mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-16 10:28:29 -07:00
1149 lines
28 KiB
Go
1149 lines
28 KiB
Go
package dnsfilter
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net"
|
|
"net/http"
|
|
"regexp"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/AdguardTeam/dnsproxy/upstream"
|
|
"github.com/AdguardTeam/golibs/log"
|
|
"github.com/bluele/gcache"
|
|
"golang.org/x/net/publicsuffix"
|
|
)
|
|
|
|
const defaultCacheSize = 64 * 1024 // in number of elements
|
|
const defaultCacheTime = 30 * time.Minute
|
|
|
|
const defaultHTTPTimeout = 5 * time.Minute
|
|
const defaultHTTPMaxIdleConnections = 100
|
|
|
|
const defaultSafebrowsingServer = "sb.adtidy.org"
|
|
const defaultSafebrowsingURL = "http://%s/safebrowsing-lookup-hash.html?prefixes=%s"
|
|
const defaultParentalServer = "pctrl.adguard.com"
|
|
const defaultParentalURL = "http://%s/check-parental-control-hash?prefixes=%s&sensitivity=%d"
|
|
const maxDialCacheSize = 2 // the number of host names for safebrowsing and parental control
|
|
|
|
// ErrInvalidSyntax is returned by AddRule when the rule is invalid
|
|
var ErrInvalidSyntax = errors.New("dnsfilter: invalid rule syntax")
|
|
|
|
// ErrAlreadyExists is returned by AddRule when the rule was already added to the filter
|
|
var ErrAlreadyExists = errors.New("dnsfilter: rule was already added")
|
|
|
|
const shortcutLength = 6 // used for rule search optimization, 6 hits the sweet spot
|
|
|
|
const enableFastLookup = true // flag for debugging, must be true in production for faster performance
|
|
const enableDelayedCompilation = true // flag for debugging, must be true in production for faster performance
|
|
|
|
// Config allows you to configure DNS filtering with New() or just change variables directly.
|
|
type Config struct {
|
|
ParentalSensitivity int `yaml:"parental_sensitivity"` // must be either 3, 10, 13 or 17
|
|
ParentalEnabled bool `yaml:"parental_enabled"`
|
|
SafeSearchEnabled bool `yaml:"safesearch_enabled"`
|
|
SafeBrowsingEnabled bool `yaml:"safebrowsing_enabled"`
|
|
ResolverAddress string // DNS server address
|
|
}
|
|
|
|
type privateConfig struct {
|
|
parentalServer string // access via methods
|
|
safeBrowsingServer string // access via methods
|
|
}
|
|
|
|
type rule struct {
|
|
text string // text without @@ decorators or $ options
|
|
shortcut string // for speeding up lookup
|
|
originalText string // original text for reporting back to applications
|
|
ip net.IP // IP address (for the case when we're matching a hosts file)
|
|
|
|
// options
|
|
options []string // optional options after $
|
|
|
|
// parsed options
|
|
apps []string
|
|
isWhitelist bool
|
|
isImportant bool
|
|
|
|
// user-supplied data
|
|
listID int64
|
|
|
|
// suffix matching
|
|
isSuffix bool
|
|
suffix string
|
|
|
|
// compiled regexp
|
|
compiled *regexp.Regexp
|
|
|
|
sync.RWMutex
|
|
}
|
|
|
|
// LookupStats store stats collected during safebrowsing or parental checks
|
|
type LookupStats struct {
|
|
Requests uint64 // number of HTTP requests that were sent
|
|
CacheHits uint64 // number of lookups that didn't need HTTP requests
|
|
Pending int64 // number of currently pending HTTP requests
|
|
PendingMax int64 // maximum number of pending HTTP requests
|
|
}
|
|
|
|
// Stats store LookupStats for safebrowsing, parental and safesearch
|
|
type Stats struct {
|
|
Safebrowsing LookupStats
|
|
Parental LookupStats
|
|
Safesearch LookupStats
|
|
}
|
|
|
|
// Dnsfilter holds added rules and performs hostname matches against the rules
|
|
type Dnsfilter struct {
|
|
storage map[string]bool // rule storage, not used for matching, just for filtering out duplicates
|
|
storageMutex sync.RWMutex
|
|
|
|
// rules are checked against these lists in the order defined here
|
|
important *rulesTable // more important than whitelist and is checked first
|
|
whiteList *rulesTable // more important than blacklist
|
|
blackList *rulesTable
|
|
|
|
// HTTP lookups for safebrowsing and parental
|
|
client http.Client // handle for http client -- single instance as recommended by docs
|
|
transport *http.Transport // handle for http transport used by http client
|
|
|
|
Config // for direct access by library users, even a = assignment
|
|
privateConfig
|
|
}
|
|
|
|
// Filter represents a filter list
|
|
type Filter struct {
|
|
ID int64 `json:"id"` // auto-assigned when filter is added (see nextFilterID), json by default keeps ID uppercase but we need lowercase
|
|
Rules []string `json:"-" yaml:"-"` // not in yaml or json
|
|
}
|
|
|
|
//go:generate stringer -type=Reason
|
|
|
|
// Reason holds an enum detailing why it was filtered or not filtered
|
|
type Reason int
|
|
|
|
const (
|
|
// reasons for not filtering
|
|
|
|
// NotFilteredNotFound - host was not find in any checks, default value for result
|
|
NotFilteredNotFound Reason = iota
|
|
// NotFilteredWhiteList - the host is explicitly whitelisted
|
|
NotFilteredWhiteList
|
|
// NotFilteredError - there was a transitive error during check
|
|
NotFilteredError
|
|
|
|
// reasons for filtering
|
|
|
|
// FilteredBlackList - the host was matched to be advertising host
|
|
FilteredBlackList
|
|
// FilteredSafeBrowsing - the host was matched to be malicious/phishing
|
|
FilteredSafeBrowsing
|
|
// FilteredParental - the host was matched to be outside of parental control settings
|
|
FilteredParental
|
|
// FilteredInvalid - the request was invalid and was not processed
|
|
FilteredInvalid
|
|
// FilteredSafeSearch - the host was replaced with safesearch variant
|
|
FilteredSafeSearch
|
|
)
|
|
|
|
// these variables need to survive coredns reload
|
|
var (
|
|
stats Stats
|
|
dialCache gcache.Cache // "host" -> "IP" cache for safebrowsing and parental control servers
|
|
safebrowsingCache gcache.Cache
|
|
parentalCache gcache.Cache
|
|
safeSearchCache gcache.Cache
|
|
)
|
|
|
|
// Result holds state of hostname check
|
|
type Result struct {
|
|
IsFiltered bool `json:",omitempty"` // True if the host name is filtered
|
|
Reason Reason `json:",omitempty"` // Reason for blocking / unblocking
|
|
Rule string `json:",omitempty"` // Original rule text
|
|
IP net.IP `json:",omitempty"` // Not nil only in the case of a hosts file syntax
|
|
FilterID int64 `json:",omitempty"` // Filter ID the rule belongs to
|
|
}
|
|
|
|
// Matched can be used to see if any match at all was found, no matter filtered or not
|
|
func (r Reason) Matched() bool {
|
|
return r != NotFilteredNotFound
|
|
}
|
|
|
|
// CheckHost tries to match host against rules, then safebrowsing and parental if they are enabled
|
|
func (d *Dnsfilter) CheckHost(host string) (Result, error) {
|
|
// sometimes DNS clients will try to resolve ".", which is a request to get root servers
|
|
if host == "" {
|
|
return Result{Reason: NotFilteredNotFound}, nil
|
|
}
|
|
host = strings.ToLower(host)
|
|
// prevent recursion
|
|
if host == d.parentalServer || host == d.safeBrowsingServer {
|
|
return Result{}, nil
|
|
}
|
|
|
|
// try filter lists first
|
|
result, err := d.matchHost(host)
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
if result.Reason.Matched() {
|
|
return result, nil
|
|
}
|
|
|
|
// check safeSearch if no match
|
|
if d.SafeSearchEnabled {
|
|
result, err = d.checkSafeSearch(host)
|
|
if err != nil {
|
|
log.Printf("Failed to safesearch HTTP lookup, ignoring check: %v", err)
|
|
return Result{}, nil
|
|
}
|
|
|
|
if result.Reason.Matched() {
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
// check safebrowsing if no match
|
|
if d.SafeBrowsingEnabled {
|
|
result, err = d.checkSafeBrowsing(host)
|
|
if err != nil {
|
|
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
|
|
log.Printf("Failed to do safebrowsing HTTP lookup, ignoring check: %v", err)
|
|
return Result{}, nil
|
|
}
|
|
if result.Reason.Matched() {
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
// check parental if no match
|
|
if d.ParentalEnabled {
|
|
result, err = d.checkParental(host)
|
|
if err != nil {
|
|
// failed to do HTTP lookup -- treat it as if we got empty response, but don't save cache
|
|
log.Printf("Failed to do parental HTTP lookup, ignoring check: %v", err)
|
|
return Result{}, nil
|
|
}
|
|
if result.Reason.Matched() {
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
// nothing matched, return nothing
|
|
return Result{}, nil
|
|
}
|
|
|
|
//
|
|
// rules table
|
|
//
|
|
|
|
type rulesTable struct {
|
|
rulesByHost map[string]*rule
|
|
rulesByShortcut map[string][]*rule
|
|
rulesLeftovers []*rule
|
|
sync.RWMutex
|
|
}
|
|
|
|
func newRulesTable() *rulesTable {
|
|
return &rulesTable{
|
|
rulesByHost: make(map[string]*rule),
|
|
rulesByShortcut: make(map[string][]*rule),
|
|
rulesLeftovers: make([]*rule, 0),
|
|
}
|
|
}
|
|
|
|
func (r *rulesTable) Add(rule *rule) {
|
|
r.Lock()
|
|
if rule.ip != nil {
|
|
// Hosts syntax
|
|
r.rulesByHost[rule.text] = rule
|
|
} else if len(rule.shortcut) == shortcutLength && enableFastLookup {
|
|
// Adblock syntax with a shortcut
|
|
r.rulesByShortcut[rule.shortcut] = append(r.rulesByShortcut[rule.shortcut], rule)
|
|
} else {
|
|
// Adblock syntax -- too short to have a shortcut
|
|
r.rulesLeftovers = append(r.rulesLeftovers, rule)
|
|
}
|
|
r.Unlock()
|
|
}
|
|
|
|
func (r *rulesTable) matchByHost(host string) (Result, error) {
|
|
// First: examine the hosts-syntax rules
|
|
res, err := r.searchByHost(host)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
if res.Reason.Matched() {
|
|
return res, nil
|
|
}
|
|
|
|
// Second: examine the adblock-syntax rules with shortcuts
|
|
res, err = r.searchShortcuts(host)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
if res.Reason.Matched() {
|
|
return res, nil
|
|
}
|
|
|
|
// Third: examine the others
|
|
res, err = r.searchLeftovers(host)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
if res.Reason.Matched() {
|
|
return res, nil
|
|
}
|
|
|
|
return Result{}, nil
|
|
}
|
|
|
|
func (r *rulesTable) searchByHost(host string) (Result, error) {
|
|
rule, ok := r.rulesByHost[host]
|
|
|
|
if ok {
|
|
return rule.match(host)
|
|
}
|
|
|
|
return Result{}, nil
|
|
}
|
|
|
|
func (r *rulesTable) searchShortcuts(host string) (Result, error) {
|
|
// check in shortcuts first
|
|
for i := 0; i < len(host); i++ {
|
|
shortcut := host[i:]
|
|
if len(shortcut) > shortcutLength {
|
|
shortcut = shortcut[:shortcutLength]
|
|
}
|
|
if len(shortcut) != shortcutLength {
|
|
continue
|
|
}
|
|
rules, ok := r.rulesByShortcut[shortcut]
|
|
if !ok {
|
|
continue
|
|
}
|
|
for _, rule := range rules {
|
|
res, err := rule.match(host)
|
|
// error? stop search
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
// matched? stop search
|
|
if res.Reason.Matched() {
|
|
return res, err
|
|
}
|
|
// continue otherwise
|
|
}
|
|
}
|
|
return Result{}, nil
|
|
}
|
|
|
|
func (r *rulesTable) searchLeftovers(host string) (Result, error) {
|
|
for _, rule := range r.rulesLeftovers {
|
|
res, err := rule.match(host)
|
|
// error? stop search
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
// matched? stop search
|
|
if res.Reason.Matched() {
|
|
return res, err
|
|
}
|
|
// continue otherwise
|
|
}
|
|
return Result{}, nil
|
|
}
|
|
|
|
func findOptionIndex(text string) int {
|
|
for i, r := range text {
|
|
// ignore non-$
|
|
if r != '$' {
|
|
continue
|
|
}
|
|
// ignore `\$`
|
|
if i > 0 && text[i-1] == '\\' {
|
|
continue
|
|
}
|
|
// ignore `$/`
|
|
if i > len(text) && text[i+1] == '/' {
|
|
continue
|
|
}
|
|
return i + 1
|
|
}
|
|
return -1
|
|
}
|
|
|
|
func (rule *rule) extractOptions() error {
|
|
optIndex := findOptionIndex(rule.text)
|
|
if optIndex == 0 { // starts with $
|
|
return ErrInvalidSyntax
|
|
}
|
|
if optIndex == len(rule.text) { // ends with $
|
|
return ErrInvalidSyntax
|
|
}
|
|
if optIndex < 0 {
|
|
return nil
|
|
}
|
|
|
|
optionsStr := rule.text[optIndex:]
|
|
rule.text = rule.text[:optIndex-1] // remove options from text
|
|
|
|
begin := 0
|
|
i := 0
|
|
for i = 0; i < len(optionsStr); i++ {
|
|
switch optionsStr[i] {
|
|
case ',':
|
|
if i > 0 {
|
|
// it might be escaped, if so, ignore
|
|
if optionsStr[i-1] == '\\' {
|
|
break // from switch, not for loop
|
|
}
|
|
}
|
|
rule.options = append(rule.options, optionsStr[begin:i])
|
|
begin = i + 1
|
|
}
|
|
}
|
|
if begin != i {
|
|
// there's still an option remaining
|
|
rule.options = append(rule.options, optionsStr[begin:])
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (rule *rule) parseOptions() error {
|
|
err := rule.extractOptions()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, option := range rule.options {
|
|
switch {
|
|
case option == "important":
|
|
rule.isImportant = true
|
|
case strings.HasPrefix(option, "app="):
|
|
option = strings.TrimPrefix(option, "app=")
|
|
rule.apps = strings.Split(option, "|")
|
|
default:
|
|
return ErrInvalidSyntax
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (rule *rule) extractShortcut() {
|
|
// regex rules have no shortcuts
|
|
if rule.text[0] == '/' && rule.text[len(rule.text)-1] == '/' {
|
|
return
|
|
}
|
|
|
|
fields := strings.FieldsFunc(rule.text, func(r rune) bool {
|
|
switch r {
|
|
case '*', '^', '|':
|
|
return true
|
|
}
|
|
return false
|
|
})
|
|
longestField := ""
|
|
for _, field := range fields {
|
|
if len(field) > len(longestField) {
|
|
longestField = field
|
|
}
|
|
}
|
|
if len(longestField) > shortcutLength {
|
|
longestField = longestField[:shortcutLength]
|
|
}
|
|
rule.shortcut = strings.ToLower(longestField)
|
|
}
|
|
|
|
func (rule *rule) compile() error {
|
|
rule.RLock()
|
|
isCompiled := rule.isSuffix || rule.compiled != nil
|
|
rule.RUnlock()
|
|
if isCompiled {
|
|
return nil
|
|
}
|
|
|
|
isSuffix, suffix := getSuffix(rule.text)
|
|
if isSuffix {
|
|
rule.Lock()
|
|
rule.isSuffix = isSuffix
|
|
rule.suffix = suffix
|
|
rule.Unlock()
|
|
return nil
|
|
}
|
|
|
|
expr, err := ruleToRegexp(rule.text)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
compiled, err := regexp.Compile(expr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rule.Lock()
|
|
rule.compiled = compiled
|
|
rule.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Checks if the rule matches the specified host and returns a corresponding Result object
|
|
func (rule *rule) match(host string) (Result, error) {
|
|
res := Result{}
|
|
|
|
if rule.ip != nil && rule.text == host {
|
|
// This is a hosts-syntax rule -- just check that the hostname matches and return the result
|
|
return Result{
|
|
IsFiltered: true,
|
|
Reason: FilteredBlackList,
|
|
Rule: rule.originalText,
|
|
IP: rule.ip,
|
|
FilterID: rule.listID,
|
|
}, nil
|
|
}
|
|
|
|
err := rule.compile()
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
rule.RLock()
|
|
matched := false
|
|
if rule.isSuffix {
|
|
if host == rule.suffix {
|
|
matched = true
|
|
} else if strings.HasSuffix(host, "."+rule.suffix) {
|
|
matched = true
|
|
}
|
|
} else {
|
|
matched = rule.compiled.MatchString(host)
|
|
}
|
|
rule.RUnlock()
|
|
if matched {
|
|
res.Reason = FilteredBlackList
|
|
res.IsFiltered = true
|
|
res.FilterID = rule.listID
|
|
res.Rule = rule.originalText
|
|
if rule.isWhitelist {
|
|
res.Reason = NotFilteredWhiteList
|
|
res.IsFiltered = false
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func getCachedReason(cache gcache.Cache, host string) (result Result, isFound bool, err error) {
|
|
isFound = false // not found yet
|
|
|
|
// get raw value
|
|
rawValue, err := cache.Get(host)
|
|
if err == gcache.KeyNotFoundError {
|
|
// not a real error, just not found
|
|
err = nil
|
|
return
|
|
}
|
|
if err != nil {
|
|
// real error
|
|
return
|
|
}
|
|
|
|
// since it can be something else, validate that it belongs to proper type
|
|
cachedValue, ok := rawValue.(Result)
|
|
if !ok {
|
|
// this is not our type -- error
|
|
text := "SHOULD NOT HAPPEN: entry with invalid type was found in lookup cache"
|
|
log.Println(text)
|
|
err = errors.New(text)
|
|
return
|
|
}
|
|
isFound = ok
|
|
return cachedValue, isFound, err
|
|
}
|
|
|
|
// for each dot, hash it and add it to string
|
|
func hostnameToHashParam(host string, addslash bool) (string, map[string]bool) {
|
|
var hashparam bytes.Buffer
|
|
hashes := map[string]bool{}
|
|
tld, icann := publicsuffix.PublicSuffix(host)
|
|
if !icann {
|
|
// private suffixes like cloudfront.net
|
|
tld = ""
|
|
}
|
|
curhost := host
|
|
for {
|
|
if curhost == "" {
|
|
// we've reached end of string
|
|
break
|
|
}
|
|
if tld != "" && curhost == tld {
|
|
// we've reached the TLD, don't hash it
|
|
break
|
|
}
|
|
tohash := []byte(curhost)
|
|
if addslash {
|
|
tohash = append(tohash, '/')
|
|
}
|
|
sum := sha256.Sum256(tohash)
|
|
hexhash := fmt.Sprintf("%X", sum)
|
|
hashes[hexhash] = true
|
|
hashparam.WriteString(fmt.Sprintf("%02X%02X%02X%02X/", sum[0], sum[1], sum[2], sum[3]))
|
|
pos := strings.IndexByte(curhost, byte('.'))
|
|
if pos < 0 {
|
|
break
|
|
}
|
|
curhost = curhost[pos+1:]
|
|
}
|
|
return hashparam.String(), hashes
|
|
}
|
|
|
|
func (d *Dnsfilter) checkSafeSearch(host string) (Result, error) {
|
|
if log.GetLevel() >= log.DEBUG {
|
|
timer := log.StartTimer()
|
|
defer timer.LogElapsed("SafeSearch HTTP lookup for %s", host)
|
|
}
|
|
|
|
if safeSearchCache == nil {
|
|
safeSearchCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
|
}
|
|
|
|
// Check cache. Return cached result if it was found
|
|
cachedValue, isFound, err := getCachedReason(safeSearchCache, host)
|
|
if isFound {
|
|
atomic.AddUint64(&stats.Safesearch.CacheHits, 1)
|
|
log.Tracef("%s: found in SafeSearch cache", host)
|
|
return cachedValue, nil
|
|
}
|
|
|
|
if err != nil {
|
|
return Result{}, err
|
|
}
|
|
|
|
safeHost, ok := d.SafeSearchDomain(host)
|
|
if !ok {
|
|
return Result{}, nil
|
|
}
|
|
|
|
res := Result{IsFiltered: true, Reason: FilteredSafeSearch}
|
|
if ip := net.ParseIP(safeHost); ip != nil {
|
|
res.IP = ip
|
|
err = safeSearchCache.Set(host, res)
|
|
if err != nil {
|
|
return Result{}, nil
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// TODO this address should be resolved with upstream that was configured in dnsforward
|
|
addrs, err := net.LookupIP(safeHost)
|
|
if err != nil {
|
|
log.Tracef("SafeSearchDomain for %s was found but failed to lookup for %s cause %s", host, safeHost, err)
|
|
return Result{}, err
|
|
}
|
|
|
|
for _, i := range addrs {
|
|
if ipv4 := i.To4(); ipv4 != nil {
|
|
res.IP = ipv4
|
|
break
|
|
}
|
|
}
|
|
|
|
if len(res.IP) == 0 {
|
|
return Result{}, fmt.Errorf("no ipv4 addresses in safe search response for %s", safeHost)
|
|
}
|
|
|
|
// Cache result
|
|
err = safeSearchCache.Set(host, res)
|
|
if err != nil {
|
|
return Result{}, nil
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func (d *Dnsfilter) checkSafeBrowsing(host string) (Result, error) {
|
|
if log.GetLevel() >= log.DEBUG {
|
|
timer := log.StartTimer()
|
|
defer timer.LogElapsed("SafeBrowsing HTTP lookup for %s", host)
|
|
}
|
|
|
|
format := func(hashparam string) string {
|
|
url := fmt.Sprintf(defaultSafebrowsingURL, d.safeBrowsingServer, hashparam)
|
|
return url
|
|
}
|
|
handleBody := func(body []byte, hashes map[string]bool) (Result, error) {
|
|
result := Result{}
|
|
scanner := bufio.NewScanner(strings.NewReader(string(body)))
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
splitted := strings.Split(line, ":")
|
|
if len(splitted) < 3 {
|
|
continue
|
|
}
|
|
hash := splitted[2]
|
|
if _, ok := hashes[hash]; ok {
|
|
// it's in the hash
|
|
result.IsFiltered = true
|
|
result.Reason = FilteredSafeBrowsing
|
|
result.Rule = splitted[0]
|
|
break
|
|
}
|
|
}
|
|
|
|
if err := scanner.Err(); err != nil {
|
|
// error, don't save cache
|
|
return Result{}, err
|
|
}
|
|
return result, nil
|
|
}
|
|
if safebrowsingCache == nil {
|
|
safebrowsingCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
|
}
|
|
result, err := d.lookupCommon(host, &stats.Safebrowsing, safebrowsingCache, true, format, handleBody)
|
|
return result, err
|
|
}
|
|
|
|
func (d *Dnsfilter) checkParental(host string) (Result, error) {
|
|
if log.GetLevel() >= log.DEBUG {
|
|
timer := log.StartTimer()
|
|
defer timer.LogElapsed("Parental HTTP lookup for %s", host)
|
|
}
|
|
|
|
format := func(hashparam string) string {
|
|
url := fmt.Sprintf(defaultParentalURL, d.parentalServer, hashparam, d.ParentalSensitivity)
|
|
return url
|
|
}
|
|
handleBody := func(body []byte, hashes map[string]bool) (Result, error) {
|
|
// parse json
|
|
var m []struct {
|
|
Blocked bool `json:"blocked"`
|
|
ClientTTL int `json:"clientTtl"`
|
|
Reason string `json:"reason"`
|
|
Hash string `json:"hash"`
|
|
}
|
|
err := json.Unmarshal(body, &m)
|
|
if err != nil {
|
|
// error, don't save cache
|
|
log.Printf("Couldn't parse json '%s': %s", body, err)
|
|
return Result{}, err
|
|
}
|
|
|
|
result := Result{}
|
|
|
|
for i := range m {
|
|
if !hashes[m[i].Hash] {
|
|
continue
|
|
}
|
|
if m[i].Blocked {
|
|
result.IsFiltered = true
|
|
result.Reason = FilteredParental
|
|
result.Rule = fmt.Sprintf("parental %s", m[i].Reason)
|
|
break
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
if parentalCache == nil {
|
|
parentalCache = gcache.New(defaultCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
|
}
|
|
result, err := d.lookupCommon(host, &stats.Parental, parentalCache, false, format, handleBody)
|
|
return result, err
|
|
}
|
|
|
|
type formatHandler func(hashparam string) string
|
|
type bodyHandler func(body []byte, hashes map[string]bool) (Result, error)
|
|
|
|
// real implementation of lookup/check
|
|
func (d *Dnsfilter) lookupCommon(host string, lookupstats *LookupStats, cache gcache.Cache, hashparamNeedSlash bool, format formatHandler, handleBody bodyHandler) (Result, error) {
|
|
// if host ends with a dot, trim it
|
|
host = strings.ToLower(strings.Trim(host, "."))
|
|
|
|
// check cache
|
|
cachedValue, isFound, err := getCachedReason(cache, host)
|
|
if isFound {
|
|
atomic.AddUint64(&lookupstats.CacheHits, 1)
|
|
log.Tracef("%s: found in the lookup cache", host)
|
|
return cachedValue, nil
|
|
}
|
|
if err != nil {
|
|
return Result{}, err
|
|
}
|
|
|
|
// convert hostname to hash parameters
|
|
hashparam, hashes := hostnameToHashParam(host, hashparamNeedSlash)
|
|
|
|
// format URL with our hashes
|
|
url := format(hashparam)
|
|
|
|
// do HTTP request
|
|
atomic.AddUint64(&lookupstats.Requests, 1)
|
|
atomic.AddInt64(&lookupstats.Pending, 1)
|
|
updateMax(&lookupstats.Pending, &lookupstats.PendingMax)
|
|
resp, err := d.client.Get(url)
|
|
atomic.AddInt64(&lookupstats.Pending, -1)
|
|
if resp != nil && resp.Body != nil {
|
|
defer resp.Body.Close()
|
|
}
|
|
if err != nil {
|
|
// error, don't save cache
|
|
return Result{}, err
|
|
}
|
|
|
|
// get body text
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
// error, don't save cache
|
|
return Result{}, err
|
|
}
|
|
|
|
// handle status code
|
|
switch {
|
|
case resp.StatusCode == 204:
|
|
// empty result, save cache
|
|
err = cache.Set(host, Result{})
|
|
if err != nil {
|
|
return Result{}, err
|
|
}
|
|
return Result{}, nil
|
|
case resp.StatusCode != 200:
|
|
// error, don't save cache
|
|
return Result{}, nil
|
|
}
|
|
|
|
result, err := handleBody(body, hashes)
|
|
if err != nil {
|
|
// error, don't save cache
|
|
return Result{}, err
|
|
}
|
|
|
|
err = cache.Set(host, result)
|
|
if err != nil {
|
|
return Result{}, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
//
|
|
// Adding rule and matching against the rules
|
|
//
|
|
|
|
// AddRules is a convinience function to add an array of filters in one call
|
|
func (d *Dnsfilter) AddRules(filters []Filter) error {
|
|
for _, f := range filters {
|
|
for _, rule := range f.Rules {
|
|
err := d.AddRule(rule, f.ID)
|
|
if err == ErrAlreadyExists || err == ErrInvalidSyntax {
|
|
continue
|
|
}
|
|
if err != nil {
|
|
log.Printf("Cannot add rule %s: %s", rule, err)
|
|
// Just ignore invalid rules
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// AddRule adds a rule, checking if it is a valid rule first and if it wasn't added already
|
|
func (d *Dnsfilter) AddRule(input string, filterListID int64) error {
|
|
input = strings.TrimSpace(input)
|
|
d.storageMutex.RLock()
|
|
_, exists := d.storage[input]
|
|
d.storageMutex.RUnlock()
|
|
if exists {
|
|
// already added
|
|
return ErrAlreadyExists
|
|
}
|
|
|
|
if !isValidRule(input) {
|
|
return ErrInvalidSyntax
|
|
}
|
|
|
|
// First, check if this is a hosts-syntax rule
|
|
if d.parseEtcHosts(input, filterListID) {
|
|
// This is a valid hosts-syntax rule, no need for further parsing
|
|
return nil
|
|
}
|
|
|
|
// Start parsing the rule
|
|
r := rule{
|
|
text: input, // will be modified
|
|
originalText: input,
|
|
listID: filterListID,
|
|
}
|
|
|
|
// Mark rule as whitelist if it starts with @@
|
|
if strings.HasPrefix(r.text, "@@") {
|
|
r.isWhitelist = true
|
|
r.text = r.text[2:]
|
|
}
|
|
|
|
err := r.parseOptions()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
r.extractShortcut()
|
|
|
|
if !enableDelayedCompilation {
|
|
err := r.compile()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
destination := d.blackList
|
|
if r.isImportant {
|
|
destination = d.important
|
|
} else if r.isWhitelist {
|
|
destination = d.whiteList
|
|
}
|
|
|
|
d.storageMutex.Lock()
|
|
d.storage[input] = true
|
|
d.storageMutex.Unlock()
|
|
destination.Add(&r)
|
|
return nil
|
|
}
|
|
|
|
// Parses the hosts-syntax rules. Returns false if the input string is not of hosts-syntax.
|
|
func (d *Dnsfilter) parseEtcHosts(input string, filterListID int64) bool {
|
|
// Strip the trailing comment
|
|
ruleText := input
|
|
if pos := strings.IndexByte(ruleText, '#'); pos != -1 {
|
|
ruleText = ruleText[0:pos]
|
|
}
|
|
fields := strings.Fields(ruleText)
|
|
if len(fields) < 2 {
|
|
return false
|
|
}
|
|
addr := net.ParseIP(fields[0])
|
|
if addr == nil {
|
|
return false
|
|
}
|
|
|
|
d.storageMutex.Lock()
|
|
d.storage[input] = true
|
|
d.storageMutex.Unlock()
|
|
|
|
for _, host := range fields[1:] {
|
|
r := rule{
|
|
text: host,
|
|
originalText: input,
|
|
listID: filterListID,
|
|
ip: addr,
|
|
}
|
|
d.blackList.Add(&r)
|
|
}
|
|
return true
|
|
}
|
|
|
|
// matchHost is a low-level way to check only if hostname is filtered by rules, skipping expensive safebrowsing and parental lookups
|
|
func (d *Dnsfilter) matchHost(host string) (Result, error) {
|
|
lists := []*rulesTable{
|
|
d.important,
|
|
d.whiteList,
|
|
d.blackList,
|
|
}
|
|
|
|
for _, table := range lists {
|
|
res, err := table.matchByHost(host)
|
|
if err != nil {
|
|
return res, err
|
|
}
|
|
if res.Reason.Matched() {
|
|
return res, nil
|
|
}
|
|
}
|
|
return Result{}, nil
|
|
}
|
|
|
|
//
|
|
// lifecycle helper functions
|
|
//
|
|
|
|
// Return TRUE if this host's IP should be cached
|
|
func (d *Dnsfilter) shouldBeInDialCache(host string) bool {
|
|
return host == d.safeBrowsingServer ||
|
|
host == d.parentalServer
|
|
}
|
|
|
|
// Search for an IP address by host name
|
|
func searchInDialCache(host string) string {
|
|
rawValue, err := dialCache.Get(host)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
|
|
ip, _ := rawValue.(string)
|
|
log.Debug("Found in cache: %s -> %s", host, ip)
|
|
return ip
|
|
}
|
|
|
|
// Add "hostname" -> "IP address" entry to cache
|
|
func addToDialCache(host, ip string) {
|
|
dialCache.Set(host, ip)
|
|
log.Debug("Added to cache: %s -> %s", host, ip)
|
|
}
|
|
|
|
type dialFunctionType func(ctx context.Context, network, addr string) (net.Conn, error)
|
|
|
|
// Connect to a remote server resolving hostname using our own DNS server
|
|
func (d *Dnsfilter) createCustomDialContext(resolverAddr string) dialFunctionType {
|
|
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
log.Tracef("network:%v addr:%v", network, addr)
|
|
|
|
host, port, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dialer := &net.Dialer{
|
|
Timeout: time.Minute * 5,
|
|
}
|
|
|
|
if net.ParseIP(host) != nil {
|
|
con, err := dialer.DialContext(ctx, network, addr)
|
|
return con, err
|
|
}
|
|
|
|
cache := d.shouldBeInDialCache(host)
|
|
if cache {
|
|
ip := searchInDialCache(host)
|
|
if len(ip) != 0 {
|
|
addr = fmt.Sprintf("%s:%s", ip, port)
|
|
return dialer.DialContext(ctx, network, addr)
|
|
}
|
|
}
|
|
|
|
r := upstream.NewResolver(resolverAddr, 30*time.Second)
|
|
addrs, e := r.LookupIPAddr(ctx, host)
|
|
log.Tracef("LookupIPAddr: %s: %v", host, addrs)
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
|
|
var firstErr error
|
|
firstErr = nil
|
|
for _, a := range addrs {
|
|
addr = fmt.Sprintf("%s:%s", a.String(), port)
|
|
con, err := dialer.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
if firstErr == nil {
|
|
firstErr = err
|
|
}
|
|
continue
|
|
}
|
|
|
|
if cache {
|
|
addToDialCache(host, a.String())
|
|
}
|
|
|
|
return con, err
|
|
}
|
|
return nil, firstErr
|
|
}
|
|
}
|
|
|
|
// New creates properly initialized DNS Filter that is ready to be used
|
|
func New(c *Config) *Dnsfilter {
|
|
d := new(Dnsfilter)
|
|
|
|
d.storage = make(map[string]bool)
|
|
d.important = newRulesTable()
|
|
d.whiteList = newRulesTable()
|
|
d.blackList = newRulesTable()
|
|
|
|
// Customize the Transport to have larger connection pool,
|
|
// We are not (re)using http.DefaultTransport because of race conditions found by tests
|
|
d.transport = &http.Transport{
|
|
Proxy: http.ProxyFromEnvironment,
|
|
MaxIdleConns: defaultHTTPMaxIdleConnections, // default 100
|
|
MaxIdleConnsPerHost: defaultHTTPMaxIdleConnections, // default 2
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
}
|
|
if c != nil && len(c.ResolverAddress) != 0 {
|
|
dialCache = gcache.New(maxDialCacheSize).LRU().Expiration(defaultCacheTime).Build()
|
|
d.transport.DialContext = d.createCustomDialContext(c.ResolverAddress)
|
|
}
|
|
d.client = http.Client{
|
|
Transport: d.transport,
|
|
Timeout: defaultHTTPTimeout,
|
|
}
|
|
d.safeBrowsingServer = defaultSafebrowsingServer
|
|
d.parentalServer = defaultParentalServer
|
|
if c != nil {
|
|
d.Config = *c
|
|
}
|
|
|
|
return d
|
|
}
|
|
|
|
// Destroy is optional if you want to tidy up goroutines without waiting for them to die off
|
|
// right now it closes idle HTTP connections if there are any
|
|
func (d *Dnsfilter) Destroy() {
|
|
if d != nil && d.transport != nil {
|
|
d.transport.CloseIdleConnections()
|
|
}
|
|
}
|
|
|
|
//
|
|
// config manipulation helpers
|
|
//
|
|
|
|
// SetSafeBrowsingServer lets you optionally change hostname of safesearch lookup
|
|
func (d *Dnsfilter) SetSafeBrowsingServer(host string) {
|
|
if len(host) == 0 {
|
|
d.safeBrowsingServer = defaultSafebrowsingServer
|
|
} else {
|
|
d.safeBrowsingServer = host
|
|
}
|
|
}
|
|
|
|
// SetHTTPTimeout lets you optionally change timeout during lookups
|
|
func (d *Dnsfilter) SetHTTPTimeout(t time.Duration) {
|
|
d.client.Timeout = t
|
|
}
|
|
|
|
// ResetHTTPTimeout resets lookup timeouts
|
|
func (d *Dnsfilter) ResetHTTPTimeout() {
|
|
d.client.Timeout = defaultHTTPTimeout
|
|
}
|
|
|
|
// SafeSearchDomain returns replacement address for search engine
|
|
func (d *Dnsfilter) SafeSearchDomain(host string) (string, bool) {
|
|
if d.SafeSearchEnabled {
|
|
val, ok := safeSearchDomains[host]
|
|
return val, ok
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
//
|
|
// stats
|
|
//
|
|
|
|
// GetStats return dns filtering stats since startup
|
|
func (d *Dnsfilter) GetStats() Stats {
|
|
return stats
|
|
}
|
|
|
|
// Count returns number of rules added to filter
|
|
func (d *Dnsfilter) Count() int {
|
|
return len(d.storage)
|
|
}
|