dnsforward: imp code

This commit is contained in:
Eugene Burkov 2024-04-09 18:09:46 +03:00
parent b27547ec80
commit 08bb7d43d2
5 changed files with 121 additions and 99 deletions

View File

@ -27,7 +27,14 @@ NOTE: Add new changes BELOW THIS COMMENT.
- Support for comments in the ipset file ([#5345]). - Support for comments in the ipset file ([#5345]).
### Fixed
- Subdomains of `in-addr.arpa` and `ip6.arpa` containing zero-length prefix
incorrectly considered invalid when specified for private RDNS upstream
servers ([#6854]).
[#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345 [#5345]: https://github.com/AdguardTeam/AdGuardHome/issues/5345
[#6854]: https://github.com/AdguardTeam/AdGuardHome/issues/6854
<!-- <!--
NOTE: Add new changes ABOVE THIS COMMENT. NOTE: Add new changes ABOVE THIS COMMENT.

View File

@ -0,0 +1,53 @@
package dnsforward
import (
"encoding/binary"
"fmt"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
// type check
var _ proxy.BeforeRequestHandler = (*Server)(nil)
// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// is one, into the server's cache.
//
// TODO(e.burkov): Write tests.
func (s *Server) HandleBefore(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return nil
}

View File

@ -1,60 +1,17 @@
package dnsforward package dnsforward
import ( import (
"encoding/binary"
"fmt" "fmt"
"net" "net"
"slices" "slices"
"strings" "strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules" "github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// type check
var _ proxy.BeforeRequestHandler = (*Server)(nil)
// HandleBefore is the handler that is called before any other processing,
// including logs. It performs access checks and puts the client ID, if there
// is one, into the server's cache.
func (s *Server) HandleBefore(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (err error) {
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return fmt.Errorf("getting clientid: %w", err)
}
blocked, _ := s.IsBlockedClient(pctx.Addr.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
if len(pctx.Req.Question) == 1 {
q := pctx.Req.Question[0]
qt := q.Qtype
host := aghnet.NormalizeDomain(q.Name)
if s.access.isBlockedHost(host, qt) {
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
return s.preBlockedResponse(pctx)
}
}
if clientID != "" {
key := [8]byte{}
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
s.clientIDCache.Set(key[:], []byte(clientID))
}
return nil
}
// clientRequestFilteringSettings looks up client filtering settings using the // clientRequestFilteringSettings looks up client filtering settings using the
// client's IP address and ID, if any, from dctx. // client's IP address and ID, if any, from dctx.
func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) { func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filtering.Settings) {

View File

@ -261,48 +261,6 @@ func (req *jsonDNSConfig) checkUpstreamMode() (err error) {
} }
} }
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// checkFallbacks returns an error if any fallback address is invalid.
func (req *jsonDNSConfig) checkFallbacks() (err error) {
if req.Fallbacks == nil {
return nil
}
_, err = proxy.ParseUpstreamsConfig(*req.Fallbacks, &upstream.Options{})
if err != nil {
return fmt.Errorf("fallback servers: %w", err)
}
return nil
}
// validate returns an error if any field of req is invalid. // validate returns an error if any field of req is invalid.
// //
// TODO(s.chzhen): Parse, don't validate. // TODO(s.chzhen): Parse, don't validate.
@ -342,23 +300,68 @@ func (req *jsonDNSConfig) validate(privateNets netutil.SubnetSet) (err error) {
return nil return nil
} }
// checkUpstreams returns an error if lines can't be parsed as an upstream
// configuration. If privateNets is not nil, it also checks that the domain
// specifications are strictly ARPA domains containing the prefixes within the
// set.
func checkUpstreams(lines []string, section string, privateNets netutil.SubnetSet) (err error) {
defer func() { err = errors.Annotate(err, "%s servers: %w", section) }()
uc, err := proxy.ParseUpstreamsConfig(lines, &upstream.Options{})
if err == nil {
defer func() { err = errors.WithDeferred(err, uc.Close()) }()
if privateNets != nil {
err = proxy.ValidatePrivateConfig(uc, privateNets)
}
}
return err
}
// checkBootstrap returns an error if any bootstrap address is invalid.
func (req *jsonDNSConfig) checkBootstrap() (err error) {
if req.Bootstraps == nil {
return nil
}
var b string
defer func() { err = errors.Annotate(err, "checking bootstrap %s: %w", b) }()
for _, b = range *req.Bootstraps {
if b == "" {
return errors.Error("empty")
}
var resolver *upstream.UpstreamResolver
if resolver, err = upstream.NewUpstreamResolver(b, nil); err != nil {
// Don't wrap the error because it's informative enough as is.
return err
}
if err = resolver.Close(); err != nil {
return fmt.Errorf("closing %s: %w", b, err)
}
}
return nil
}
// validateUpstreamDNSServers returns an error if any field of req is invalid. // validateUpstreamDNSServers returns an error if any field of req is invalid.
func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) { func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetSet) (err error) {
if req.Upstreams != nil { if req.Upstreams != nil {
_, err = proxy.ParseUpstreamsConfig(*req.Upstreams, &upstream.Options{}) err = checkUpstreams(*req.Upstreams, "upstream", nil)
if err != nil { if err != nil {
return fmt.Errorf("upstream servers: %w", err) // Don't wrap the error since it's informative enough as is.
return err
} }
} }
if req.LocalPTRUpstreams != nil { if req.LocalPTRUpstreams != nil {
var uc *proxy.UpstreamConfig err = checkUpstreams(*req.LocalPTRUpstreams, "private upstream", privateNets)
uc, err = proxy.ParseUpstreamsConfig(*req.LocalPTRUpstreams, &upstream.Options{})
if err == nil {
err = proxy.ValidatePrivateConfig(uc, privateNets)
}
if err != nil { if err != nil {
return fmt.Errorf("private upstream servers: %w", err) // Don't wrap the error since it's informative enough as is.
return err
} }
} }
@ -368,10 +371,12 @@ func (req *jsonDNSConfig) validateUpstreamDNSServers(privateNets netutil.SubnetS
return err return err
} }
err = req.checkFallbacks() if req.Fallbacks != nil {
if err != nil { err = checkUpstreams(*req.Fallbacks, "fallback", nil)
// Don't wrap the error since it's informative enough as is. if err != nil {
return err // Don't wrap the error since it's informative enough as is.
return err
}
} }
return nil return nil

View File

@ -12,8 +12,8 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// TODO(e.burkov): Call all the other methods by a [proxy.MessageConstructor] // TODO(e.burkov): Name all the methods by a [proxy.MessageConstructor]
// template. // template. Also extract all the methods to a separate entity.
// reply creates a DNS response for req. // reply creates a DNS response for req.
func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) { func (*Server) reply(req *dns.Msg, code int) (resp *dns.Msg) {