all: remove private rdns logic

This commit is contained in:
Eugene Burkov 2024-04-03 17:10:33 +03:00
parent 5cc05e2c4b
commit 2046156580
12 changed files with 230 additions and 965 deletions

5
go.mod
View File

@ -3,8 +3,8 @@ module github.com/AdguardTeam/AdGuardHome
go 1.21.8 go 1.21.8
require ( require (
github.com/AdguardTeam/dnsproxy v0.67.0 github.com/AdguardTeam/dnsproxy v0.67.1-0.20240403090357-a2c0e321a217
github.com/AdguardTeam/golibs v0.21.0 github.com/AdguardTeam/golibs v0.22.0
github.com/AdguardTeam/urlfilter v0.18.0 github.com/AdguardTeam/urlfilter v0.18.0
github.com/NYTimes/gziphandler v1.1.1 github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.2.7 github.com/ameshkov/dnscrypt/v2 v2.2.7
@ -49,6 +49,7 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect
github.com/jessevdk/go-flags v1.5.0 // indirect
github.com/mdlayher/socket v0.5.0 // indirect github.com/mdlayher/socket v0.5.0 // indirect
github.com/onsi/ginkgo/v2 v2.16.0 // indirect github.com/onsi/ginkgo/v2 v2.16.0 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect

7
go.sum
View File

@ -1,7 +1,11 @@
github.com/AdguardTeam/dnsproxy v0.67.0 h1:7oKfcA8sm9d1N4qvhsNmQWBX4+fs3sX4cAnERmBXEbw= github.com/AdguardTeam/dnsproxy v0.67.0 h1:7oKfcA8sm9d1N4qvhsNmQWBX4+fs3sX4cAnERmBXEbw=
github.com/AdguardTeam/dnsproxy v0.67.0/go.mod h1:XLfD6IpSplUZZ+f5vhWSJW1mp4wm+KkHWiMo9w7U1Ls= github.com/AdguardTeam/dnsproxy v0.67.0/go.mod h1:XLfD6IpSplUZZ+f5vhWSJW1mp4wm+KkHWiMo9w7U1Ls=
github.com/AdguardTeam/dnsproxy v0.67.1-0.20240403090357-a2c0e321a217 h1:ryczFRf8y6PEzCjgy/S3Ptg4Ea1TUYFyiEZnoEyEV7s=
github.com/AdguardTeam/dnsproxy v0.67.1-0.20240403090357-a2c0e321a217/go.mod h1:5wIQueGTDX1Uk4GYevRh7HCtsCUR/U9lxf478+STOZI=
github.com/AdguardTeam/golibs v0.21.0 h1:0swWyNaHTmT7aMwffKd9d54g4wBd8Oaj0fl+5l/PRdE= github.com/AdguardTeam/golibs v0.21.0 h1:0swWyNaHTmT7aMwffKd9d54g4wBd8Oaj0fl+5l/PRdE=
github.com/AdguardTeam/golibs v0.21.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI= github.com/AdguardTeam/golibs v0.21.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI=
github.com/AdguardTeam/golibs v0.22.0 h1:wvT/UFIT8XIBfMabnD3LcDRiorx8J0lc3A/bzD6OX7c=
github.com/AdguardTeam/golibs v0.22.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI=
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ= github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=
github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s= github.com/AdguardTeam/urlfilter v0.18.0/go.mod h1:IXxBwedLiZA2viyHkaFxY/8mjub0li2PXRg8a3d9Z1s=
github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I= github.com/NYTimes/gziphandler v1.1.1 h1:ZUDjpQae29j0ryrS0u/B8HZfJBtBQHjqw2rQ2cqUQ3I=
@ -58,6 +62,8 @@ github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Go
github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8 h1:V3plQrMHRWOB5zMm3yNqvBxDQVW1+/wHBSok5uPdmVs= github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8 h1:V3plQrMHRWOB5zMm3yNqvBxDQVW1+/wHBSok5uPdmVs=
github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8/go.mod h1:izxuNQZeFrbx2nK2fAyN5iNUB34Fe9j0nK4PwLzAkKw= github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8/go.mod h1:izxuNQZeFrbx2nK2fAyN5iNUB34Fe9j0nK4PwLzAkKw=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4=
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk= github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk=
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8= github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8=
@ -158,6 +164,7 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -235,9 +235,18 @@ type DNSCryptConfig struct {
// ServerConfig represents server configuration. // ServerConfig represents server configuration.
// The zero ServerConfig is empty and ready for use. // The zero ServerConfig is empty and ready for use.
type ServerConfig struct { type ServerConfig struct {
UDPListenAddrs []*net.UDPAddr // UDP listen address // UDPListenAddrs is the list of addresses to listen for DNS-over-UDP.
TCPListenAddrs []*net.TCPAddr // TCP listen address UDPListenAddrs []*net.UDPAddr
UpstreamConfig *proxy.UpstreamConfig // Upstream DNS servers config
// TCPListenAddrs is the list of addresses to listen for DNS-over-TCP.
TCPListenAddrs []*net.TCPAddr
// UpstreamConfig is the general configuration of upstream DNS servers.
UpstreamConfig *proxy.UpstreamConfig
// PrivateRDNSUpstreamConfig is the configuration of upstream DNS servers
// for private reverse DNS.
PrivateRDNSUpstreamConfig *proxy.UpstreamConfig
// AddrProcConf defines the configuration for the client IP processor. // AddrProcConf defines the configuration for the client IP processor.
// If nil, [client.EmptyAddrProc] is used. // If nil, [client.EmptyAddrProc] is used.
@ -306,24 +315,28 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies) trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
conf = &proxy.Config{ conf = &proxy.Config{
HTTP3: srvConf.ServeHTTP3, HTTP3: srvConf.ServeHTTP3,
Ratelimit: int(srvConf.Ratelimit), Ratelimit: int(srvConf.Ratelimit),
RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4, RatelimitSubnetLenIPv4: srvConf.RatelimitSubnetLenIPv4,
RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6, RatelimitSubnetLenIPv6: srvConf.RatelimitSubnetLenIPv6,
RatelimitWhitelist: srvConf.RatelimitWhitelist, RatelimitWhitelist: srvConf.RatelimitWhitelist,
RefuseAny: srvConf.RefuseAny, RefuseAny: srvConf.RefuseAny,
TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes), TrustedProxies: netutil.SliceSubnetSet(trustedPrefixes),
CacheMinTTL: srvConf.CacheMinTTL, CacheMinTTL: srvConf.CacheMinTTL,
CacheMaxTTL: srvConf.CacheMaxTTL, CacheMaxTTL: srvConf.CacheMaxTTL,
CacheOptimistic: srvConf.CacheOptimistic, CacheOptimistic: srvConf.CacheOptimistic,
UpstreamConfig: srvConf.UpstreamConfig, UpstreamConfig: srvConf.UpstreamConfig,
BeforeRequestHandler: s.beforeRequestHandler, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
RequestHandler: s.handleDNSRequest, BeforeRequestHandler: s.beforeRequestHandler,
HTTPSServerName: aghhttp.UserAgent(), RequestHandler: s.handleDNSRequest,
EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled, HTTPSServerName: aghhttp.UserAgent(),
MaxGoroutines: srvConf.MaxGoroutines, EnableEDNSClientSubnet: srvConf.EDNSClientSubnet.Enabled,
UseDNS64: srvConf.UseDNS64, MaxGoroutines: srvConf.MaxGoroutines,
DNS64Prefs: srvConf.DNS64Prefixes, UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets,
// TODO(e.burkov): !! add message constructor
} }
if srvConf.EDNSClientSubnet.UseCustom { if srvConf.EDNSClientSubnet.UseCustom {

View File

@ -63,6 +63,7 @@ func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns
return rr return rr
} }
// TODO(e.burkov): !! add ErrNoUptreams case
func TestServer_HandleDNSRequest_dns64(t *testing.T) { func TestServer_HandleDNSRequest_dns64(t *testing.T) {
const ( const (
ipv4Domain = "ipv4.only." ipv4Domain = "ipv4.only."

View File

@ -135,12 +135,6 @@ type Server struct {
// WHOIS, etc. // WHOIS, etc.
addrProc client.AddressProcessor addrProc client.AddressProcessor
// localResolvers is a DNS proxy instance used to resolve PTR records for
// addresses considered private as per the [privateNets].
//
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
localResolvers *proxy.Proxy
// sysResolvers used to fetch system resolvers to use by default for private // sysResolvers used to fetch system resolvers to use by default for private
// PTR resolving. // PTR resolving.
sysResolvers SystemResolvers sysResolvers SystemResolvers
@ -158,12 +152,6 @@ type Server struct {
// [upstream.Resolver] interface. // [upstream.Resolver] interface.
bootResolvers []*upstream.UpstreamResolver bootResolvers []*upstream.UpstreamResolver
// recDetector is a cache for recursive requests. It is used to detect and
// prevent recursive requests only for private upstreams.
//
// See https://github.com/adguardTeam/adGuardHome/issues/3185#issuecomment-851048135.
recDetector *recursionDetector
// dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major // dns64Pref is the NAT64 prefix used for DNS64 response mapping. The major
// part of DNS64 happens inside the [proxy] package, but there still are // part of DNS64 happens inside the [proxy] package, but there still are
// some places where response mapping is needed (e.g. DHCP). // some places where response mapping is needed (e.g. DHCP).
@ -212,14 +200,6 @@ type DNSCreateParams struct {
LocalDomain string LocalDomain string
} }
const (
// recursionTTL is the time recursive request is cached for.
recursionTTL = 1 * time.Second
// cachedRecurrentReqNum is the maximum number of cached recurrent
// requests.
cachedRecurrentReqNum = 1000
)
// NewServer creates a new instance of the dnsforward.Server // NewServer creates a new instance of the dnsforward.Server
// Note: this function must be called only once // Note: this function must be called only once
// //
@ -256,7 +236,6 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// TODO(e.burkov): Use some case-insensitive string comparison. // TODO(e.burkov): Use some case-insensitive string comparison.
localDomainSuffix: strings.ToLower(localDomainSuffix), localDomainSuffix: strings.ToLower(localDomainSuffix),
etcHosts: etcHosts, etcHosts: etcHosts,
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
clientIDCache: cache.New(cache.Config{ clientIDCache: cache.New(cache.Config{
EnableLRU: true, EnableLRU: true,
MaxCount: defaultClientIDCacheCount, MaxCount: defaultClientIDCacheCount,
@ -366,6 +345,7 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
s.serverLock.RLock() s.serverLock.RLock()
defer s.serverLock.RUnlock() defer s.serverLock.RUnlock()
// TODO(e.burkov): Migrate to netip.Addr already.
arpa, err := netutil.IPToReversedAddr(ip.AsSlice()) arpa, err := netutil.IPToReversedAddr(ip.AsSlice())
if err != nil { if err != nil {
return "", 0, fmt.Errorf("reversing ip: %w", err) return "", 0, fmt.Errorf("reversing ip: %w", err)
@ -386,25 +366,23 @@ func (s *Server) Exchange(ip netip.Addr) (host string, ttl time.Duration, err er
} }
dctx := &proxy.DNSContext{ dctx := &proxy.DNSContext{
Proto: "udp", Proto: "udp",
Req: req, Req: req,
IsPrivateClient: true,
} }
var resolver *proxy.Proxy
var errMsg string var errMsg string
if s.privateNets.Contains(ip) { if s.privateNets.Contains(ip) {
if !s.conf.UsePrivateRDNS { if !s.conf.UsePrivateRDNS {
return "", 0, nil return "", 0, nil
} }
resolver = s.localResolvers
errMsg = "resolving a private address: %w" errMsg = "resolving a private address: %w"
s.recDetector.add(*req) dctx.RequestedPrivateRDNS = netip.PrefixFrom(ip, ip.BitLen())
} else { } else {
resolver = s.internalProxy
errMsg = "resolving an address: %w" errMsg = "resolving an address: %w"
} }
if err = resolver.Resolve(dctx); err != nil { if err = s.internalProxy.Resolve(dctx); err != nil {
return "", 0, fmt.Errorf(errMsg, err) return "", 0, fmt.Errorf(errMsg, err)
} }
@ -541,35 +519,6 @@ func (err *LocalResolversError) Unwrap() error {
return err.Err return err.Err
} }
// setupLocalResolvers initializes and sets the resolvers for local addresses.
// It assumes s.serverLock is locked or s not running. It returns the upstream
// configuration used for private PTR resolving, or nil if it's disabled. Note,
// that it's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
func (s *Server) setupLocalResolvers(boot upstream.Resolver) (uc *proxy.UpstreamConfig, err error) {
if !s.conf.UsePrivateRDNS {
// It's safe to put nil into [proxy.Config.PrivateRDNSUpstreamConfig].
return nil, nil
}
uc, err = s.prepareLocalResolvers(boot)
if err != nil {
// Don't wrap the error because it's informative enough as is.
return nil, err
}
localResolvers, err := proxy.New(&proxy.Config{
UpstreamConfig: uc,
})
if err != nil {
return nil, &LocalResolversError{Err: err}
}
s.localResolvers = localResolvers
// TODO(e.burkov): Should we also consider the DNS64 usage?
return uc, nil
}
// Prepare initializes parameters of s using data from conf. conf must not be // Prepare initializes parameters of s using data from conf. conf must not be
// nil. // nil.
func (s *Server) Prepare(conf *ServerConfig) (err error) { func (s *Server) Prepare(conf *ServerConfig) (err error) {
@ -586,7 +535,7 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.initDefaultSettings() s.initDefaultSettings()
boot, err := s.prepareInternalDNS() err = s.prepareInternalDNS()
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return err return err
@ -608,12 +557,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
return fmt.Errorf("preparing access: %w", err) return fmt.Errorf("preparing access: %w", err)
} }
// TODO(e.burkov): Remove once the local resolvers logic moved to dnsproxy.
proxyConfig.PrivateRDNSUpstreamConfig, err = s.setupLocalResolvers(boot)
if err != nil {
return fmt.Errorf("setting up resolvers: %w", err)
}
proxyConfig.Fallbacks, err = s.setupFallbackDNS() proxyConfig.Fallbacks, err = s.setupFallbackDNS()
if err != nil { if err != nil {
return fmt.Errorf("setting up fallback dns servers: %w", err) return fmt.Errorf("setting up fallback dns servers: %w", err)
@ -626,8 +569,6 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.dnsProxy = dnsProxy s.dnsProxy = dnsProxy
s.recDetector.clear()
s.setupAddrProc() s.setupAddrProc()
s.registerHandlers() s.registerHandlers()
@ -638,10 +579,10 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
// prepareInternalDNS initializes the internal state of s before initializing // prepareInternalDNS initializes the internal state of s before initializing
// the primary DNS proxy instance. It assumes s.serverLock is locked or the // the primary DNS proxy instance. It assumes s.serverLock is locked or the
// Server not running. // Server not running.
func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) { func (s *Server) prepareInternalDNS() (err error) {
err = s.prepareIpsetListSettings() err = s.prepareIpsetListSettings()
if err != nil { if err != nil {
return nil, fmt.Errorf("preparing ipset settings: %w", err) return fmt.Errorf("preparing ipset settings: %w", err)
} }
s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{ s.bootstrap, s.bootResolvers, err = s.createBootstrap(s.conf.BootstrapDNS, &upstream.Options{
@ -650,21 +591,28 @@ func (s *Server) prepareInternalDNS() (boot upstream.Resolver, err error) {
}) })
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return nil, err return err
} }
err = s.prepareUpstreamSettings(s.bootstrap) err = s.prepareUpstreamSettings(s.bootstrap)
if err != nil { if err != nil {
// Don't wrap the error, because it's informative enough as is. // Don't wrap the error, because it's informative enough as is.
return s.bootstrap, err return err
}
if s.conf.UsePrivateRDNS {
s.conf.PrivateRDNSUpstreamConfig, err = s.prepareLocalResolvers(s.bootstrap)
if err != nil {
return fmt.Errorf("setting up resolvers: %w", err)
}
} }
err = s.prepareInternalProxy() err = s.prepareInternalProxy()
if err != nil { if err != nil {
return s.bootstrap, fmt.Errorf("preparing internal proxy: %w", err) return fmt.Errorf("preparing internal proxy: %w", err)
} }
return s.bootstrap, nil return nil
} }
// setupFallbackDNS initializes the fallback DNS servers. // setupFallbackDNS initializes the fallback DNS servers.
@ -743,10 +691,15 @@ func validateBlockingMode(
func (s *Server) prepareInternalProxy() (err error) { func (s *Server) prepareInternalProxy() (err error) {
srvConf := s.conf srvConf := s.conf
conf := &proxy.Config{ conf := &proxy.Config{
CacheEnabled: true, CacheEnabled: true,
CacheSizeBytes: 4096, CacheSizeBytes: 4096,
UpstreamConfig: srvConf.UpstreamConfig, PrivateRDNSUpstreamConfig: srvConf.PrivateRDNSUpstreamConfig,
MaxGoroutines: s.conf.MaxGoroutines, UpstreamConfig: srvConf.UpstreamConfig,
MaxGoroutines: srvConf.MaxGoroutines,
UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS,
// TODO(e.burkov): !! add message constructor
} }
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)
@ -782,11 +735,6 @@ func (s *Server) stopLocked() (err error) {
} }
} }
logCloserErr(s.internalProxy.UpstreamConfig, "dnsforward: closing internal resolvers: %s")
if s.localResolvers != nil {
logCloserErr(s.localResolvers.UpstreamConfig, "dnsforward: closing local resolvers: %s")
}
for _, b := range s.bootResolvers { for _, b := range s.bootResolvers {
logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address()) logCloserErr(b, "dnsforward: closing bootstrap %s: %s", b.Address())
} }

View File

@ -352,7 +352,11 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
} }
if req.LocalPTRUpstreams != nil { if req.LocalPTRUpstreams != nil {
err = ValidateUpstreamsPrivate(*req.LocalPTRUpstreams, privateNets) var uc *proxy.UpstreamConfig
uc, err = proxy.ParseUpstreamsConfig(*req.LocalPTRUpstreams, &upstream.Options{})
if err == nil {
err = proxy.ValidatePrivateConfig(uc, privateNets)
}
if err != nil { if err != nil {
return fmt.Errorf("private upstream servers: %w", err) return fmt.Errorf("private upstream servers: %w", err)
} }

View File

@ -245,9 +245,8 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
wantSet: "", wantSet: "",
}, { }, {
name: "local_ptr_upstreams_bad", name: "local_ptr_upstreams_bad",
wantSet: `validating dns config: ` + wantSet: `validating dns config: private upstream servers: ` +
`private upstream servers: checking domain-specific upstreams: ` + `bad arpa domain name "non.arpa": not a reversed ip network`,
`bad arpa domain name "non.arpa.": not a reversed ip network`,
}, { }, {
name: "local_ptr_upstreams_null", name: "local_ptr_upstreams_null",
wantSet: "", wantSet: "",
@ -318,58 +317,6 @@ func TestIsCommentOrEmpty(t *testing.T) {
} }
} }
func TestValidateUpstreamsPrivate(t *testing.T) {
ss := netutil.SubnetSetFunc(netutil.IsLocallyServed)
testCases := []struct {
name string
wantErr string
u string
}{{
name: "success_address",
wantErr: ``,
u: "[/1.0.0.127.in-addr.arpa/]#",
}, {
name: "success_subnet",
wantErr: ``,
u: "[/127.in-addr.arpa/]#",
}, {
name: "not_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`bad arpa domain name "hello.world.": not a reversed ip network`,
u: "[/hello.world/]#",
}, {
name: "non-private_arpa_address",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network`,
u: "[/1.2.3.4.in-addr.arpa/]#",
}, {
name: "non-private_arpa_subnet",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "128.in-addr.arpa." should point to a locally-served network`,
u: "[/128.in-addr.arpa/]#",
}, {
name: "several_bad",
wantErr: `checking domain-specific upstreams: ` +
`arpa domain "1.2.3.4.in-addr.arpa." should point to a locally-served network` + "\n" +
`bad arpa domain name "non.arpa.": not a reversed ip network`,
u: "[/non.arpa/1.2.3.4.in-addr.arpa/127.in-addr.arpa/]#",
}, {
name: "partial_good",
wantErr: "",
u: "[/a.1.2.3.10.in-addr.arpa/a.10.in-addr.arpa/]#",
}}
for _, tc := range testCases {
set := []string{"192.168.0.1", tc.u}
t.Run(tc.name, func(t *testing.T) {
err := ValidateUpstreamsPrivate(set, ss)
testutil.AssertErrorMsg(t, tc.wantErr, err)
})
}
}
func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) { func newLocalUpstreamListener(t *testing.T, port uint16, handler dns.Handler) (real netip.AddrPort) {
t.Helper() t.Helper()

View File

@ -4,14 +4,11 @@ import (
"encoding/binary" "encoding/binary"
"net" "net"
"net/netip" "net/netip"
"strconv"
"strings" "strings"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
@ -34,11 +31,6 @@ type dnsContext struct {
// response is modified by filters. // response is modified by filters.
origResp *dns.Msg origResp *dns.Msg
// unreversedReqIP stores an IP address obtained from a PTR request if it
// was parsed successfully and belongs to one of the locally served IP
// ranges.
unreversedReqIP netip.Addr
// err is the error returned from a processing function. // err is the error returned from a processing function.
err error err error
@ -63,10 +55,6 @@ type dnsContext struct {
// responseAD shows if the response had the AD bit set. // responseAD shows if the response had the AD bit set.
responseAD bool responseAD bool
// isLocalClient shows if client's IP address is from locally served
// network.
isLocalClient bool
// isDHCPHost is true if the request for a local domain name and the DHCP is // isDHCPHost is true if the request for a local domain name and the DHCP is
// available for this request. // available for this request.
isDHCPHost bool isDHCPHost bool
@ -109,15 +97,11 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
// (*proxy.Proxy).handleDNSRequest method performs it before calling the // (*proxy.Proxy).handleDNSRequest method performs it before calling the
// appropriate handler. // appropriate handler.
mods := []modProcessFunc{ mods := []modProcessFunc{
s.processRecursion,
s.processInitial, s.processInitial,
s.processDDRQuery, s.processDDRQuery,
s.processDetermineLocal,
s.processDHCPHosts, s.processDHCPHosts,
s.processRestrictLocal,
s.processDHCPAddrs, s.processDHCPAddrs,
s.processFilteringBeforeRequest, s.processFilteringBeforeRequest,
s.processLocalPTR,
s.processUpstream, s.processUpstream,
s.processFilteringAfterResponse, s.processFilteringAfterResponse,
s.ipset.process, s.ipset.process,
@ -145,24 +129,6 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
return nil return nil
} }
// processRecursion checks the incoming request and halts its handling by
// answering NXDOMAIN if s has tried to resolve it recently.
func (s *Server) processRecursion(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing recursion")
defer log.Debug("dnsforward: finished processing recursion")
pctx := dctx.proxyCtx
if msg := pctx.Req; msg != nil && s.recDetector.check(*msg) {
log.Debug("dnsforward: recursion detected resolving %q", msg.Question[0].Name)
pctx.Res = s.genNXDomain(pctx.Req)
return resultCodeFinish
}
return resultCodeSuccess
}
// mozillaFQDN is the domain used to signal the Firefox browser to not use its // mozillaFQDN is the domain used to signal the Firefox browser to not use its
// own DoH server. // own DoH server.
// //
@ -339,19 +305,6 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
return resp return resp
} }
// processDetermineLocal determines if the client's IP address is from locally
// served network and saves the result into the context.
func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local detection")
defer log.Debug("dnsforward: finished processing local detection")
rc = resultCodeSuccess
dctx.isLocalClient = s.privateNets.Contains(dctx.proxyCtx.Addr.Addr())
return rc
}
// processDHCPHosts respond to A requests if the target hostname is known to // processDHCPHosts respond to A requests if the target hostname is known to
// the server. It responds with a mapped IP address if the DNS64 is enabled and // the server. It responds with a mapped IP address if the DNS64 is enabled and
// the request is for AAAA. // the request is for AAAA.
@ -370,7 +323,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
if !dctx.isLocalClient { if !pctx.IsPrivateClient {
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost) log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
pctx.Res = s.genNXDomain(req) pctx.Res = s.genNXDomain(req)
@ -416,141 +369,6 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// indexFirstV4Label returns the index at which the reversed IPv4 address
// starts, assuming the domain is pre-validated ARPA domain having in-addr and
// arpa labels removed.
func indexFirstV4Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv4len && idx > 0; labelsNum++ {
curIdx := strings.LastIndexByte(domain[:idx-1], '.') + 1
_, parseErr := strconv.ParseUint(domain[curIdx:idx-1], 10, 8)
if parseErr != nil {
return idx
}
idx = curIdx
}
return idx
}
// indexFirstV6Label returns the index at which the reversed IPv6 address
// starts, assuming the domain is pre-validated ARPA domain having ip6 and arpa
// labels removed.
func indexFirstV6Label(domain string) (idx int) {
idx = len(domain)
for labelsNum := 0; labelsNum < net.IPv6len*2 && idx > 0; labelsNum++ {
curIdx := idx - len("a.")
if curIdx > 1 && domain[curIdx-1] != '.' {
return idx
}
nibble := domain[curIdx]
if (nibble < '0' || nibble > '9') && (nibble < 'a' || nibble > 'f') {
return idx
}
idx = curIdx
}
return idx
}
// extractARPASubnet tries to convert a reversed ARPA address being a part of
// domain to an IP network. domain must be an FQDN.
//
// TODO(e.burkov): Move to golibs.
func extractARPASubnet(domain string) (pref netip.Prefix, err error) {
err = netutil.ValidateDomainName(strings.TrimSuffix(domain, "."))
if err != nil {
// Don't wrap the error since it's informative enough as is.
return netip.Prefix{}, err
}
const (
v4Suffix = "in-addr.arpa."
v6Suffix = "ip6.arpa."
)
domain = strings.ToLower(domain)
var idx int
switch {
case strings.HasSuffix(domain, v4Suffix):
idx = indexFirstV4Label(domain[:len(domain)-len(v4Suffix)])
case strings.HasSuffix(domain, v6Suffix):
idx = indexFirstV6Label(domain[:len(domain)-len(v6Suffix)])
default:
return netip.Prefix{}, &netutil.AddrError{
Err: netutil.ErrNotAReversedSubnet,
Kind: netutil.AddrKindARPA,
Addr: domain,
}
}
return netutil.PrefixFromReversedAddr(domain[idx:])
}
// processRestrictLocal responds with NXDOMAIN to PTR requests for IP addresses
// in locally served network from external clients.
func (s *Server) processRestrictLocal(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local restriction")
defer log.Debug("dnsforward: finished processing local restriction")
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
if q.Qtype != dns.TypePTR {
// No need for restriction.
return resultCodeSuccess
}
subnet, err := extractARPASubnet(q.Name)
if err != nil {
if errors.Is(err, netutil.ErrNotAReversedSubnet) {
log.Debug("dnsforward: request is not for arpa domain")
return resultCodeSuccess
}
log.Debug("dnsforward: parsing reversed addr: %s", err)
return resultCodeError
}
// Restrict an access to local addresses for external clients. We also
// assume that all the DHCP leases we give are locally served or at least
// shouldn't be accessible externally.
subnetAddr := subnet.Addr()
if !s.privateNets.Contains(subnetAddr) {
return resultCodeSuccess
}
log.Debug("dnsforward: addr %s is from locally served network", subnetAddr)
if !dctx.isLocalClient {
log.Debug("dnsforward: %q requests an internal ip", pctx.Addr)
pctx.Res = s.genNXDomain(req)
// Do not even put into query log.
return resultCodeFinish
}
// Do not perform unreversing ever again.
dctx.unreversedReqIP = subnetAddr
// There is no need to filter request from external addresses since this
// code is only executed when the request is for locally served ARPA
// hostname so disable redundant filters.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
// Nothing to restrict.
return resultCodeSuccess
}
// processDHCPAddrs responds to PTR requests if the target IP is leased by the // processDHCPAddrs responds to PTR requests if the target IP is leased by the
// DHCP server. // DHCP server.
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
@ -562,17 +380,18 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
ipAddr := dctx.unreversedReqIP pref := pctx.RequestedPrivateRDNS
if ipAddr == (netip.Addr{}) { if pref == (netip.Prefix{}) {
return resultCodeSuccess return resultCodeSuccess
} }
host := s.dhcpServer.HostByIP(ipAddr) addr := pref.Addr()
host := s.dhcpServer.HostByIP(addr)
if host == "" { if host == "" {
return resultCodeSuccess return resultCodeSuccess
} }
log.Debug("dnsforward: dhcp client %s is %q", ipAddr, host) log.Debug("dnsforward: dhcp client %s is %q", addr, host)
req := pctx.Req req := pctx.Req
resp := s.makeResponse(req) resp := s.makeResponse(req)
@ -593,62 +412,22 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
return resultCodeSuccess return resultCodeSuccess
} }
// processLocalPTR responds to PTR requests if the target IP is detected to be
// inside the local network and the query was not answered from DHCP.
func (s *Server) processLocalPTR(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing local ptr")
defer log.Debug("dnsforward: finished processing local ptr")
pctx := dctx.proxyCtx
if pctx.Res != nil {
return resultCodeSuccess
}
ip := dctx.unreversedReqIP
if ip == (netip.Addr{}) {
return resultCodeSuccess
}
s.serverLock.RLock()
defer s.serverLock.RUnlock()
if s.conf.UsePrivateRDNS {
s.recDetector.add(*pctx.Req)
if err := s.localResolvers.Resolve(pctx); err != nil {
log.Debug("dnsforward: resolving private address: %s", err)
// Generate the server failure if the private upstream configuration
// is empty.
//
// This is a crutch, see TODO at [Server.localResolvers].
if errors.Is(err, upstream.ErrNoUpstreams) {
pctx.Res = s.genServerFailure(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
dctx.err = err
return resultCodeError
}
}
if pctx.Res == nil {
pctx.Res = s.genNXDomain(pctx.Req)
// Do not even put into query log.
return resultCodeFinish
}
return resultCodeSuccess
}
// Apply filtering logic // Apply filtering logic
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) { func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: started processing filtering before req") log.Debug("dnsforward: started processing filtering before req")
defer log.Debug("dnsforward: finished processing filtering before req") defer log.Debug("dnsforward: finished processing filtering before req")
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
// There is no need to filter request for locally served ARPA hostname
// so disable redundant filters.
//
// TODO(e.burkov): !! check the above comment.
dctx.setts.ParentalEnabled = false
dctx.setts.SafeBrowsingEnabled = false
dctx.setts.SafeSearchEnabled = false
dctx.setts.ServicesRules = nil
}
if dctx.proxyCtx.Res != nil { if dctx.proxyCtx.Res != nil {
// Go on since the response is already set. // Go on since the response is already set.
return resultCodeSuccess return resultCodeSuccess
@ -713,18 +492,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
} }
if err := prx.Resolve(pctx); err != nil { if err := prx.Resolve(pctx); err != nil {
if errors.Is(err, upstream.ErrNoUpstreams) { // TODO(e.burkov): !! try getting error from here.
// Do not even put into querylog. Currently this happens either
// when the private resolvers enabled and the request is DNS64 PTR,
// or when the client isn't considered local by prx.
//
// TODO(e.burkov): Make proxy detect local client the same way as
// AGH does.
pctx.Res = s.genNXDomain(req)
return resultCodeFinish
}
dctx.err = err dctx.err = err
return resultCodeError return resultCodeError

View File

@ -9,6 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
@ -379,44 +380,6 @@ func createTestDNSFilter(t *testing.T) (f *filtering.DNSFilter) {
return f return f
} }
func TestServer_ProcessDetermineLocal(t *testing.T) {
s := &Server{
privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed),
}
testCases := []struct {
want assert.BoolAssertionFunc
name string
cliAddr netip.AddrPort
}{{
want: assert.True,
name: "local",
cliAddr: netip.MustParseAddrPort("192.168.0.1:1"),
}, {
want: assert.False,
name: "external",
cliAddr: netip.MustParseAddrPort("250.249.0.1:1"),
}, {
want: assert.False,
name: "invalid",
cliAddr: netip.AddrPort{},
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
proxyCtx := &proxy.DNSContext{
Addr: tc.cliAddr,
}
dctx := &dnsContext{
proxyCtx: proxyCtx,
}
s.processDetermineLocal(dctx)
tc.want(t, dctx.isLocalClient)
})
}
}
func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
const ( const (
localDomainSuffix = "lan" localDomainSuffix = "lan"
@ -486,9 +449,9 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
dctx := &dnsContext{ dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{ proxyCtx: &proxy.DNSContext{
Req: req, Req: req,
IsPrivateClient: tc.isLocalCli,
}, },
isLocalClient: tc.isLocalCli,
} }
res := s.processDHCPHosts(dctx) res := s.processDHCPHosts(dctx)
@ -621,9 +584,9 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
dctx := &dnsContext{ dctx := &dnsContext{
proxyCtx: &proxy.DNSContext{ proxyCtx: &proxy.DNSContext{
Req: req, Req: req,
IsPrivateClient: true,
}, },
isLocalClient: true,
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
@ -658,19 +621,25 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
} }
} }
func TestServer_ProcessRestrictLocal(t *testing.T) { func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
intAddr := netip.MustParseAddr("192.168.1.1")
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
require.NoError(t, err)
extAddr := netip.MustParseAddr("254.253.252.1")
extPTRQuestion, err := netutil.IPToReversedAddr(extAddr.AsSlice())
require.NoError(t, err)
const ( const (
extPTRQuestion = "251.252.253.254.in-addr.arpa." extPTRAnswer = "host1.example.net."
extPTRAnswer = "host1.example.net." intPTRAnswer = "some.local-client."
intPTRQuestion = "1.1.168.192.in-addr.arpa."
intPTRAnswer = "some.local-client."
) )
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce( resp := aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, extPTRQuestion, extPTRAnswer),
aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer), aghtest.MatchedResponse(req, dns.TypePTR, intPTRQuestion, intPTRAnswer),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), (&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
) )
require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
@ -696,123 +665,162 @@ func TestServer_ProcessRestrictLocal(t *testing.T) {
startDeferStop(t, s) startDeferStop(t, s)
testCases := []struct { testCases := []struct {
name string name string
want string question string
question net.IP wantErr error
cliAddr netip.AddrPort wantAns []dns.RR
wantLen int isPrivate bool
}{{ }{{
name: "from_local_to_external", name: "from_local_for_external",
want: "host1.example.net.", question: extPTRQuestion,
question: net.IP{254, 253, 252, 251}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("192.168.10.10:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: true,
}, { }, {
name: "from_external_for_local", name: "from_external_for_local",
want: "", question: intPTRQuestion,
question: net.IP{192, 168, 1, 1}, wantErr: upstream.ErrNoUpstreams,
cliAddr: netip.MustParseAddrPort("254.253.252.251:1"), wantAns: nil,
wantLen: 0, isPrivate: false,
}, { }, {
name: "from_local_for_local", name: "from_local_for_local",
want: "some.local-client.", question: intPTRQuestion,
question: net.IP{192, 168, 1, 1}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("192.168.1.2:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(intPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(intPTRAnswer) + 1),
},
Ptr: dns.Fqdn(intPTRAnswer),
}},
isPrivate: true,
}, { }, {
name: "from_external_for_external", name: "from_external_for_external",
want: "host1.example.net.", question: extPTRQuestion,
question: net.IP{254, 253, 252, 251}, wantErr: nil,
cliAddr: netip.MustParseAddrPort("254.253.252.255:1"), wantAns: []dns.RR{&dns.PTR{
wantLen: 1, Hdr: dns.RR_Header{
Name: dns.Fqdn(extPTRQuestion),
Rrtype: dns.TypePTR,
Class: dns.ClassINET,
Ttl: 60,
Rdlength: uint16(len(extPTRAnswer) + 1),
},
Ptr: dns.Fqdn(extPTRAnswer),
}},
isPrivate: false,
}} }}
for _, tc := range testCases { for _, tc := range testCases {
reqAddr, err := dns.ReverseAddr(tc.question.String()) pref, extErr := netutil.ExtractReversedAddr(tc.question)
require.NoError(t, err) require.NoError(t, extErr)
req := createTestMessageWithType(reqAddr, dns.TypePTR)
req := createTestMessageWithType(dns.Fqdn(tc.question), dns.TypePTR)
pctx := &proxy.DNSContext{ pctx := &proxy.DNSContext{
Proto: proxy.ProtoTCP, Req: req,
Req: req, IsPrivateClient: tc.isPrivate,
Addr: tc.cliAddr, }
// TODO(e.burkov): Configure the subnet set properly.
if netutil.IsLocallyServed(pref.Addr()) {
pctx.RequestedPrivateRDNS = pref
} }
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
err = s.handleDNSRequest(nil, pctx) err = s.handleDNSRequest(s.dnsProxy, pctx)
require.NoError(t, err) require.ErrorIs(t, err, tc.wantErr)
require.NotNil(t, pctx.Res)
require.Len(t, pctx.Res.Answer, tc.wantLen)
if tc.wantLen > 0 { require.NotNil(t, pctx.Res)
assert.Equal(t, tc.want, pctx.Res.Answer[0].(*dns.PTR).Ptr) assert.Equal(t, tc.wantAns, pctx.Res.Answer)
}
}) })
} }
} }
func TestServer_ProcessLocalPTR_usingResolvers(t *testing.T) { func TestServer_ProcessUpstream_localPTR(t *testing.T) {
const locDomain = "some.local." const locDomain = "some.local."
const reqAddr = "1.1.168.192.in-addr.arpa." const reqAddr = "1.1.168.192.in-addr.arpa."
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
resp := aghalg.Coalesce( resp := aghalg.Coalesce(
aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain), aghtest.MatchedResponse(req, dns.TypePTR, reqAddr, locDomain),
new(dns.Msg).SetRcode(req, dns.RcodeNameError), (&dns.Msg{}).SetRcode(req, dns.RcodeNameError),
) )
require.NoError(testutil.PanicT{}, w.WriteMsg(resp)) require.NoError(testutil.PanicT{}, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
s := createTestServer( newPrxCtx := func() (prxCtx *proxy.DNSContext) {
t, return &proxy.DNSContext{
&filtering.Config{ Addr: testClientAddrPort,
BlockingMode: filtering.BlockingModeDefault, Req: createTestMessageWithType(reqAddr, dns.TypePTR),
}, IsPrivateClient: true,
ServerConfig{ RequestedPrivateRDNS: netip.MustParsePrefix("192.168.1.1/32"),
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
var proxyCtx *proxy.DNSContext
var dnsCtx *dnsContext
setup := func(use bool) {
proxyCtx = &proxy.DNSContext{
Addr: testClientAddrPort,
Req: createTestMessageWithType(reqAddr, dns.TypePTR),
} }
dnsCtx = &dnsContext{
proxyCtx: proxyCtx,
unreversedReqIP: netip.MustParseAddr("192.168.1.1"),
}
s.conf.UsePrivateRDNS = use
} }
t.Run("enabled", func(t *testing.T) { t.Run("enabled", func(t *testing.T) {
setup(true) s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: true,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
pctx := newPrxCtx()
rc := s.processLocalPTR(dnsCtx) rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeSuccess, rc) require.Equal(t, resultCodeSuccess, rc)
require.NotEmpty(t, proxyCtx.Res.Answer) require.NotEmpty(t, pctx.Res.Answer)
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
assert.Equal(t, locDomain, proxyCtx.Res.Answer[0].(*dns.PTR).Ptr) assert.Equal(t, locDomain, ptr.Ptr)
}) })
t.Run("disabled", func(t *testing.T) { t.Run("disabled", func(t *testing.T) {
setup(false) s := createTestServer(
t,
&filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
},
ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
},
)
pctx := newPrxCtx()
rc := s.processLocalPTR(dnsCtx) rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeFinish, rc) require.Equal(t, resultCodeError, rc)
require.Empty(t, proxyCtx.Res.Answer) require.Empty(t, pctx.Res.Answer)
}) })
} }
@ -830,129 +838,3 @@ func TestIPStringFromAddr(t *testing.T) {
assert.Empty(t, ipStringFromAddr(nil)) assert.Empty(t, ipStringFromAddr(nil))
}) })
} }
// TODO(e.burkov): Add fuzzing when moving to golibs.
func TestExtractARPASubnet(t *testing.T) {
const (
v4Suf = `in-addr.arpa.`
v4Part = `2.1.` + v4Suf
v4Whole = `4.3.` + v4Part
v6Suf = `ip6.arpa.`
v6Part = `4.3.2.1.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Suf
v6Whole = `f.e.d.c.0.0.0.0.0.0.0.0.0.0.0.0.` + v6Part
)
v4Pref := netip.MustParsePrefix("1.2.3.4/32")
v4PrefPart := netip.MustParsePrefix("1.2.0.0/16")
v6Pref := netip.MustParsePrefix("::1234:0:0:0:cdef/128")
v6PrefPart := netip.MustParsePrefix("0:0:0:1234::/64")
testCases := []struct {
want netip.Prefix
name string
domain string
wantErr string
}{{
want: netip.Prefix{},
name: "not_an_arpa",
domain: "some.domain.name.",
wantErr: `bad arpa domain name "some.domain.name.": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "bad_domain_name",
domain: "abc.123.",
wantErr: `bad domain name "abc.123": ` +
`bad top-level domain name label "123": all octets are numeric`,
}, {
want: v4Pref,
name: "whole_v4",
domain: v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4",
domain: v4Part,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_within_domain",
domain: "a." + v4Whole,
wantErr: "",
}, {
want: v4Pref,
name: "whole_v4_additional_label",
domain: "5." + v4Whole,
wantErr: "",
}, {
want: v4PrefPart,
name: "partial_v4_within_domain",
domain: "a." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4",
domain: "256." + v4Part,
wantErr: "",
}, {
want: v4PrefPart,
name: "overflow_v4_within_domain",
domain: "a.256." + v4Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v4",
domain: v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v4_within_domain",
domain: "a." + v4Suf,
wantErr: `bad arpa domain name "in-addr.arpa": ` +
`not a reversed ip network`,
}, {
want: v6Pref,
name: "whole_v6",
domain: v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6",
domain: v6Part,
}, {
want: v6Pref,
name: "whole_v6_within_domain",
domain: "g." + v6Whole,
wantErr: "",
}, {
want: v6Pref,
name: "whole_v6_additional_label",
domain: "1." + v6Whole,
wantErr: "",
}, {
want: v6PrefPart,
name: "partial_v6_within_domain",
domain: "label." + v6Part,
wantErr: "",
}, {
want: netip.Prefix{},
name: "empty_v6",
domain: v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}, {
want: netip.Prefix{},
name: "empty_v6_within_domain",
domain: "g." + v6Suf,
wantErr: `bad arpa domain name "ip6.arpa": not a reversed ip network`,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
subnet, err := extractARPASubnet(tc.domain)
testutil.AssertErrorMsg(t, tc.wantErr, err)
assert.Equal(t, tc.want, subnet)
})
}
}

View File

@ -1,115 +0,0 @@
package dnsforward
import (
"bytes"
"encoding/binary"
"time"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
)
// uint* sizes in bytes to improve readability.
//
// TODO(e.burkov): Remove when there will be a more regardful way to define
// those. See https://github.com/golang/go/issues/29982.
const (
uint16sz = 2
uint64sz = 8
)
// recursionDetector detects recursion in DNS forwarding.
type recursionDetector struct {
recentRequests cache.Cache
ttl time.Duration
}
// check checks if the passed req was already sent by the server.
func (rd *recursionDetector) check(msg dns.Msg) (ok bool) {
if len(msg.Question) == 0 {
return false
}
key := msgToSignature(msg)
expireData := rd.recentRequests.Get(key)
if expireData == nil {
return false
}
expire := time.Unix(0, int64(binary.BigEndian.Uint64(expireData)))
return time.Now().Before(expire)
}
// add caches the msg if it has anything in the questions section.
func (rd *recursionDetector) add(msg dns.Msg) {
now := time.Now()
if len(msg.Question) == 0 {
return
}
key := msgToSignature(msg)
expire64 := uint64(now.Add(rd.ttl).UnixNano())
expire := make([]byte, uint64sz)
binary.BigEndian.PutUint64(expire, expire64)
rd.recentRequests.Set(key, expire)
}
// clear clears the recent requests cache.
func (rd *recursionDetector) clear() {
rd.recentRequests.Clear()
}
// newRecursionDetector returns the initialized *recursionDetector.
func newRecursionDetector(ttl time.Duration, suspectsNum uint) (rd *recursionDetector) {
return &recursionDetector{
recentRequests: cache.New(cache.Config{
EnableLRU: true,
MaxCount: suspectsNum,
}),
ttl: ttl,
}
}
// msgToSignature converts msg into it's signature represented in bytes.
func msgToSignature(msg dns.Msg) (sig []byte) {
sig = make([]byte, uint16sz*2+netutil.MaxDomainNameLen)
// The binary.BigEndian byte order is used everywhere except when the real
// machine's endianness is needed.
byteOrder := binary.BigEndian
byteOrder.PutUint16(sig[0:], msg.Id)
q := msg.Question[0]
byteOrder.PutUint16(sig[uint16sz:], q.Qtype)
copy(sig[2*uint16sz:], []byte(q.Name))
return sig
}
// msgToSignatureSlow converts msg into it's signature represented in bytes in
// the less efficient way.
//
// See BenchmarkMsgToSignature.
func msgToSignatureSlow(msg dns.Msg) (sig []byte) {
type msgSignature struct {
name [netutil.MaxDomainNameLen]byte
id uint16
qtype uint16
}
b := bytes.NewBuffer(sig)
q := msg.Question[0]
signature := msgSignature{
id: msg.Id,
qtype: q.Qtype,
}
copy(signature.name[:], q.Name)
if err := binary.Write(b, binary.BigEndian, signature); err != nil {
log.Debug("writing message signature: %s", err)
}
return b.Bytes()
}

View File

@ -1,148 +0,0 @@
package dnsforward
import (
"encoding/binary"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
func TestRecursionDetector_Check(t *testing.T) {
rd := newRecursionDetector(0, 2)
const (
recID = 1234
recTTL = time.Hour * 100
)
const nonRecID = recID * 2
sampleQuestion := dns.Question{
Name: "some.domain",
Qtype: dns.TypeAAAA,
}
sampleMsg := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: recID,
},
Question: []dns.Question{sampleQuestion},
}
// Manually add the message with big ttl.
key := msgToSignature(sampleMsg)
expire := make([]byte, uint64sz)
binary.BigEndian.PutUint64(expire, uint64(time.Now().Add(recTTL).UnixNano()))
rd.recentRequests.Set(key, expire)
// Add an expired message.
sampleMsg.Id = nonRecID
rd.add(sampleMsg)
testCases := []struct {
name string
questions []dns.Question
id uint16
want bool
}{{
name: "recurrent",
questions: []dns.Question{sampleQuestion},
id: recID,
want: true,
}, {
name: "not_suspected",
questions: []dns.Question{sampleQuestion},
id: recID + 1,
want: false,
}, {
name: "expired",
questions: []dns.Question{sampleQuestion},
id: nonRecID,
want: false,
}, {
name: "empty",
questions: []dns.Question{},
id: nonRecID,
want: false,
}}
for _, tc := range testCases {
sampleMsg.Id = tc.id
sampleMsg.Question = tc.questions
t.Run(tc.name, func(t *testing.T) {
detected := rd.check(sampleMsg)
assert.Equal(t, tc.want, detected)
})
}
}
func TestRecursionDetector_Suspect(t *testing.T) {
rd := newRecursionDetector(0, 1)
testCases := []struct {
name string
msg dns.Msg
want int
}{{
name: "simple",
msg: dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: "some.domain",
Qtype: dns.TypeA,
}},
},
want: 1,
}, {
name: "unencumbered",
msg: dns.Msg{},
want: 0,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(rd.clear)
rd.add(tc.msg)
assert.Equal(t, tc.want, rd.recentRequests.Stats().Count)
})
}
}
var sink []byte
func BenchmarkMsgToSignature(b *testing.B) {
const name = "some.not.very.long.host.name"
msg := dns.Msg{
MsgHdr: dns.MsgHdr{
Id: 1234,
},
Question: []dns.Question{{
Name: name,
Qtype: dns.TypeAAAA,
}},
}
b.Run("efficient", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = msgToSignature(msg)
}
assert.NotEmpty(b, sink)
})
b.Run("inefficient", func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
sink = msgToSignatureSlow(msg)
}
assert.NotEmpty(b, sink)
})
}

View File

@ -2,19 +2,14 @@ package dnsforward
import ( import (
"fmt" "fmt"
"net/netip"
"os" "os"
"slices"
"time" "time"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/golibs/stringutil"
"golang.org/x/exp/maps"
) )
// loadUpstreams parses upstream DNS servers from the configured file or from // loadUpstreams parses upstream DNS servers from the configured file or from
@ -174,41 +169,3 @@ func (s *Server) createBootstrap(
func IsCommentOrEmpty(s string) (ok bool) { func IsCommentOrEmpty(s string) (ok bool) {
return len(s) == 0 || s[0] == '#' return len(s) == 0 || s[0] == '#'
} }
// ValidateUpstreamsPrivate validates each upstream and returns an error if any
// upstream is invalid or if there are no default upstreams specified. It also
// checks each domain of domain-specific upstreams for being ARPA pointing to
// a locally-served network. privateNets must not be nil.
func ValidateUpstreamsPrivate(upstreams []string, privateNets netutil.SubnetSet) (err error) {
conf, err := proxy.ParseUpstreamsConfig(upstreams, &upstream.Options{})
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
if conf == nil {
return nil
}
keys := maps.Keys(conf.DomainReservedUpstreams)
slices.Sort(keys)
var errs []error
for _, domain := range keys {
var subnet netip.Prefix
subnet, err = extractARPASubnet(domain)
if err != nil {
errs = append(errs, err)
continue
}
if !privateNets.Contains(subnet.Addr()) {
errs = append(
errs,
fmt.Errorf("arpa domain %q should point to a locally-served network", domain),
)
}
}
return errors.Annotate(errors.Join(errs...), "checking domain-specific upstreams: %w")
}