From 9200163f85e009680fe88457d4bc01a6b51c59de Mon Sep 17 00:00:00 2001 From: Ainar Garipov Date: Wed, 17 Aug 2022 18:23:30 +0300 Subject: [PATCH] all: sync with master --- client/src/__locales/da.json | 1 + client/src/__locales/es.json | 6 +- internal/aghalg/aghalg.go | 14 ++ internal/aghnet/hostscontainer_test.go | 2 +- internal/aghnet/net.go | 11 +- internal/aghnet/net_test.go | 4 +- internal/aghos/aghos_test.go | 2 +- internal/aghos/filewalker_internal_test.go | 57 +++++++ internal/aghos/filewalker_test.go | 52 +------ internal/aghtest/exchanger.go | 20 --- internal/aghtest/fswatcher.go | 23 --- internal/aghtest/interface.go | 135 ++++++++++++++++ internal/aghtest/interface_test.go | 9 ++ internal/aghtest/testfs.go | 46 ------ internal/aghtest/upstream.go | 154 +++++++++++-------- internal/dnsforward/config.go | 3 +- internal/dnsforward/dns.go | 97 ++++++++++++ internal/dnsforward/dns_test.go | 171 +++++++++++++++++++++ internal/dnsforward/dnsforward_test.go | 71 +++++---- internal/filtering/filtering_test.go | 92 +++++------ internal/filtering/safebrowsing.go | 4 +- internal/filtering/safebrowsing_test.go | 30 ++-- internal/home/config.go | 1 + internal/home/controlinstall.go | 2 +- internal/home/rdns_test.go | 59 ++++--- 25 files changed, 740 insertions(+), 326 deletions(-) create mode 100644 internal/aghos/filewalker_internal_test.go delete mode 100644 internal/aghtest/exchanger.go delete mode 100644 internal/aghtest/fswatcher.go create mode 100644 internal/aghtest/interface.go create mode 100644 internal/aghtest/interface_test.go delete mode 100644 internal/aghtest/testfs.go diff --git a/client/src/__locales/da.json b/client/src/__locales/da.json index 7e800e1c..32bec1e2 100644 --- a/client/src/__locales/da.json +++ b/client/src/__locales/da.json @@ -222,6 +222,7 @@ "updated_upstream_dns_toast": "Upstream-servere er gemt", "dns_test_ok_toast": "Angivne DNS-servere fungerer korrekt", "dns_test_not_ok_toast": "Server \"{{key}}\": Kunne ikke bruges. Tjek, at du har angivet den korrekt", + "dns_test_warning_toast": "Upstream \"{{key}}\" svarer ikke på testforespørgsler og fungerer muligvis ikke korrekt", "unblock": "Afblokering", "block": "Blokering", "disallow_this_client": "Afvis denne klient", diff --git a/client/src/__locales/es.json b/client/src/__locales/es.json index 56895b58..895d0c33 100644 --- a/client/src/__locales/es.json +++ b/client/src/__locales/es.json @@ -47,7 +47,7 @@ "form_error_server_name": "Nombre de servidor no válido", "form_error_subnet": "La subred \"{{cidr}}\" no contiene la dirección IP \"{{ip}}\"", "form_error_positive": "Debe ser mayor que 0", - "form_error_gateway_ip": "Asignación no puede tener la dirección IP de la puerta de enlace", + "form_error_gateway_ip": "La asignación no puede tener la dirección IP de la puerta de enlace", "out_of_range_error": "Debe estar fuera del rango \"{{start}}\"-\"{{end}}\"", "lower_range_start_error": "Debe ser inferior que el inicio de rango", "greater_range_start_error": "Debe ser mayor que el inicio de rango", @@ -222,7 +222,7 @@ "updated_upstream_dns_toast": "Servidores DNS de subida guardados correctamente", "dns_test_ok_toast": "Los servidores DNS especificados funcionan correctamente", "dns_test_not_ok_toast": "Servidor \"{{key}}\": no se puede utilizar, por favor revisa si lo has escrito correctamente", - "dns_test_warning_toast": "Upstream \"{{key}}\" no responde a las peticiones de prueba y es posible que no funcione correctamente", + "dns_test_warning_toast": "DNS de subida \"{{key}}\" no responde a las peticiones de prueba y es posible que no funcione correctamente", "unblock": "Desbloquear", "block": "Bloquear", "disallow_this_client": "No permitir a este cliente", @@ -364,7 +364,7 @@ "encryption_config_saved": "Configuración de cifrado guardado", "encryption_server": "Nombre del servidor", "encryption_server_enter": "Ingresa el nombre del dominio", - "encryption_server_desc": "Si se configura, AdGuard Home detecta los ClientID, responde a las consultas DDR y realiza validaciones de conexión adicionales. Si no se configura, estas funciones están deshabilitadas. Debe coincidir con uno de los nombres DNS del certificado.", + "encryption_server_desc": "Si se configura, AdGuard Home detecta los ID de clientes, responde a las consultas DDR y realiza validaciones de conexión adicionales. Si no se configura, estas funciones se deshabilitarán. Debe coincidir con uno de los nombres DNS del certificado.", "encryption_redirect": "Redireccionar a HTTPS automáticamente", "encryption_redirect_desc": "Si está marcado, AdGuard Home redireccionará automáticamente de HTTP a las direcciones HTTPS.", "encryption_https": "Puerto HTTPS", diff --git a/internal/aghalg/aghalg.go b/internal/aghalg/aghalg.go index 65e81cbc..f0b71a09 100644 --- a/internal/aghalg/aghalg.go +++ b/internal/aghalg/aghalg.go @@ -10,6 +10,20 @@ import ( "golang.org/x/exp/slices" ) +// Coalesce returns the first non-zero value. It is named after the function +// COALESCE in SQL. If values or all its elements are empty, it returns a zero +// value. +func Coalesce[T comparable](values ...T) (res T) { + var zero T + for _, v := range values { + if v != zero { + return v + } + } + + return zero +} + // UniqChecker allows validating uniqueness of comparable items. // // TODO(a.garipov): The Ordered constraint is only really necessary in Validate. diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 019c713e..1f75a3c9 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -470,7 +470,7 @@ func TestHostsContainer(t *testing.T) { }}, }, { req: &urlfilter.DNSRequest{ - Hostname: "nonexisting", + Hostname: "nonexistent.example", DNSType: dns.TypeA, }, name: "non-existing", diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index 268380bd..2de9c630 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -154,10 +154,13 @@ func GetValidNetInterfacesForWeb() (netIfaces []*NetInterface, err error) { return netIfaces, nil } -// GetInterfaceByIP returns the name of interface containing provided ip. +// InterfaceByIP returns the name of the interface bound to ip. // -// TODO(e.burkov): See TODO on GetValidInterfacesForWeb. -func GetInterfaceByIP(ip net.IP) string { +// TODO(a.garipov, e.burkov): This function is technically incorrect, since one +// IP address can be shared by multiple interfaces in some configurations. +// +// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb. +func InterfaceByIP(ip net.IP) (ifaceName string) { ifaces, err := GetValidNetInterfacesForWeb() if err != nil { return "" @@ -177,7 +180,7 @@ func GetInterfaceByIP(ip net.IP) string { // GetSubnet returns pointer to net.IPNet for the specified interface or nil if // the search fails. // -// TODO(e.burkov): See TODO on GetValidInterfacesForWeb. +// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb. func GetSubnet(ifaceName string) *net.IPNet { netIfaces, err := GetValidNetInterfacesForWeb() if err != nil { diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index 40d395ba..d4ee59ee 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -132,7 +132,7 @@ func TestGatewayIP(t *testing.T) { } } -func TestGetInterfaceByIP(t *testing.T) { +func TestInterfaceByIP(t *testing.T) { ifaces, err := GetValidNetInterfacesForWeb() require.NoError(t, err) require.NotEmpty(t, ifaces) @@ -142,7 +142,7 @@ func TestGetInterfaceByIP(t *testing.T) { require.NotEmpty(t, iface.Addresses) for _, ip := range iface.Addresses { - ifaceName := GetInterfaceByIP(ip) + ifaceName := InterfaceByIP(ip) require.Equal(t, iface.Name, ifaceName) } }) diff --git a/internal/aghos/aghos_test.go b/internal/aghos/aghos_test.go index e68c26b7..684f646e 100644 --- a/internal/aghos/aghos_test.go +++ b/internal/aghos/aghos_test.go @@ -1,4 +1,4 @@ -package aghos +package aghos_test import ( "testing" diff --git a/internal/aghos/filewalker_internal_test.go b/internal/aghos/filewalker_internal_test.go new file mode 100644 index 00000000..bb162812 --- /dev/null +++ b/internal/aghos/filewalker_internal_test.go @@ -0,0 +1,57 @@ +package aghos + +import ( + "io/fs" + "path" + "testing" + "testing/fstest" + + "github.com/AdguardTeam/golibs/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// errFS is an fs.FS implementation, method Open of which always returns +// errFSOpen. +type errFS struct{} + +// errFSOpen is returned from errGlobFS.Open. +const errFSOpen errors.Error = "test open error" + +// Open implements the fs.FS interface for *errGlobFS. fsys is always nil and +// err is always errFSOpen. +func (efs *errFS) Open(name string) (fsys fs.File, err error) { + return nil, errFSOpen +} + +func TestWalkerFunc_CheckFile(t *testing.T) { + emptyFS := fstest.MapFS{} + + t.Run("non-existing", func(t *testing.T) { + _, ok, err := checkFile(emptyFS, nil, "lol") + require.NoError(t, err) + + assert.True(t, ok) + }) + + t.Run("invalid_argument", func(t *testing.T) { + _, ok, err := checkFile(&errFS{}, nil, "") + require.ErrorIs(t, err, errFSOpen) + + assert.False(t, ok) + }) + + t.Run("ignore_dirs", func(t *testing.T) { + const dirName = "dir" + + testFS := fstest.MapFS{ + path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}}, + } + + patterns, ok, err := checkFile(testFS, nil, dirName) + require.NoError(t, err) + + assert.Empty(t, patterns) + assert.True(t, ok) + }) +} diff --git a/internal/aghos/filewalker_test.go b/internal/aghos/filewalker_test.go index 97d1a845..94443831 100644 --- a/internal/aghos/filewalker_test.go +++ b/internal/aghos/filewalker_test.go @@ -1,13 +1,13 @@ -package aghos +package aghos_test import ( "bufio" "io" - "io/fs" "path" "testing" "testing/fstest" + "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -16,7 +16,7 @@ import ( func TestFileWalker_Walk(t *testing.T) { const attribute = `000` - makeFileWalker := func(_ string) (fw FileWalker) { + makeFileWalker := func(_ string) (fw aghos.FileWalker) { return func(r io.Reader) (patterns []string, cont bool, err error) { s := bufio.NewScanner(r) for s.Scan() { @@ -113,7 +113,7 @@ func TestFileWalker_Walk(t *testing.T) { f := fstest.MapFS{ filename: &fstest.MapFile{Data: []byte("[]")}, } - ok, err := FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) { + ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, cont bool, err error) { s := bufio.NewScanner(r) for s.Scan() { patterns = append(patterns, s.Text()) @@ -134,7 +134,7 @@ func TestFileWalker_Walk(t *testing.T) { "mockfile.txt": &fstest.MapFile{Data: []byte(`mockdata`)}, } - ok, err := FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) { + ok, err := aghos.FileWalker(func(r io.Reader) (patterns []string, ok bool, err error) { return nil, true, rerr }).Walk(f, "*") require.ErrorIs(t, err, rerr) @@ -142,45 +142,3 @@ func TestFileWalker_Walk(t *testing.T) { assert.False(t, ok) }) } - -type errFS struct { - fs.GlobFS -} - -const errErrFSOpen errors.Error = "this error is always returned" - -func (efs *errFS) Open(name string) (fs.File, error) { - return nil, errErrFSOpen -} - -func TestWalkerFunc_CheckFile(t *testing.T) { - emptyFS := fstest.MapFS{} - - t.Run("non-existing", func(t *testing.T) { - _, ok, err := checkFile(emptyFS, nil, "lol") - require.NoError(t, err) - - assert.True(t, ok) - }) - - t.Run("invalid_argument", func(t *testing.T) { - _, ok, err := checkFile(&errFS{}, nil, "") - require.ErrorIs(t, err, errErrFSOpen) - - assert.False(t, ok) - }) - - t.Run("ignore_dirs", func(t *testing.T) { - const dirName = "dir" - - testFS := fstest.MapFS{ - path.Join(dirName, "file"): &fstest.MapFile{Data: []byte{}}, - } - - patterns, ok, err := checkFile(testFS, nil, dirName) - require.NoError(t, err) - - assert.Empty(t, patterns) - assert.True(t, ok) - }) -} diff --git a/internal/aghtest/exchanger.go b/internal/aghtest/exchanger.go deleted file mode 100644 index 2c617814..00000000 --- a/internal/aghtest/exchanger.go +++ /dev/null @@ -1,20 +0,0 @@ -package aghtest - -import ( - "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/miekg/dns" -) - -// Exchanger is a mock aghnet.Exchanger implementation for tests. -type Exchanger struct { - Ups upstream.Upstream -} - -// Exchange implements aghnet.Exchanger interface for *Exchanger. -func (e *Exchanger) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { - if e.Ups == nil { - e.Ups = &TestErrUpstream{} - } - - return e.Ups.Exchange(req) -} diff --git a/internal/aghtest/fswatcher.go b/internal/aghtest/fswatcher.go deleted file mode 100644 index 0df4470d..00000000 --- a/internal/aghtest/fswatcher.go +++ /dev/null @@ -1,23 +0,0 @@ -package aghtest - -// FSWatcher is a mock aghos.FSWatcher implementation to use in tests. -type FSWatcher struct { - OnEvents func() (e <-chan struct{}) - OnAdd func(name string) (err error) - OnClose func() (err error) -} - -// Events implements the aghos.FSWatcher interface for *FSWatcher. -func (w *FSWatcher) Events() (e <-chan struct{}) { - return w.OnEvents() -} - -// Add implements the aghos.FSWatcher interface for *FSWatcher. -func (w *FSWatcher) Add(name string) (err error) { - return w.OnAdd(name) -} - -// Close implements the aghos.FSWatcher interface for *FSWatcher. -func (w *FSWatcher) Close() (err error) { - return w.OnClose() -} diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go new file mode 100644 index 00000000..2de9d372 --- /dev/null +++ b/internal/aghtest/interface.go @@ -0,0 +1,135 @@ +package aghtest + +import ( + "io/fs" + "net" + + "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/miekg/dns" +) + +// Interface Mocks +// +// Keep entities in this file in alphabetic order. + +// Standard Library + +// type check +var _ fs.FS = &FS{} + +// FS is a mock [fs.FS] implementation for tests. +type FS struct { + OnOpen func(name string) (fs.File, error) +} + +// Open implements the [fs.FS] interface for *FS. +func (fsys *FS) Open(name string) (fs.File, error) { + return fsys.OnOpen(name) +} + +// type check +var _ fs.GlobFS = &GlobFS{} + +// GlobFS is a mock [fs.GlobFS] implementation for tests. +type GlobFS struct { + // FS is embedded here to avoid implementing all it's methods. + FS + OnGlob func(pattern string) ([]string, error) +} + +// Glob implements the [fs.GlobFS] interface for *GlobFS. +func (fsys *GlobFS) Glob(pattern string) ([]string, error) { + return fsys.OnGlob(pattern) +} + +// type check +var _ fs.StatFS = &StatFS{} + +// StatFS is a mock [fs.StatFS] implementation for tests. +type StatFS struct { + // FS is embedded here to avoid implementing all it's methods. + FS + OnStat func(name string) (fs.FileInfo, error) +} + +// Stat implements the [fs.StatFS] interface for *StatFS. +func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) { + return fsys.OnStat(name) +} + +// type check +var _ net.Listener = (*Listener)(nil) + +// Listener is a mock [net.Listener] implementation for tests. +type Listener struct { + OnAccept func() (conn net.Conn, err error) + OnAddr func() (addr net.Addr) + OnClose func() (err error) +} + +// Accept implements the [net.Listener] interface for *Listener. +func (l *Listener) Accept() (conn net.Conn, err error) { + return l.OnAccept() +} + +// Addr implements the [net.Listener] interface for *Listener. +func (l *Listener) Addr() (addr net.Addr) { + return l.OnAddr() +} + +// Close implements the [net.Listener] interface for *Listener. +func (l *Listener) Close() (err error) { + return l.OnClose() +} + +// Module dnsproxy + +// type check +var _ upstream.Upstream = (*UpstreamMock)(nil) + +// UpstreamMock is a mock [upstream.Upstream] implementation for tests. +// +// TODO(a.garipov): Replace with all uses of Upstream with UpstreamMock and +// rename it to just Upstream. +type UpstreamMock struct { + OnAddress func() (addr string) + OnExchange func(req *dns.Msg) (resp *dns.Msg, err error) +} + +// Address implements the [upstream.Upstream] interface for *UpstreamMock. +func (u *UpstreamMock) Address() (addr string) { + return u.OnAddress() +} + +// Exchange implements the [upstream.Upstream] interface for *UpstreamMock. +func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { + return u.OnExchange(req) +} + +// Module AdGuardHome + +// type check +var _ aghos.FSWatcher = (*FSWatcher)(nil) + +// FSWatcher is a mock [aghos.FSWatcher] implementation for tests. +type FSWatcher struct { + OnEvents func() (e <-chan struct{}) + OnAdd func(name string) (err error) + OnClose func() (err error) +} + +// Events implements the [aghos.FSWatcher] interface for *FSWatcher. +func (w *FSWatcher) Events() (e <-chan struct{}) { + return w.OnEvents() +} + +// Add implements the [aghos.FSWatcher] interface for *FSWatcher. +func (w *FSWatcher) Add(name string) (err error) { + return w.OnAdd(name) +} + +// Close implements the [aghos.FSWatcher] interface for *FSWatcher. +func (w *FSWatcher) Close() (err error) { + return w.OnClose() +} diff --git a/internal/aghtest/interface_test.go b/internal/aghtest/interface_test.go new file mode 100644 index 00000000..5a465c2c --- /dev/null +++ b/internal/aghtest/interface_test.go @@ -0,0 +1,9 @@ +package aghtest_test + +import ( + "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" +) + +// type check +var _ aghos.FSWatcher = (*aghtest.FSWatcher)(nil) diff --git a/internal/aghtest/testfs.go b/internal/aghtest/testfs.go deleted file mode 100644 index 88203fec..00000000 --- a/internal/aghtest/testfs.go +++ /dev/null @@ -1,46 +0,0 @@ -package aghtest - -import "io/fs" - -// type check -var _ fs.FS = &FS{} - -// FS is a mock fs.FS implementation to use in tests. -type FS struct { - OnOpen func(name string) (fs.File, error) -} - -// Open implements the fs.FS interface for *FS. -func (fsys *FS) Open(name string) (fs.File, error) { - return fsys.OnOpen(name) -} - -// type check -var _ fs.StatFS = &StatFS{} - -// StatFS is a mock fs.StatFS implementation to use in tests. -type StatFS struct { - // FS is embedded here to avoid implementing all it's methods. - FS - OnStat func(name string) (fs.FileInfo, error) -} - -// Stat implements the fs.StatFS interface for *StatFS. -func (fsys *StatFS) Stat(name string) (fs.FileInfo, error) { - return fsys.OnStat(name) -} - -// type check -var _ fs.GlobFS = &GlobFS{} - -// GlobFS is a mock fs.GlobFS implementation to use in tests. -type GlobFS struct { - // FS is embedded here to avoid implementing all it's methods. - FS - OnGlob func(pattern string) ([]string, error) -} - -// Glob implements the fs.GlobFS interface for *GlobFS. -func (fsys *GlobFS) Glob(pattern string) ([]string, error) { - return fsys.OnGlob(pattern) -} diff --git a/internal/aghtest/upstream.go b/internal/aghtest/upstream.go index 95d8f5ad..699c14b9 100644 --- a/internal/aghtest/upstream.go +++ b/internal/aghtest/upstream.go @@ -6,12 +6,18 @@ import ( "fmt" "net" "strings" - "sync" + "testing" + "github.com/AdguardTeam/golibs/errors" "github.com/miekg/dns" + "github.com/stretchr/testify/require" ) +// Additional Upstream Testing Utilities + // Upstream is a mock implementation of upstream.Upstream. +// +// TODO(a.garipov): Replace with UpstreamMock and rename it to just Upstream. type Upstream struct { // CName is a map of hostname to canonical name. CName map[string][]string @@ -25,6 +31,43 @@ type Upstream struct { Addr string } +// RespondTo returns a response with answer if req has class cl, question type +// qt, and target targ. +func RespondTo(t testing.TB, req *dns.Msg, cl, qt uint16, targ, answer string) (resp *dns.Msg) { + t.Helper() + + require.NotNil(t, req) + require.Len(t, req.Question, 1) + + q := req.Question[0] + targ = dns.Fqdn(targ) + if q.Qclass != cl || q.Qtype != qt || q.Name != targ { + return nil + } + + respHdr := dns.RR_Header{ + Name: targ, + Rrtype: qt, + Class: cl, + Ttl: 60, + } + + resp = new(dns.Msg).SetReply(req) + switch qt { + case dns.TypePTR: + resp.Answer = []dns.RR{ + &dns.PTR{ + Hdr: respHdr, + Ptr: answer, + }, + } + default: + t.Fatalf("unsupported question type: %s", dns.Type(qt)) + } + + return resp +} + // Exchange implements the upstream.Upstream interface for *Upstream. // // TODO(a.garipov): Split further into handlers. @@ -76,74 +119,57 @@ func (u *Upstream) Address() string { return u.Addr } -// TestBlockUpstream implements upstream.Upstream interface for replacing real -// upstream in tests. -type TestBlockUpstream struct { - Hostname string - - // lock protects reqNum. - lock sync.RWMutex - reqNum int - - Block bool -} - -// Exchange returns a message unique for TestBlockUpstream's Hostname-Block -// pair. -func (u *TestBlockUpstream) Exchange(r *dns.Msg) (*dns.Msg, error) { - u.lock.Lock() - defer u.lock.Unlock() - u.reqNum++ - - hash := sha256.Sum256([]byte(u.Hostname)) - hashToReturn := hex.EncodeToString(hash[:]) - if !u.Block { - hashToReturn = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28) +// NewBlockUpstream returns an [*UpstreamMock] that works like an upstream that +// supports hash-based safe-browsing/adult-blocking feature. If shouldBlock is +// true, hostname's actual hash is returned, blocking it. Otherwise, it returns +// a different hash. +func NewBlockUpstream(hostname string, shouldBlock bool) (u *UpstreamMock) { + hash := sha256.Sum256([]byte(hostname)) + hashStr := hex.EncodeToString(hash[:]) + if !shouldBlock { + hashStr = hex.EncodeToString(hash[:])[:2] + strings.Repeat("ab", 28) } - m := &dns.Msg{} - m.SetReply(r) - m.Answer = []dns.RR{ - &dns.TXT{ - Hdr: dns.RR_Header{ - Name: r.Question[0].Name, - }, - Txt: []string{ - hashToReturn, - }, + ans := &dns.TXT{ + Hdr: dns.RR_Header{ + Name: "", + Rrtype: dns.TypeTXT, + Class: dns.ClassINET, + Ttl: 60, + }, + Txt: []string{hashStr}, + } + respTmpl := &dns.Msg{ + Answer: []dns.RR{ans}, + } + + return &UpstreamMock{ + OnAddress: func() (addr string) { + return "sbpc.upstream.example" + }, + OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + resp = respTmpl.Copy() + resp.SetReply(req) + resp.Answer[0].(*dns.TXT).Hdr.Name = req.Question[0].Name + + return resp, nil }, } - - return m, nil } -// Address always returns an empty string. -func (u *TestBlockUpstream) Address() string { - return "" -} +// ErrUpstream is the error returned from the [*UpstreamMock] created by +// [NewErrorUpstream]. +const ErrUpstream errors.Error = "test upstream error" -// RequestsCount returns the number of handled requests. It's safe for -// concurrent use. -func (u *TestBlockUpstream) RequestsCount() int { - u.lock.Lock() - defer u.lock.Unlock() - - return u.reqNum -} - -// TestErrUpstream implements upstream.Upstream interface for replacing real -// upstream in tests. -type TestErrUpstream struct { - // The error returned by Exchange may be unwrapped to the Err. - Err error -} - -// Exchange always returns nil Msg and non-nil error. -func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) { - return nil, fmt.Errorf("errupstream: %w", u.Err) -} - -// Address always returns an empty string. -func (u *TestErrUpstream) Address() string { - return "" +// NewErrorUpstream returns an [*UpstreamMock] that returns [ErrUpstream] from +// its Exchange method. +func NewErrorUpstream() (u *UpstreamMock) { + return &UpstreamMock{ + OnAddress: func() (addr string) { + return "error.upstream.example" + }, + OnExchange: func(_ *dns.Msg) (resp *dns.Msg, err error) { + return nil, errors.Error("test upstream error") + }, + } } diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index 7210ede9..63af9ed1 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -121,6 +121,7 @@ type FilteringConfig struct { EnableDNSSEC bool `yaml:"enable_dnssec"` // Set AD flag in outcoming DNS request EnableEDNSClientSubnet bool `yaml:"edns_client_subnet"` // Enable EDNS Client Subnet option MaxGoroutines uint32 `yaml:"max_goroutines"` // Max. number of parallel goroutines for processing incoming requests + HandleDDR bool `yaml:"handle_ddr"` // Handle DDR requests // IpsetList is the ipset configuration that allows AdGuard Home to add // IP addresses of the specified domain names to an ipset list. Syntax: @@ -151,7 +152,7 @@ type TLSConfig struct { PrivateKeyData []byte `yaml:"-" json:"-"` // ServerName is the hostname of the server. Currently, it is only being - // used for ClientID checking. + // used for ClientID checking and Discovery of Designated Resolvers (DDR). ServerName string `yaml:"-" json:"-"` cert tls.Certificate diff --git a/internal/dnsforward/dns.go b/internal/dnsforward/dns.go index 01e02480..f3c98361 100644 --- a/internal/dnsforward/dns.go +++ b/internal/dnsforward/dns.go @@ -76,6 +76,10 @@ const ( resultCodeError ) +// ddrHostFQDN is the FQDN used in Discovery of Designated Resolvers (DDR) requests. +// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html. +const ddrHostFQDN = "_dns.resolver.arpa." + // handleDNSRequest filters the incoming DNS requests and writes them to the query log func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { ctx := &dnsContext{ @@ -94,6 +98,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error { mods := []modProcessFunc{ s.processRecursion, s.processInitial, + s.processDDRQuery, s.processDetermineLocal, s.processDHCPHosts, s.processRestrictLocal, @@ -239,6 +244,98 @@ func (s *Server) onDHCPLeaseChanged(flags int) { s.setTableIPToHost(ipToHost) } +// processDDRQuery responds to SVCB query for a special use domain name +// ‘_dns.resolver.arpa’. The result contains different types of encryption +// supported by current user configuration. +// +// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html. +func (s *Server) processDDRQuery(ctx *dnsContext) (rc resultCode) { + d := ctx.proxyCtx + question := d.Req.Question[0] + + if !s.conf.HandleDDR { + return resultCodeSuccess + } + + if question.Name == ddrHostFQDN { + if s.dnsProxy.TLSListenAddr == nil && s.conf.HTTPSListenAddrs == nil && + s.dnsProxy.QUICListenAddr == nil || question.Qtype != dns.TypeSVCB { + d.Res = s.makeResponse(d.Req) + + return resultCodeFinish + } + + d.Res = s.makeDDRResponse(d.Req) + + return resultCodeFinish + } + + return resultCodeSuccess +} + +// makeDDRResponse creates DDR answer according to server configuration. The +// contructed SVCB resource records have the priority of 1 for each entry, +// similar to examples provided by https://www.ietf.org/archive/id/draft-ietf-add-ddr-06.html. +// +// TODO(a.meshkov): Consider setting the priority values based on the protocol. +func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) { + resp = s.makeResponse(req) + // TODO(e.burkov): Think about storing the FQDN version of the server's + // name somewhere. + domainName := dns.Fqdn(s.conf.ServerName) + + for _, addr := range s.conf.HTTPSListenAddrs { + values := []dns.SVCBKeyValue{ + &dns.SVCBAlpn{Alpn: []string{"h2"}}, + &dns.SVCBPort{Port: uint16(addr.Port)}, + &dns.SVCBDoHPath{Template: "/dns-query?dns"}, + } + + ans := &dns.SVCB{ + Hdr: s.hdr(req, dns.TypeSVCB), + Priority: 1, + Target: domainName, + Value: values, + } + + resp.Answer = append(resp.Answer, ans) + } + + for _, addr := range s.dnsProxy.TLSListenAddr { + values := []dns.SVCBKeyValue{ + &dns.SVCBAlpn{Alpn: []string{"dot"}}, + &dns.SVCBPort{Port: uint16(addr.Port)}, + } + + ans := &dns.SVCB{ + Hdr: s.hdr(req, dns.TypeSVCB), + Priority: 1, + Target: domainName, + Value: values, + } + + resp.Answer = append(resp.Answer, ans) + } + + for _, addr := range s.dnsProxy.QUICListenAddr { + values := []dns.SVCBKeyValue{ + &dns.SVCBAlpn{Alpn: []string{"doq"}}, + &dns.SVCBPort{Port: uint16(addr.Port)}, + } + + ans := &dns.SVCB{ + Hdr: s.hdr(req, dns.TypeSVCB), + Priority: 1, + Target: domainName, + Value: values, + } + + resp.Answer = append(resp.Answer, ans) + } + + return resp +} + // processDetermineLocal determines if the client's IP address is from // locally-served network and saves the result into the context. func (s *Server) processDetermineLocal(dctx *dnsContext) (rc resultCode) { diff --git a/internal/dnsforward/dns_test.go b/internal/dnsforward/dns_test.go index 27108a70..25e28afd 100644 --- a/internal/dnsforward/dns_test.go +++ b/internal/dnsforward/dns_test.go @@ -14,6 +14,177 @@ import ( "github.com/stretchr/testify/require" ) +const ( + ddrTestDomainName = "dns.example.net" + ddrTestFQDN = ddrTestDomainName + "." +) + +func TestServer_ProcessDDRQuery(t *testing.T) { + dohSVCB := &dns.SVCB{ + Priority: 1, + Target: ddrTestFQDN, + Value: []dns.SVCBKeyValue{ + &dns.SVCBAlpn{Alpn: []string{"h2"}}, + &dns.SVCBPort{Port: 8044}, + &dns.SVCBDoHPath{Template: "/dns-query?dns"}, + }, + } + + dotSVCB := &dns.SVCB{ + Priority: 1, + Target: ddrTestFQDN, + Value: []dns.SVCBKeyValue{ + &dns.SVCBAlpn{Alpn: []string{"dot"}}, + &dns.SVCBPort{Port: 8043}, + }, + } + + doqSVCB := &dns.SVCB{ + Priority: 1, + Target: ddrTestFQDN, + Value: []dns.SVCBKeyValue{ + &dns.SVCBAlpn{Alpn: []string{"doq"}}, + &dns.SVCBPort{Port: 8042}, + }, + } + + testCases := []struct { + name string + host string + want []*dns.SVCB + wantRes resultCode + portDoH int + portDoT int + portDoQ int + qtype uint16 + ddrEnabled bool + }{{ + name: "pass_host", + wantRes: resultCodeSuccess, + host: "example.net.", + qtype: dns.TypeSVCB, + ddrEnabled: true, + portDoH: 8043, + }, { + name: "pass_qtype", + wantRes: resultCodeFinish, + host: ddrHostFQDN, + qtype: dns.TypeA, + ddrEnabled: true, + portDoH: 8043, + }, { + name: "pass_disabled_tls", + wantRes: resultCodeFinish, + host: ddrHostFQDN, + qtype: dns.TypeSVCB, + ddrEnabled: true, + }, { + name: "pass_disabled_ddr", + wantRes: resultCodeSuccess, + host: ddrHostFQDN, + qtype: dns.TypeSVCB, + ddrEnabled: false, + portDoH: 8043, + }, { + name: "dot", + wantRes: resultCodeFinish, + want: []*dns.SVCB{dotSVCB}, + host: ddrHostFQDN, + qtype: dns.TypeSVCB, + ddrEnabled: true, + portDoT: 8043, + }, { + name: "doh", + wantRes: resultCodeFinish, + want: []*dns.SVCB{dohSVCB}, + host: ddrHostFQDN, + qtype: dns.TypeSVCB, + ddrEnabled: true, + portDoH: 8044, + }, { + name: "doq", + wantRes: resultCodeFinish, + want: []*dns.SVCB{doqSVCB}, + host: ddrHostFQDN, + qtype: dns.TypeSVCB, + ddrEnabled: true, + portDoQ: 8042, + }, { + name: "dot_doh", + wantRes: resultCodeFinish, + want: []*dns.SVCB{dotSVCB, dohSVCB}, + host: ddrHostFQDN, + qtype: dns.TypeSVCB, + ddrEnabled: true, + portDoT: 8043, + portDoH: 8044, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s := prepareTestServer(t, tc.portDoH, tc.portDoT, tc.portDoQ, tc.ddrEnabled) + + req := createTestMessageWithType(tc.host, tc.qtype) + + dctx := &dnsContext{ + proxyCtx: &proxy.DNSContext{ + Req: req, + }, + } + + res := s.processDDRQuery(dctx) + require.Equal(t, tc.wantRes, res) + + if tc.wantRes != resultCodeFinish { + return + } + + msg := dctx.proxyCtx.Res + require.NotNil(t, msg) + + for _, v := range tc.want { + v.Hdr = s.hdr(req, dns.TypeSVCB) + } + + assert.ElementsMatch(t, tc.want, msg.Answer) + }) + } +} + +func prepareTestServer(t *testing.T, portDoH, portDoT, portDoQ int, ddrEnabled bool) (s *Server) { + t.Helper() + + proxyConf := proxy.Config{} + + if portDoT > 0 { + proxyConf.TLSListenAddr = []*net.TCPAddr{{Port: portDoT}} + } + + if portDoQ > 0 { + proxyConf.QUICListenAddr = []*net.UDPAddr{{Port: portDoQ}} + } + + s = &Server{ + dnsProxy: &proxy.Proxy{ + Config: proxyConf, + }, + conf: ServerConfig{ + FilteringConfig: FilteringConfig{ + HandleDDR: ddrEnabled, + }, + TLSConfig: TLSConfig{ + ServerName: ddrTestDomainName, + }, + }, + } + + if portDoH > 0 { + s.conf.TLSConfig.HTTPSListenAddrs = []*net.TCPAddr{{Port: portDoH}} + } + + return s +} + func TestServer_ProcessDetermineLocal(t *testing.T) { s := &Server{ privateNets: netutil.SubnetSetFunc(netutil.IsLocallyServed), diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index e4d61b48..5144680b 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -17,13 +17,13 @@ import ( "testing/fstest" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/dhcpd" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" - "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/timeutil" @@ -853,10 +853,7 @@ func TestBlockedByHosts(t *testing.T) { func TestBlockedBySafeBrowsing(t *testing.T) { const hostname = "wmconvirus.narod.ru" - sbUps := &aghtest.TestBlockUpstream{ - Hostname: hostname, - Block: true, - } + sbUps := aghtest.NewBlockUpstream(hostname, true) ans4, _ := (&aghtest.TestResolver{}).HostToIPs(hostname) filterConf := &filtering.Config{ @@ -1029,7 +1026,7 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { s.conf.UDPListenAddrs = []*net.UDPAddr{{}} s.conf.TCPListenAddrs = []*net.TCPAddr{{}} s.conf.UpstreamDNS = []string{"127.0.0.1:53"} - s.conf.FilteringConfig.ProtectionEnabled = true + s.conf.ProtectionEnabled = true err = s.Prepare(nil) require.NoError(t, err) @@ -1177,25 +1174,48 @@ func TestNewServer(t *testing.T) { } func TestServer_Exchange(t *testing.T) { - extUpstream := &aghtest.Upstream{ - Reverse: map[string][]string{ - "1.1.1.1.in-addr.arpa.": {"one.one.one.one"}, + const ( + onesHost = "one.one.one.one" + localDomainHost = "local.domain" + ) + + var ( + onesIP = net.IP{1, 1, 1, 1} + localIP = net.IP{192, 168, 1, 1} + ) + + revExtIPv4, err := netutil.IPToReversedAddr(onesIP) + require.NoError(t, err) + + extUpstream := &aghtest.UpstreamMock{ + OnAddress: func() (addr string) { return "external.upstream.example" }, + OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + resp = aghalg.Coalesce( + aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revExtIPv4, onesHost), + new(dns.Msg).SetRcode(req, dns.RcodeNameError), + ) + + return resp, nil }, } - locUpstream := &aghtest.Upstream{ - Reverse: map[string][]string{ - "1.1.168.192.in-addr.arpa.": {"local.domain"}, - "2.1.168.192.in-addr.arpa.": {}, + + revLocIPv4, err := netutil.IPToReversedAddr(localIP) + require.NoError(t, err) + + locUpstream := &aghtest.UpstreamMock{ + OnAddress: func() (addr string) { return "local.upstream.example" }, + OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + resp = aghalg.Coalesce( + aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revLocIPv4, localDomainHost), + new(dns.Msg).SetRcode(req, dns.RcodeNameError), + ) + + return resp, nil }, } - upstreamErr := errors.Error("upstream error") - errUpstream := &aghtest.TestErrUpstream{ - Err: upstreamErr, - } - nonPtrUpstream := &aghtest.TestBlockUpstream{ - Hostname: "some-host", - Block: true, - } + + errUpstream := aghtest.NewErrorUpstream() + nonPtrUpstream := aghtest.NewBlockUpstream("some-host", true) srv := NewCustomServer(&proxy.Proxy{ Config: proxy.Config{ @@ -1209,7 +1229,6 @@ func TestServer_Exchange(t *testing.T) { srv.privateNets = netutil.SubnetSetFunc(netutil.IsLocallyServed) - localIP := net.IP{192, 168, 1, 1} testCases := []struct { name string want string @@ -1218,20 +1237,20 @@ func TestServer_Exchange(t *testing.T) { req net.IP }{{ name: "external_good", - want: "one.one.one.one", + want: onesHost, wantErr: nil, locUpstream: nil, - req: net.IP{1, 1, 1, 1}, + req: onesIP, }, { name: "local_good", - want: "local.domain", + want: localDomainHost, wantErr: nil, locUpstream: locUpstream, req: localIP, }, { name: "upstream_error", want: "", - wantErr: upstreamErr, + wantErr: aghtest.ErrUpstream, locUpstream: errUpstream, req: localIP, }, { diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_test.go index 79c4d040..95554b07 100644 --- a/internal/filtering/filtering_test.go +++ b/internal/filtering/filtering_test.go @@ -21,6 +21,11 @@ func TestMain(m *testing.M) { aghtest.DiscardLogOutput(m) } +const ( + sbBlocked = "wmconvirus.narod.ru" + pcBlocked = "pornhub.com" +) + var setts = Settings{ ProtectionEnabled: true, } @@ -173,43 +178,37 @@ func TestSafeBrowsing(t *testing.T) { d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) - const matching = "wmconvirus.narod.ru" - d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ - Hostname: matching, - Block: true, - }) - d.checkMatch(t, matching) - require.Contains(t, logOutput.String(), "SafeBrowsing lookup for "+matching) + d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) + d.checkMatch(t, sbBlocked) - d.checkMatch(t, "test."+matching) + require.Contains(t, logOutput.String(), fmt.Sprintf("safebrowsing lookup for %q", sbBlocked)) + + d.checkMatch(t, "test."+sbBlocked) d.checkMatchEmpty(t, "yandex.ru") - d.checkMatchEmpty(t, "pornhub.com") + d.checkMatchEmpty(t, pcBlocked) // Cached result. d.safeBrowsingServer = "127.0.0.1" - d.checkMatch(t, matching) - d.checkMatchEmpty(t, "pornhub.com") + d.checkMatch(t, sbBlocked) + d.checkMatchEmpty(t, pcBlocked) d.safeBrowsingServer = defaultSafebrowsingServer } func TestParallelSB(t *testing.T) { d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) - const matching = "wmconvirus.narod.ru" - d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ - Hostname: matching, - Block: true, - }) + + d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) t.Run("group", func(t *testing.T) { for i := 0; i < 100; i++ { t.Run(fmt.Sprintf("aaa%d", i), func(t *testing.T) { t.Parallel() - d.checkMatch(t, matching) - d.checkMatch(t, "test."+matching) + d.checkMatch(t, sbBlocked) + d.checkMatch(t, "test."+sbBlocked) d.checkMatchEmpty(t, "yandex.ru") - d.checkMatchEmpty(t, "pornhub.com") + d.checkMatchEmpty(t, pcBlocked) }) } }) @@ -382,23 +381,19 @@ func TestParentalControl(t *testing.T) { d := newForTest(t, &Config{ParentalEnabled: true}, nil) t.Cleanup(d.Close) - const matching = "pornhub.com" - d.SetParentalUpstream(&aghtest.TestBlockUpstream{ - Hostname: matching, - Block: true, - }) - d.checkMatch(t, matching) - require.Contains(t, logOutput.String(), "Parental lookup for "+matching) + d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true)) + d.checkMatch(t, pcBlocked) + require.Contains(t, logOutput.String(), fmt.Sprintf("parental lookup for %q", pcBlocked)) - d.checkMatch(t, "www."+matching) + d.checkMatch(t, "www."+pcBlocked) d.checkMatchEmpty(t, "www.yandex.ru") d.checkMatchEmpty(t, "yandex.ru") d.checkMatchEmpty(t, "api.jquery.com") // Test cached result. d.parentalServer = "127.0.0.1" - d.checkMatch(t, matching) + d.checkMatch(t, pcBlocked) d.checkMatchEmpty(t, "yandex.ru") } @@ -445,7 +440,7 @@ func TestMatching(t *testing.T) { }, { name: "sanity", rules: "||doubleclick.net^", - host: "wmconvirus.narod.ru", + host: sbBlocked, wantIsFiltered: false, wantReason: NotFilteredNotFound, wantDNSType: dns.TypeA, @@ -765,14 +760,9 @@ func TestClientSettings(t *testing.T) { }}, ) t.Cleanup(d.Close) - d.SetParentalUpstream(&aghtest.TestBlockUpstream{ - Hostname: "pornhub.com", - Block: true, - }) - d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ - Hostname: "wmconvirus.narod.ru", - Block: true, - }) + + d.SetParentalUpstream(aghtest.NewBlockUpstream(pcBlocked, true)) + d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) type testCase struct { name string @@ -787,12 +777,12 @@ func TestClientSettings(t *testing.T) { wantReason: FilteredBlockList, }, { name: "parental", - host: "pornhub.com", + host: pcBlocked, before: true, wantReason: FilteredParental, }, { name: "safebrowsing", - host: "wmconvirus.narod.ru", + host: sbBlocked, before: false, wantReason: FilteredSafeBrowsing, }, { @@ -836,33 +826,29 @@ func TestClientSettings(t *testing.T) { func BenchmarkSafeBrowsing(b *testing.B) { d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) b.Cleanup(d.Close) - blocked := "wmconvirus.narod.ru" - d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ - Hostname: blocked, - Block: true, - }) + + d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) + for n := 0; n < b.N; n++ { - res, err := d.CheckHost(blocked, dns.TypeA, &setts) + res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts) require.NoError(b, err) - assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) + assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked) } } func BenchmarkSafeBrowsingParallel(b *testing.B) { d := newForTest(b, &Config{SafeBrowsingEnabled: true}, nil) b.Cleanup(d.Close) - blocked := "wmconvirus.narod.ru" - d.SetSafeBrowsingUpstream(&aghtest.TestBlockUpstream{ - Hostname: blocked, - Block: true, - }) + + d.SetSafeBrowsingUpstream(aghtest.NewBlockUpstream(sbBlocked, true)) + b.RunParallel(func(pb *testing.PB) { for pb.Next() { - res, err := d.CheckHost(blocked, dns.TypeA, &setts) + res, err := d.CheckHost(sbBlocked, dns.TypeA, &setts) require.NoError(b, err) - assert.True(b, res.IsFiltered, "Expected hostname %s to match", blocked) + assert.Truef(b, res.IsFiltered, "expected hostname %q to match", sbBlocked) } }) } diff --git a/internal/filtering/safebrowsing.go b/internal/filtering/safebrowsing.go index 046756ac..e49c4070 100644 --- a/internal/filtering/safebrowsing.go +++ b/internal/filtering/safebrowsing.go @@ -314,7 +314,7 @@ func (d *DNSFilter) checkSafeBrowsing( if log.GetLevel() >= log.DEBUG { timer := log.StartTimer() - defer timer.LogElapsed("SafeBrowsing lookup for %s", host) + defer timer.LogElapsed("safebrowsing lookup for %q", host) } sctx := &sbCtx{ @@ -348,7 +348,7 @@ func (d *DNSFilter) checkParental( if log.GetLevel() >= log.DEBUG { timer := log.StartTimer() - defer timer.LogElapsed("Parental lookup for %s", host) + defer timer.LogElapsed("parental lookup for %q", host) } sctx := &sbCtx{ diff --git a/internal/filtering/safebrowsing_test.go b/internal/filtering/safebrowsing_test.go index 2dec3668..f2cc846c 100644 --- a/internal/filtering/safebrowsing_test.go +++ b/internal/filtering/safebrowsing_test.go @@ -74,21 +74,20 @@ func TestSafeBrowsingCache(t *testing.T) { c.hashToHost[hash] = "sub.host.com" assert.Equal(t, -1, c.getCached()) - // match "sub.host.com" from cache, - // but another hash for "nonexisting.com" is not in cache - // which means that we must get data from server for it + // Match "sub.host.com" from cache. Another hash for "host.example" is not + // in the cache, so get data for it from the server. c.hashToHost = make(map[[32]byte]string) hash = sha256.Sum256([]byte("sub.host.com")) c.hashToHost[hash] = "sub.host.com" - hash = sha256.Sum256([]byte("nonexisting.com")) - c.hashToHost[hash] = "nonexisting.com" + hash = sha256.Sum256([]byte("host.example")) + c.hashToHost[hash] = "host.example" assert.Empty(t, c.getCached()) hash = sha256.Sum256([]byte("sub.host.com")) _, ok := c.hashToHost[hash] assert.False(t, ok) - hash = sha256.Sum256([]byte("nonexisting.com")) + hash = sha256.Sum256([]byte("host.example")) _, ok = c.hashToHost[hash] assert.True(t, ok) @@ -111,8 +110,7 @@ func TestSBPC_checkErrorUpstream(t *testing.T) { d := newForTest(t, &Config{SafeBrowsingEnabled: true}, nil) t.Cleanup(d.Close) - ups := &aghtest.TestErrUpstream{} - + ups := aghtest.NewErrorUpstream() d.SetSafeBrowsingUpstream(ups) d.SetParentalUpstream(ups) @@ -170,10 +168,16 @@ func TestSBPC(t *testing.T) { for _, tc := range testCases { // Prepare the upstream. - ups := &aghtest.TestBlockUpstream{ - Hostname: hostname, - Block: tc.block, + ups := aghtest.NewBlockUpstream(hostname, tc.block) + + var numReq int + onExchange := ups.OnExchange + ups.OnExchange = func(req *dns.Msg) (resp *dns.Msg, err error) { + numReq++ + + return onExchange(req) } + d.SetSafeBrowsingUpstream(ups) d.SetParentalUpstream(ups) @@ -196,7 +200,7 @@ func TestSBPC(t *testing.T) { assert.Equal(t, hits, tc.testCache.Stats().Hit) // There was one request to an upstream. - assert.Equal(t, 1, ups.RequestsCount()) + assert.Equal(t, 1, numReq) // Now make the same request to check the cache was used. res, err = tc.testFunc(hostname, dns.TypeA, setts) @@ -214,7 +218,7 @@ func TestSBPC(t *testing.T) { assert.Equal(t, hits+1, tc.testCache.Stats().Hit) // Check that there were no additional requests. - assert.Equal(t, 1, ups.RequestsCount()) + assert.Equal(t, 1, numReq) }) purgeCaches(d) diff --git a/internal/home/config.go b/internal/home/config.go index 05397426..5bd9b98b 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -209,6 +209,7 @@ var config = &configuration{ Ratelimit: 20, RefuseAny: true, AllServers: false, + HandleDDR: true, FastestTimeout: timeutil.Duration{ Duration: fastip.DefaultPingWaitTimeout, }, diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 5304b794..c46f3459 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -216,7 +216,7 @@ func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) func handleStaticIP(ip net.IP, set bool) staticIPJSON { resp := staticIPJSON{} - interfaceName := aghnet.GetInterfaceByIP(ip) + interfaceName := aghnet.InterfaceByIP(ip) resp.Static = "no" if len(interfaceName) == 0 { diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 08f4f013..870d0f04 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -3,15 +3,16 @@ package home import ( "bytes" "encoding/binary" + "fmt" "net" "sync" "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/cache" - "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" @@ -80,8 +81,10 @@ func TestRDNS_Begin(t *testing.T) { binary.BigEndian.PutUint64(ttl, uint64(time.Now().Add(100*time.Hour).Unix())) rdns := &RDNS{ - ipCache: ipCache, - exchanger: &rDNSExchanger{}, + ipCache: ipCache, + exchanger: &rDNSExchanger{ + ex: aghtest.NewErrorUpstream(), + }, clients: &clientsContainer{ list: map[string]*Client{}, idIndex: tc.cliIDIndex, @@ -108,16 +111,22 @@ func TestRDNS_Begin(t *testing.T) { // rDNSExchanger is a mock dnsforward.RDNSExchanger implementation for tests. type rDNSExchanger struct { - ex aghtest.Exchanger + ex upstream.Upstream usePrivate bool } // Exchange implements dnsforward.RDNSExchanger interface for *RDNSExchanger. func (e *rDNSExchanger) Exchange(ip net.IP) (host string, err error) { + rev, err := netutil.IPToReversedAddr(ip) + if err != nil { + return "", fmt.Errorf("reversing ip: %w", err) + } + req := &dns.Msg{ Question: []dns.Question{{ - Name: ip.String(), - Qtype: dns.TypePTR, + Name: dns.Fqdn(rev), + Qclass: dns.ClassINET, + Qtype: dns.TypePTR, }}, } @@ -146,7 +155,9 @@ func TestRDNS_ensurePrivateCache(t *testing.T) { MaxCount: defaultRDNSCacheSize, }) - ex := &rDNSExchanger{} + ex := &rDNSExchanger{ + ex: aghtest.NewErrorUpstream(), + } rdns := &RDNS{ ipCache: ipCache, @@ -167,15 +178,27 @@ func TestRDNS_WorkerLoop(t *testing.T) { w := &bytes.Buffer{} aghtest.ReplaceLogWriter(t, w) - locUpstream := &aghtest.Upstream{ - Reverse: map[string][]string{ - "192.168.1.1": {"local.domain"}, - "2a00:1450:400c:c06::93": {"ipv6.domain"}, + localIP := net.IP{192, 168, 1, 1} + revIPv4, err := netutil.IPToReversedAddr(localIP) + require.NoError(t, err) + + revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93")) + require.NoError(t, err) + + locUpstream := &aghtest.UpstreamMock{ + OnAddress: func() (addr string) { return "local.upstream.example" }, + OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) { + resp = aghalg.Coalesce( + aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv4, "local.domain"), + aghtest.RespondTo(t, req, dns.ClassINET, dns.TypePTR, revIPv6, "ipv6.domain"), + new(dns.Msg).SetRcode(req, dns.RcodeNameError), + ) + + return resp, nil }, } - errUpstream := &aghtest.TestErrUpstream{ - Err: errors.Error("1234"), - } + + errUpstream := aghtest.NewErrorUpstream() testCases := []struct { ups upstream.Upstream @@ -186,10 +209,10 @@ func TestRDNS_WorkerLoop(t *testing.T) { ups: locUpstream, wantLog: "", name: "all_good", - cliIP: net.IP{192, 168, 1, 1}, + cliIP: localIP, }, { ups: errUpstream, - wantLog: `rdns: resolving "192.168.1.2": errupstream: 1234`, + wantLog: `rdns: resolving "192.168.1.2": test upstream error`, name: "resolve_error", cliIP: net.IP{192, 168, 1, 2}, }, { @@ -211,9 +234,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { ch := make(chan net.IP) rdns := &RDNS{ exchanger: &rDNSExchanger{ - ex: aghtest.Exchanger{ - Ups: tc.ups, - }, + ex: tc.ups, }, clients: cc, ipCh: ch,