diff --git a/dns.go b/dns.go index 54e56184..6ad93d3c 100644 --- a/dns.go +++ b/dns.go @@ -33,6 +33,7 @@ func generateServerConfig() dnsforward.ServerConfig { newconfig := dnsforward.ServerConfig{ UDPListenAddr: &net.UDPAddr{Port: config.DNS.Port}, + TCPListenAddr: &net.TCPAddr{Port: config.DNS.Port}, FilteringConfig: config.DNS.FilteringConfig, Filters: filters, } diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index a242f460..4075e614 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -61,6 +61,7 @@ type FilteringConfig struct { // The zero ServerConfig is empty and ready for use. type ServerConfig struct { UDPListenAddr *net.UDPAddr // UDP listen address + TCPListenAddr *net.TCPAddr // TCP listen address Upstreams []upstream.Upstream // Configured upstreams Filters []dnsfilter.Filter // A list of filters to use @@ -70,6 +71,7 @@ type ServerConfig struct { // if any of ServerConfig values are zero, then default values from below are used var defaultValues = ServerConfig{ UDPListenAddr: &net.UDPAddr{Port: 53}, + TCPListenAddr: &net.TCPAddr{Port: 53}, FilteringConfig: FilteringConfig{BlockedResponseTTL: 3600}, } @@ -120,9 +122,9 @@ func (s *Server) startInternal(config *ServerConfig) error { go statsRotator() }) - // TODO: Add TCPListenAddr proxyConfig := proxy.Config{ UDPListenAddr: s.UDPListenAddr, + TCPListenAddr: s.TCPListenAddr, Ratelimit: s.Ratelimit, RatelimitWhitelist: s.RatelimitWhitelist, RefuseAny: s.RefuseAny, @@ -135,6 +137,10 @@ func (s *Server) startInternal(config *ServerConfig) error { proxyConfig.UDPListenAddr = defaultValues.UDPListenAddr } + if proxyConfig.TCPListenAddr == nil { + proxyConfig.TCPListenAddr = defaultValues.TCPListenAddr + } + if len(proxyConfig.Upstreams) == 0 { proxyConfig.Upstreams = defaultValues.Upstreams } diff --git a/dnsforward/dnsforward_test.go b/dnsforward/dnsforward_test.go index fe638c42..0edde88b 100644 --- a/dnsforward/dnsforward_test.go +++ b/dnsforward/dnsforward_test.go @@ -13,34 +13,31 @@ import ( func TestServer(t *testing.T) { s := Server{} s.UDPListenAddr = &net.UDPAddr{Port: 0} + s.TCPListenAddr = &net.TCPAddr{Port: 0} err := s.Start(nil) if err != nil { t.Fatalf("Failed to start server: %s", err) } - // server is running, send a message + // message over UDP + req := createTestMessage() addr := s.dnsProxy.Addr("udp") - req := dns.Msg{} - req.Id = dns.Id() - req.RecursionDesired = true - req.Question = []dns.Question{ - {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, - } - - reply, err := dns.Exchange(&req, addr.String()) + client := dns.Client{Net: "udp"} + reply, _, err := client.Exchange(req, addr.String()) if err != nil { t.Fatalf("Couldn't talk to server %s: %s", addr, err) } - if len(reply.Answer) != 1 { - t.Fatalf("DNS server %s returned reply with wrong number of answers - %d", addr, len(reply.Answer)) - } - if a, ok := reply.Answer[0].(*dns.A); ok { - if !net.IPv4(8, 8, 8, 8).Equal(a.A) { - t.Fatalf("DNS server %s returned wrong answer instead of 8.8.8.8: %v", addr, a.A) - } - } else { - t.Fatalf("DNS server %s returned wrong answer type instead of A: %v", addr, reply.Answer[0]) + assertResponse(t, reply) + + // message over TCP + req = createTestMessage() + addr = s.dnsProxy.Addr("tcp") + client = dns.Client{Net: "tcp"} + reply, _, err = client.Exchange(req, addr.String()) + if err != nil { + t.Fatalf("Couldn't talk to server %s: %s", addr, err) } + assertResponse(t, reply) err = s.Stop() if err != nil { @@ -51,6 +48,7 @@ func TestServer(t *testing.T) { func TestInvalidRequest(t *testing.T) { s := Server{} s.UDPListenAddr = &net.UDPAddr{Port: 0} + s.TCPListenAddr = &net.TCPAddr{Port: 0} err := s.Start(nil) if err != nil { t.Fatalf("Failed to start server: %s", err) @@ -199,6 +197,7 @@ func TestBlockedBySafeBrowsing(t *testing.T) { func createTestServer() *Server { s := Server{} s.UDPListenAddr = &net.UDPAddr{Port: 0} + s.TCPListenAddr = &net.TCPAddr{Port: 0} s.FilteringConfig.FilteringEnabled = true s.FilteringConfig.ProtectionEnabled = true s.FilteringConfig.SafeBrowsingEnabled = true @@ -212,3 +211,26 @@ func createTestServer() *Server { s.Filters = append(s.Filters, filter) return &s } + +func createTestMessage() *dns.Msg { + req := dns.Msg{} + req.Id = dns.Id() + req.RecursionDesired = true + req.Question = []dns.Question{ + {Name: "google-public-dns-a.google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}, + } + return &req +} + +func assertResponse(t *testing.T, reply *dns.Msg) { + if len(reply.Answer) != 1 { + t.Fatalf("DNS server returned reply with wrong number of answers - %d", len(reply.Answer)) + } + if a, ok := reply.Answer[0].(*dns.A); ok { + if !net.IPv4(8, 8, 8, 8).Equal(a.A) { + t.Fatalf("DNS server returned wrong answer instead of 8.8.8.8: %v", a.A) + } + } else { + t.Fatalf("DNS server returned wrong answer type instead of A: %v", reply.Answer[0]) + } +}