mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-15 09:58:42 -07:00
dnsforward: imp code
This commit is contained in:
parent
b27547ec80
commit
08bb7d43d2
@ -27,7 +27,14 @@ NOTE: Add new changes BELOW THIS COMMENT.
|
|||||||
|
|
||||||
- Support for comments in the ipset file ([#5345]).
|
- Support for comments in the ipset file ([#5345]).
|
||||||
|
|
||||||
|
### Fixed
|
||||||
|
|
||||||
|
- Subdomains of `in-addr.arpa` and `ip6.arpa` containing zero-length prefix
|
||||||
|
incorrectly considered invalid when specified for private RDNS upstream
|
||||||
|
servers ([#6854]).
|
||||||
|
|
||||||
[#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345
|
[#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345
|
||||||
|
[#6854]: https://github.com/AdguardTeam/AdGuardHome/issues/6854
|
||||||
|
|
||||||
<!--
|
<!--
|
||||||
NOTE: Add new changes ABOVE THIS COMMENT.
|
NOTE: Add new changes ABOVE THIS COMMENT.
|
||||||
|
53
internal/dnsforward/beforerequest.go
Normal file
53
internal/dnsforward/beforerequest.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package dnsforward
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||||
|
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||||
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
)
|
||||||
|
|
||||||
|
// type check
|
||||||
|
var _ proxy.BeforeRequestHandler = (*Server)(nil)
|
||||||
|
|
||||||
|
// HandleBefore is the handler that is called before any other processing,
|
||||||
|
// including logs. It performs access checks and puts the client ID, if there
|
||||||
|
// is one, into the server's cache.
|
||||||
|
//
|
||||||
|
// TODO(e.burkov): Write tests.
|
||||||
|
func (s *Server) HandleBefore(
|
||||||
|
_ *proxy.Proxy,
|
||||||
|
pctx *proxy.DNSContext,
|
||||||
|
) (err error) {
|
||||||
|
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("getting clientid: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
|
||||||
|
if blocked {
|
||||||
|
return s.preBlockedResponse(pctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(pctx.Req.Question) == 1 {
|
||||||
|
q := pctx.Req.Question[0]
|
||||||
|
qt := q.Qtype
|
||||||
|
host := aghnet.NormalizeDomain(q.Name)
|
||||||
|
if s.access.isBlockedHost(host, qt) {
|
||||||
|
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
|
||||||
|
|
||||||
|
return s.preBlockedResponse(pctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if clientID != "" {
|
||||||
|
key := [8]byte{}
|
||||||
|
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||||
|
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -1,60 +1,17 @@
|
|||||||
package dnsforward
|
package dnsforward
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/AdguardTeam/urlfilter/rules"
|
"github.com/AdguardTeam/urlfilter/rules"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// type check
|
|
||||||
var _ proxy.BeforeRequestHandler = (*Server)(nil)
|
|
||||||
|
|
||||||
// HandleBefore is the handler that is called before any other processing,
|
|
||||||
// including logs. It performs access checks and puts the client ID, if there
|
|
||||||
// is one, into the server's cache.
|
|
||||||
func (s *Server) HandleBefore(
|
|
||||||
_ *proxy.Proxy,
|
|
||||||
pctx *proxy.DNSContext,
|
|
||||||
) (err error) {
|
|
||||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("getting clientid: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
|
|
||||||
if blocked {
|
|
||||||
return s.preBlockedResponse(pctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(pctx.Req.Question) == 1 {
|
|
||||||
q := pctx.Req.Question[0]
|
|
||||||
qt := q.Qtype
|
|
||||||
host := aghnet.NormalizeDomain(q.Name)
|
|
||||||
if s.access.isBlockedHost(host, qt) {
|
|
||||||
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
|
|
||||||
|
|
||||||
return s.preBlockedResponse(pctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if clientID != "" {
|
|
||||||
key := [8]byte{}
|
|
||||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
|
||||||
s.clientIDCache.Set(key[:], []byte(clientID))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// clientRequestFilteringSettings looks up client filtering settings using the
|
// clientRequestFilteringSettings looks up client filtering settings using the
|
||||||
// client's IP address and ID, if any, from dctx.
|
// client's IP address and ID, if any, from dctx.
|
||||||
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
|
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {
|
||||||
|
@ -261,48 +261,6 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkBootstrap returns an error if any bootstrap address is invalid.
|
|
||||||
func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
|
||||||
if req.Bootstraps == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var b string
|
|
||||||
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
|
|
||||||
|
|
||||||
for _, b = range *req.Bootstraps {
|
|
||||||
if b == "" {
|
|
||||||
return errors.Error("empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
var resolver *upstream.UpstreamResolver
|
|
||||||
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
|
||||||
// Don't wrap the error because it's informative enough as is.
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = resolver.Close(); err != nil {
|
|
||||||
return fmt.Errorf("closing %s: %w", b, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkFallbacks returns an error if any fallback address is invalid.
|
|
||||||
func (req *jsonDNSConfig) checkFallbacks() (err error) {
|
|
||||||
if req.Fallbacks == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("fallback servers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validate returns an error if any field of req is invalid.
|
// validate returns an error if any field of req is invalid.
|
||||||
//
|
//
|
||||||
// TODO(s.chzhen): Parse, don't validate.
|
// TODO(s.chzhen): Parse, don't validate.
|
||||||
@ -342,23 +300,68 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checkUpstreams returns an error if lines can't be parsed as an upstream
|
||||||
|
// configuration. If privateNets is not nil, it also checks that the domain
|
||||||
|
// specifications are strictly ARPA domains containing the prefixes within the
|
||||||
|
// set.
|
||||||
|
func checkUpstreams(lines []string, section string, privateNets netutil.SubnetSet) (err error) {
|
||||||
|
defer func() { err = errors.Annotate(err, "%s servers: %w", section) }()
|
||||||
|
|
||||||
|
uc, err := proxy.ParseUpstreamsConfig(lines, &upstream.Options{})
|
||||||
|
if err == nil {
|
||||||
|
defer func() { err = errors.WithDeferred(err, uc.Close()) }()
|
||||||
|
|
||||||
|
if privateNets != nil {
|
||||||
|
err = proxy.ValidatePrivateConfig(uc, privateNets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkBootstrap returns an error if any bootstrap address is invalid.
|
||||||
|
func (req *jsonDNSConfig) checkBootstrap() (err error) {
|
||||||
|
if req.Bootstraps == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var b string
|
||||||
|
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
|
||||||
|
|
||||||
|
for _, b = range *req.Bootstraps {
|
||||||
|
if b == "" {
|
||||||
|
return errors.Error("empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
var resolver *upstream.UpstreamResolver
|
||||||
|
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
|
||||||
|
// Don't wrap the error because it's informative enough as is.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = resolver.Close(); err != nil {
|
||||||
|
return fmt.Errorf("closing %s: %w", b, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// validateUpstreamDNSServers returns an error if any field of req is invalid.
|
// validateUpstreamDNSServers returns an error if any field of req is invalid.
|
||||||
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
|
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
|
||||||
if req.Upstreams != nil {
|
if req.Upstreams != nil {
|
||||||
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{})
|
err = checkUpstreams(*req.Upstreams, "upstream", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("upstream servers: %w", err)
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.LocalPTRUpstreams != nil {
|
if req.LocalPTRUpstreams != nil {
|
||||||
var uc *proxy.UpstreamConfig
|
err = checkUpstreams(*req.LocalPTRUpstreams, "private upstream", privateNets)
|
||||||
uc, err = proxy.ParseUpstreamsConfig(*req.LocalPTRUpstreams, &upstream.Options{})
|
|
||||||
if err == nil {
|
|
||||||
err = proxy.ValidatePrivateConfig(uc, privateNets)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("private upstream servers: %w", err)
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -368,10 +371,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = req.checkFallbacks()
|
if req.Fallbacks != nil {
|
||||||
if err != nil {
|
err = checkUpstreams(*req.Fallbacks, "fallback", nil)
|
||||||
// Don't wrap the error since it's informative enough as is.
|
if err != nil {
|
||||||
return err
|
// Don't wrap the error since it's informative enough as is.
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -12,8 +12,8 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO(e.burkov): Call all the other methods by a [proxy.MessageConstructor]
|
// TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor]
|
||||||
// template.
|
// template. Also extract all the methods to a separate entity.
|
||||||
|
|
||||||
// reply creates a DNS response for req.
|
// reply creates a DNS response for req.
|
||||||
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
|
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
|
||||||
|
Loading…
Reference in New Issue
Block a user