2018-08-30 07:25:33 -07:00
package dnsfilter
import (
"bufio"
"errors"
"fmt"
"log"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/coredns/coredns/core/dnsserver"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/coredns/coredns/plugin/pkg/dnstest"
"github.com/coredns/coredns/plugin/pkg/upstream"
"github.com/coredns/coredns/request"
"github.com/mholt/caddy"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
"github.com/AdguardTeam/AdguardDNS/dnsfilter"
"golang.org/x/net/context"
)
// defaultSOA provides the timer fields of the SOA record that genSOA puts
// into the authority section of synthesized NXDOMAIN answers.
var defaultSOA = &dns.SOA{
	// Values copied from verisign's nonexistent .com domain. Their exact
	// values are not important in our use case: they are used for domain
	// transfers between primary/secondary DNS servers.
	Minttl:  86400,
	Expire:  604800,
	Retry:   900,
	Refresh: 1800,
}
// init registers the "dnsfilter" directive for DNS servers with caddy, using
// setup as its action.
func init() {
	def := caddy.Plugin{
		ServerType: "dns",
		Action:     setup,
	}
	caddy.RegisterPlugin("dnsfilter", def)
}
2018-09-14 06:50:56 -07:00
type plug struct {
2018-08-30 07:25:33 -07:00
d * dnsfilter . Dnsfilter
Next plugin . Handler
upstream upstream . Upstream
hosts map [ string ] net . IP
SafeBrowsingBlockHost string
ParentalBlockHost string
QueryLogEnabled bool
}
2018-09-14 06:50:56 -07:00
var defaultPlugin = plug {
2018-08-30 07:25:33 -07:00
SafeBrowsingBlockHost : "safebrowsing.block.dns.adguard.com" ,
ParentalBlockHost : "family.block.dns.adguard.com" ,
}
2018-09-14 06:50:56 -07:00
func newDNSCounter ( name string , help string ) prometheus . Counter {
2018-08-30 07:25:33 -07:00
return prometheus . NewCounter ( prometheus . CounterOpts {
Namespace : plugin . Namespace ,
Subsystem : "dnsfilter" ,
Name : name ,
Help : help ,
} )
}
var (
2018-09-14 06:50:56 -07:00
requests = newDNSCounter ( "requests_total" , "Count of requests seen by dnsfilter." )
filtered = newDNSCounter ( "filtered_total" , "Count of requests filtered by dnsfilter." )
filteredLists = newDNSCounter ( "filtered_lists_total" , "Count of requests filtered by dnsfilter using lists." )
filteredSafebrowsing = newDNSCounter ( "filtered_safebrowsing_total" , "Count of requests filtered by dnsfilter using safebrowsing." )
filteredParental = newDNSCounter ( "filtered_parental_total" , "Count of requests filtered by dnsfilter using parental." )
filteredInvalid = newDNSCounter ( "filtered_invalid_total" , "Count of requests filtered by dnsfilter because they were invalid." )
whitelisted = newDNSCounter ( "whitelisted_total" , "Count of requests not filtered by dnsfilter because they are whitelisted." )
safesearch = newDNSCounter ( "safesearch_total" , "Count of requests replaced by dnsfilter safesearch." )
errorsTotal = newDNSCounter ( "errors_total" , "Count of requests that dnsfilter couldn't process because of transitive errors." )
2018-08-30 07:25:33 -07:00
)
//
// coredns handling functions
//
2018-09-14 06:50:56 -07:00
func setupPlugin ( c * caddy . Controller ) ( * plug , error ) {
2018-08-30 07:25:33 -07:00
// create new Plugin and copy default values
2018-09-14 06:50:56 -07:00
var p = new ( plug )
* p = defaultPlugin
p . d = dnsfilter . New ( )
p . hosts = make ( map [ string ] net . IP )
2018-08-30 07:25:33 -07:00
var filterFileName string
for c . Next ( ) {
args := c . RemainingArgs ( )
if len ( args ) == 0 {
// must have at least one argument
return nil , c . ArgErr ( )
}
filterFileName = args [ 0 ]
for c . NextBlock ( ) {
switch c . Val ( ) {
case "safebrowsing" :
2018-09-14 06:50:56 -07:00
p . d . EnableSafeBrowsing ( )
2018-08-30 07:25:33 -07:00
if c . NextArg ( ) {
if len ( c . Val ( ) ) == 0 {
return nil , c . ArgErr ( )
}
2018-09-14 06:50:56 -07:00
p . d . SetSafeBrowsingServer ( c . Val ( ) )
2018-08-30 07:25:33 -07:00
}
case "safesearch" :
2018-09-14 06:50:56 -07:00
p . d . EnableSafeSearch ( )
2018-08-30 07:25:33 -07:00
case "parental" :
if ! c . NextArg ( ) {
return nil , c . ArgErr ( )
}
sensitivity , err := strconv . Atoi ( c . Val ( ) )
if err != nil {
return nil , c . ArgErr ( )
}
2018-09-14 06:50:56 -07:00
err = p . d . EnableParental ( sensitivity )
2018-08-30 07:25:33 -07:00
if err != nil {
return nil , c . ArgErr ( )
}
if c . NextArg ( ) {
if len ( c . Val ( ) ) == 0 {
return nil , c . ArgErr ( )
}
2018-09-14 06:50:56 -07:00
p . ParentalBlockHost = c . Val ( )
2018-08-30 07:25:33 -07:00
}
case "querylog" :
2018-09-14 06:50:56 -07:00
p . QueryLogEnabled = true
2018-09-05 11:21:46 -07:00
onceQueryLog . Do ( func ( ) {
2018-08-30 07:25:33 -07:00
go startQueryLogServer ( ) // TODO: how to handle errors?
} )
}
}
}
file , err := os . Open ( filterFileName )
if err != nil {
return nil , err
}
defer file . Close ( )
2018-09-05 16:07:23 -07:00
count := 0
2018-08-30 07:25:33 -07:00
scanner := bufio . NewScanner ( file )
for scanner . Scan ( ) {
text := scanner . Text ( )
2018-09-14 06:50:56 -07:00
if p . parseEtcHosts ( text ) {
2018-08-30 07:25:33 -07:00
continue
}
2018-09-14 06:50:56 -07:00
err = p . d . AddRule ( text , 0 )
2018-08-30 07:25:33 -07:00
if err == dnsfilter . ErrInvalidSyntax {
continue
}
if err != nil {
return nil , err
}
2018-09-05 16:07:23 -07:00
count ++
2018-08-30 07:25:33 -07:00
}
2018-09-05 16:07:23 -07:00
log . Printf ( "Added %d rules from %s" , count , filterFileName )
2018-08-30 07:25:33 -07:00
if err = scanner . Err ( ) ; err != nil {
return nil , err
}
2018-09-14 06:50:56 -07:00
p . upstream , err = upstream . New ( nil )
2018-08-30 07:25:33 -07:00
if err != nil {
return nil , err
}
2018-09-14 06:50:56 -07:00
return p , nil
2018-08-30 07:25:33 -07:00
}
func setup ( c * caddy . Controller ) error {
2018-09-14 06:50:56 -07:00
p , err := setupPlugin ( c )
2018-08-30 07:25:33 -07:00
if err != nil {
return err
}
config := dnsserver . GetConfig ( c )
config . AddPlugin ( func ( next plugin . Handler ) plugin . Handler {
2018-09-14 06:50:56 -07:00
p . Next = next
return p
2018-08-30 07:25:33 -07:00
} )
c . OnStartup ( func ( ) error {
2018-09-05 16:07:57 -07:00
m := dnsserver . GetConfig ( c ) . Handler ( "prometheus" )
if m == nil {
return nil
}
if x , ok := m . ( * metrics . Metrics ) ; ok {
x . MustRegister ( requests )
x . MustRegister ( filtered )
x . MustRegister ( filteredLists )
x . MustRegister ( filteredSafebrowsing )
x . MustRegister ( filteredParental )
x . MustRegister ( whitelisted )
x . MustRegister ( safesearch )
x . MustRegister ( errorsTotal )
2018-09-14 06:50:56 -07:00
x . MustRegister ( p )
2018-09-05 16:07:57 -07:00
}
2018-08-30 07:25:33 -07:00
return nil
} )
2018-09-14 06:50:56 -07:00
c . OnShutdown ( p . onShutdown )
2018-08-30 07:25:33 -07:00
return nil
}
2018-09-14 06:50:56 -07:00
func ( p * plug ) parseEtcHosts ( text string ) bool {
2018-08-30 07:25:33 -07:00
if pos := strings . IndexByte ( text , '#' ) ; pos != - 1 {
text = text [ 0 : pos ]
}
fields := strings . Fields ( text )
if len ( fields ) < 2 {
return false
}
addr := net . ParseIP ( fields [ 0 ] )
if addr == nil {
return false
}
for _ , host := range fields [ 1 : ] {
2018-09-14 06:50:56 -07:00
if val , ok := p . hosts [ host ] ; ok {
2018-08-30 07:25:33 -07:00
log . Printf ( "warning: host %s already has value %s, will overwrite it with %s" , host , val , addr )
}
2018-09-14 06:50:56 -07:00
p . hosts [ host ] = addr
2018-08-30 07:25:33 -07:00
}
return true
}
2018-09-14 06:50:56 -07:00
func ( p * plug ) onShutdown ( ) error {
p . d . Destroy ( )
p . d = nil
2018-08-30 07:25:33 -07:00
return nil
}
// statsFunc emits one statistic (name, help text, value and value type) into
// sink, which is either a Describe or a Collect channel passed as interface{}.
type statsFunc func(sink interface{}, metricName string, helpText string, v float64, kind prometheus.ValueType)
func doDesc ( ch interface { } , name string , text string , value float64 , valueType prometheus . ValueType ) {
realch , ok := ch . ( chan <- * prometheus . Desc )
2018-09-14 06:50:56 -07:00
if ! ok {
2018-08-30 07:25:33 -07:00
log . Printf ( "Couldn't convert ch to chan<- *prometheus.Desc\n" )
return
}
realch <- prometheus . NewDesc ( name , text , nil , nil )
}
func doMetric ( ch interface { } , name string , text string , value float64 , valueType prometheus . ValueType ) {
realch , ok := ch . ( chan <- prometheus . Metric )
2018-09-14 06:50:56 -07:00
if ! ok {
2018-08-30 07:25:33 -07:00
log . Printf ( "Couldn't convert ch to chan<- prometheus.Metric\n" )
return
}
desc := prometheus . NewDesc ( name , text , nil , nil )
realch <- prometheus . MustNewConstMetric ( desc , valueType , value )
}
// gen forwards a single statistic to doFunc unchanged.
func gen(ch interface{}, doFunc statsFunc, name string, text string, value float64, valueType prometheus.ValueType) {
	emit := doFunc
	emit(ch, name, text, value, valueType)
}
2018-09-07 06:10:11 -07:00
// doStatsLookup emits the four statistics of one lookup service (requests,
// cache hits, pending, maximum pending) through doFunc. Their names are
// prefixed with coredns_dnsfilter_<name>_.
func doStatsLookup(ch interface{}, doFunc statsFunc, name string, lookupstats *dnsfilter.LookupStats) {
	entries := []struct {
		nameFmt string
		helpFmt string
		value   float64
		kind    prometheus.ValueType
	}{
		{"coredns_dnsfilter_%s_requests", "Number of %s HTTP requests that were sent", float64(lookupstats.Requests), prometheus.CounterValue},
		{"coredns_dnsfilter_%s_cachehits", "Number of %s lookups that didn't need HTTP requests", float64(lookupstats.CacheHits), prometheus.CounterValue},
		{"coredns_dnsfilter_%s_pending", "Number of currently pending %s HTTP requests", float64(lookupstats.Pending), prometheus.GaugeValue},
		{"coredns_dnsfilter_%s_pending_max", "Maximum number of pending %s HTTP requests", float64(lookupstats.PendingMax), prometheus.GaugeValue},
	}
	for _, e := range entries {
		gen(ch, doFunc, fmt.Sprintf(e.nameFmt, name), fmt.Sprintf(e.helpFmt, name), e.value, e.kind)
	}
}
2018-09-14 06:50:56 -07:00
func ( p * plug ) doStats ( ch interface { } , doFunc statsFunc ) {
stats := p . d . GetStats ( )
2018-09-07 06:10:11 -07:00
doStatsLookup ( ch , doFunc , "safebrowsing" , & stats . Safebrowsing )
doStatsLookup ( ch , doFunc , "parental" , & stats . Parental )
2018-08-30 07:25:33 -07:00
}
2018-09-14 06:50:56 -07:00
// Describe is called by prometheus handler to know stat types
func ( p * plug ) Describe ( ch chan <- * prometheus . Desc ) {
p . doStats ( ch , doDesc )
2018-08-30 07:25:33 -07:00
}
2018-09-14 06:50:56 -07:00
// Collect is called by prometheus handler to collect stats
func ( p * plug ) Collect ( ch chan <- prometheus . Metric ) {
p . doStats ( ch , doMetric )
2018-08-30 07:25:33 -07:00
}
2018-09-14 06:50:56 -07:00
func ( p * plug ) replaceHostWithValAndReply ( ctx context . Context , w dns . ResponseWriter , r * dns . Msg , host string , val string , question dns . Question ) ( int , error ) {
2018-08-30 07:25:33 -07:00
// check if it's a domain name or IP address
addr := net . ParseIP ( val )
var records [ ] dns . RR
log . Println ( "Will give" , val , "instead of" , host )
if addr != nil {
// this is an IP address, return it
result , err := dns . NewRR ( host + " A " + val )
if err != nil {
log . Printf ( "Got error %s\n" , err )
return dns . RcodeServerFailure , fmt . Errorf ( "plugin/dnsfilter: %s" , err )
}
records = append ( records , result )
} else {
// this is a domain name, need to look it up
req := new ( dns . Msg )
req . SetQuestion ( dns . Fqdn ( val ) , question . Qtype )
req . RecursionDesired = true
reqstate := request . Request { W : w , Req : req , Context : ctx }
2018-09-14 06:50:56 -07:00
result , err := p . upstream . Lookup ( reqstate , dns . Fqdn ( val ) , reqstate . QType ( ) )
2018-08-30 07:25:33 -07:00
if err != nil {
log . Printf ( "Got error %s\n" , err )
return dns . RcodeServerFailure , fmt . Errorf ( "plugin/dnsfilter: %s" , err )
}
if result != nil {
for _ , answer := range result . Answer {
answer . Header ( ) . Name = question . Name
}
records = result . Answer
}
}
m := new ( dns . Msg )
m . SetReply ( r )
m . Authoritative , m . RecursionAvailable , m . Compress = true , true , true
m . Answer = append ( m . Answer , records ... )
state := request . Request { W : w , Req : r , Context : ctx }
state . SizeAndDo ( m )
err := state . W . WriteMsg ( m )
if err != nil {
log . Printf ( "Got error %s\n" , err )
return dns . RcodeServerFailure , fmt . Errorf ( "plugin/dnsfilter: %s" , err )
}
return dns . RcodeSuccess , nil
}
// genSOA generates an SOA record that makes DNS clients cache NXDOMAIN
// results. The only value that matters is the TTL in the header. Refresh,
// retry, expire and minttl are taken from defaultSOA and are irrelevant here.
//
// r must contain at least one question. Its first question's name is used as
// the zone.
func genSOA(r *dns.Msg) []dns.RR {
	zone := r.Question[0].Name
	header := dns.RR_Header{Name: zone, Rrtype: dns.TypeSOA, Ttl: 3600, Class: dns.ClassINET}

	mbox := "hostmaster."
	if len(zone) > 0 && zone[0] != '.' {
		mbox += zone
	}

	// Work on a copy. defaultSOA is a shared package-level pointer, so
	// writing to it would race between concurrent requests, and every
	// returned record would alias the same object.
	soa := *defaultSOA
	soa.Hdr = header
	soa.Mbox = mbox
	soa.Ns = "fake-for-negative-caching.adguard.com."
	soa.Serial = uint32(time.Now().Unix())
	return []dns.RR{&soa}
}
// writeNXdomain writes an authoritative NXDOMAIN reply to r, with the SOA
// from genSOA in the authority section. It returns dns.RcodeNameError on
// success, or dns.RcodeServerFailure and the write error otherwise.
func writeNXdomain(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) {
	reply := new(dns.Msg)
	reply.SetRcode(r, dns.RcodeNameError)
	reply.Authoritative = true
	reply.RecursionAvailable = true
	reply.Compress = true
	reply.Ns = genSOA(r)

	state := request.Request{W: w, Req: r, Context: ctx}
	state.SizeAndDo(reply)
	if err := state.W.WriteMsg(reply); err != nil {
		log.Printf("Got error %s\n", err)
		return dns.RcodeServerFailure, err
	}
	return dns.RcodeNameError, nil
}
2018-09-14 06:50:56 -07:00
func ( p * plug ) serveDNSInternal ( ctx context . Context , w dns . ResponseWriter , r * dns . Msg ) ( int , dnsfilter . Result , error ) {
2018-08-30 07:25:33 -07:00
if len ( r . Question ) != 1 {
// google DNS, bind and others do the same
2018-09-14 06:50:56 -07:00
return dns . RcodeFormatError , dnsfilter . Result { } , fmt . Errorf ( "Got DNS request with != 1 questions" )
2018-08-30 07:25:33 -07:00
}
for _ , question := range r . Question {
host := strings . ToLower ( strings . TrimSuffix ( question . Name , "." ) )
// is it a safesearch domain?
2018-09-14 06:50:56 -07:00
if val , ok := p . d . SafeSearchDomain ( host ) ; ok {
rcode , err := p . replaceHostWithValAndReply ( ctx , w , r , host , val , question )
2018-08-30 07:25:33 -07:00
if err != nil {
2018-09-14 06:50:56 -07:00
return rcode , dnsfilter . Result { } , err
2018-08-30 07:25:33 -07:00
}
2018-09-14 06:50:56 -07:00
return rcode , dnsfilter . Result { Reason : dnsfilter . FilteredSafeSearch } , err
2018-08-30 07:25:33 -07:00
}
// is it in hosts?
2018-09-14 06:50:56 -07:00
if val , ok := p . hosts [ host ] ; ok {
2018-08-30 07:25:33 -07:00
// it is, if it's a loopback host, reply with NXDOMAIN
2018-09-25 08:34:01 -07:00
// TODO: research if it's better than 127.0.0.1
if false && val . IsLoopback ( ) {
2018-08-30 07:25:33 -07:00
rcode , err := writeNXdomain ( ctx , w , r )
if err != nil {
2018-09-14 06:50:56 -07:00
return rcode , dnsfilter . Result { } , err
2018-08-30 07:25:33 -07:00
}
2018-09-14 06:50:56 -07:00
return rcode , dnsfilter . Result { Reason : dnsfilter . FilteredInvalid } , err
2018-08-30 07:25:33 -07:00
}
// it's not a loopback host, replace it with value specified
2018-09-14 06:50:56 -07:00
rcode , err := p . replaceHostWithValAndReply ( ctx , w , r , host , val . String ( ) , question )
2018-08-30 07:25:33 -07:00
if err != nil {
2018-09-14 06:50:56 -07:00
return rcode , dnsfilter . Result { } , err
2018-08-30 07:25:33 -07:00
}
2018-09-14 06:50:56 -07:00
return rcode , dnsfilter . Result { Reason : dnsfilter . FilteredSafeSearch } , err
2018-08-30 07:25:33 -07:00
}
// needs to be filtered instead
2018-09-14 06:50:56 -07:00
result , err := p . d . CheckHost ( host )
2018-08-30 07:25:33 -07:00
if err != nil {
log . Printf ( "plugin/dnsfilter: %s\n" , err )
2018-09-14 06:50:56 -07:00
return dns . RcodeServerFailure , dnsfilter . Result { } , fmt . Errorf ( "plugin/dnsfilter: %s" , err )
2018-08-30 07:25:33 -07:00
}
2018-09-05 16:09:57 -07:00
if result . IsFiltered {
switch result . Reason {
case dnsfilter . FilteredSafeBrowsing :
// return cname safebrowsing.block.dns.adguard.com
2018-09-14 06:50:56 -07:00
val := p . SafeBrowsingBlockHost
rcode , err := p . replaceHostWithValAndReply ( ctx , w , r , host , val , question )
2018-09-05 16:09:57 -07:00
if err != nil {
2018-09-14 06:50:56 -07:00
return rcode , dnsfilter . Result { } , err
2018-09-05 16:09:57 -07:00
}
2018-09-14 06:50:56 -07:00
return rcode , result , err
2018-09-05 16:09:57 -07:00
case dnsfilter . FilteredParental :
// return cname family.block.dns.adguard.com
2018-09-14 06:50:56 -07:00
val := p . ParentalBlockHost
rcode , err := p . replaceHostWithValAndReply ( ctx , w , r , host , val , question )
2018-09-05 16:09:57 -07:00
if err != nil {
2018-09-14 06:50:56 -07:00
return rcode , dnsfilter . Result { } , err
2018-09-05 16:09:57 -07:00
}
2018-09-14 06:50:56 -07:00
return rcode , result , err
2018-09-05 16:09:57 -07:00
case dnsfilter . FilteredBlackList :
// return NXdomain
rcode , err := writeNXdomain ( ctx , w , r )
if err != nil {
2018-09-14 06:50:56 -07:00
return rcode , dnsfilter . Result { } , err
2018-09-05 16:09:57 -07:00
}
2018-09-14 06:50:56 -07:00
return rcode , result , err
2018-09-05 16:09:57 -07:00
default :
log . Printf ( "SHOULD NOT HAPPEN -- got unknown reason for filtering: %T %v %s" , result . Reason , result . Reason , result . Reason . String ( ) )
2018-08-30 07:25:33 -07:00
}
2018-09-05 16:09:57 -07:00
} else {
switch result . Reason {
case dnsfilter . NotFilteredWhiteList :
2018-09-14 06:50:56 -07:00
rcode , err := plugin . NextOrFailure ( p . Name ( ) , p . Next , ctx , w , r )
return rcode , result , err
2018-09-05 16:09:57 -07:00
case dnsfilter . NotFilteredNotFound :
// do nothing, pass through to lower code
default :
log . Printf ( "SHOULD NOT HAPPEN -- got unknown reason for not filtering: %T %v %s" , result . Reason , result . Reason , result . Reason . String ( ) )
2018-08-30 07:25:33 -07:00
}
}
}
2018-09-14 06:50:56 -07:00
rcode , err := plugin . NextOrFailure ( p . Name ( ) , p . Next , ctx , w , r )
return rcode , dnsfilter . Result { } , err
2018-08-30 07:25:33 -07:00
}
2018-09-14 06:50:56 -07:00
// ServeDNS handles the DNS request and refuses if it's in filterlists
func ( p * plug ) ServeDNS ( ctx context . Context , w dns . ResponseWriter , r * dns . Msg ) ( int , error ) {
2018-08-30 07:25:33 -07:00
start := time . Now ( )
requests . Inc ( )
2018-08-31 09:59:04 -07:00
state := request . Request { W : w , Req : r }
ip := state . IP ( )
2018-08-30 07:25:33 -07:00
// capture the written answer
rrw := dnstest . NewRecorder ( w )
2018-09-14 06:50:56 -07:00
rcode , result , err := p . serveDNSInternal ( ctx , rrw , r )
2018-08-30 07:25:33 -07:00
if rcode > 0 {
// actually send the answer if we have one
answer := new ( dns . Msg )
answer . SetRcode ( r , rcode )
state . SizeAndDo ( answer )
2018-09-14 06:50:56 -07:00
err = w . WriteMsg ( answer )
if err != nil {
return dns . RcodeServerFailure , err
}
2018-08-30 07:25:33 -07:00
}
// increment counters
switch {
case err != nil :
errorsTotal . Inc ( )
case result . Reason == dnsfilter . FilteredBlackList :
filtered . Inc ( )
filteredLists . Inc ( )
case result . Reason == dnsfilter . FilteredSafeBrowsing :
filtered . Inc ( )
filteredSafebrowsing . Inc ( )
case result . Reason == dnsfilter . FilteredParental :
filtered . Inc ( )
filteredParental . Inc ( )
case result . Reason == dnsfilter . FilteredInvalid :
filtered . Inc ( )
filteredInvalid . Inc ( )
case result . Reason == dnsfilter . FilteredSafeSearch :
// the request was passsed through but not filtered, don't increment filtered
safesearch . Inc ( )
case result . Reason == dnsfilter . NotFilteredWhiteList :
whitelisted . Inc ( )
case result . Reason == dnsfilter . NotFilteredNotFound :
// do nothing
case result . Reason == dnsfilter . NotFilteredError :
text := "SHOULD NOT HAPPEN: got DNSFILTER_NOTFILTERED_ERROR without err != nil!"
log . Println ( text )
err = errors . New ( text )
rcode = dns . RcodeServerFailure
}
// log
2018-09-14 06:50:56 -07:00
if p . QueryLogEnabled {
2018-09-05 11:21:46 -07:00
logRequest ( r , rrw . Msg , result , time . Since ( start ) , ip )
2018-08-30 07:25:33 -07:00
}
return rcode , err
}
2018-09-14 06:50:56 -07:00
// Name returns the plugin's name as used in the Corefile and plugin.cfg.
func (p *plug) Name() string {
	const pluginName = "dnsfilter"
	return pluginName
}
2018-08-30 07:25:33 -07:00
2018-09-05 16:07:57 -07:00
// onceQueryLog ensures setupPlugin starts the query log server at most once
// per process, however many dnsfilter blocks enable "querylog".
var onceQueryLog sync.Once