diff --git a/internal/aghalg/aghalg.go b/internal/aghalg/aghalg.go index f0b71a09..b554a917 100644 --- a/internal/aghalg/aghalg.go +++ b/internal/aghalg/aghalg.go @@ -10,9 +10,14 @@ import ( "golang.org/x/exp/slices" ) -// Coalesce returns the first non-zero value. It is named after the function +// Coalesce returns the first non-zero value. It is named after function // COALESCE in SQL. If values or all its elements are empty, it returns a zero // value. +// +// T is comparable, because Go currently doesn't have a comparableWithZeroValue +// constraint. +// +// TODO(a.garipov): Think of ways to merge with [CoalesceSlice]. func Coalesce[T comparable](values ...T) (res T) { var zero T for _, v := range values { @@ -24,6 +29,20 @@ func Coalesce[T comparable](values ...T) (res T) { return zero } +// CoalesceSlice returns the first non-zero value. It is named after function +// COALESCE in SQL. If values or all its elements are empty, it returns nil. +// +// TODO(a.garipov): Think of ways to merge with [Coalesce]. +func CoalesceSlice[E any, S []E](values ...S) (res S) { + for _, v := range values { + if v != nil { + return v + } + } + + return nil +} + // UniqChecker allows validating uniqueness of comparable items. // // TODO(a.garipov): The Ordered constraint is only really necessary in Validate. diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 19ec710e..2b2ba2e8 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -337,7 +338,7 @@ func (s *Server) prepareUpstreamSettings() error { if s.conf.UpstreamDNSFileName != "" { data, err := os.ReadFile(s.conf.UpstreamDNSFileName) if err != nil { - return err + return fmt.Errorf("reading upstream from file: %w", err) } upstreams = stringutil.SplitTrimmed(string(data), "\n") @@ -356,7 +357,7 @@ func (s *Server) prepareUpstreamSettings() error { }, ) if err != nil { - return fmt.Errorf("dns: proxy.ParseUpstreamsConfig: %w", err) + return fmt.Errorf("parsing upstream config: %w", err) } if len(upstreamConfig.Upstreams) == 0 { @@ -370,8 +371,9 @@ func (s *Server) prepareUpstreamSettings() error { }, ) if err != nil { - return fmt.Errorf("dns: failed to parse default upstreams: %v", err) + return fmt.Errorf("parsing default upstreams: %w", err) } + upstreamConfig.Upstreams = uc.Upstreams } @@ -380,30 +382,6 @@ func (s *Server) prepareUpstreamSettings() error { return nil } -// prepareInternalProxy initializes the DNS proxy that is used for internal DNS -// queries, such at client PTR resolving and updater hostname resolving. -func (s *Server) prepareInternalProxy() { - conf := &proxy.Config{ - CacheEnabled: true, - CacheSizeBytes: 4096, - UpstreamConfig: s.conf.UpstreamConfig, - MaxGoroutines: int(s.conf.MaxGoroutines), - } - - srvConf := s.conf - setProxyUpstreamMode( - conf, - srvConf.AllServers, - srvConf.FastestAddr, - srvConf.FastestTimeout.Duration, - ) - - // TODO(a.garipov): Make a proper constructor for proxy.Proxy. - s.internalProxy = &proxy.Proxy{ - Config: *conf, - } -} - // setProxyUpstreamMode sets the upstream mode and related settings in conf // based on provided parameters. func setProxyUpstreamMode( @@ -432,13 +410,15 @@ func (s *Server) prepareTLS(proxyConfig *proxy.Config) error { return nil } - if s.conf.TLSListenAddrs != nil { - proxyConfig.TLSListenAddr = s.conf.TLSListenAddrs - } + proxyConfig.TLSListenAddr = aghalg.CoalesceSlice( + s.conf.TLSListenAddrs, + proxyConfig.TLSListenAddr, + ) - if s.conf.QUICListenAddrs != nil { - proxyConfig.QUICListenAddr = s.conf.QUICListenAddrs - } + proxyConfig.QUICListenAddr = aghalg.CoalesceSlice( + s.conf.QUICListenAddrs, + proxyConfig.QUICListenAddr, + ) var err error s.conf.cert, err = tls.X509KeyPair(s.conf.CertificateChainData, s.conf.PrivateKeyData) diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 63086d3d..33be1e83 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -434,65 +434,54 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) { return nil } -// Prepare the object -func (s *Server) Prepare(config *ServerConfig) error { - // Initialize the server configuration - // -- - if config != nil { - s.conf = *config - if s.conf.BlockingMode == "custom_ip" { - if s.conf.BlockingIPv4 == nil || s.conf.BlockingIPv6 == nil { - return fmt.Errorf("dns: invalid custom blocking IP address specified") - } - } +// Prepare initializes parameters of s using data from conf. conf must not be +// nil. +func (s *Server) Prepare(conf *ServerConfig) (err error) { + s.conf = *conf + + err = validateBlockingMode(s.conf.BlockingMode, s.conf.BlockingIPv4, s.conf.BlockingIPv6) + if err != nil { + return fmt.Errorf("checking blocking mode: %w", err) } - // Set default values in the case if nothing is configured - // -- s.initDefaultSettings() - // Initialize ipset configuration - // -- - err := s.ipset.init(s.conf.IpsetList) + err = s.ipset.init(s.conf.IpsetList) if err != nil { + // Don't wrap the error, because it's informative enough as is. return err } - log.Debug("inited ipset") - - // Prepare DNS servers settings - // -- err = s.prepareUpstreamSettings() if err != nil { - return err + return fmt.Errorf("preparing upstream settings: %w", err) } - // Create DNS proxy configuration - // -- var proxyConfig proxy.Config proxyConfig, err = s.createProxyConfig() if err != nil { - return err + return fmt.Errorf("preparing proxy: %w", err) } - // Prepare a DNS proxy instance that we use for internal DNS queries - // -- - s.prepareInternalProxy() - - s.access, err = newAccessCtx(s.conf.AllowedClients, s.conf.DisallowedClients, s.conf.BlockedHosts) + err = s.prepareInternalProxy() if err != nil { - return err + return fmt.Errorf("preparing internal proxy: %w", err) + } + + s.access, err = newAccessCtx( + s.conf.AllowedClients, + s.conf.DisallowedClients, + s.conf.BlockedHosts, + ) + if err != nil { + return fmt.Errorf("preparing access: %w", err) } - // Register web handlers if necessary - // -- if !webRegistered && s.conf.HTTPRegister != nil { webRegistered = true s.registerHandlers() } - // Create the main DNS proxy instance - // -- s.dnsProxy = &proxy.Proxy{Config: proxyConfig} err = s.setupResolvers(s.conf.LocalPTRResolvers) @@ -505,6 +494,61 @@ func (s *Server) Prepare(config *ServerConfig) error { return nil } +// validateBlockingMode returns an error if the blocking mode data aren't valid. +func validateBlockingMode(mode BlockingMode, blockingIPv4, blockingIPv6 net.IP) (err error) { + switch mode { + case + BlockingModeDefault, + BlockingModeNXDOMAIN, + BlockingModeREFUSED, + BlockingModeNullIP: + return nil + case BlockingModeCustomIP: + if blockingIPv4 == nil { + return fmt.Errorf("blocking_ipv4 must be set when blocking_mode is custom_ip") + } else if blockingIPv6 == nil { + return fmt.Errorf("blocking_ipv6 must be set when blocking_mode is custom_ip") + } + + return nil + default: + return fmt.Errorf("bad blocking mode %q", mode) + } +} + +// prepareInternalProxy initializes the DNS proxy that is used for internal DNS +// queries, such at client PTR resolving and updater hostname resolving. +func (s *Server) prepareInternalProxy() (err error) { + conf := &proxy.Config{ + CacheEnabled: true, + CacheSizeBytes: 4096, + UpstreamConfig: s.conf.UpstreamConfig, + MaxGoroutines: int(s.conf.MaxGoroutines), + } + + srvConf := s.conf + setProxyUpstreamMode( + conf, + srvConf.AllServers, + srvConf.FastestAddr, + srvConf.FastestTimeout.Duration, + ) + + // TODO(a.garipov): Make a proper constructor for proxy.Proxy. + p := &proxy.Proxy{ + Config: *conf, + } + + err = p.Init() + if err != nil { + return err + } + + s.internalProxy = p + + return nil +} + // Stop stops the DNS server. func (s *Server) Stop() error { s.serverLock.Lock() @@ -550,7 +594,7 @@ func (s *Server) proxy() (p *proxy.Proxy) { } // Reconfigure applies the new configuration to the DNS server. -func (s *Server) Reconfigure(config *ServerConfig) error { +func (s *Server) Reconfigure(conf *ServerConfig) error { s.serverLock.Lock() defer s.serverLock.Unlock() @@ -564,7 +608,12 @@ func (s *Server) Reconfigure(config *ServerConfig) error { // We wait for some time and hope that this fd will be closed. time.Sleep(100 * time.Millisecond) - err = s.Prepare(config) + // TODO(a.garipov): This whole piece of API is weird and needs to be remade. + if conf == nil { + conf = &s.conf + } + + err = s.Prepare(conf) if err != nil { return fmt.Errorf("could not reconfigure the server: %w", err) } diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 1da83ad4..a76d5a41 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -78,9 +78,11 @@ func createTestServer( }) require.NoError(t, err) - s.conf = forwardConf + if forwardConf.BlockingMode == "" { + forwardConf.BlockingMode = BlockingModeDefault + } - err = s.Prepare(nil) + err = s.Prepare(&forwardConf) require.NoError(t, err) s.serverLock.Lock() @@ -152,7 +154,7 @@ func createTestTLS(t *testing.T, tlsConf TLSConfig) (s *Server, certPem []byte) tlsConf.CertificateChainData, tlsConf.PrivateKeyData = certPem, keyPem s.conf.TLSConfig = tlsConf - err := s.Prepare(nil) + err := s.Prepare(&s.conf) require.NoErrorf(t, err, "failed to prepare server: %s", err) return s, certPem @@ -286,6 +288,9 @@ func TestServer_timeout(t *testing.T) { t.Run("custom", func(t *testing.T) { srvConf := &ServerConfig{ UpstreamTimeout: timeout, + FilteringConfig: FilteringConfig{ + BlockingMode: BlockingModeDefault, + }, } s, err := NewServer(DNSCreateParams{}) @@ -301,7 +306,8 @@ func TestServer_timeout(t *testing.T) { s, err := NewServer(DNSCreateParams{}) require.NoError(t, err) - err = s.Prepare(nil) + s.conf.FilteringConfig.BlockingMode = BlockingModeDefault + err = s.Prepare(&s.conf) require.NoError(t, err) assert.Equal(t, DefaultTimeout, s.conf.UpstreamTimeout) @@ -915,6 +921,7 @@ func TestRewrite(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{{}}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, + BlockingMode: BlockingModeDefault, UpstreamDNS: []string{"8.8.8.8:53"}, }, })) @@ -1026,9 +1033,10 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} - s.conf.ProtectionEnabled = true + s.conf.FilteringConfig.ProtectionEnabled = true + s.conf.FilteringConfig.BlockingMode = BlockingModeDefault - err = s.Prepare(nil) + err = s.Prepare(&s.conf) require.NoError(t, err) err = s.Start() @@ -1098,8 +1106,9 @@ func TestPTRResponseFromHosts(t *testing.T) { s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} + s.conf.FilteringConfig.BlockingMode = BlockingModeDefault - err = s.Prepare(nil) + err = s.Prepare(&s.conf) require.NoError(t, err) err = s.Start() diff --git a/internal/dnsforward/filter_test.go b/internal/dnsforward/filter_test.go index 69dcf8f9..e25d4037 100644 --- a/internal/dnsforward/filter_test.go +++ b/internal/dnsforward/filter_test.go @@ -45,8 +45,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { }) require.NoError(t, err) - s.conf = forwardConf - err = s.Prepare(nil) + err = s.Prepare(&forwardConf) require.NoError(t, err) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 49b65be4..5df74fe8 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -20,16 +20,14 @@ import ( "golang.org/x/exp/slices" ) -type dnsConfig struct { - Upstreams *[]string `json:"upstream_dns"` - UpstreamsFile *string `json:"upstream_dns_file"` - Bootstraps *[]string `json:"bootstrap_dns"` - +// jsonDNSConfig is the JSON representation of the DNS server configuration. +type jsonDNSConfig struct { + Upstreams *[]string `json:"upstream_dns"` + UpstreamsFile *string `json:"upstream_dns_file"` + Bootstraps *[]string `json:"bootstrap_dns"` ProtectionEnabled *bool `json:"protection_enabled"` RateLimit *uint32 `json:"ratelimit"` BlockingMode *BlockingMode `json:"blocking_mode"` - BlockingIPv4 net.IP `json:"blocking_ipv4"` - BlockingIPv6 net.IP `json:"blocking_ipv6"` EDNSCSEnabled *bool `json:"edns_cs_enabled"` DNSSECEnabled *bool `json:"dnssec_enabled"` DisableIPv6 *bool `json:"disable_ipv6"` @@ -41,9 +39,11 @@ type dnsConfig struct { ResolveClients *bool `json:"resolve_clients"` UsePrivateRDNS *bool `json:"use_private_ptr_resolvers"` LocalPTRUpstreams *[]string `json:"local_ptr_upstreams"` + BlockingIPv4 net.IP `json:"blocking_ipv4"` + BlockingIPv6 net.IP `json:"blocking_ipv6"` } -func (s *Server) getDNSConfig() (c *dnsConfig) { +func (s *Server) getDNSConfig() (c *jsonDNSConfig) { s.serverLock.RLock() defer s.serverLock.RUnlock() @@ -72,7 +72,7 @@ func (s *Server) getDNSConfig() (c *dnsConfig) { upstreamMode = "parallel" } - return &dnsConfig{ + return &jsonDNSConfig{ Upstreams: &upstreams, UpstreamsFile: &upstreamFile, Bootstraps: &bootstraps, @@ -102,13 +102,13 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { } resp := struct { - dnsConfig + jsonDNSConfig // DefautLocalPTRUpstreams is used to pass the addresses from // systemResolvers to the front-end. It's not a pointer to the slice // since there is no need to omit it while decoding from JSON. DefautLocalPTRUpstreams []string `json:"default_local_ptr_upstreams,omitempty"` }{ - dnsConfig: *s.getDNSConfig(), + jsonDNSConfig: *s.getDNSConfig(), DefautLocalPTRUpstreams: defLocalPTRUps, } @@ -121,31 +121,21 @@ func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { } } -func (req *dnsConfig) checkBlockingMode() bool { +func (req *jsonDNSConfig) checkBlockingMode() (err error) { if req.BlockingMode == nil { - return true + return nil } - switch bm := *req.BlockingMode; bm { - case BlockingModeDefault, - BlockingModeREFUSED, - BlockingModeNXDOMAIN, - BlockingModeNullIP: - return true - case BlockingModeCustomIP: - return req.BlockingIPv4.To4() != nil && req.BlockingIPv6 != nil - default: - return false - } + return validateBlockingMode(*req.BlockingMode, req.BlockingIPv4, req.BlockingIPv6) } -func (req *dnsConfig) checkUpstreamsMode() bool { +func (req *jsonDNSConfig) checkUpstreamsMode() bool { valid := []string{"", "fastest_addr", "parallel"} return req.UpstreamMode == nil || stringutil.InSlice(valid, *req.UpstreamMode) } -func (req *dnsConfig) checkBootstrap() (err error) { +func (req *jsonDNSConfig) checkBootstrap() (err error) { if req.Bootstraps == nil { return nil } @@ -167,7 +157,7 @@ func (req *dnsConfig) checkBootstrap() (err error) { } // validate returns an error if any field of req is invalid. -func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) { +func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) { if req.Upstreams != nil { err = ValidateUpstreams(*req.Upstreams) if err != nil { @@ -187,9 +177,12 @@ func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) { return err } + err = req.checkBlockingMode() + if err != nil { + return err + } + switch { - case !req.checkBlockingMode(): - return errors.Error("blocking_mode: incorrect value") case !req.checkUpstreamsMode(): return errors.Error("upstream_mode: incorrect value") case !req.checkCacheTTL(): @@ -199,7 +192,7 @@ func (req *dnsConfig) validate(privateNets netutil.SubnetSet) (err error) { } } -func (req *dnsConfig) checkCacheTTL() bool { +func (req *jsonDNSConfig) checkCacheTTL() bool { if req.CacheMinTTL == nil && req.CacheMaxTTL == nil { return true } @@ -216,7 +209,7 @@ func (req *dnsConfig) checkCacheTTL() bool { } func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { - req := &dnsConfig{} + req := &jsonDNSConfig{} err := json.NewDecoder(r.Body).Decode(req) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "decoding request: %s", err) @@ -242,100 +235,75 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { } } -func (s *Server) setConfigRestartable(dc *dnsConfig) (restart bool) { - if dc.Upstreams != nil { - s.conf.UpstreamDNS = *dc.Upstreams - restart = true - } - - if dc.LocalPTRUpstreams != nil { - s.conf.LocalPTRResolvers = *dc.LocalPTRUpstreams - restart = true - } - - if dc.UpstreamsFile != nil { - s.conf.UpstreamDNSFileName = *dc.UpstreamsFile - restart = true - } - - if dc.Bootstraps != nil { - s.conf.BootstrapDNS = *dc.Bootstraps - restart = true - } - - if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit { - s.conf.Ratelimit = *dc.RateLimit - restart = true - } - - if dc.EDNSCSEnabled != nil { - s.conf.EnableEDNSClientSubnet = *dc.EDNSCSEnabled - restart = true - } - - if dc.CacheSize != nil { - s.conf.CacheSize = *dc.CacheSize - restart = true - } - - if dc.CacheMinTTL != nil { - s.conf.CacheMinTTL = *dc.CacheMinTTL - restart = true - } - - if dc.CacheMaxTTL != nil { - s.conf.CacheMaxTTL = *dc.CacheMaxTTL - restart = true - } - - if dc.CacheOptimistic != nil { - s.conf.CacheOptimistic = *dc.CacheOptimistic - restart = true - } - - return restart -} - -func (s *Server) setConfig(dc *dnsConfig) (restart bool) { +// setConfigRestartable sets the server parameters. shouldRestart is true if +// the server should be restarted to apply changes. +func (s *Server) setConfig(dc *jsonDNSConfig) (shouldRestart bool) { s.serverLock.Lock() defer s.serverLock.Unlock() - if dc.ProtectionEnabled != nil { - s.conf.ProtectionEnabled = *dc.ProtectionEnabled - } - if dc.BlockingMode != nil { s.conf.BlockingMode = *dc.BlockingMode - if *dc.BlockingMode == "custom_ip" { + if *dc.BlockingMode == BlockingModeCustomIP { s.conf.BlockingIPv4 = dc.BlockingIPv4.To4() s.conf.BlockingIPv6 = dc.BlockingIPv6.To16() } } - if dc.DNSSECEnabled != nil { - s.conf.EnableDNSSEC = *dc.DNSSECEnabled - } - - if dc.DisableIPv6 != nil { - s.conf.AAAADisabled = *dc.DisableIPv6 - } - if dc.UpstreamMode != nil { s.conf.AllServers = *dc.UpstreamMode == "parallel" s.conf.FastestAddr = *dc.UpstreamMode == "fastest_addr" } - if dc.ResolveClients != nil { - s.conf.ResolveClients = *dc.ResolveClients - } - - if dc.UsePrivateRDNS != nil { - s.conf.UsePrivateRDNS = *dc.UsePrivateRDNS - } + setIfNotNil(&s.conf.ProtectionEnabled, dc.ProtectionEnabled) + setIfNotNil(&s.conf.EnableDNSSEC, dc.DNSSECEnabled) + setIfNotNil(&s.conf.AAAADisabled, dc.DisableIPv6) + setIfNotNil(&s.conf.ResolveClients, dc.ResolveClients) + setIfNotNil(&s.conf.UsePrivateRDNS, dc.UsePrivateRDNS) return s.setConfigRestartable(dc) } +// setIfNotNil sets the value pointed at by currentPtr to the value pointed at +// by newPtr if newPtr is not nil. currentPtr must not be nil. +func setIfNotNil[T any](currentPtr, newPtr *T) (hasSet bool) { + if newPtr == nil { + return false + } + + *currentPtr = *newPtr + + return true +} + +// setConfigRestartable sets the parameters which trigger a restart. +// shouldRestart is true if the server should be restarted to apply changes. +// s.serverLock is expected to be locked. +func (s *Server) setConfigRestartable(dc *jsonDNSConfig) (shouldRestart bool) { + for _, hasSet := range []bool{ + setIfNotNil(&s.conf.UpstreamDNS, dc.Upstreams), + setIfNotNil(&s.conf.LocalPTRResolvers, dc.LocalPTRUpstreams), + setIfNotNil(&s.conf.UpstreamDNSFileName, dc.UpstreamsFile), + setIfNotNil(&s.conf.BootstrapDNS, dc.Bootstraps), + setIfNotNil(&s.conf.EnableEDNSClientSubnet, dc.EDNSCSEnabled), + setIfNotNil(&s.conf.CacheSize, dc.CacheSize), + setIfNotNil(&s.conf.CacheMinTTL, dc.CacheMinTTL), + setIfNotNil(&s.conf.CacheMaxTTL, dc.CacheMaxTTL), + setIfNotNil(&s.conf.CacheOptimistic, dc.CacheOptimistic), + } { + shouldRestart = shouldRestart || hasSet + if shouldRestart { + break + } + } + + if dc.RateLimit != nil && s.conf.Ratelimit != *dc.RateLimit { + s.conf.Ratelimit = *dc.RateLimit + shouldRestart = true + } + + return shouldRestart +} + // upstreamJSON is a request body for handleTestUpstreamDNS endpoint. type upstreamJSON struct { Upstreams []string `json:"upstream_dns"` diff --git a/internal/dnsforward/http_test.go b/internal/dnsforward/http_test.go index 8863c6c6..1ae39455 100644 --- a/internal/dnsforward/http_test.go +++ b/internal/dnsforward/http_test.go @@ -62,6 +62,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, + BlockingMode: BlockingModeDefault, UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, }, ConfigModified: func() {}, @@ -135,6 +136,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { TCPListenAddrs: []*net.TCPAddr{}, FilteringConfig: FilteringConfig{ ProtectionEnabled: true, + BlockingMode: BlockingModeDefault, UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, }, ConfigModified: func() {}, @@ -164,7 +166,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { wantSet: "", }, { name: "blocking_mode_bad", - wantSet: "blocking_mode: incorrect value", + wantSet: "blocking_ipv4 must be set when blocking_mode is custom_ip", }, { name: "ratelimit", wantSet: "", diff --git a/internal/dnsforward/msg.go b/internal/dnsforward/msg.go index c5333ba2..e6bd178e 100644 --- a/internal/dnsforward/msg.go +++ b/internal/dnsforward/msg.go @@ -37,66 +37,73 @@ func ipsFromRules(resRules []*filtering.ResultRule) (ips []net.IP) { return ips } -// genDNSFilterMessage generates a DNS message corresponding to the filtering result -func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *filtering.Result) *dns.Msg { - m := d.Req - - if m.Question[0].Qtype != dns.TypeA && m.Question[0].Qtype != dns.TypeAAAA { +// genDNSFilterMessage generates a filtered response to req for the filtering +// result res. +func (s *Server) genDNSFilterMessage( + dctx *proxy.DNSContext, + res *filtering.Result, +) (resp *dns.Msg) { + req := dctx.Req + if qt := req.Question[0].Qtype; qt != dns.TypeA && qt != dns.TypeAAAA { if s.conf.BlockingMode == BlockingModeNullIP { - return s.makeResponse(m) + return s.makeResponse(req) } - return s.genNXDomain(m) + + return s.genNXDomain(req) } - ips := ipsFromRules(result.Rules) - switch result.Reason { + switch res.Reason { case filtering.FilteredSafeBrowsing: - return s.genBlockedHost(m, s.conf.SafeBrowsingBlockHost, d) + return s.genBlockedHost(req, s.conf.SafeBrowsingBlockHost, dctx) case filtering.FilteredParental: - return s.genBlockedHost(m, s.conf.ParentalBlockHost, d) + return s.genBlockedHost(req, s.conf.ParentalBlockHost, dctx) default: - // If the query was filtered by "Safe search", filtering also must return - // the IP address that must be used in response. - // In this case regardless of the filtering method, we should return it - if result.Reason == filtering.FilteredSafeSearch && len(ips) > 0 { - return s.genResponseWithIPs(m, ips) + // If the query was filtered by Safe Search, filtering also must return + // the IP addresses that must be used in response. Return them + // regardless of the filtering method. + ips := ipsFromRules(res.Rules) + if res.Reason == filtering.FilteredSafeSearch && len(ips) > 0 { + return s.genResponseWithIPs(req, ips) } - switch s.conf.BlockingMode { - case BlockingModeCustomIP: - switch m.Question[0].Qtype { - case dns.TypeA: - return s.genARecord(m, s.conf.BlockingIPv4) - case dns.TypeAAAA: - return s.genAAAARecord(m, s.conf.BlockingIPv6) - default: - // Generally shouldn't happen, since the types - // are checked above. - log.Error( - "dns: invalid msg type %d for blocking mode %s", - m.Question[0].Qtype, - s.conf.BlockingMode, - ) + return s.genForBlockingMode(req, ips) + } +} - return s.makeResponse(m) - } - case BlockingModeDefault: - if len(ips) > 0 { - return s.genResponseWithIPs(m, ips) - } - - return s.makeResponseNullIP(m) - case BlockingModeNullIP: - return s.makeResponseNullIP(m) - case BlockingModeNXDOMAIN: - return s.genNXDomain(m) - case BlockingModeREFUSED: - return s.makeResponseREFUSED(m) +// genForBlockingMode generates a filtered response to req based on the server's +// blocking mode. +func (s *Server) genForBlockingMode(req *dns.Msg, ips []net.IP) (resp *dns.Msg) { + qt := req.Question[0].Qtype + switch m := s.conf.BlockingMode; m { + case BlockingModeCustomIP: + switch qt { + case dns.TypeA: + return s.genARecord(req, s.conf.BlockingIPv4) + case dns.TypeAAAA: + return s.genAAAARecord(req, s.conf.BlockingIPv6) default: - log.Error("dns: invalid blocking mode %q", s.conf.BlockingMode) + // Generally shouldn't happen, since the types are checked in + // genDNSFilterMessage. + log.Error("dns: invalid msg type %s for blocking mode %s", dns.Type(qt), m) - return s.makeResponse(m) + return s.makeResponse(req) } + case BlockingModeDefault: + if len(ips) > 0 { + return s.genResponseWithIPs(req, ips) + } + + return s.makeResponseNullIP(req) + case BlockingModeNullIP: + return s.makeResponseNullIP(req) + case BlockingModeNXDOMAIN: + return s.genNXDomain(req) + case BlockingModeREFUSED: + return s.makeResponseREFUSED(req) + default: + log.Error("dns: invalid blocking mode %q", s.conf.BlockingMode) + + return s.makeResponse(req) } } diff --git a/internal/dnsforward/testdata/TestDNSForwardHTTP_handleGetConfig.json b/internal/dnsforward/testdata/TestDNSForwardHTTP_handleGetConfig.json index 538865c5..3ac6f2f5 100644 --- a/internal/dnsforward/testdata/TestDNSForwardHTTP_handleGetConfig.json +++ b/internal/dnsforward/testdata/TestDNSForwardHTTP_handleGetConfig.json @@ -13,7 +13,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -42,7 +42,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -71,7 +71,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -86,4 +86,4 @@ "use_private_ptr_resolvers": false, "local_ptr_upstreams": [] } -} \ No newline at end of file +} diff --git a/internal/dnsforward/testdata/TestDNSForwardHTTP_handleSetConfig.json b/internal/dnsforward/testdata/TestDNSForwardHTTP_handleSetConfig.json index 830bf491..f55359a9 100644 --- a/internal/dnsforward/testdata/TestDNSForwardHTTP_handleSetConfig.json +++ b/internal/dnsforward/testdata/TestDNSForwardHTTP_handleSetConfig.json @@ -20,7 +20,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -53,7 +53,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -121,7 +121,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -155,7 +155,7 @@ ], "protection_enabled": true, "ratelimit": 6, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -189,7 +189,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": true, @@ -223,7 +223,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -257,7 +257,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -291,7 +291,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -325,7 +325,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -361,7 +361,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -397,7 +397,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -432,7 +432,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -466,7 +466,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -502,7 +502,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -541,7 +541,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, @@ -575,7 +575,7 @@ ], "protection_enabled": true, "ratelimit": 0, - "blocking_mode": "", + "blocking_mode": "default", "blocking_ipv4": "", "blocking_ipv6": "", "edns_cs_enabled": false, diff --git a/internal/home/config.go b/internal/home/config.go index 5bd9b98b..7ac6470e 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -203,9 +203,9 @@ var config = &configuration{ Port: defaultPortDNS, StatsInterval: 1, FilteringConfig: dnsforward.FilteringConfig{ - ProtectionEnabled: true, // whether or not use any of filtering features - BlockingMode: "default", // mode how to answer filtered requests - BlockedResponseTTL: 10, // in seconds + ProtectionEnabled: true, // whether or not use any of filtering features + BlockingMode: dnsforward.BlockingModeDefault, + BlockedResponseTTL: 10, // in seconds Ratelimit: 20, RefuseAny: true, AllServers: false, diff --git a/scripts/make/go-lint.sh b/scripts/make/go-lint.sh index 63e32dd9..8e35efd5 100644 --- a/scripts/make/go-lint.sh +++ b/scripts/make/go-lint.sh @@ -221,8 +221,8 @@ exit_on_output gofumpt --extra -e -l . # Apply more lax standards to the code we haven't properly refactored yet. gocyclo --over 17 ./internal/querylog/ -gocyclo --over 16 ./internal/dnsforward/ gocyclo --over 15 ./internal/home/ ./internal/dhcpd +gocyclo --over 14 ./internal/dnsforward/ gocyclo --over 13 ./internal/filtering/ # Apply stricter standards to new or somewhat refactored code.