all: imp code

This commit is contained in:
Eugene Burkov 2024-04-03 19:25:27 +03:00
parent 2046156580
commit d8108e6516
10 changed files with 146 additions and 116 deletions

1
go.mod
View File

@ -49,7 +49,6 @@ require (
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7 // indirect
github.com/jessevdk/go-flags v1.5.0 // indirect
github.com/mdlayher/socket v0.5.0 // indirect github.com/mdlayher/socket v0.5.0 // indirect
github.com/onsi/ginkgo/v2 v2.16.0 // indirect github.com/onsi/ginkgo/v2 v2.16.0 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect

7
go.sum
View File

@ -1,9 +1,5 @@
github.com/AdguardTeam/dnsproxy v0.67.0 h1:7oKfcA8sm9d1N4qvhsNmQWBX4+fs3sX4cAnERmBXEbw=
github.com/AdguardTeam/dnsproxy v0.67.0/go.mod h1:XLfD6IpSplUZZ+f5vhWSJW1mp4wm+KkHWiMo9w7U1Ls=
github.com/AdguardTeam/dnsproxy v0.67.1-0.20240403090357-a2c0e321a217 h1:ryczFRf8y6PEzCjgy/S3Ptg4Ea1TUYFyiEZnoEyEV7s= github.com/AdguardTeam/dnsproxy v0.67.1-0.20240403090357-a2c0e321a217 h1:ryczFRf8y6PEzCjgy/S3Ptg4Ea1TUYFyiEZnoEyEV7s=
github.com/AdguardTeam/dnsproxy v0.67.1-0.20240403090357-a2c0e321a217/go.mod h1:5wIQueGTDX1Uk4GYevRh7HCtsCUR/U9lxf478+STOZI= github.com/AdguardTeam/dnsproxy v0.67.1-0.20240403090357-a2c0e321a217/go.mod h1:5wIQueGTDX1Uk4GYevRh7HCtsCUR/U9lxf478+STOZI=
github.com/AdguardTeam/golibs v0.21.0 h1:0swWyNaHTmT7aMwffKd9d54g4wBd8Oaj0fl+5l/PRdE=
github.com/AdguardTeam/golibs v0.21.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI=
github.com/AdguardTeam/golibs v0.22.0 h1:wvT/UFIT8XIBfMabnD3LcDRiorx8J0lc3A/bzD6OX7c= github.com/AdguardTeam/golibs v0.22.0 h1:wvT/UFIT8XIBfMabnD3LcDRiorx8J0lc3A/bzD6OX7c=
github.com/AdguardTeam/golibs v0.22.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI= github.com/AdguardTeam/golibs v0.22.0/go.mod h1:/votX6WK1PdcZ3T2kBOPjPCGmfhlKixhI6ljYrFRPvI=
github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ= github.com/AdguardTeam/urlfilter v0.18.0 h1:ZZzwODC/ADpjJSODxySrrUnt/fvOCfGFaCW6j+wsGfQ=
@ -62,8 +58,6 @@ github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Go
github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8 h1:V3plQrMHRWOB5zMm3yNqvBxDQVW1+/wHBSok5uPdmVs= github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8 h1:V3plQrMHRWOB5zMm3yNqvBxDQVW1+/wHBSok5uPdmVs=
github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8/go.mod h1:izxuNQZeFrbx2nK2fAyN5iNUB34Fe9j0nK4PwLzAkKw= github.com/insomniacslk/dhcp v0.0.0-20240227161007-c728f5dd21c8/go.mod h1:izxuNQZeFrbx2nK2fAyN5iNUB34Fe9j0nK4PwLzAkKw=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc=
github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4=
github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk= github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk=
github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8= github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8=
@ -164,7 +158,6 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201015000850-e3ed0017c211/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@ -336,7 +336,7 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
DNS64Prefs: srvConf.DNS64Prefixes, DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS, UsePrivateRDNS: srvConf.UsePrivateRDNS,
PrivateSubnets: s.privateNets, PrivateSubnets: s.privateNets,
// TODO(e.burkov): !! add message constructor MessageConstructor: s,
} }
if srvConf.EDNSClientSubnet.UseCustom { if srvConf.EDNSClientSubnet.UseCustom {

View File

@ -3,7 +3,6 @@ package dnsforward
import ( import (
"net" "net"
"testing" "testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
@ -11,6 +10,7 @@ import (
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -63,8 +63,9 @@ func newRR(t *testing.T, name string, qtype uint16, ttl uint32, val any) (rr dns
return rr return rr
} }
// TODO(e.burkov): !! add ErrNoUptreams case
func TestServer_HandleDNSRequest_dns64(t *testing.T) { func TestServer_HandleDNSRequest_dns64(t *testing.T) {
t.Parallel()
const ( const (
ipv4Domain = "ipv4.only." ipv4Domain = "ipv4.only."
ipv6Domain = "ipv6.only." ipv6Domain = "ipv6.only."
@ -253,33 +254,35 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) { localUpsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
require.Len(pt, m.Question, 1) require.Len(pt, m.Question, 1)
require.Equal(pt, m.Question[0].Name, ptr64Domain) require.Equal(pt, m.Question[0].Name, ptr64Domain)
resp := (&dns.Msg{
Answer: []dns.RR{localRR}, resp := (&dns.Msg{}).SetReply(m)
}).SetReply(m) resp.Answer = []dns.RR{localRR}
require.NoError(t, w.WriteMsg(resp)) require.NoError(t, w.WriteMsg(resp))
}) })
localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String() localUpsAddr := aghtest.StartLocalhostUpstream(t, localUpsHdlr).String()
client := &dns.Client{ client := &dns.Client{
Net: "tcp", Net: string(proxy.ProtoTCP),
Timeout: 1 * time.Second, Timeout: testTimeout,
} }
for _, tc := range testCases { for _, tc := range testCases {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel()
upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) { upsHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, req *dns.Msg) {
q := req.Question[0] q := req.Question[0]
require.Contains(pt, tc.upsAns, q.Qtype)
require.Contains(pt, tc.upsAns, q.Qtype)
answer := tc.upsAns[q.Qtype] answer := tc.upsAns[q.Qtype]
resp := (&dns.Msg{ resp := (&dns.Msg{}).SetReply(req)
Answer: answer[sectionAnswer], resp.Answer = answer[sectionAnswer]
Ns: answer[sectionAuthority], resp.Ns = answer[sectionAuthority]
Extra: answer[sectionAdditional], resp.Extra = answer[sectionAdditional]
}).SetReply(req)
require.NoError(pt, w.WriteMsg(resp)) require.NoError(pt, w.WriteMsg(resp))
}) })
@ -309,10 +312,54 @@ func TestServer_HandleDNSRequest_dns64(t *testing.T) {
req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype) req := (&dns.Msg{}).SetQuestion(tc.qname, tc.qtype)
resp, _, excErr := client.Exchange(req, s.dnsProxy.Addr(proxy.ProtoTCP).String()) resp, _, excErr := client.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, excErr) require.NoError(t, excErr)
require.Equal(t, tc.wantAns, resp.Answer) require.Equal(t, tc.wantAns, resp.Answer)
}) })
} }
} }
func TestServer_dns64WithDisabledRDNS(t *testing.T) {
t.Parallel()
// Shouldn't go to upstream at all.
panicHdlr := dns.HandlerFunc(func(w dns.ResponseWriter, m *dns.Msg) {
panic("not implemented")
})
upsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
localUpsAddr := aghtest.StartLocalhostUpstream(t, panicHdlr).String()
s := createTestServer(t, &filtering.Config{
BlockingMode: filtering.BlockingModeDefault,
}, ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
UseDNS64: true,
Config: Config{
UpstreamMode: UpstreamModeLoadBalance,
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
UpstreamDNS: []string{upsAddr},
},
UsePrivateRDNS: false,
LocalPTRResolvers: []string{localUpsAddr},
ServePlainDNS: true,
})
startDeferStop(t, s)
mappedIPv6 := net.ParseIP("64:ff9b::102:304")
arpa, err := netutil.IPToReversedAddr(mappedIPv6)
require.NoError(t, err)
req := (&dns.Msg{}).SetQuestion(dns.Fqdn(arpa), dns.TypePTR)
cli := &dns.Client{
Net: string(proxy.ProtoTCP),
Timeout: testTimeout,
}
resp, _, err := cli.Exchange(req, s.proxy().Addr(proxy.ProtoTCP).String())
require.NoError(t, err)
assert.Equal(t, dns.RcodeNameError, resp.Rcode)
}

View File

@ -496,29 +496,6 @@ func (s *Server) prepareLocalResolvers(
return uc, nil return uc, nil
} }
// LocalResolversError is an error type for errors during local resolvers setup.
// This is only needed to distinguish these errors from errors returned by
// creating the proxy.
type LocalResolversError struct {
Err error
}
// type check
var _ error = (*LocalResolversError)(nil)
// Error implements the error interface for *LocalResolversError.
func (err *LocalResolversError) Error() (s string) {
return fmt.Sprintf("creating local resolvers: %s", err.Err)
}
// type check
var _ errors.Wrapper = (*LocalResolversError)(nil)
// Unwrap implements the [errors.Wrapper] interface for *LocalResolversError.
func (err *LocalResolversError) Unwrap() error {
return err.Err
}
// Prepare initializes parameters of s using data from conf. conf must not be // Prepare initializes parameters of s using data from conf. conf must not be
// nil. // nil.
func (s *Server) Prepare(conf *ServerConfig) (err error) { func (s *Server) Prepare(conf *ServerConfig) (err error) {
@ -699,7 +676,7 @@ func (s *Server) prepareInternalProxy() (err error) {
UseDNS64: srvConf.UseDNS64, UseDNS64: srvConf.UseDNS64,
DNS64Prefs: srvConf.DNS64Prefixes, DNS64Prefs: srvConf.DNS64Prefixes,
UsePrivateRDNS: srvConf.UsePrivateRDNS, UsePrivateRDNS: srvConf.UsePrivateRDNS,
// TODO(e.burkov): !! add message constructor MessageConstructor: s,
} }
err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration) err = setProxyUpstreamMode(conf, srvConf.UpstreamMode, srvConf.FastestTimeout.Duration)

View File

@ -143,7 +143,7 @@ func (s *Server) filterDNSRewrite(
res *filtering.Result, res *filtering.Result,
pctx *proxy.DNSContext, pctx *proxy.DNSContext,
) (err error) { ) (err error) {
resp := s.makeResponse(req) resp := s.replyCompressed(req)
dnsrr := res.DNSRewriteResult dnsrr := res.DNSRewriteResult
if dnsrr == nil { if dnsrr == nil {
return errors.Error("no dns rewrite rule content") return errors.Error("no dns rewrite rule content")

View File

@ -11,17 +11,21 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// makeResponse creates a DNS response by req and sets necessary flags. It also // TODO(e.burkov): Call all the other methods by a [proxy.MessageConstructor]
// guarantees that req.Question will be not empty. // template.
func (s *Server) makeResponse(req *dns.Msg) (resp *dns.Msg) {
resp = &dns.Msg{
MsgHdr: dns.MsgHdr{
RecursionAvailable: true,
},
Compress: true,
}
resp.SetReply(req) // reply creates a DNS response for req.
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, code)
resp.RecursionAvailable = true
return resp
}
// replyCompressed creates a DNS response for req and sets the compress flag.
func (s *Server) replyCompressed(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeSuccess)
resp.Compress = true
return resp return resp
} }
@ -51,7 +55,7 @@ func (s *Server) genDNSFilterMessage(
if qt != dns.TypeA && qt != dns.TypeAAAA { if qt != dns.TypeA && qt != dns.TypeAAAA {
m, _, _ := s.dnsFilter.BlockingMode() m, _, _ := s.dnsFilter.BlockingMode()
if m == filtering.BlockingModeNullIP { if m == filtering.BlockingModeNullIP {
return s.makeResponse(req) return s.replyCompressed(req)
} }
return s.newMsgNODATA(req) return s.newMsgNODATA(req)
@ -75,7 +79,7 @@ func (s *Server) genDNSFilterMessage(
// getCNAMEWithIPs generates a filtered response to req for with CNAME record // getCNAMEWithIPs generates a filtered response to req for with CNAME record
// and provided ips. // and provided ips.
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) { func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
resp = s.makeResponse(req) resp = s.replyCompressed(req)
originalName := req.Question[0].Name originalName := req.Question[0].Name
@ -121,13 +125,13 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
case filtering.BlockingModeNullIP: case filtering.BlockingModeNullIP:
return s.makeResponseNullIP(req) return s.makeResponseNullIP(req)
case filtering.BlockingModeNXDOMAIN: case filtering.BlockingModeNXDOMAIN:
return s.genNXDomain(req) return s.NewMsgNXDOMAIN(req)
case filtering.BlockingModeREFUSED: case filtering.BlockingModeREFUSED:
return s.makeResponseREFUSED(req) return s.makeResponseREFUSED(req)
default: default:
log.Error("dnsforward: invalid blocking mode %q", mode) log.Error("dnsforward: invalid blocking mode %q", mode)
return s.makeResponse(req) return s.replyCompressed(req)
} }
} }
@ -148,25 +152,18 @@ func (s *Server) makeResponseCustomIP(
// genDNSFilterMessage. // genDNSFilterMessage.
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt)) log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
return s.makeResponse(req) return s.replyCompressed(req)
} }
} }
func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeServerFailure)
resp.RecursionAvailable = true
return &resp
}
func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { func (s *Server) genARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request) resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerA(request, ip)) resp.Answer = append(resp.Answer, s.genAnswerA(request, ip))
return resp return resp
} }
func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg { func (s *Server) genAAAARecord(request *dns.Msg, ip netip.Addr) *dns.Msg {
resp := s.makeResponse(request) resp := s.replyCompressed(request)
resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip)) resp.Answer = append(resp.Answer, s.genAnswerAAAA(request, ip))
return resp return resp
} }
@ -252,7 +249,7 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
// Go on and return an empty response. // Go on and return an empty response.
} }
resp = s.makeResponse(req) resp = s.replyCompressed(req)
resp.Answer = ans resp.Answer = ans
return resp return resp
@ -288,7 +285,7 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
case dns.TypeAAAA: case dns.TypeAAAA:
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()}) resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
default: default:
resp = s.makeResponse(req) resp = s.replyCompressed(req)
} }
return resp return resp
@ -298,7 +295,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if newAddr == "" { if newAddr == "" {
log.Info("dnsforward: block host is not specified") log.Info("dnsforward: block host is not specified")
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
ip, err := netip.ParseAddr(newAddr) ip, err := netip.ParseAddr(newAddr)
@ -321,17 +318,17 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
if prx == nil { if prx == nil {
log.Debug("dnsforward: %s", srvClosedErr) log.Debug("dnsforward: %s", srvClosedErr)
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
err = prx.Resolve(newContext) err = prx.Resolve(newContext)
if err != nil { if err != nil {
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err) log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
return s.genServerFailure(request) return s.NewMsgSERVFAIL(request)
} }
resp := s.makeResponse(request) resp := s.replyCompressed(request)
if newContext.Res != nil { if newContext.Res != nil {
for _, answer := range newContext.Res.Answer { for _, answer := range newContext.Res.Answer {
answer.Header().Name = request.Question[0].Name answer.Header().Name = request.Question[0].Name
@ -357,32 +354,20 @@ func (s *Server) preBlockedResponse(pctx *proxy.DNSContext) (reply bool, err err
} }
// Create REFUSED DNS response // Create REFUSED DNS response
func (s *Server) makeResponseREFUSED(request *dns.Msg) *dns.Msg { func (s *Server) makeResponseREFUSED(req *dns.Msg) *dns.Msg {
resp := dns.Msg{} return s.reply(req, dns.RcodeRefused)
resp.SetRcode(request, dns.RcodeRefused)
resp.RecursionAvailable = true
return &resp
} }
// newMsgNODATA returns a properly initialized NODATA response. // newMsgNODATA returns a properly initialized NODATA response.
// //
// See https://www.rfc-editor.org/rfc/rfc2308#section-2.2. // See https://www.rfc-editor.org/rfc/rfc2308#section-2.2.
func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) { func (s *Server) newMsgNODATA(req *dns.Msg) (resp *dns.Msg) {
resp = (&dns.Msg{}).SetRcode(req, dns.RcodeSuccess) resp = s.reply(req, dns.RcodeSuccess)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(req) resp.Ns = s.genSOA(req)
return resp return resp
} }
func (s *Server) genNXDomain(request *dns.Msg) *dns.Msg {
resp := dns.Msg{}
resp.SetRcode(request, dns.RcodeNameError)
resp.RecursionAvailable = true
resp.Ns = s.genSOA(request)
return &resp
}
func (s *Server) genSOA(request *dns.Msg) []dns.RR { func (s *Server) genSOA(request *dns.Msg) []dns.RR {
zone := "" zone := ""
if len(request.Question) > 0 { if len(request.Question) > 0 {
@ -414,5 +399,43 @@ func (s *Server) genSOA(request *dns.Msg) []dns.RR {
if len(zone) > 0 && zone[0] != '.' { if len(zone) > 0 && zone[0] != '.' {
soa.Mbox += zone soa.Mbox += zone
} }
return []dns.RR{&soa} return []dns.RR{&soa}
} }
// type check
var _ proxy.MessageConstructor = (*Server)(nil)
// NewMsgNXDOMAIN implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNXDOMAIN(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNameError)
resp.Ns = s.genSOA(req)
return resp
}
// NewMsgSERVFAIL implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgSERVFAIL(req *dns.Msg) (resp *dns.Msg) {
return s.reply(req, dns.RcodeServerFailure)
}
// NewMsgNOTIMPLEMENTED implements the [proxy.MessageConstructor] interface for
// *Server.
func (s *Server) NewMsgNOTIMPLEMENTED(req *dns.Msg) (resp *dns.Msg) {
resp = s.reply(req, dns.RcodeNotImplemented)
// Most of the Internet and especially the inner core has an MTU of at least
// 1500 octets. Maximum DNS/UDP payload size for IPv6 on MTU 1500 ethernet
// is 1452 (1500 minus 40 (IPv6 header size) minus 8 (UDP header size)).
//
// See appendix A of https://datatracker.ietf.org/doc/draft-ietf-dnsop-avoid-fragmentation/17.
const maxUDPPayload = 1452
// NOTIMPLEMENTED without EDNS is treated as 'we don't support EDNS', so
// explicitly set it.
resp.SetEdns0(maxUDPPayload, false)
return resp
}

View File

@ -165,14 +165,14 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
} }
if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN { if (qt == dns.TypeA || qt == dns.TypeAAAA) && q.Name == mozillaFQDN {
pctx.Res = s.genNXDomain(pctx.Req) pctx.Res = s.NewMsgNXDOMAIN(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
if q.Name == healthcheckFQDN { if q.Name == healthcheckFQDN {
// Generate a NODATA negative response to make nslookup exit with 0. // Generate a NODATA negative response to make nslookup exit with 0.
pctx.Res = s.makeResponse(pctx.Req) pctx.Res = s.replyCompressed(pctx.Req)
return resultCodeFinish return resultCodeFinish
} }
@ -238,7 +238,7 @@ func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
// //
// [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html. // [draft standard]: https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) { func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
resp = s.makeResponse(req) resp = s.replyCompressed(req)
if req.Question[0].Qtype != dns.TypeSVCB { if req.Question[0].Qtype != dns.TypeSVCB {
return resp return resp
} }
@ -325,7 +325,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
if !pctx.IsPrivateClient { if !pctx.IsPrivateClient {
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost) log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
pctx.Res = s.genNXDomain(req) pctx.Res = s.NewMsgNXDOMAIN(req)
// Do not even put into query log. // Do not even put into query log.
return resultCodeFinish return resultCodeFinish
@ -342,7 +342,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip) log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
resp := s.makeResponse(req) resp := s.replyCompressed(req)
switch q.Qtype { switch q.Qtype {
case dns.TypeA: case dns.TypeA:
a := &dns.A{ a := &dns.A{
@ -394,7 +394,7 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
log.Debug("dnsforward: dhcp client %s is %q", addr, host) log.Debug("dnsforward: dhcp client %s is %q", addr, host)
req := pctx.Req req := pctx.Req
resp := s.makeResponse(req) resp := s.replyCompressed(req)
ptr := &dns.PTR{ ptr := &dns.PTR{
Hdr: dns.RR_Header{ Hdr: dns.RR_Header{
Name: req.Question[0].Name, Name: req.Question[0].Name,
@ -474,7 +474,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
// local domain name if there is one. // local domain name if there is one.
name := req.Question[0].Name name := req.Question[0].Name
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1]) log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
pctx.Res = s.genNXDomain(req) pctx.Res = s.NewMsgNXDOMAIN(req)
return resultCodeFinish return resultCodeFinish
} }
@ -491,10 +491,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
return resultCodeError return resultCodeError
} }
if err := prx.Resolve(pctx); err != nil { if dctx.err = prx.Resolve(pctx); dctx.err != nil {
// TODO(e.burkov): !! try getting error from here.
dctx.err = err
return resultCodeError return resultCodeError
} }

View File

@ -621,6 +621,9 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
} }
} }
// TODO(e.burkov): Rewrite this test to use the whole server instead of just
// testing the [handleDNSRequest] method. See comment on
// "from_external_for_local" test case.
func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) { func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
intAddr := netip.MustParseAddr("192.168.1.1") intAddr := netip.MustParseAddr("192.168.1.1")
intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice()) intPTRQuestion, err := netutil.IPToReversedAddr(intAddr.AsSlice())
@ -686,6 +689,9 @@ func TestServer_HandleDNSRequest_restrictLocal(t *testing.T) {
}}, }},
isPrivate: true, isPrivate: true,
}, { }, {
// In theory this case is not reproducible because [proxy.Proxy] should
// respond to such queries with NXDOMAIN before they reach
// [Server.handleDNSRequest].
name: "from_external_for_local", name: "from_external_for_local",
question: intPTRQuestion, question: intPTRQuestion,
wantErr: upstream.ErrNoUpstreams, wantErr: upstream.ErrNoUpstreams,

View File

@ -18,7 +18,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil"
@ -158,17 +157,6 @@ func initDNSServer(
} }
err = Context.dnsServer.Prepare(dnsConf) err = Context.dnsServer.Prepare(dnsConf)
// TODO(e.burkov): Recreate the server with private RDNS disabled. This
// should go away once the private RDNS resolution is moved to the proxy.
var locResErr *dnsforward.LocalResolversError
if errors.As(err, &locResErr) && errors.Is(locResErr.Err, upstream.ErrNoUpstreams) {
log.Info("WARNING: no local resolvers configured while private RDNS " +
"resolution enabled, trying to disable")
dnsConf.UsePrivateRDNS = false
err = Context.dnsServer.Prepare(dnsConf)
}
if err != nil { if err != nil {
return fmt.Errorf("dnsServer.Prepare: %w", err) return fmt.Errorf("dnsServer.Prepare: %w", err)
} }