diff --git a/dnsforward/dnsforward.go b/dnsforward/dnsforward.go index c3c05ad7..ccfa410b 100644 --- a/dnsforward/dnsforward.go +++ b/dnsforward/dnsforward.go @@ -635,8 +635,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error d.Res = s.genDNSFilterMessage(d, &res) } else if res.Reason == dnsfilter.ReasonRewrite && len(res.IPList) != 0 { - resp := dns.Msg{} - resp.SetReply(req) + resp := s.makeResponse(req) name := host if len(res.CanonName) != 0 { @@ -657,7 +656,7 @@ func (s *Server) filterDNSRequest(d *proxy.DNSContext) (*dnsfilter.Result, error } } - d.Res = &resp + d.Res = resp } return &res, err @@ -711,6 +710,15 @@ func (s *Server) filterResponse(d *proxy.DNSContext) (*dnsfilter.Result, error) return nil, nil } +// Create a DNS response by DNS request and set necessary flags +func (s *Server) makeResponse(req *dns.Msg) *dns.Msg { + resp := dns.Msg{} + resp.SetReply(req) + resp.RecursionAvailable = true + resp.Compress = true + return &resp +} + // genDNSFilterMessage generates a DNS message corresponding to the filtering result func (s *Server) genDNSFilterMessage(d *proxy.DNSContext, result *dnsfilter.Result) *dns.Msg { m := d.Req @@ -758,17 +766,15 @@ func (s *Server) genServerFailure(request *dns.Msg) *dns.Msg { } func (s *Server) genARecord(request *dns.Msg, ip net.IP) *dns.Msg { - resp := dns.Msg{} - resp.SetReply(request) + resp := s.makeResponse(request) resp.Answer = append(resp.Answer, s.genAAnswer(request, ip)) - return &resp + return resp } func (s *Server) genAAAARecord(request *dns.Msg, ip net.IP) *dns.Msg { - resp := dns.Msg{} - resp.SetReply(request) + resp := s.makeResponse(request) resp.Answer = append(resp.Answer, s.genAAAAAnswer(request, ip)) - return &resp + return resp } func (s *Server) genAAnswer(req *dns.Msg, ip net.IP) *dns.A { @@ -804,9 +810,8 @@ func (s *Server) genResponseWithIP(req *dns.Msg, ip net.IP) *dns.Msg { } // empty response - resp := dns.Msg{} - resp.SetReply(req) - return &resp + resp := s.makeResponse(req) + return resp } func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg { @@ -834,9 +839,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo return s.genServerFailure(request) } - resp := dns.Msg{} - resp.SetReply(request) - resp.Authoritative, resp.RecursionAvailable = true, true + resp := s.makeResponse(request) if newContext.Res != nil { for _, answer := range newContext.Res.Answer { answer.Header().Name = request.Question[0].Name @@ -844,7 +847,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo } } - return &resp + return resp } // Make a CNAME response