diff --git a/dns.go b/dns.go index e637fa4a..4eebc110 100644 --- a/dns.go +++ b/dns.go @@ -4,16 +4,36 @@ import ( "fmt" "net" "os" + "strings" + "sync" + "time" "github.com/AdguardTeam/AdGuardHome/dnsfilter" "github.com/AdguardTeam/AdGuardHome/dnsforward" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/log" "github.com/joomcode/errorx" + "github.com/miekg/dns" ) var dnsServer *dnsforward.Server +const ( + rdnsTimeout = 3 * time.Second // max time to wait for rDNS response +) + +type dnsContext struct { + rdnsChannel chan string // pass data from DNS request handling thread to rDNS thread + // contains IP addresses of clients to be resolved by rDNS + // if IP address couldn't be resolved, it stays here forever to prevent further attempts to resolve the same IP + rdnsIP map[string]bool + rdnsLock sync.Mutex // synchronize access to rdnsIP + upstream upstream.Upstream // Upstream object for our own DNS server +} + +var dnsctx dnsContext + // initDNSServer creates an instance of the dnsforward.Server // Please note that we must do it even if we don't start it // so that we had access to the query log and the stats @@ -24,12 +44,121 @@ func initDNSServer(baseDir string) { } dnsServer = dnsforward.NewServer(baseDir) + + bindhost := config.DNS.BindHost + if config.DNS.BindHost == "0.0.0.0" { + bindhost = "127.0.0.1" + } + resolverAddress := fmt.Sprintf("%s:%d", bindhost, config.DNS.Port) + opts := upstream.Options{ + Timeout: rdnsTimeout, + } + dnsctx.upstream, err = upstream.AddressToUpstream(resolverAddress, opts) + if err != nil { + log.Error("upstream.AddressToUpstream: %s", err) + return + } + dnsctx.rdnsChannel = make(chan string, 256) + go asyncRDNSLoop() } func isRunning() bool { return dnsServer != nil && dnsServer.IsRunning() } +func beginAsyncRDNS(ip string) { + log.Tracef("Adding %s for rDNS resolve", ip) + select { + case dnsctx.rdnsChannel <- ip: + // + default: + log.Tracef("rDNS queue is full") + } +} + +// Use rDNS to get hostname by IP address +func resolveRDNS(ip string) string { + log.Tracef("Resolving host for %s", ip) + + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + { + Qtype: dns.TypePTR, + Qclass: dns.ClassINET, + }, + } + var err error + req.Question[0].Name, err = dns.ReverseAddr(ip) + if err != nil { + log.Error("dns.ReverseAddr: %s", err) + return "" + } + + resp, err := dnsctx.upstream.Exchange(&req) + if err != nil { + log.Error("upstream.Exchange: %s", err) + return "" + } + if len(resp.Answer) != 1 { + log.Error("len(resp.Answer) != 1") + return "" + } + ptr, ok := resp.Answer[0].(*dns.PTR) + if !ok { + log.Error("not a dns.PTR response") + return "" + } + + log.Tracef("PTR response: %s", ptr.String()) + if strings.HasSuffix(ptr.Ptr, ".") { + ptr.Ptr = ptr.Ptr[:len(ptr.Ptr)-1] + } + + return ptr.Ptr +} + +// Wait for a signal and then synchronously resolve hostname by IP address +// Add the hostname:IP pair to "Clients" array +func asyncRDNSLoop() { + for { + var ip string + ip = <-dnsctx.rdnsChannel + + host := resolveRDNS(ip) + if len(host) == 0 { + continue + } + + dnsctx.rdnsLock.Lock() + delete(dnsctx.rdnsIP, ip) + dnsctx.rdnsLock.Unlock() + + clientAddHost(ip, host, ClientSourceRDNS) + } +} + +func onDNSRequest(d *proxy.DNSContext) { + if d.Req.Question[0].Qtype == dns.TypeA { + ip, _, _ := net.SplitHostPort(d.Addr.String()) + if clientExists(ip) { + return + } + + // add IP to rdnsIP, if not exists + dnsctx.rdnsLock.Lock() + defer dnsctx.rdnsLock.Unlock() + _, ok := dnsctx.rdnsIP[ip] + if ok { + return + } + dnsctx.rdnsIP[ip] = true + + beginAsyncRDNS(ip) + } +} + func generateServerConfig() dnsforward.ServerConfig { filters := []dnsfilter.Filter{} userFilter := userFilter() @@ -71,6 +200,7 @@ func generateServerConfig() dnsforward.ServerConfig { newconfig.DomainsReservedUpstreams = upstreamConfig.DomainReservedUpstreams newconfig.AllServers = config.DNS.AllServers newconfig.FilterHandler = applyClientSettings + newconfig.OnDNSRequest = onDNSRequest return newconfig } diff --git a/dns_test.go b/dns_test.go new file mode 100644 index 00000000..6623d1be --- /dev/null +++ b/dns_test.go @@ -0,0 +1,9 @@ +package main + +import "testing" + +func TestResolveRDNS(t *testing.T) { + if r := resolveRDNS("1.1.1.1", "1.1.1.1"); r != "one.one.one.one" { + t.Errorf("resolveRDNS(): %s", r) + } +} diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index 34e6285c..bc8ed460 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -88,6 +88,7 @@ type ServerConfig struct { Upstreams []upstream.Upstream // Configured upstreams DomainsReservedUpstreams map[string][]upstream.Upstream // Map of domains and lists of configured upstreams Filters []dnsfilter.Filter // A list of filters to use + OnDNSRequest func(d *proxy.DNSContext) FilteringConfig TLSConfig @@ -324,6 +325,10 @@ func (s *Server) GetStatsHistory(timeUnit time.Duration, startTime time.Time, en func (s *Server) handleDNSRequest(p *proxy.Proxy, d *proxy.DNSContext) error { start := time.Now() + if s.conf.OnDNSRequest != nil { + s.conf.OnDNSRequest(d) + } + // use dnsfilter before cache -- changed settings or filters would require cache invalidation otherwise res, err := s.filterDNSRequest(d) if err != nil {