mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-15 09:58:42 -07:00
Pull request: all: allow clientid in access settings
Updates #2624. Updates #3162. Squashed commit of the following: commit 68860da717a23a0bfeba14b7fe10b5e4ad38726d Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:41:33 2021 +0300 all: imp types, names commit ebd4ec26636853d0d58c4e331e6a78feede20813 Merge: 239eb72116e5e09c
Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:14:33 2021 +0300 Merge branch 'master' into 2624-clientid-access commit 239eb7215abc47e99a0300a0f4cf56002689b1a9 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 15:13:10 2021 +0300 all: fix client blocking check commit e6bece3ea8367b3cbe3d90702a3368c870ad4f13 Merge: 9935f2a39d1656b5
Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Jun 29 13:12:28 2021 +0300 Merge branch 'master' into 2624-clientid-access commit 9935f2a30bcfae2b853f3ef610c0ab7a56a8f448 Author: Ildar Kamalov <ik@adguard.com> Date: Tue Jun 29 11:26:51 2021 +0300 client: show block button for client id commit ed786a6a74a081cd89e9d67df3537a4fadd54831 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 15:56:23 2021 +0300 client: imp i18n commit 4fed21c68473ad408960c08a7d87624cabce1911 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 15:34:09 2021 +0300 all: imp i18n, docs commit 55e65c0d6b939560c53dcb834a4557eb3853d194 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Fri Jun 25 13:34:01 2021 +0300 all: fix cache, imp code, docs, tests commit c1e5a83e76deb44b1f92729bb9ddfcc6a96ac4a8 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Thu Jun 24 19:27:12 2021 +0300 all: allow clientid in access settings
This commit is contained in:
parent
16e5e09c2e
commit
e08a64ebe4
@ -15,6 +15,7 @@ and this project adheres to
|
||||
|
||||
### Added
|
||||
|
||||
- Blocking access using client IDs ([#2624], [#3162]).
|
||||
- `source` directives support in `/etc/network/interfaces` on Linux ([#3257]).
|
||||
- RFC 9000 support in DNS-over-QUIC.
|
||||
- Completely disabling statistics by setting the statistics interval to zero
|
||||
@ -80,9 +81,11 @@ released by then.
|
||||
[#2439]: https://github.com/AdguardTeam/AdGuardHome/issues/2439
|
||||
[#2441]: https://github.com/AdguardTeam/AdGuardHome/issues/2441
|
||||
[#2443]: https://github.com/AdguardTeam/AdGuardHome/issues/2443
|
||||
[#2624]: https://github.com/AdguardTeam/AdGuardHome/issues/2624
|
||||
[#2763]: https://github.com/AdguardTeam/AdGuardHome/issues/2763
|
||||
[#3013]: https://github.com/AdguardTeam/AdGuardHome/issues/3013
|
||||
[#3136]: https://github.com/AdguardTeam/AdGuardHome/issues/3136
|
||||
[#3162]: https://github.com/AdguardTeam/AdGuardHome/issues/3162
|
||||
[#3166]: https://github.com/AdguardTeam/AdGuardHome/issues/3166
|
||||
[#3172]: https://github.com/AdguardTeam/AdGuardHome/issues/3172
|
||||
[#3184]: https://github.com/AdguardTeam/AdGuardHome/issues/3184
|
||||
|
@ -159,8 +159,10 @@ attributes to make it work in Markdown renderers that strip "id". -->
|
||||
|
||||
* Minimize scope of variables as much as possible.
|
||||
|
||||
* No shadowing, since it can often lead to subtle bugs, especially with
|
||||
errors.
|
||||
* No name shadowing, including of predeclared identifiers, since it can often
|
||||
lead to subtle bugs, especially with errors. This rule does not apply to
|
||||
struct fields, since they are always used together with the name of the
|
||||
struct value, so there isn't any confusion.
|
||||
|
||||
* Prefer constants to variables where possible. Avoid global variables. Use
|
||||
[constant errors] instead of `errors.New`.
|
||||
|
@ -426,9 +426,9 @@
|
||||
"access_title": "Access settings",
|
||||
"access_desc": "Here you can configure access rules for the AdGuard Home DNS server.",
|
||||
"access_allowed_title": "Allowed clients",
|
||||
"access_allowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will accept requests from these IP addresses only.",
|
||||
"access_allowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will accept requests only from these clients.",
|
||||
"access_disallowed_title": "Disallowed clients",
|
||||
"access_disallowed_desc": "A list of CIDR or IP addresses. If configured, AdGuard Home will drop requests from these IP addresses.",
|
||||
"access_disallowed_desc": "A list of CIDRs, IP addresses, or client IDs. If configured, AdGuard Home will drop requests from these clients. If allowed clients are configured, this field is ignored.",
|
||||
"access_blocked_title": "Disallowed domains",
|
||||
"access_blocked_desc": "Not to be confused with filters. AdGuard Home drops DNS queries matching these domains, and these queries don't even appear in the query log. You can specify exact domain names, wildcards, or URL filter rules, e.g. \"example.org\", \"*.example.org\", or \"||example.org^\" correspondingly.",
|
||||
"access_settings_saved": "Access settings successfully saved",
|
||||
|
@ -9,7 +9,7 @@ import Card from '../ui/Card';
|
||||
import Cell from '../ui/Cell';
|
||||
|
||||
import { getPercent, sortIp } from '../../helpers/helpers';
|
||||
import { BLOCK_ACTIONS, R_CLIENT_ID, STATUS_COLORS } from '../../helpers/constants';
|
||||
import { BLOCK_ACTIONS, STATUS_COLORS } from '../../helpers/constants';
|
||||
import { toggleClientBlock } from '../../actions/access';
|
||||
import { renderFormattedClientCell } from '../../helpers/renderFormattedClientCell';
|
||||
import { getStats } from '../../actions/stats';
|
||||
@ -35,10 +35,6 @@ const CountCell = (row) => {
|
||||
};
|
||||
|
||||
const renderBlockingButton = (ip, disallowed, disallowed_rule) => {
|
||||
if (R_CLIENT_ID.test(ip)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const dispatch = useDispatch();
|
||||
const { t } = useTranslation();
|
||||
const processingSet = useSelector((state) => state.access.processingSet);
|
||||
|
2
go.mod
2
go.mod
@ -3,7 +3,7 @@ module github.com/AdguardTeam/AdGuardHome
|
||||
go 1.16
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/dnsproxy v0.37.7
|
||||
github.com/AdguardTeam/dnsproxy v0.38.0
|
||||
github.com/AdguardTeam/golibs v0.8.0
|
||||
github.com/AdguardTeam/urlfilter v0.14.6
|
||||
github.com/NYTimes/gziphandler v1.1.1
|
||||
|
4
go.sum
4
go.sum
@ -9,8 +9,8 @@ dmitri.shuralyov.com/state v0.0.0-20180228185332-28bcc343414c/go.mod h1:0PRwlb0D
|
||||
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
|
||||
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf h1:gc042VRSIRSUzZ+Px6xQCRWNJZTaPkomisDfUZmoFNk=
|
||||
github.com/AdguardTeam/dhcp v0.0.0-20210519141215-51808c73c0bf/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
|
||||
github.com/AdguardTeam/dnsproxy v0.37.7 h1:yp0vEVYobf/1l8iY7es9yMqguw8BUEeC74OGA4G2v2A=
|
||||
github.com/AdguardTeam/dnsproxy v0.37.7/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M=
|
||||
github.com/AdguardTeam/dnsproxy v0.38.0 h1:7GyyNJOieIVOgdnhu47exqWjHPQro7wQhqzvQjaZt6M=
|
||||
github.com/AdguardTeam/dnsproxy v0.38.0/go.mod h1:xMfevPAwpK1ULoLO0CARg/OiUsPH92kfyliXhPTc62M=
|
||||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
|
@ -27,10 +27,9 @@ type EtcHostsContainer struct {
|
||||
lock sync.RWMutex
|
||||
// table is the host-to-IPs map.
|
||||
table map[string][]net.IP
|
||||
// tableReverse is the IP-to-hosts map.
|
||||
//
|
||||
// TODO(a.garipov): Make better use of newtypes. Perhaps a custom map.
|
||||
tableReverse map[string][]string
|
||||
// tableReverse is the IP-to-hosts map. The type of the values in the
|
||||
// map is []string.
|
||||
tableReverse *IPMap
|
||||
|
||||
hostsFn string // path to the main hosts-file
|
||||
hostsDirs []string // paths to OS-specific directories with hosts-files
|
||||
@ -80,7 +79,7 @@ func (ehc *EtcHostsContainer) Init(hostsFn string) {
|
||||
var err error
|
||||
ehc.watcher, err = fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
log.Error("etchostscontainer: %s", err)
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -141,7 +140,7 @@ func (ehc *EtcHostsContainer) Process(host string, qtype uint16) []net.IP {
|
||||
copy(ipsCopy, ips)
|
||||
}
|
||||
|
||||
log.Debug("etchostscontainer: answer: %s -> %v", host, ipsCopy)
|
||||
log.Debug("etchosts: answer: %s -> %v", host, ipsCopy)
|
||||
return ipsCopy
|
||||
}
|
||||
|
||||
@ -151,38 +150,40 @@ func (ehc *EtcHostsContainer) ProcessReverse(addr string, qtype uint16) (hosts [
|
||||
return nil
|
||||
}
|
||||
|
||||
ipReal := UnreverseAddr(addr)
|
||||
if ipReal == nil {
|
||||
ip := UnreverseAddr(addr)
|
||||
if ip == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipStr := ipReal.String()
|
||||
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
hosts = ehc.tableReverse[ipStr]
|
||||
|
||||
if len(hosts) == 0 {
|
||||
return nil // not found
|
||||
v, ok := ehc.tableReverse.Get(ip)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("etchostscontainer: reverse-lookup: %s -> %s", addr, hosts)
|
||||
hosts, ok = v.([]string)
|
||||
if !ok {
|
||||
log.Error("etchosts: bad type %T in tableReverse for %s", v, ip)
|
||||
|
||||
return nil
|
||||
} else if len(hosts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debug("etchosts: reverse-lookup: %s -> %s", addr, hosts)
|
||||
|
||||
return hosts
|
||||
}
|
||||
|
||||
// List returns an IP-to-hostnames table. It is safe for concurrent use.
|
||||
func (ehc *EtcHostsContainer) List() (ipToHosts map[string][]string) {
|
||||
// List returns an IP-to-hostnames table. The type of the values in the map is
|
||||
// []string. It is safe for concurrent use.
|
||||
func (ehc *EtcHostsContainer) List() (ipToHosts *IPMap) {
|
||||
ehc.lock.RLock()
|
||||
defer ehc.lock.RUnlock()
|
||||
|
||||
ipToHosts = make(map[string][]string, len(ehc.tableReverse))
|
||||
for k, v := range ehc.tableReverse {
|
||||
ipToHosts[k] = v
|
||||
}
|
||||
|
||||
return ipToHosts
|
||||
return ehc.tableReverse.ShallowClone()
|
||||
}
|
||||
|
||||
// update table
|
||||
@ -205,29 +206,31 @@ func (ehc *EtcHostsContainer) updateTable(table map[string][]net.IP, host string
|
||||
ok = true
|
||||
}
|
||||
if ok {
|
||||
log.Debug("etchostscontainer: added %s -> %s", ipAddr, host)
|
||||
log.Debug("etchosts: added %s -> %s", ipAddr, host)
|
||||
}
|
||||
}
|
||||
|
||||
// updateTableRev updates the reverse address table.
|
||||
func (ehc *EtcHostsContainer) updateTableRev(tableRev map[string][]string, newHost string, ipAddr net.IP) {
|
||||
ipStr := ipAddr.String()
|
||||
hosts, ok := tableRev[ipStr]
|
||||
func (ehc *EtcHostsContainer) updateTableRev(tableRev *IPMap, newHost string, ip net.IP) {
|
||||
v, ok := tableRev.Get(ip)
|
||||
if !ok {
|
||||
tableRev[ipStr] = []string{newHost}
|
||||
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
|
||||
tableRev.Set(ip, []string{newHost})
|
||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
hosts, _ := v.([]string)
|
||||
for _, host := range hosts {
|
||||
if host == newHost {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
tableRev[ipStr] = append(tableRev[ipStr], newHost)
|
||||
log.Debug("etchostscontainer: added reverse-address %s -> %s", ipStr, newHost)
|
||||
hosts = append(hosts, newHost)
|
||||
tableRev.Set(ip, hosts)
|
||||
|
||||
log.Debug("etchosts: added reverse-address %s -> %s", ip, newHost)
|
||||
}
|
||||
|
||||
// parseHostsLine parses hosts from the fields.
|
||||
@ -255,12 +258,12 @@ func parseHostsLine(fields []string) (hosts []string) {
|
||||
// line for one IP are supported.
|
||||
func (ehc *EtcHostsContainer) load(
|
||||
table map[string][]net.IP,
|
||||
tableRev map[string][]string,
|
||||
tableRev *IPMap,
|
||||
fn string,
|
||||
) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
log.Error("etchostscontainer: %s", err)
|
||||
log.Error("etchosts: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -268,11 +271,11 @@ func (ehc *EtcHostsContainer) load(
|
||||
defer func() {
|
||||
derr := f.Close()
|
||||
if derr != nil {
|
||||
log.Error("etchostscontainer: closing file: %s", err)
|
||||
log.Error("etchosts: closing file: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debug("etchostscontainer: loading hosts from file %s", fn)
|
||||
log.Debug("etchosts: loading hosts from file %s", fn)
|
||||
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
@ -296,7 +299,7 @@ func (ehc *EtcHostsContainer) load(
|
||||
|
||||
err = s.Err()
|
||||
if err != nil {
|
||||
log.Error("etchostscontainer: %s", err)
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -334,7 +337,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
|
||||
}
|
||||
|
||||
if event.Op&fsnotify.Write == fsnotify.Write {
|
||||
log.Debug("etchostscontainer: modified: %s", event.Name)
|
||||
log.Debug("etchosts: modified: %s", event.Name)
|
||||
ehc.updateHosts()
|
||||
}
|
||||
|
||||
@ -342,7 +345,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Error("etchostscontainer: %s", err)
|
||||
log.Error("etchosts: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -350,7 +353,7 @@ func (ehc *EtcHostsContainer) watcherLoop() {
|
||||
// updateHosts - loads system hosts
|
||||
func (ehc *EtcHostsContainer) updateHosts() {
|
||||
table := make(map[string][]net.IP)
|
||||
tableRev := make(map[string][]string)
|
||||
tableRev := NewIPMap(0)
|
||||
|
||||
ehc.load(table, tableRev, ehc.hostsFn)
|
||||
|
||||
@ -358,7 +361,7 @@ func (ehc *EtcHostsContainer) updateHosts() {
|
||||
des, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
log.Error("etchostscontainer: Opening directory: %q: %s", dir, err)
|
||||
log.Error("etchosts: Opening directory: %q: %s", dir, err)
|
||||
}
|
||||
|
||||
continue
|
||||
|
@ -70,7 +70,7 @@ func TestEtcHostsContainerResolution(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("hosts_file", func(t *testing.T) {
|
||||
names, ok := ehc.List()["127.0.0.1"]
|
||||
names, ok := ehc.List().Get(net.IP{127, 0, 0, 1})
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, []string{"host", "localhost"}, names)
|
||||
})
|
||||
|
112
internal/aghnet/ipmap.go
Normal file
112
internal/aghnet/ipmap.go
Normal file
@ -0,0 +1,112 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ipArr is a representation of an IP address as an array of bytes.
|
||||
type ipArr [16]byte
|
||||
|
||||
// String implements the fmt.Stringer interface for ipArr.
|
||||
func (a ipArr) String() (s string) {
|
||||
return net.IP(a[:]).String()
|
||||
}
|
||||
|
||||
// IPMap is a map of IP addresses.
|
||||
type IPMap struct {
|
||||
m map[ipArr]interface{}
|
||||
}
|
||||
|
||||
// NewIPMap returns a new empty IP map using hint as a size hint for the
|
||||
// underlying map.
|
||||
func NewIPMap(hint int) (m *IPMap) {
|
||||
return &IPMap{
|
||||
m: make(map[ipArr]interface{}, hint),
|
||||
}
|
||||
}
|
||||
|
||||
// ipToArr converts a net.IP into an ipArr.
|
||||
//
|
||||
// TODO(a.garipov): Use the slice-to-array conversion in Go 1.17.
|
||||
func ipToArr(ip net.IP) (a ipArr) {
|
||||
copy(a[:], ip.To16())
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
// Del deletes ip from the map. Calling Del on a nil *IPMap has no effect, just
|
||||
// like delete on an empty map doesn't.
|
||||
func (m *IPMap) Del(ip net.IP) {
|
||||
if m != nil {
|
||||
delete(m.m, ipToArr(ip))
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the value from the map. Calling Get on a nil *IPMap returns nil
|
||||
// and false, just like indexing on an empty map does.
|
||||
func (m *IPMap) Get(ip net.IP) (v interface{}, ok bool) {
|
||||
if m != nil {
|
||||
v, ok = m.m[ipToArr(ip)]
|
||||
|
||||
return v, ok
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Len returns the length of the map. A nil *IPMap has a length of zero, just
|
||||
// like an empty map.
|
||||
func (m *IPMap) Len() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(m.m)
|
||||
}
|
||||
|
||||
// Range calls f for each key and value present in the map in an undefined
|
||||
// order. If cont is false, range stops the iteration. Calling Range on a nil
|
||||
// *IPMap has no effect, just like ranging over a nil map.
|
||||
func (m *IPMap) Range(f func(ip net.IP, v interface{}) (cont bool)) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for k, v := range m.m {
|
||||
if !f(net.IP(k[:]), v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set sets the value. Set panics if the m is a nil *IPMap, just like a nil map
|
||||
// does.
|
||||
func (m *IPMap) Set(ip net.IP, v interface{}) {
|
||||
m.m[ipToArr(ip)] = v
|
||||
}
|
||||
|
||||
// ShallowClone returns a shallow clone of the map.
|
||||
func (m *IPMap) ShallowClone() (sclone *IPMap) {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sclone = NewIPMap(m.Len())
|
||||
m.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
sclone.Set(ip, v)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
return sclone
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer interface for *IPMap.
|
||||
func (m *IPMap) String() (s string) {
|
||||
if m == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
|
||||
return fmt.Sprint(m.m)
|
||||
}
|
142
internal/aghnet/ipmap_test.go
Normal file
142
internal/aghnet/ipmap_test.go
Normal file
@ -0,0 +1,142 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIPMap_allocs(t *testing.T) {
|
||||
ip4 := net.IP{1, 2, 3, 4}
|
||||
m := NewIPMap(0)
|
||||
m.Set(ip4, 42)
|
||||
|
||||
t.Run("get", func(t *testing.T) {
|
||||
var v interface{}
|
||||
var ok bool
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
v, ok = m.Get(ip4)
|
||||
})
|
||||
|
||||
require.True(t, ok)
|
||||
require.Equal(t, 42, v)
|
||||
|
||||
assert.Equal(t, float64(0), allocs)
|
||||
})
|
||||
|
||||
t.Run("len", func(t *testing.T) {
|
||||
var n int
|
||||
allocs := testing.AllocsPerRun(100, func() {
|
||||
n = m.Len()
|
||||
})
|
||||
|
||||
require.Equal(t, 1, n)
|
||||
|
||||
assert.Equal(t, float64(0), allocs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIPMap(t *testing.T) {
|
||||
ip4 := net.IP{1, 2, 3, 4}
|
||||
ip6 := net.IP{
|
||||
0x12, 0x34, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x56, 0x78,
|
||||
}
|
||||
|
||||
val := 42
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
var m *IPMap
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
m.Del(ip4)
|
||||
m.Del(ip6)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
v, ok := m.Get(ip4)
|
||||
assert.Nil(t, v)
|
||||
assert.False(t, ok)
|
||||
|
||||
v, ok = m.Get(ip6)
|
||||
assert.Nil(t, v)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
assert.Equal(t, 0, m.Len())
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
n := 0
|
||||
m.Range(func(_ net.IP, _ interface{}) (cont bool) {
|
||||
n++
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, 0, n)
|
||||
})
|
||||
|
||||
assert.Panics(t, func() {
|
||||
m.Set(ip4, val)
|
||||
})
|
||||
|
||||
assert.Panics(t, func() {
|
||||
m.Set(ip6, val)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
sclone := m.ShallowClone()
|
||||
assert.Nil(t, sclone)
|
||||
})
|
||||
})
|
||||
|
||||
testIPMap := func(t *testing.T, ip net.IP, s string) {
|
||||
m := NewIPMap(0)
|
||||
assert.Equal(t, 0, m.Len())
|
||||
|
||||
v, ok := m.Get(ip)
|
||||
assert.Nil(t, v)
|
||||
assert.False(t, ok)
|
||||
|
||||
m.Set(ip, val)
|
||||
v, ok = m.Get(ip)
|
||||
assert.Equal(t, val, v)
|
||||
assert.True(t, ok)
|
||||
|
||||
n := 0
|
||||
m.Range(func(ipKey net.IP, v interface{}) (cont bool) {
|
||||
assert.Equal(t, ip.To16(), ipKey)
|
||||
assert.Equal(t, val, v)
|
||||
|
||||
n++
|
||||
|
||||
return false
|
||||
})
|
||||
assert.Equal(t, 1, n)
|
||||
|
||||
sclone := m.ShallowClone()
|
||||
assert.Equal(t, m, sclone)
|
||||
|
||||
assert.Equal(t, s, m.String())
|
||||
|
||||
m.Del(ip)
|
||||
v, ok = m.Get(ip)
|
||||
assert.Nil(t, v)
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, 0, m.Len())
|
||||
}
|
||||
|
||||
t.Run("ipv4", func(t *testing.T) {
|
||||
testIPMap(t, ip4, "map[1.2.3.4:42]")
|
||||
})
|
||||
|
||||
t.Run("ipv6", func(t *testing.T) {
|
||||
testIPMap(t, ip6, "map[1234::5678:42]")
|
||||
})
|
||||
}
|
@ -6,138 +6,163 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
)
|
||||
|
||||
// accessCtx controls IP and client blocking that takes place before all other
|
||||
// processing. An accessCtx is safe for concurrent use.
|
||||
type accessCtx struct {
|
||||
lock sync.Mutex
|
||||
allowedIPs *aghnet.IPMap
|
||||
blockedIPs *aghnet.IPMap
|
||||
|
||||
// allowedClients are the IP addresses of clients in the allowlist.
|
||||
allowedClients *aghstrings.Set
|
||||
allowedClientIDs *aghstrings.Set
|
||||
blockedClientIDs *aghstrings.Set
|
||||
|
||||
// disallowedClients are the IP addresses of clients in the blocklist.
|
||||
disallowedClients *aghstrings.Set
|
||||
blockedHostsEng *urlfilter.DNSEngine
|
||||
|
||||
allowedClientsIPNet []net.IPNet // CIDRs of whitelist clients
|
||||
disallowedClientsIPNet []net.IPNet // CIDRs of clients that should be blocked
|
||||
|
||||
blockedHostsEngine *urlfilter.DNSEngine // finds hosts that should be blocked
|
||||
// TODO(a.garipov): Create a type for a set of IP networks.
|
||||
// aghnet.IPNetSet?
|
||||
allowedNets []*net.IPNet
|
||||
blockedNets []*net.IPNet
|
||||
}
|
||||
|
||||
func newAccessCtx(allowedClients, disallowedClients, blockedHosts []string) (a *accessCtx, err error) {
|
||||
a = &accessCtx{
|
||||
allowedClients: aghstrings.NewSet(),
|
||||
disallowedClients: aghstrings.NewSet(),
|
||||
}
|
||||
// unit is a convenient alias for struct{}
|
||||
type unit = struct{}
|
||||
|
||||
err = processIPCIDRArray(a.allowedClients, &a.allowedClientsIPNet, allowedClients)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("processing allowed clients: %w", err)
|
||||
}
|
||||
// processAccessClients is a helper for processing a list of client strings,
|
||||
// which may be an IP address, a CIDR, or a ClientID.
|
||||
func processAccessClients(
|
||||
clientStrs []string,
|
||||
ips *aghnet.IPMap,
|
||||
nets *[]*net.IPNet,
|
||||
clientIDs *aghstrings.Set,
|
||||
) (err error) {
|
||||
for i, s := range clientStrs {
|
||||
if ip := net.ParseIP(s); ip != nil {
|
||||
ips.Set(ip, unit{})
|
||||
} else if cidrIP, ipnet, cidrErr := net.ParseCIDR(s); cidrErr == nil {
|
||||
ipnet.IP = cidrIP
|
||||
*nets = append(*nets, ipnet)
|
||||
} else {
|
||||
idErr := ValidateClientID(s)
|
||||
if idErr != nil {
|
||||
return fmt.Errorf(
|
||||
"value %q at index %d: bad ip, cidr, or clientid",
|
||||
s,
|
||||
i,
|
||||
)
|
||||
}
|
||||
|
||||
err = processIPCIDRArray(a.disallowedClients, &a.disallowedClientsIPNet, disallowedClients)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("processing disallowed clients: %w", err)
|
||||
}
|
||||
|
||||
b := &strings.Builder{}
|
||||
for _, s := range blockedHosts {
|
||||
aghstrings.WriteToBuilder(b, strings.ToLower(s), "\n")
|
||||
}
|
||||
|
||||
listArray := []filterlist.RuleList{}
|
||||
list := &filterlist.StringRuleList{
|
||||
ID: int(0),
|
||||
RulesText: b.String(),
|
||||
IgnoreCosmetic: true,
|
||||
}
|
||||
listArray = append(listArray, list)
|
||||
rulesStorage, err := filterlist.NewRuleStorage(listArray)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("filterlist.NewRuleStorage(): %w", err)
|
||||
}
|
||||
a.blockedHostsEngine = urlfilter.NewDNSEngine(rulesStorage)
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Split array of IP or CIDR into 2 containers for fast search
|
||||
func processIPCIDRArray(dst *aghstrings.Set, dstIPNet *[]net.IPNet, src []string) error {
|
||||
for _, s := range src {
|
||||
ip := net.ParseIP(s)
|
||||
if ip != nil {
|
||||
dst.Add(s)
|
||||
|
||||
continue
|
||||
clientIDs.Add(s)
|
||||
}
|
||||
|
||||
_, ipnet, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*dstIPNet = append(*dstIPNet, *ipnet)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsBlockedIP - return TRUE if this client should be blocked
|
||||
// Returns the item from the "disallowedClients" list that lead to blocking IP.
|
||||
// If it returns TRUE and an empty string, it means that the "allowedClients" is not empty,
|
||||
// but the ip does not belong to it.
|
||||
func (a *accessCtx) IsBlockedIP(ip net.IP) (bool, string) {
|
||||
ipStr := ip.String()
|
||||
// newAccessCtx creates a new accessCtx.
|
||||
func newAccessCtx(allowed, blocked, blockedHosts []string) (a *accessCtx, err error) {
|
||||
a = &accessCtx{
|
||||
allowedIPs: aghnet.NewIPMap(0),
|
||||
blockedIPs: aghnet.NewIPMap(0),
|
||||
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
|
||||
if a.allowedClients.Len() != 0 || len(a.allowedClientsIPNet) != 0 {
|
||||
if a.allowedClients.Has(ipStr) {
|
||||
return false, ""
|
||||
}
|
||||
|
||||
if len(a.allowedClientsIPNet) != 0 {
|
||||
for _, ipnet := range a.allowedClientsIPNet {
|
||||
if ipnet.Contains(ip) {
|
||||
return false, ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, ""
|
||||
allowedClientIDs: aghstrings.NewSet(),
|
||||
blockedClientIDs: aghstrings.NewSet(),
|
||||
}
|
||||
|
||||
if a.disallowedClients.Has(ipStr) {
|
||||
return true, ipStr
|
||||
err = processAccessClients(allowed, a.allowedIPs, &a.allowedNets, a.allowedClientIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding allowed: %w", err)
|
||||
}
|
||||
|
||||
if len(a.disallowedClientsIPNet) != 0 {
|
||||
for _, ipnet := range a.disallowedClientsIPNet {
|
||||
if ipnet.Contains(ip) {
|
||||
return true, ipnet.String()
|
||||
}
|
||||
}
|
||||
err = processAccessClients(blocked, a.blockedIPs, &a.blockedNets, a.blockedClientIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding blocked: %w", err)
|
||||
}
|
||||
|
||||
return false, ""
|
||||
b := &strings.Builder{}
|
||||
for _, h := range blockedHosts {
|
||||
aghstrings.WriteToBuilder(b, strings.ToLower(h), "\n")
|
||||
}
|
||||
|
||||
lists := []filterlist.RuleList{
|
||||
&filterlist.StringRuleList{
|
||||
ID: int(0),
|
||||
RulesText: b.String(),
|
||||
IgnoreCosmetic: true,
|
||||
},
|
||||
}
|
||||
|
||||
rulesStrg, err := filterlist.NewRuleStorage(lists)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("adding blocked hosts: %w", err)
|
||||
}
|
||||
|
||||
a.blockedHostsEng = urlfilter.NewDNSEngine(rulesStrg)
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// IsBlockedDomain - return TRUE if this domain should be blocked
|
||||
func (a *accessCtx) IsBlockedDomain(host string) (ok bool) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
// allowlistMode returns true if this *accessCtx is in the allowlist mode.
|
||||
func (a *accessCtx) allowlistMode() (ok bool) {
|
||||
return a.allowedIPs.Len() != 0 || a.allowedClientIDs.Len() != 0 || len(a.allowedNets) != 0
|
||||
}
|
||||
|
||||
_, ok = a.blockedHostsEngine.Match(strings.ToLower(host))
|
||||
// isBlockedClientID returns true if the ClientID should be blocked.
|
||||
func (a *accessCtx) isBlockedClientID(id string) (ok bool) {
|
||||
allowlistMode := a.allowlistMode()
|
||||
if id == "" {
|
||||
// In allowlist mode, consider requests without client IDs
|
||||
// blocked by default.
|
||||
return allowlistMode
|
||||
}
|
||||
|
||||
if allowlistMode {
|
||||
return !a.allowedClientIDs.Has(id)
|
||||
}
|
||||
|
||||
return a.blockedClientIDs.Has(id)
|
||||
}
|
||||
|
||||
// isBlockedHost returns true if host should be blocked.
|
||||
func (a *accessCtx) isBlockedHost(host string) (ok bool) {
|
||||
_, ok = a.blockedHostsEng.Match(strings.ToLower(host))
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// isBlockedIP returns the status of the IP address blocking as well as the rule
|
||||
// that blocked it.
|
||||
func (a *accessCtx) isBlockedIP(ip net.IP) (blocked bool, rule string) {
|
||||
blocked = true
|
||||
ips := a.blockedIPs
|
||||
ipnets := a.blockedNets
|
||||
|
||||
if a.allowlistMode() {
|
||||
// Enable allowlist mode and use the allowlist sets.
|
||||
blocked = false
|
||||
ips = a.allowedIPs
|
||||
ipnets = a.allowedNets
|
||||
}
|
||||
|
||||
if _, ok := ips.Get(ip); ok {
|
||||
return blocked, ip.String()
|
||||
}
|
||||
|
||||
for _, ipnet := range ipnets {
|
||||
if ipnet.Contains(ip) {
|
||||
return blocked, ipnet.String()
|
||||
}
|
||||
}
|
||||
|
||||
return !blocked, ""
|
||||
}
|
||||
|
||||
type accessListJSON struct {
|
||||
AllowedClients []string `json:"allowed_clients"`
|
||||
DisallowedClients []string `json:"disallowed_clients"`
|
||||
@ -161,62 +186,43 @@ func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err := json.NewEncoder(w).Encode(j)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusInternalServerError, "json.Encode: %s", err)
|
||||
httpError(r, w, http.StatusInternalServerError, "encoding response: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func checkIPCIDRArray(src []string) error {
|
||||
for _, s := range src {
|
||||
ip := net.ParseIP(s)
|
||||
if ip != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
_, _, err := net.ParseCIDR(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
j := accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&j)
|
||||
list := accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
return
|
||||
}
|
||||
httpError(r, w, http.StatusBadRequest, "decoding request: %s", err)
|
||||
|
||||
err = checkIPCIDRArray(j.AllowedClients)
|
||||
if err == nil {
|
||||
err = checkIPCIDRArray(j.DisallowedClients)
|
||||
}
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "%s", err)
|
||||
return
|
||||
}
|
||||
|
||||
var a *accessCtx
|
||||
a, err = newAccessCtx(j.AllowedClients, j.DisallowedClients, j.BlockedHosts)
|
||||
a, err = newAccessCtx(list.AllowedClients, list.DisallowedClients, list.BlockedHosts)
|
||||
if err != nil {
|
||||
httpError(r, w, http.StatusBadRequest, "creating access ctx: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer log.Debug("Access: updated lists: %d, %d, %d",
|
||||
len(j.AllowedClients), len(j.DisallowedClients), len(j.BlockedHosts))
|
||||
defer log.Debug(
|
||||
"access: updated lists: %d, %d, %d",
|
||||
len(list.AllowedClients),
|
||||
len(list.DisallowedClients),
|
||||
len(list.BlockedHosts),
|
||||
)
|
||||
|
||||
defer s.conf.ConfigModified()
|
||||
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.conf.AllowedClients = j.AllowedClients
|
||||
s.conf.DisallowedClients = j.DisallowedClients
|
||||
s.conf.BlockedHosts = j.BlockedHosts
|
||||
s.conf.AllowedClients = list.AllowedClients
|
||||
s.conf.DisallowedClients = list.DisallowedClients
|
||||
s.conf.BlockedHosts = list.BlockedHosts
|
||||
s.access = a
|
||||
}
|
||||
|
@ -8,99 +8,23 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
const (
|
||||
ip int = iota
|
||||
cidr
|
||||
)
|
||||
func TestIsBlockedClientID(t *testing.T) {
|
||||
clientID := "client-1"
|
||||
clients := []string{clientID}
|
||||
|
||||
rules := []string{
|
||||
ip: "1.1.1.1",
|
||||
cidr: "2.2.0.0/16",
|
||||
}
|
||||
a, err := newAccessCtx(clients, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
allowed bool
|
||||
ip net.IP
|
||||
wantDis bool
|
||||
wantRule string
|
||||
}{{
|
||||
name: "allow_ip",
|
||||
allowed: true,
|
||||
ip: net.IPv4(1, 1, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "disallow_ip",
|
||||
allowed: true,
|
||||
ip: net.IPv4(1, 1, 1, 2),
|
||||
wantDis: true,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_cidr",
|
||||
allowed: true,
|
||||
ip: net.IPv4(2, 2, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "disallow_cidr",
|
||||
allowed: true,
|
||||
ip: net.IPv4(2, 3, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_ip",
|
||||
allowed: false,
|
||||
ip: net.IPv4(1, 1, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: rules[ip],
|
||||
}, {
|
||||
name: "disallow_ip",
|
||||
allowed: false,
|
||||
ip: net.IPv4(1, 1, 1, 2),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}, {
|
||||
name: "allow_cidr",
|
||||
allowed: false,
|
||||
ip: net.IPv4(2, 2, 1, 1),
|
||||
wantDis: true,
|
||||
wantRule: rules[cidr],
|
||||
}, {
|
||||
name: "disallow_cidr",
|
||||
allowed: false,
|
||||
ip: net.IPv4(2, 3, 1, 1),
|
||||
wantDis: false,
|
||||
wantRule: "",
|
||||
}}
|
||||
assert.False(t, a.isBlockedClientID(clientID))
|
||||
|
||||
for _, tc := range testCases {
|
||||
prefix := "allowed_"
|
||||
if !tc.allowed {
|
||||
prefix = "disallowed_"
|
||||
}
|
||||
a, err = newAccessCtx(nil, clients, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run(prefix+tc.name, func(t *testing.T) {
|
||||
allowedRules := rules
|
||||
var disallowedRules []string
|
||||
|
||||
if !tc.allowed {
|
||||
allowedRules, disallowedRules = disallowedRules, allowedRules
|
||||
}
|
||||
|
||||
aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
disallowed, rule := aCtx.IsBlockedIP(tc.ip)
|
||||
assert.Equal(t, tc.wantDis, disallowed)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
})
|
||||
}
|
||||
assert.True(t, a.isBlockedClientID(clientID))
|
||||
}
|
||||
|
||||
func TestIsBlockedDomain(t *testing.T) {
|
||||
aCtx, err := newAccessCtx(nil, nil, []string{
|
||||
func TestIsBlockedHost(t *testing.T) {
|
||||
a, err := newAccessCtx(nil, nil, []string{
|
||||
"host1",
|
||||
"*.host.com",
|
||||
"||host3.com^",
|
||||
@ -108,50 +32,106 @@ func TestIsBlockedDomain(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
domain string
|
||||
want bool
|
||||
name string
|
||||
host string
|
||||
want bool
|
||||
}{{
|
||||
name: "plain_match",
|
||||
domain: "host1",
|
||||
want: true,
|
||||
name: "plain_match",
|
||||
host: "host1",
|
||||
want: true,
|
||||
}, {
|
||||
name: "plain_mismatch",
|
||||
domain: "host2",
|
||||
want: false,
|
||||
name: "plain_mismatch",
|
||||
host: "host2",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-1_match_short",
|
||||
domain: "asdf.host.com",
|
||||
want: true,
|
||||
name: "subdomain_match_short",
|
||||
host: "asdf.host.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-1_match_long",
|
||||
domain: "qwer.asdf.host.com",
|
||||
want: true,
|
||||
name: "subdomain_match_long",
|
||||
host: "qwer.asdf.host.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-1_mismatch_no-lead",
|
||||
domain: "host.com",
|
||||
want: false,
|
||||
name: "subdomain_mismatch_no_lead",
|
||||
host: "host.com",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-1_mismatch_bad-asterisk",
|
||||
domain: "asdf.zhost.com",
|
||||
want: false,
|
||||
name: "subdomain_mismatch_bad_asterisk",
|
||||
host: "asdf.zhost.com",
|
||||
want: false,
|
||||
}, {
|
||||
name: "wildcard_type-2_match_simple",
|
||||
domain: "host3.com",
|
||||
want: true,
|
||||
name: "rule_match_simple",
|
||||
host: "host3.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-2_match_complex",
|
||||
domain: "asdf.host3.com",
|
||||
want: true,
|
||||
name: "rule_match_complex",
|
||||
host: "asdf.host3.com",
|
||||
want: true,
|
||||
}, {
|
||||
name: "wildcard_type-2_mismatch",
|
||||
domain: ".host3.com",
|
||||
want: false,
|
||||
name: "rule_mismatch",
|
||||
host: ".host3.com",
|
||||
want: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain))
|
||||
assert.Equal(t, tc.want, a.isBlockedHost(tc.host))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBlockedIP(t *testing.T) {
|
||||
clients := []string{
|
||||
"1.2.3.4",
|
||||
"5.6.7.8/24",
|
||||
}
|
||||
|
||||
allowCtx, err := newAccessCtx(clients, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
blockCtx, err := newAccessCtx(nil, clients, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantRule string
|
||||
ip net.IP
|
||||
wantBlocked bool
|
||||
}{{
|
||||
name: "match_ip",
|
||||
wantRule: "1.2.3.4",
|
||||
ip: net.IP{1, 2, 3, 4},
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "match_cidr",
|
||||
wantRule: "5.6.7.8/24",
|
||||
ip: net.IP{5, 6, 7, 100},
|
||||
wantBlocked: true,
|
||||
}, {
|
||||
name: "no_match_ip",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 2, 3, 4},
|
||||
wantBlocked: false,
|
||||
}, {
|
||||
name: "no_match_cidr",
|
||||
wantRule: "",
|
||||
ip: net.IP{9, 6, 7, 100},
|
||||
wantBlocked: false,
|
||||
}}
|
||||
|
||||
t.Run("allow", func(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
blocked, rule := allowCtx.isBlockedIP(tc.ip)
|
||||
assert.Equal(t, !tc.wantBlocked, blocked)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("block", func(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
blocked, rule := blockCtx.isBlockedIP(tc.ip)
|
||||
assert.Equal(t, tc.wantBlocked, blocked)
|
||||
assert.Equal(t, tc.wantRule, rule)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
@ -50,15 +51,15 @@ func clientIDFromClientServerName(hostSrvName, cliSrvName string, strict bool) (
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// processClientIDHTTPS extracts the client's ID from the path of the
|
||||
// clientIDFromDNSContextHTTPS extracts the client's ID from the path of the
|
||||
// client's DNS-over-HTTPS request.
|
||||
func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
||||
pctx := ctx.proxyCtx
|
||||
func clientIDFromDNSContextHTTPS(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
r := pctx.HTTPRequest
|
||||
if r == nil {
|
||||
ctx.err = fmt.Errorf("proxy ctx http request of proto %s is nil", pctx.Proto)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx http request of proto %s is nil",
|
||||
pctx.Proto,
|
||||
)
|
||||
}
|
||||
|
||||
origPath := r.URL.Path
|
||||
@ -68,34 +69,25 @@ func processClientIDHTTPS(ctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if len(parts) == 0 || parts[0] != "dns-query" {
|
||||
ctx.err = fmt.Errorf("client id check: invalid path %q", origPath)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: invalid path %q", origPath)
|
||||
}
|
||||
|
||||
clientID := ""
|
||||
switch len(parts) {
|
||||
case 1:
|
||||
// Just /dns-query, no client ID.
|
||||
return resultCodeSuccess
|
||||
return "", nil
|
||||
case 2:
|
||||
clientID = parts[1]
|
||||
default:
|
||||
ctx.err = fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: invalid path %q: extra parts", origPath)
|
||||
}
|
||||
|
||||
err := ValidateClientID(clientID)
|
||||
err = ValidateClientID(clientID)
|
||||
if err != nil {
|
||||
ctx.err = fmt.Errorf("client id check: %w", err)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: %w", err)
|
||||
}
|
||||
|
||||
ctx.clientID = clientID
|
||||
|
||||
return resultCodeSuccess
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// tlsConn is a narrow interface for *tls.Conn to simplify testing.
|
||||
@ -108,53 +100,73 @@ type quicSession interface {
|
||||
ConnectionState() (cs quic.ConnectionState)
|
||||
}
|
||||
|
||||
// processClientID extracts the client's ID from the server name of the client's
|
||||
// DoT or DoQ request or the path of the client's DoH.
|
||||
func processClientID(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
// clientIDFromDNSContext extracts the client's ID from the server name of the
|
||||
// client's DoT or DoQ request or the path of the client's DoH. If the protocol
|
||||
// is not one of these, clientID is an empty string and err is nil.
|
||||
func (s *Server) clientIDFromDNSContext(pctx *proxy.DNSContext) (clientID string, err error) {
|
||||
proto := pctx.Proto
|
||||
if proto == proxy.ProtoHTTPS {
|
||||
return processClientIDHTTPS(dctx)
|
||||
return clientIDFromDNSContextHTTPS(pctx)
|
||||
} else if proto != proxy.ProtoTLS && proto != proxy.ProtoQUIC {
|
||||
return resultCodeSuccess
|
||||
return "", nil
|
||||
}
|
||||
|
||||
srvConf := dctx.srv.conf
|
||||
hostSrvName := srvConf.TLSConfig.ServerName
|
||||
hostSrvName := s.conf.ServerName
|
||||
if hostSrvName == "" {
|
||||
return resultCodeSuccess
|
||||
return "", nil
|
||||
}
|
||||
|
||||
cliSrvName := ""
|
||||
if proto == proxy.ProtoTLS {
|
||||
switch proto {
|
||||
case proxy.ProtoTLS:
|
||||
conn := pctx.Conn
|
||||
tc, ok := conn.(tlsConn)
|
||||
if !ok {
|
||||
dctx.err = fmt.Errorf("proxy ctx conn of proto %s is %T, want *tls.Conn", proto, conn)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx conn of proto %s is %T, want *tls.Conn",
|
||||
proto,
|
||||
conn,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = tc.ConnectionState().ServerName
|
||||
} else if proto == proxy.ProtoQUIC {
|
||||
case proxy.ProtoQUIC:
|
||||
qs, ok := pctx.QUICSession.(quicSession)
|
||||
if !ok {
|
||||
dctx.err = fmt.Errorf("proxy ctx quic session of proto %s is %T, want quic.Session", proto, pctx.QUICSession)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf(
|
||||
"proxy ctx quic session of proto %s is %T, want quic.Session",
|
||||
proto,
|
||||
pctx.QUICSession,
|
||||
)
|
||||
}
|
||||
|
||||
cliSrvName = qs.ConnectionState().TLS.ServerName
|
||||
}
|
||||
|
||||
clientID, err := clientIDFromClientServerName(hostSrvName, cliSrvName, srvConf.StrictSNICheck)
|
||||
clientID, err = clientIDFromClientServerName(
|
||||
hostSrvName,
|
||||
cliSrvName,
|
||||
s.conf.StrictSNICheck,
|
||||
)
|
||||
if err != nil {
|
||||
dctx.err = fmt.Errorf("client id check: %w", err)
|
||||
|
||||
return resultCodeError
|
||||
return "", fmt.Errorf("client id check: %w", err)
|
||||
}
|
||||
|
||||
dctx.clientID = clientID
|
||||
return clientID, nil
|
||||
}
|
||||
|
||||
// processClientID puts the clientID into the DNS context, if there is one.
|
||||
func (s *Server) processClientID(dctx *dnsContext) (rc resultCode) {
|
||||
pctx := dctx.proxyCtx
|
||||
|
||||
var key [8]byte
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
clientIDData := s.clientIDCache.Get(key[:])
|
||||
if clientIDData == nil {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
dctx.clientID = string(clientIDData)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
@ -45,15 +45,14 @@ func (c testQUICSession) ConnectionState() (cs quic.ConnectionState) {
|
||||
return cs
|
||||
}
|
||||
|
||||
func TestProcessClientID(t *testing.T) {
|
||||
func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
proto proxy.Proto
|
||||
hostSrvName string
|
||||
cliSrvName string
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
wantRes resultCode
|
||||
strictSNI bool
|
||||
}{{
|
||||
name: "udp",
|
||||
@ -62,7 +61,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_no_client_id",
|
||||
@ -71,7 +69,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "example.com",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_no_client_server_name",
|
||||
@ -81,7 +78,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_no_client_server_name_no_strict",
|
||||
@ -90,7 +86,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: false,
|
||||
}, {
|
||||
name: "tls_client_id",
|
||||
@ -99,7 +94,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_hostname_error",
|
||||
@ -109,7 +103,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: client server name "cli.example.net" ` +
|
||||
`doesn't match host server name "example.com"`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_invalid_client_id",
|
||||
@ -119,7 +112,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
`invalid char '!' at index 0`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "tls_client_id_too_long",
|
||||
@ -131,7 +123,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
wantErrMsg: `client id check: invalid client id "abcdefghijklmno` +
|
||||
`pqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789": ` +
|
||||
`label is too long, max: 63`,
|
||||
wantRes: resultCodeError,
|
||||
strictSNI: true,
|
||||
}, {
|
||||
name: "quic_client_id",
|
||||
@ -140,7 +131,6 @@ func TestProcessClientID(t *testing.T) {
|
||||
cliSrvName: "cli.example.com",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
strictSNI: true,
|
||||
}}
|
||||
|
||||
@ -150,6 +140,7 @@ func TestProcessClientID(t *testing.T) {
|
||||
ServerName: tc.hostSrvName,
|
||||
StrictSNICheck: tc.strictSNI,
|
||||
}
|
||||
|
||||
srv := &Server{
|
||||
conf: ServerConfig{TLSConfig: tlsConf},
|
||||
}
|
||||
@ -168,79 +159,68 @@ func TestProcessClientID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
srv: srv,
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
},
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: tc.proto,
|
||||
Conn: conn,
|
||||
QUICSession: qs,
|
||||
}
|
||||
|
||||
res := processClientID(dctx)
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||
clientID, err := srv.clientIDFromDNSContext(pctx)
|
||||
assert.Equal(t, tc.wantClientID, clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, dctx.err)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, dctx.err)
|
||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessClientID_https(t *testing.T) {
|
||||
func TestClientIDFromDNSContextHTTPS(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
path string
|
||||
wantClientID string
|
||||
wantErrMsg string
|
||||
wantRes resultCode
|
||||
}{{
|
||||
name: "no_client_id",
|
||||
path: "/dns-query",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "no_client_id_slash",
|
||||
path: "/dns-query/",
|
||||
wantClientID: "",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "client_id",
|
||||
path: "/dns-query/cli",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "client_id_slash",
|
||||
path: "/dns-query/cli/",
|
||||
wantClientID: "cli",
|
||||
wantErrMsg: "",
|
||||
wantRes: resultCodeSuccess,
|
||||
}, {
|
||||
name: "bad_url",
|
||||
path: "/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/foo"`,
|
||||
wantRes: resultCodeError,
|
||||
}, {
|
||||
name: "extra",
|
||||
path: "/dns-query/cli/foo",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid path "/dns-query/cli/foo": extra parts`,
|
||||
wantRes: resultCodeError,
|
||||
}, {
|
||||
name: "invalid_client_id",
|
||||
path: "/dns-query/!!!",
|
||||
wantClientID: "",
|
||||
wantErrMsg: `client id check: invalid client id "!!!": ` +
|
||||
`invalid char '!' at index 0`,
|
||||
wantRes: resultCodeError,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@ -251,23 +231,20 @@ func TestProcessClientID_https(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: &proxy.DNSContext{
|
||||
Proto: proxy.ProtoHTTPS,
|
||||
HTTPRequest: r,
|
||||
},
|
||||
pctx := &proxy.DNSContext{
|
||||
Proto: proxy.ProtoHTTPS,
|
||||
HTTPRequest: r,
|
||||
}
|
||||
|
||||
res := processClientID(dctx)
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
assert.Equal(t, tc.wantClientID, dctx.clientID)
|
||||
clientID, err := clientIDFromDNSContextHTTPS(pctx)
|
||||
assert.Equal(t, tc.wantClientID, clientID)
|
||||
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.NoError(t, dctx.err)
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, dctx.err)
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, tc.wantErrMsg, dctx.err.Error())
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -331,7 +331,7 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
upstreams = aghstrings.FilterOut(upstreams, aghstrings.IsCommentOrEmpty)
|
||||
upstreamConfig, err := proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
},
|
||||
@ -342,10 +342,10 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
|
||||
if len(upstreamConfig.Upstreams) == 0 {
|
||||
log.Info("warning: no default upstream servers specified, using %v", defaultDNS)
|
||||
var uc proxy.UpstreamConfig
|
||||
var uc *proxy.UpstreamConfig
|
||||
uc, err = proxy.ParseUpstreamsConfig(
|
||||
defaultDNS,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: s.conf.BootstrapDNS,
|
||||
Timeout: s.conf.UpstreamTimeout,
|
||||
},
|
||||
@ -356,7 +356,8 @@ func (s *Server) prepareUpstreamSettings() error {
|
||||
upstreamConfig.Upstreams = uc.Upstreams
|
||||
}
|
||||
|
||||
s.conf.UpstreamConfig = &upstreamConfig
|
||||
s.conf.UpstreamConfig = upstreamConfig
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -89,7 +89,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, d *proxy.DNSContext) error {
|
||||
s.processInternalHosts,
|
||||
s.processRestrictLocal,
|
||||
s.processInternalIPAddrs,
|
||||
processClientID,
|
||||
s.processClientID,
|
||||
processFilteringBeforeRequest,
|
||||
s.processLocalPTR,
|
||||
s.processUpstream,
|
||||
@ -165,7 +165,7 @@ func (s *Server) setTableHostToIP(t hostToIPTable) {
|
||||
s.tableHostToIP = t
|
||||
}
|
||||
|
||||
func (s *Server) setTableIPToHost(t ipToHostTable) {
|
||||
func (s *Server) setTableIPToHost(t *aghnet.IPMap) {
|
||||
s.tableIPToHostLock.Lock()
|
||||
defer s.tableIPToHostLock.Unlock()
|
||||
|
||||
@ -188,13 +188,13 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
}
|
||||
|
||||
var hostToIP hostToIPTable
|
||||
var ipToHost ipToHostTable
|
||||
var ipToHost *aghnet.IPMap
|
||||
if add {
|
||||
hostToIP = make(hostToIPTable)
|
||||
ipToHost = make(ipToHostTable)
|
||||
|
||||
ll := s.dhcpServer.Leases(dhcpd.LeasesAll)
|
||||
|
||||
hostToIP = make(hostToIPTable, len(ll))
|
||||
ipToHost = aghnet.NewIPMap(len(ll))
|
||||
|
||||
for _, l := range ll {
|
||||
// TODO(a.garipov): Remove this after we're finished
|
||||
// with the client hostname validations in the DHCP
|
||||
@ -210,14 +210,14 @@ func (s *Server) onDHCPLeaseChanged(flags int) {
|
||||
|
||||
lowhost := strings.ToLower(l.Hostname)
|
||||
|
||||
ipToHost[l.IP.String()] = lowhost
|
||||
ipToHost.Set(l.IP, lowhost)
|
||||
|
||||
ip := make(net.IP, 4)
|
||||
copy(ip, l.IP.To4())
|
||||
hostToIP[lowhost] = ip
|
||||
}
|
||||
|
||||
log.Debug("dns: added %d A/PTR entries from DHCP", len(ipToHost))
|
||||
log.Debug("dns: added %d A/PTR entries from DHCP", ipToHost.Len())
|
||||
}
|
||||
|
||||
s.setTableHostToIP(hostToIP)
|
||||
@ -377,7 +377,15 @@ func (s *Server) ipToHost(ip net.IP) (host string, ok bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
host, ok = s.tableIPToHost[ip.String()]
|
||||
var v interface{}
|
||||
v, ok = s.tableIPToHost.Get(ip)
|
||||
|
||||
var typOK bool
|
||||
if host, typOK = v.(string); !typOK {
|
||||
log.Error("dns: bad type %T in tableIPToHost for %s", v, ip)
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
return host, ok
|
||||
}
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
@ -26,6 +27,11 @@ import (
|
||||
// DefaultTimeout is the default upstream timeout
|
||||
const DefaultTimeout = 10 * time.Second
|
||||
|
||||
// defaultClientIDCacheCount is the default count of items in the LRU client ID
|
||||
// cache. The assumption here is that there won't be more than this many
|
||||
// requests between the BeforeRequestHandler stage and the actual processing.
|
||||
const defaultClientIDCacheCount = 1024
|
||||
|
||||
const (
|
||||
safeBrowsingBlockHost = "standard-block.dns.adguard.com"
|
||||
parentalBlockHost = "family-block.dns.adguard.com"
|
||||
@ -44,12 +50,6 @@ var webRegistered bool
|
||||
// hostToIPTable is an alias for the type of Server.tableHostToIP.
|
||||
type hostToIPTable = map[string]net.IP
|
||||
|
||||
// ipToHostTable is an alias for the type of Server.tableIPToHost.
|
||||
//
|
||||
// TODO(a.garipov): Define an IPMap type in aghnet and use here and in other
|
||||
// places?
|
||||
type ipToHostTable = map[string]string
|
||||
|
||||
// Server is the main way to start a DNS server.
|
||||
//
|
||||
// Example:
|
||||
@ -81,9 +81,13 @@ type Server struct {
|
||||
tableHostToIP hostToIPTable
|
||||
tableHostToIPLock sync.Mutex
|
||||
|
||||
tableIPToHost ipToHostTable
|
||||
tableIPToHost *aghnet.IPMap
|
||||
tableIPToHostLock sync.Mutex
|
||||
|
||||
// clientIDCache is a temporary storage for clientIDs that were
|
||||
// extracted during the BeforeRequestHandler stage.
|
||||
clientIDCache cache.Cache
|
||||
|
||||
// DNS proxy instance for internal usage
|
||||
// We don't Start() it and so no listen port is required.
|
||||
internalProxy *proxy.Proxy
|
||||
@ -152,6 +156,10 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
subnetDetector: p.SubnetDetector,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
recDetector: newRecursionDetector(recursionTTL, cachedRecurrentReqNum),
|
||||
clientIDCache: cache.New(cache.Config{
|
||||
EnableLRU: true,
|
||||
MaxCount: defaultClientIDCacheCount,
|
||||
}),
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Enable the refresher after the actual implementation
|
||||
@ -414,19 +422,22 @@ func (s *Server) setupResolvers(localAddrs []string) (err error) {
|
||||
|
||||
log.Debug("upstreams to resolve PTR for local addresses: %v", localAddrs)
|
||||
|
||||
var upsConfig proxy.UpstreamConfig
|
||||
upsConfig, err = proxy.ParseUpstreamsConfig(localAddrs, upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's ceritificates?
|
||||
})
|
||||
var upsConfig *proxy.UpstreamConfig
|
||||
upsConfig, err = proxy.ParseUpstreamsConfig(
|
||||
localAddrs,
|
||||
&upstream.Options{
|
||||
Bootstrap: bootstraps,
|
||||
Timeout: defaultLocalTimeout,
|
||||
// TODO(e.burkov): Should we verify server's ceritificates?
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parsing upstreams: %w", err)
|
||||
}
|
||||
|
||||
s.localResolvers = &proxy.Proxy{
|
||||
Config: proxy.Config{
|
||||
UpstreamConfig: &upsConfig,
|
||||
UpstreamConfig: upsConfig,
|
||||
},
|
||||
}
|
||||
|
||||
@ -577,11 +588,33 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// IsBlockedIP - return TRUE if this client should be blocked
|
||||
func (s *Server) IsBlockedIP(ip net.IP) (bool, string) {
|
||||
if ip == nil {
|
||||
return false, ""
|
||||
// IsBlockedClient returns true if the client is blocked by the current access
|
||||
// settings.
|
||||
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
allowlistMode := s.access.allowlistMode()
|
||||
blockedByIP, rule := s.access.isBlockedIP(ip)
|
||||
blockedByClientID := s.access.isBlockedClientID(clientID)
|
||||
|
||||
// Allow if at least one of the checks allows in allowlist mode, but
|
||||
// block if at least one of the checks blocks in blocklist mode.
|
||||
if allowlistMode && blockedByIP && blockedByClientID {
|
||||
log.Debug("client %s (id %q) is not in access allowlist", ip, clientID)
|
||||
|
||||
// Return now without substituting the empty rule for the
|
||||
// clientID because the rule can't be empty here.
|
||||
return true, rule
|
||||
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
|
||||
log.Debug("client %s (id %q) is in access blocklist", ip, clientID)
|
||||
|
||||
blocked = true
|
||||
}
|
||||
|
||||
return s.access.IsBlockedIP(ip)
|
||||
if rule == "" {
|
||||
rule = clientID
|
||||
}
|
||||
|
||||
return blocked, rule
|
||||
}
|
||||
|
@ -257,19 +257,22 @@ func TestServer(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
net string
|
||||
proto proxy.Proto
|
||||
}{{
|
||||
name: "message_over_udp",
|
||||
net: "",
|
||||
proto: proxy.ProtoUDP,
|
||||
}, {
|
||||
name: "message_over_tcp",
|
||||
net: "tcp",
|
||||
proto: proxy.ProtoTCP,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
addr := s.dnsProxy.Addr(tc.proto)
|
||||
client := dns.Client{Net: tc.proto}
|
||||
client := dns.Client{Net: tc.net}
|
||||
|
||||
reply, _, err := client.Exchange(createGoogleATestMessage(), addr.String())
|
||||
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
@ -324,7 +327,7 @@ func TestServerWithProtectionDisabled(t *testing.T) {
|
||||
// Message over UDP.
|
||||
req := createGoogleATestMessage()
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
client := dns.Client{Net: proxy.ProtoUDP}
|
||||
client := &dns.Client{}
|
||||
|
||||
reply, _, err := client.Exchange(req, addr.String())
|
||||
require.NoErrorf(t, err, "сouldn't talk to server %s: %s", addr, err)
|
||||
@ -376,7 +379,7 @@ func TestDoQServer(t *testing.T) {
|
||||
|
||||
// Create a DNS-over-QUIC upstream.
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoQUIC)
|
||||
opts := upstream.Options{InsecureSkipVerify: true}
|
||||
opts := &upstream.Options{InsecureSkipVerify: true}
|
||||
u, err := upstream.AddressToUpstream(fmt.Sprintf("%s://%s", proxy.ProtoQUIC, addr), opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -420,7 +423,7 @@ func TestServerRace(t *testing.T) {
|
||||
|
||||
// Message over UDP.
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
conn, err := dns.Dial(proxy.ProtoUDP, addr.String())
|
||||
conn, err := dns.Dial("udp", addr.String())
|
||||
require.NoErrorf(t, err, "cannot connect to the proxy: %s", err)
|
||||
|
||||
sendTestMessagesAsync(t, conn)
|
||||
@ -445,7 +448,7 @@ func TestSafeSearch(t *testing.T) {
|
||||
startDeferStop(t, s)
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP).String()
|
||||
client := dns.Client{Net: proxy.ProtoUDP}
|
||||
client := &dns.Client{}
|
||||
|
||||
yandexIP := net.IP{213, 180, 193, 56}
|
||||
googleIP, _ := resolver.HostToIPs("forcesafesearch.google.com")
|
||||
@ -507,7 +510,6 @@ func TestInvalidRequest(t *testing.T) {
|
||||
|
||||
// Send a DNS request without question.
|
||||
_, _, err := (&dns.Client{
|
||||
Net: proxy.ProtoUDP,
|
||||
Timeout: 500 * time.Millisecond,
|
||||
}).Exchange(&req, addr)
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -11,23 +12,39 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func (s *Server) beforeRequestHandler(_ *proxy.Proxy, d *proxy.DNSContext) (bool, error) {
|
||||
ip := aghnet.IPFromAddr(d.Addr)
|
||||
disallowed, _ := s.access.IsBlockedIP(ip)
|
||||
if disallowed {
|
||||
log.Tracef("Client IP %s is blocked by settings", ip)
|
||||
// beforeRequestHandler is the handler that is called before any other
|
||||
// processing, including logs. It performs access checks and puts the client
|
||||
// ID, if there is one, into the server's cache.
|
||||
func (s *Server) beforeRequestHandler(
|
||||
_ *proxy.Proxy,
|
||||
pctx *proxy.DNSContext,
|
||||
) (reply bool, err error) {
|
||||
ip := aghnet.IPFromAddr(pctx.Addr)
|
||||
clientID, err := s.clientIDFromDNSContext(pctx)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("getting clientid: %w", err)
|
||||
}
|
||||
|
||||
blocked, _ := s.IsBlockedClient(ip, clientID)
|
||||
if blocked {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if len(d.Req.Question) == 1 {
|
||||
host := strings.TrimSuffix(d.Req.Question[0].Name, ".")
|
||||
if s.access.IsBlockedDomain(host) {
|
||||
log.Tracef("domain %s is blocked by access settings", host)
|
||||
if len(pctx.Req.Question) == 1 {
|
||||
host := strings.TrimSuffix(pctx.Req.Question[0].Name, ".")
|
||||
if s.access.isBlockedHost(host) {
|
||||
log.Debug("host %s is in access blocklist", host)
|
||||
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if clientID != "" {
|
||||
key := [8]byte{}
|
||||
binary.BigEndian.PutUint64(key[:], pctx.RequestID)
|
||||
s.clientIDCache.Set(key[:], []byte(clientID))
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
@ -167,7 +167,7 @@ func (req *dnsConfig) checkBootstrap() (string, error) {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: empty")
|
||||
}
|
||||
|
||||
if _, err := upstream.NewResolver(boot, upstream.Options{Timeout: 0}); err != nil {
|
||||
if _, err := upstream.NewResolver(boot, nil); err != nil {
|
||||
return boot, fmt.Errorf("invalid bootstrap server address: %w", err)
|
||||
}
|
||||
}
|
||||
@ -348,7 +348,7 @@ func ValidateUpstreams(upstreams []string) (err error) {
|
||||
|
||||
_, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: []string{},
|
||||
Timeout: DefaultTimeout,
|
||||
},
|
||||
@ -546,7 +546,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
|
||||
|
||||
log.Debug("checking if dns server %q works...", input)
|
||||
var u upstream.Upstream
|
||||
u, err = upstream.AddressToUpstream(input, upstream.Options{
|
||||
u, err = upstream.AddressToUpstream(input, &upstream.Options{
|
||||
Bootstrap: bootstrap,
|
||||
Timeout: timeout,
|
||||
})
|
||||
|
@ -46,7 +46,7 @@ func (l *testStats) Update(e stats.Entry) {
|
||||
func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
proto string
|
||||
proto proxy.Proto
|
||||
addr net.Addr
|
||||
clientID string
|
||||
wantLogProto querylog.ClientProto
|
||||
@ -156,7 +156,7 @@ func TestProcessQueryLogsAndStats(t *testing.T) {
|
||||
wantStatResult: stats.RParental,
|
||||
}}
|
||||
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", upstream.Options{})
|
||||
ups, err := upstream.AddressToUpstream("1.1.1.1", nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
@ -49,7 +49,7 @@ func (d *DNSFilter) initSecurityServices() error {
|
||||
var err error
|
||||
d.safeBrowsingServer = defaultSafebrowsingServer
|
||||
d.parentalServer = defaultParentalServer
|
||||
opts := upstream.Options{
|
||||
opts := &upstream.Options{
|
||||
Timeout: dnsTimeout,
|
||||
ServerIPAddrs: []net.IP{
|
||||
{94, 140, 14, 15},
|
||||
|
@ -78,10 +78,13 @@ type RuntimeClientWHOISInfo struct {
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for
|
||||
// different types (string, net.IP, and so on).
|
||||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // ID -> client
|
||||
ipToRC map[string]*RuntimeClient // IP -> runtime client
|
||||
lock sync.Mutex
|
||||
list map[string]*Client // name -> client
|
||||
idIndex map[string]*Client // ID -> client
|
||||
|
||||
// ipToRC is the IP address to *RuntimeClient map.
|
||||
ipToRC *aghnet.IPMap
|
||||
|
||||
lock sync.Mutex
|
||||
|
||||
allTags *aghstrings.Set
|
||||
|
||||
@ -109,7 +112,7 @@ func (clients *clientsContainer) Init(
|
||||
}
|
||||
clients.list = make(map[string]*Client)
|
||||
clients.idIndex = make(map[string]*Client)
|
||||
clients.ipToRC = make(map[string]*RuntimeClient)
|
||||
clients.ipToRC = aghnet.NewIPMap(0)
|
||||
|
||||
clients.allTags = aghstrings.NewSet(clientTags...)
|
||||
|
||||
@ -250,18 +253,17 @@ func (clients *clientsContainer) onHostsChanged() {
|
||||
clients.addFromHostsFile()
|
||||
}
|
||||
|
||||
// Exists checks if client with this ID already exists.
|
||||
func (clients *clientsContainer) Exists(id string, source clientSource) (ok bool) {
|
||||
// Exists checks if client with this IP address already exists.
|
||||
func (clients *clientsContainer) Exists(ip net.IP, source clientSource) (ok bool) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok = clients.findLocked(id)
|
||||
_, ok = clients.findLocked(ip.String())
|
||||
if ok {
|
||||
return true
|
||||
}
|
||||
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.ipToRC[id]
|
||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@ -288,13 +290,14 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||
for _, id := range ids {
|
||||
var name string
|
||||
whois := &querylog.ClientWHOIS{}
|
||||
ip := net.ParseIP(id)
|
||||
|
||||
c, ok := clients.Find(id)
|
||||
if ok {
|
||||
name = c.Name
|
||||
} else {
|
||||
var rc RuntimeClient
|
||||
rc, ok = clients.FindRuntimeClient(id)
|
||||
} else if ip != nil {
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.FindRuntimeClient(ip)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
@ -303,8 +306,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
|
||||
whois = toQueryLogWHOIS(rc.WHOISInfo)
|
||||
}
|
||||
|
||||
ip := net.ParseIP(id)
|
||||
disallowed, disallowedRule := clients.dnsServer.IsBlockedIP(ip)
|
||||
disallowed, disallowedRule := clients.dnsServer.IsBlockedClient(ip, id)
|
||||
|
||||
return &querylog.Client{
|
||||
Name: name,
|
||||
@ -356,10 +358,10 @@ func (clients *clientsContainer) findUpstreams(
|
||||
return c.upstreamConfig, nil
|
||||
}
|
||||
|
||||
var conf proxy.UpstreamConfig
|
||||
var conf *proxy.UpstreamConfig
|
||||
conf, err = proxy.ParseUpstreamsConfig(
|
||||
upstreams,
|
||||
upstream.Options{
|
||||
&upstream.Options{
|
||||
Bootstrap: config.DNS.BootstrapDNS,
|
||||
Timeout: config.DNS.UpstreamTimeout.Duration,
|
||||
},
|
||||
@ -368,9 +370,9 @@ func (clients *clientsContainer) findUpstreams(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.upstreamConfig = &conf
|
||||
c.upstreamConfig = conf
|
||||
|
||||
return &conf, nil
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
// findLocked searches for a client by its ID. For internal use only.
|
||||
@ -423,22 +425,35 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// findRuntimeClientLocked finds a runtime client by their IP address. For
|
||||
// internal use only.
|
||||
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||
var v interface{}
|
||||
v, ok = clients.ipToRC.Get(ip)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
rc, ok = v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return rc, true
|
||||
}
|
||||
|
||||
// FindRuntimeClient finds a runtime client by their IP.
|
||||
func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bool) {
|
||||
ipAddr := net.ParseIP(ip)
|
||||
if ipAddr == nil {
|
||||
return RuntimeClient{}, false
|
||||
func (clients *clientsContainer) FindRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
|
||||
if ip == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
if ok {
|
||||
return *rc, true
|
||||
}
|
||||
|
||||
return RuntimeClient{}, false
|
||||
return clients.findRuntimeClientLocked(ip)
|
||||
}
|
||||
|
||||
// check validates the client.
|
||||
@ -621,17 +636,17 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
|
||||
}
|
||||
|
||||
// SetWHOISInfo sets the WHOIS information for a client.
|
||||
func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISInfo) {
|
||||
func (clients *clientsContainer) SetWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
_, ok := clients.findLocked(ip)
|
||||
_, ok := clients.findLocked(ip.String())
|
||||
if ok {
|
||||
log.Debug("clients: client for %s is already created, ignore whois info", ip)
|
||||
return
|
||||
}
|
||||
|
||||
rc, ok := clients.ipToRC[ip]
|
||||
rc, ok := clients.findRuntimeClientLocked(ip)
|
||||
if ok {
|
||||
rc.WHOISInfo = wi
|
||||
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
|
||||
@ -646,14 +661,15 @@ func (clients *clientsContainer) SetWHOISInfo(ip string, wi *RuntimeClientWHOISI
|
||||
}
|
||||
|
||||
rc.WHOISInfo = wi
|
||||
clients.ipToRC[ip] = rc
|
||||
|
||||
clients.ipToRC.Set(ip, rc)
|
||||
|
||||
log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
|
||||
}
|
||||
|
||||
// AddHost adds a new IP-hostname pairing. The priorities of the sources is
|
||||
// taken into account. ok is true if the pairing was added.
|
||||
func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok bool, err error) {
|
||||
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
|
||||
clients.lock.Lock()
|
||||
defer clients.lock.Unlock()
|
||||
|
||||
@ -663,9 +679,9 @@ func (clients *clientsContainer) AddHost(ip, host string, src clientSource) (ok
|
||||
}
|
||||
|
||||
// addHostLocked adds a new IP-hostname pairing. For internal use only.
|
||||
func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource) (ok bool) {
|
||||
func (clients *clientsContainer) addHostLocked(ip net.IP, host string, src clientSource) (ok bool) {
|
||||
var rc *RuntimeClient
|
||||
rc, ok = clients.ipToRC[ip]
|
||||
rc, ok = clients.findRuntimeClientLocked(ip)
|
||||
if ok {
|
||||
if rc.Source > src {
|
||||
return false
|
||||
@ -679,10 +695,10 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
|
||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||
}
|
||||
|
||||
clients.ipToRC[ip] = rc
|
||||
clients.ipToRC.Set(ip, rc)
|
||||
}
|
||||
|
||||
log.Debug("clients: added %q -> %q [%d]", ip, host, len(clients.ipToRC))
|
||||
log.Debug("clients: added %s -> %q [%d]", ip, host, clients.ipToRC.Len())
|
||||
|
||||
return true
|
||||
}
|
||||
@ -690,12 +706,21 @@ func (clients *clientsContainer) addHostLocked(ip, host string, src clientSource
|
||||
// rmHostsBySrc removes all entries that match the specified source.
|
||||
func (clients *clientsContainer) rmHostsBySrc(src clientSource) {
|
||||
n := 0
|
||||
for k, v := range clients.ipToRC {
|
||||
if v.Source == src {
|
||||
delete(clients.ipToRC, k)
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("clients: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
if rc.Source == src {
|
||||
clients.ipToRC.Del(ip)
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: removed %d client aliases", n)
|
||||
}
|
||||
@ -715,16 +740,23 @@ func (clients *clientsContainer) addFromHostsFile() {
|
||||
clients.rmHostsBySrc(ClientSourceHostsFile)
|
||||
|
||||
n := 0
|
||||
for ip, names := range hosts {
|
||||
hosts.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
names, ok := v.([]string)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
ok := clients.addHostLocked(ip, name, ClientSourceHostsFile)
|
||||
ok = clients.addHostLocked(ip, name, ClientSourceHostsFile)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debug("Clients: added %d client aliases from system hosts-file", n)
|
||||
return true
|
||||
})
|
||||
|
||||
log.Debug("clients: added %d client aliases from system hosts-file", n)
|
||||
}
|
||||
|
||||
// addFromSystemARP adds the IP-hostname pairings from the output of the arp -a
|
||||
@ -752,15 +784,16 @@ func (clients *clientsContainer) addFromSystemARP() {
|
||||
// TODO(a.garipov): Rewrite to use bufio.Scanner.
|
||||
lines := strings.Split(string(data), "\n")
|
||||
for _, ln := range lines {
|
||||
open := strings.Index(ln, " (")
|
||||
close := strings.Index(ln, ") ")
|
||||
if open == -1 || close == -1 || open >= close {
|
||||
lparen := strings.Index(ln, " (")
|
||||
rparen := strings.Index(ln, ") ")
|
||||
if lparen == -1 || rparen == -1 || lparen >= rparen {
|
||||
continue
|
||||
}
|
||||
|
||||
host := ln[:open]
|
||||
ip := ln[open+2 : close]
|
||||
if aghnet.ValidateDomainName(host) != nil || net.ParseIP(ip) == nil {
|
||||
host := ln[:lparen]
|
||||
ipStr := ln[lparen+2 : rparen]
|
||||
ip := net.ParseIP(ipStr)
|
||||
if aghnet.ValidateDomainName(host) != nil || ip == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -796,7 +829,7 @@ func (clients *clientsContainer) updateFromDHCP(add bool) {
|
||||
continue
|
||||
}
|
||||
|
||||
ok := clients.addHostLocked(l.IP.String(), l.Hostname, ClientSourceDHCP)
|
||||
ok := clients.addHostLocked(l.IP, l.Hostname, ClientSourceDHCP)
|
||||
if ok {
|
||||
n++
|
||||
}
|
||||
|
@ -26,6 +26,7 @@ func TestClients(t *testing.T) {
|
||||
|
||||
ok, err := clients.Add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c = &Client{
|
||||
@ -35,23 +36,27 @@ func TestClients(t *testing.T) {
|
||||
|
||||
ok, err = clients.Add(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
c, ok = clients.Find("1.1.1.1")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, ok = clients.Find("1:2:3::4")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1", c.Name)
|
||||
|
||||
c, ok = clients.Find("2.2.2.2")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client2", c.Name)
|
||||
|
||||
assert.False(t, clients.Exists("1.2.3.4", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("2.2.2.2", ClientSourceHostsFile))
|
||||
assert.False(t, clients.Exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
@ -101,8 +106,8 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||
assert.False(t, clients.Exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
||||
|
||||
err = clients.Update("client1", &Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
@ -113,21 +118,25 @@ func TestClients(t *testing.T) {
|
||||
|
||||
c, ok := clients.Find("1.1.1.2")
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, "client1-renamed", c.Name)
|
||||
assert.True(t, c.UseOwnSettings)
|
||||
|
||||
nilCli, ok := clients.list["client1"]
|
||||
require.False(t, ok)
|
||||
|
||||
assert.Nil(t, nilCli)
|
||||
|
||||
require.Len(t, c.IDs, 1)
|
||||
|
||||
assert.Equal(t, "1.1.1.2", c.IDs[0])
|
||||
})
|
||||
|
||||
t.Run("del_success", func(t *testing.T) {
|
||||
ok := clients.Del("client1-renamed")
|
||||
require.True(t, ok)
|
||||
assert.False(t, clients.Exists("1.1.1.2", ClientSourceHostsFile))
|
||||
|
||||
assert.False(t, clients.Exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("del_fail", func(t *testing.T) {
|
||||
@ -136,37 +145,44 @@ func TestClients(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("addhost_success", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceARP)
|
||||
ip := net.IP{1, 1, 1, 1}
|
||||
|
||||
ok, err := clients.AddHost(ip, "host", ClientSourceARP)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost("1.1.1.1", "host2", ClientSourceARP)
|
||||
ok, err = clients.AddHost(ip, "host2", ClientSourceARP)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
ok, err = clients.AddHost("1.1.1.1", "host3", ClientSourceHostsFile)
|
||||
ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.True(t, clients.Exists("1.1.1.1", ClientSourceHostsFile))
|
||||
assert.True(t, clients.Exists(ip, ClientSourceHostsFile))
|
||||
})
|
||||
|
||||
t.Run("dhcp_replaces_arp", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.2.3.4", "from_arp", ClientSourceARP)
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
|
||||
ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
assert.True(t, clients.Exists(ip, ClientSourceARP))
|
||||
|
||||
assert.True(t, clients.Exists("1.2.3.4", ClientSourceARP))
|
||||
|
||||
ok, err = clients.AddHost("1.2.3.4", "from_dhcp", ClientSourceDHCP)
|
||||
ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
assert.True(t, clients.Exists("1.2.3.4", ClientSourceDHCP))
|
||||
assert.True(t, ok)
|
||||
assert.True(t, clients.Exists(ip, ClientSourceDHCP))
|
||||
})
|
||||
|
||||
t.Run("addhost_fail", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host1", ClientSourceRDNS)
|
||||
ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
@ -183,31 +199,39 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("new_client", func(t *testing.T) {
|
||||
clients.SetWHOISInfo("1.1.1.255", whois)
|
||||
ip := net.IP{1, 1, 1, 255}
|
||||
clients.SetWHOISInfo(ip, whois)
|
||||
v, _ := clients.ipToRC.Get(ip)
|
||||
require.NotNil(t, v)
|
||||
|
||||
require.NotNil(t, clients.ipToRC["1.1.1.255"])
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
h := clients.ipToRC["1.1.1.255"]
|
||||
require.NotNil(t, h)
|
||||
|
||||
assert.Equal(t, h.WHOISInfo, whois)
|
||||
assert.Equal(t, rc.WHOISInfo, whois)
|
||||
})
|
||||
|
||||
t.Run("existing_auto-client", func(t *testing.T) {
|
||||
ok, err := clients.AddHost("1.1.1.1", "host", ClientSourceRDNS)
|
||||
ip := net.IP{1, 1, 1, 1}
|
||||
ok, err := clients.AddHost(ip, "host", ClientSourceRDNS)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.SetWHOISInfo("1.1.1.1", whois)
|
||||
clients.SetWHOISInfo(ip, whois)
|
||||
v, _ := clients.ipToRC.Get(ip)
|
||||
require.NotNil(t, v)
|
||||
|
||||
require.NotNil(t, clients.ipToRC["1.1.1.1"])
|
||||
h := clients.ipToRC["1.1.1.1"]
|
||||
require.NotNil(t, h)
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
require.True(t, ok)
|
||||
require.NotNil(t, rc)
|
||||
|
||||
assert.Equal(t, h.WHOISInfo, whois)
|
||||
assert.Equal(t, rc.WHOISInfo, whois)
|
||||
})
|
||||
|
||||
t.Run("can't_set_manually-added", func(t *testing.T) {
|
||||
ip := net.IP{1, 1, 1, 2}
|
||||
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{"1.1.1.2"},
|
||||
Name: "client1",
|
||||
@ -215,8 +239,10 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
clients.SetWHOISInfo("1.1.1.2", whois)
|
||||
require.Nil(t, clients.ipToRC["1.1.1.2"])
|
||||
clients.SetWHOISInfo(ip, whois)
|
||||
v, _ := clients.ipToRC.Get(ip)
|
||||
require.Nil(t, v)
|
||||
|
||||
assert.True(t, clients.Del("client1"))
|
||||
})
|
||||
}
|
||||
@ -228,16 +254,18 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
clients.Init(nil, nil, nil)
|
||||
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
ip := net.IP{1, 1, 1, 1}
|
||||
|
||||
// Add a client.
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||
IDs: []string{ip.String(), "1:2:3::4", "aa:aa:aa:aa:aa:aa", "2.2.2.0/24"},
|
||||
Name: "client1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Now add an auto-client with the same IP.
|
||||
ok, err = clients.AddHost("1.1.1.1", "test", ClientSourceRDNS)
|
||||
ok, err = clients.AddHost(ip, "test", ClientSourceRDNS)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
})
|
||||
@ -245,7 +273,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
t.Run("complicated", func(t *testing.T) {
|
||||
var err error
|
||||
|
||||
testIP := net.IP{1, 2, 3, 4}
|
||||
ip := net.IP{1, 2, 3, 4}
|
||||
|
||||
// First, init a DHCP server with a single static lease.
|
||||
config := dhcpd.ServerConfig{
|
||||
@ -267,7 +295,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
|
||||
err = clients.dhcpServer.AddStaticLease(&dhcpd.Lease{
|
||||
HWAddr: net.HardwareAddr{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA},
|
||||
IP: testIP,
|
||||
IP: ip,
|
||||
Hostname: "testhost",
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
})
|
||||
@ -275,7 +303,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
ok, err := clients.Add(&Client{
|
||||
IDs: []string{testIP.String()},
|
||||
IDs: []string{ip.String()},
|
||||
Name: "client2",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -5,6 +5,8 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// clientJSON is a common structure used by several handlers to deal with
|
||||
@ -44,13 +46,13 @@ type clientJSON struct {
|
||||
type runtimeClientJSON struct {
|
||||
WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"`
|
||||
|
||||
IP string `json:"ip"`
|
||||
Name string `json:"name"`
|
||||
Source string `json:"source"`
|
||||
IP net.IP `json:"ip"`
|
||||
}
|
||||
|
||||
type clientListJSON struct {
|
||||
Clients []clientJSON `json:"clients"`
|
||||
Clients []*clientJSON `json:"clients"`
|
||||
RuntimeClients []runtimeClientJSON `json:"auto_clients"`
|
||||
Tags []string `json:"supported_tags"`
|
||||
}
|
||||
@ -66,11 +68,20 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
|
||||
cj := clientToJSON(c)
|
||||
data.Clients = append(data.Clients, cj)
|
||||
}
|
||||
for ip, rc := range clients.ipToRC {
|
||||
|
||||
clients.ipToRC.Range(func(ip net.IP, v interface{}) (cont bool) {
|
||||
rc, ok := v.(*RuntimeClient)
|
||||
if !ok {
|
||||
log.Error("dns: bad type %T in ipToRC for %s", v, ip)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
cj := runtimeClientJSON{
|
||||
IP: ip,
|
||||
Name: rc.Host,
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
|
||||
Name: rc.Host,
|
||||
IP: ip,
|
||||
}
|
||||
|
||||
cj.Source = "etc/hosts"
|
||||
@ -86,7 +97,9 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, _ *http
|
||||
}
|
||||
|
||||
data.RuntimeClients = append(data.RuntimeClients, cj)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
data.Tags = clientTags
|
||||
|
||||
@ -118,8 +131,8 @@ func jsonToClient(cj clientJSON) (c *Client) {
|
||||
}
|
||||
|
||||
// Convert Client object to JSON
|
||||
func clientToJSON(c *Client) clientJSON {
|
||||
cj := clientJSON{
|
||||
func clientToJSON(c *Client) (cj *clientJSON) {
|
||||
return &clientJSON{
|
||||
Name: c.Name,
|
||||
IDs: c.IDs,
|
||||
Tags: c.Tags,
|
||||
@ -134,19 +147,6 @@ func clientToJSON(c *Client) clientJSON {
|
||||
|
||||
Upstreams: c.Upstreams,
|
||||
}
|
||||
|
||||
return cj
|
||||
}
|
||||
|
||||
// runtimeClientToJSON converts a RuntimeClient into a JSON struct.
|
||||
func runtimeClientToJSON(ip string, rc RuntimeClient) (cj clientJSON) {
|
||||
cj = clientJSON{
|
||||
Name: rc.Host,
|
||||
IDs: []string{ip},
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
}
|
||||
|
||||
return cj
|
||||
}
|
||||
|
||||
// Add a new client
|
||||
@ -230,7 +230,7 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
// Get the list of clients by IP address list
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
data := []map[string]clientJSON{}
|
||||
data := []map[string]*clientJSON{}
|
||||
for i := 0; i < len(q); i++ {
|
||||
idStr := q.Get(fmt.Sprintf("ip%d", i))
|
||||
if idStr == "" {
|
||||
@ -239,20 +239,16 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
|
||||
ip := net.ParseIP(idStr)
|
||||
c, ok := clients.Find(idStr)
|
||||
var cj clientJSON
|
||||
var cj *clientJSON
|
||||
if !ok {
|
||||
var found bool
|
||||
cj, found = clients.findRuntime(ip, idStr)
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
cj = clients.findRuntime(ip, idStr)
|
||||
} else {
|
||||
cj = clientToJSON(c)
|
||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
}
|
||||
|
||||
data = append(data, map[string]clientJSON{
|
||||
data = append(data, map[string]*clientJSON{
|
||||
idStr: cj,
|
||||
})
|
||||
}
|
||||
@ -265,39 +261,37 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
}
|
||||
|
||||
// findRuntime looks up the IP in runtime and temporary storages, like
|
||||
// /etc/hosts tables, DHCP leases, or blocklists.
|
||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj clientJSON, found bool) {
|
||||
if ip == nil {
|
||||
return cj, false
|
||||
}
|
||||
|
||||
rc, ok := clients.FindRuntimeClient(idStr)
|
||||
// /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be
|
||||
// non-nil.
|
||||
func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) {
|
||||
rc, ok := clients.FindRuntimeClient(ip)
|
||||
if !ok {
|
||||
// It is still possible that the IP used to be in the runtime
|
||||
// clients list, but then the server was reloaded. So, check
|
||||
// the DNS server's blocked IP list.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2428.
|
||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
||||
if rule == "" {
|
||||
return clientJSON{}, false
|
||||
}
|
||||
|
||||
cj = clientJSON{
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
cj = &clientJSON{
|
||||
IDs: []string{idStr},
|
||||
Disallowed: &disallowed,
|
||||
DisallowedRule: &rule,
|
||||
WHOISInfo: &RuntimeClientWHOISInfo{},
|
||||
}
|
||||
|
||||
return cj, true
|
||||
return cj
|
||||
}
|
||||
|
||||
cj = runtimeClientToJSON(idStr, rc)
|
||||
disallowed, rule := clients.dnsServer.IsBlockedIP(ip)
|
||||
cj = &clientJSON{
|
||||
Name: rc.Host,
|
||||
IDs: []string{idStr},
|
||||
WHOISInfo: rc.WHOISInfo,
|
||||
}
|
||||
|
||||
disallowed, rule := clients.dnsServer.IsBlockedClient(ip, idStr)
|
||||
cj.Disallowed, cj.DisallowedRule = &disallowed, &rule
|
||||
|
||||
return cj, true
|
||||
return cj
|
||||
}
|
||||
|
||||
// RegisterClientsHandlers registers HTTP handlers
|
||||
|
@ -105,8 +105,8 @@ func isRunning() bool {
|
||||
return Context.dnsServer != nil && Context.dnsServer.IsRunning()
|
||||
}
|
||||
|
||||
func onDNSRequest(d *proxy.DNSContext) {
|
||||
ip := aghnet.IPFromAddr(d.Addr)
|
||||
func onDNSRequest(pctx *proxy.DNSContext) {
|
||||
ip := aghnet.IPFromAddr(pctx.Addr)
|
||||
if ip == nil {
|
||||
// This would be quite weird if we get here.
|
||||
return
|
||||
|
@ -503,7 +503,7 @@ Please note, that this is crucial for a server to be able to use privileged port
|
||||
You have two options:
|
||||
1. Run AdGuard Home with root privileges
|
||||
2. On Linux you can grant the CAP_NET_BIND_SERVICE capability:
|
||||
https://github.com/AdguardTeam/AdGuardHome/internal/wiki/Getting-Started#running-without-superuser`
|
||||
https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#running-without-superuser`
|
||||
|
||||
log.Fatal(msg)
|
||||
}
|
||||
|
@ -102,12 +102,7 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) {
|
||||
func (r *RDNS) Begin(ip net.IP) {
|
||||
r.ensurePrivateCache()
|
||||
|
||||
if r.isCached(ip) {
|
||||
return
|
||||
}
|
||||
|
||||
id := ip.String()
|
||||
if r.clients.Exists(id, ClientSourceRDNS) {
|
||||
if r.isCached(ip) || r.clients.Exists(ip, ClientSourceRDNS) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -138,6 +133,6 @@ func (r *RDNS) workerLoop() {
|
||||
|
||||
// Don't handle any errors since AddHost doesn't return non-nil
|
||||
// errors for now.
|
||||
_, _ = r.clients.AddHost(ip.String(), host, ClientSourceRDNS)
|
||||
_, _ = r.clients.AddHost(ip, host, ClientSourceRDNS)
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
@ -84,7 +85,7 @@ func TestRDNS_Begin(t *testing.T) {
|
||||
clients: &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: tc.cliIDIndex,
|
||||
ipToRC: map[string]*RuntimeClient{},
|
||||
ipToRC: aghnet.NewIPMap(0),
|
||||
allTags: aghstrings.NewSet(),
|
||||
},
|
||||
}
|
||||
@ -204,7 +205,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
cc := &clientsContainer{
|
||||
list: map[string]*Client{},
|
||||
idIndex: map[string]*Client{},
|
||||
ipToRC: map[string]*RuntimeClient{},
|
||||
ipToRC: aghnet.NewIPMap(0),
|
||||
allTags: aghstrings.NewSet(),
|
||||
}
|
||||
ch := make(chan net.IP)
|
||||
@ -236,7 +237,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
assert.True(t, cc.Exists(tc.cliIP.String(), ClientSourceRDNS))
|
||||
assert.True(t, cc.Exists(tc.cliIP, ClientSourceRDNS))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -252,7 +252,6 @@ func (w *WHOIS) workerLoop() {
|
||||
continue
|
||||
}
|
||||
|
||||
id := ip.String()
|
||||
w.clients.SetWHOISInfo(id, info)
|
||||
w.clients.SetWHOISInfo(ip, info)
|
||||
}
|
||||
}
|
||||
|
@ -720,7 +720,10 @@ func (s *statsCtx) GetTopClientsIP(maxCount uint) []net.IP {
|
||||
a := convertMapToSlice(m, int(maxCount))
|
||||
d := []net.IP{}
|
||||
for _, it := range a {
|
||||
d = append(d, net.ParseIP(it.Name))
|
||||
ip := net.ParseIP(it.Name)
|
||||
if ip != nil {
|
||||
d = append(d, ip)
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
@ -4,6 +4,11 @@
|
||||
|
||||
## v0.107: API changes
|
||||
|
||||
### Client IDs in Access Settings
|
||||
|
||||
* The `POST /control/access/set` HTTP API now accepts client IDs in
|
||||
`"allowed_clients"` and `"disallowed_clients"` fields.
|
||||
|
||||
### The new field `"unicode_name"` in `DNSQuestion`
|
||||
|
||||
* The new optional field `"unicode_name"` is the Unicode representation of
|
||||
@ -17,7 +22,7 @@
|
||||
|
||||
### Disabling Statistics
|
||||
|
||||
* The API `POST /control/stats_config` HTTP API allows disabling statistics by
|
||||
* The `POST /control/stats_config` HTTP API allows disabling statistics by
|
||||
setting `"interval"` to `0`.
|
||||
|
||||
### `POST /control/dhcp/reset_leases`
|
||||
|
@ -1957,10 +1957,7 @@
|
||||
'disallowed_rule':
|
||||
'type': 'string'
|
||||
'description': >
|
||||
The rule due to which the client is disallowed. If disallowed is
|
||||
set to true, and this string is empty, then the client IP is
|
||||
disallowed by the "allowed IP list", that is it is not included in
|
||||
the allowed list.
|
||||
The rule due to which the client is allowed or blocked.
|
||||
'name':
|
||||
'description': >
|
||||
Persistent client's name or an empty string if this is a runtime
|
||||
@ -2352,17 +2349,19 @@
|
||||
'description': 'Client and host access list'
|
||||
'properties':
|
||||
'allowed_clients':
|
||||
'description': 'Allowlist of clients.'
|
||||
'description': >
|
||||
The allowlist of clients: IP addresses, CIDRs, or client IDs.
|
||||
'items':
|
||||
'type': 'string'
|
||||
'type': 'array'
|
||||
'disallowed_clients':
|
||||
'description': 'Blocklist of clients.'
|
||||
'description': >
|
||||
The blocklist of clients: IP addresses, CIDRs, or client IDs.
|
||||
'items':
|
||||
'type': 'string'
|
||||
'type': 'array'
|
||||
'blocked_hosts':
|
||||
'description': 'Blocklist of hosts.'
|
||||
'description': 'The blocklist of hosts.'
|
||||
'items':
|
||||
'type': 'string'
|
||||
'type': 'array'
|
||||
|
Loading…
Reference in New Issue
Block a user