mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-15 18:08:30 -07:00
Pull request 2138: AG-27492-client-persistent-storage
Squashed commit of the following: commit37e33ec761
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 14 15:25:25 2024 +0300 aghalg: imp code commit6b2f09a442
Merge:b8ea924aa
37736289e
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 14 15:04:59 2024 +0300 Merge branch 'master' into AG-27492-client-persistent-storage commitb8ea924aa7
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Feb 13 19:07:52 2024 +0300 home: imp tests commitaa6fec03b1
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Feb 13 14:54:28 2024 +0300 home: imp docs commit10637fdec4
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Feb 8 20:16:11 2024 +0300 all: imp code commitb45c7d868d
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Feb 7 19:15:11 2024 +0300 aghalg: add tests commit7abe33dbaa
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Feb 6 20:50:22 2024 +0300 all: imp code, tests commit4a44e993c9
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Feb 1 14:59:11 2024 +0300 all: persistent client index commit66b16e216e
Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Jan 31 15:06:05 2024 +0300 aghalg: ordered map
This commit is contained in:
parent
37736289e2
commit
fede297942
86
internal/aghalg/sortedmap.go
Normal file
86
internal/aghalg/sortedmap.go
Normal file
@ -0,0 +1,86 @@
|
||||
package aghalg
|
||||
|
||||
import (
|
||||
"slices"
|
||||
)
|
||||
|
||||
// SortedMap is a map that keeps elements in order with internal sorting
|
||||
// function. Must be initialised by the [NewSortedMap].
|
||||
type SortedMap[K comparable, V any] struct {
|
||||
vals map[K]V
|
||||
cmp func(a, b K) (res int)
|
||||
keys []K
|
||||
}
|
||||
|
||||
// NewSortedMap initializes the new instance of sorted map. cmp is a sort
|
||||
// function to keep elements in order.
|
||||
//
|
||||
// TODO(s.chzhen): Use cmp.Compare in Go 1.21.
|
||||
func NewSortedMap[K comparable, V any](cmp func(a, b K) (res int)) SortedMap[K, V] {
|
||||
return SortedMap[K, V]{
|
||||
vals: map[K]V{},
|
||||
cmp: cmp,
|
||||
}
|
||||
}
|
||||
|
||||
// Set adds val with key to the sorted map. It panics if the m is nil.
|
||||
func (m *SortedMap[K, V]) Set(key K, val V) {
|
||||
m.vals[key] = val
|
||||
|
||||
i, has := slices.BinarySearchFunc(m.keys, key, m.cmp)
|
||||
if has {
|
||||
m.keys[i] = key
|
||||
} else {
|
||||
m.keys = slices.Insert(m.keys, i, key)
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns val by key from the sorted map.
|
||||
func (m *SortedMap[K, V]) Get(key K) (val V, ok bool) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
val, ok = m.vals[key]
|
||||
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// Del removes the value by key from the sorted map.
|
||||
func (m *SortedMap[K, V]) Del(key K) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, has := m.vals[key]; !has {
|
||||
return
|
||||
}
|
||||
|
||||
delete(m.vals, key)
|
||||
i, _ := slices.BinarySearchFunc(m.keys, key, m.cmp)
|
||||
m.keys = slices.Delete(m.keys, i, i+1)
|
||||
}
|
||||
|
||||
// Clear removes all elements from the sorted map.
|
||||
func (m *SortedMap[K, V]) Clear() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
m.keys = nil
|
||||
clear(m.vals)
|
||||
}
|
||||
|
||||
// Range calls cb for each element of the map, sorted by m.cmp. If cb returns
|
||||
// false it stops.
|
||||
func (m *SortedMap[K, V]) Range(cb func(K, V) (cont bool)) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, k := range m.keys {
|
||||
if !cb(k, m.vals[k]) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
95
internal/aghalg/sortedmap_test.go
Normal file
95
internal/aghalg/sortedmap_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
package aghalg
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewSortedMap(t *testing.T) {
|
||||
var m SortedMap[string, int]
|
||||
|
||||
letters := []string{}
|
||||
for i := 0; i < 10; i++ {
|
||||
r := string('a' + rune(i))
|
||||
letters = append(letters, r)
|
||||
}
|
||||
|
||||
t.Run("create_and_fill", func(t *testing.T) {
|
||||
m = NewSortedMap[string, int](strings.Compare)
|
||||
|
||||
nums := []int{}
|
||||
for i, r := range letters {
|
||||
m.Set(r, i)
|
||||
nums = append(nums, i)
|
||||
}
|
||||
|
||||
gotLetters := []string{}
|
||||
gotNums := []int{}
|
||||
m.Range(func(k string, v int) bool {
|
||||
gotLetters = append(gotLetters, k)
|
||||
gotNums = append(gotNums, v)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Equal(t, letters, gotLetters)
|
||||
assert.Equal(t, nums, gotNums)
|
||||
|
||||
n, ok := m.Get(letters[0])
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, nums[0], n)
|
||||
})
|
||||
|
||||
t.Run("clear", func(t *testing.T) {
|
||||
lastLetter := letters[len(letters)-1]
|
||||
m.Del(lastLetter)
|
||||
|
||||
_, ok := m.Get(lastLetter)
|
||||
assert.False(t, ok)
|
||||
|
||||
m.Clear()
|
||||
|
||||
gotLetters := []string{}
|
||||
m.Range(func(k string, _ int) bool {
|
||||
gotLetters = append(gotLetters, k)
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
assert.Len(t, gotLetters, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNewSortedMap_nil(t *testing.T) {
|
||||
const (
|
||||
key = "key"
|
||||
val = "val"
|
||||
)
|
||||
|
||||
var m SortedMap[string, string]
|
||||
|
||||
assert.Panics(t, func() {
|
||||
m.Set(key, val)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
_, ok := m.Get(key)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
m.Range(func(_, _ string) (cont bool) {
|
||||
return true
|
||||
})
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
m.Del(key)
|
||||
})
|
||||
|
||||
assert.NotPanics(t, func() {
|
||||
m.Clear()
|
||||
})
|
||||
}
|
@ -30,6 +30,16 @@ func NewUID() (uid UID, err error) {
|
||||
return UID(uuidv7), err
|
||||
}
|
||||
|
||||
// MustNewUID is a wrapper around [NewUID] that panics if there is an error.
|
||||
func MustNewUID() (uid UID) {
|
||||
uid, err := NewUID()
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("unexpected uuidv7 error: %w", err))
|
||||
}
|
||||
|
||||
return uid
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ encoding.TextMarshaler = UID{}
|
||||
|
||||
|
249
internal/home/clientindex.go
Normal file
249
internal/home/clientindex.go
Normal file
@ -0,0 +1,249 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
)
|
||||
|
||||
// macKey contains MAC as byte array of 6, 8, or 20 bytes.
|
||||
type macKey any
|
||||
|
||||
// macToKey converts mac into key of type macKey, which is used as the key of
|
||||
// the [clientIndex.macToUID]. mac must be valid MAC address.
|
||||
func macToKey(mac net.HardwareAddr) (key macKey) {
|
||||
switch len(mac) {
|
||||
case 6:
|
||||
return [6]byte(mac)
|
||||
case 8:
|
||||
return [8]byte(mac)
|
||||
case 20:
|
||||
return [20]byte(mac)
|
||||
default:
|
||||
panic(fmt.Errorf("invalid mac address %#v", mac))
|
||||
}
|
||||
}
|
||||
|
||||
// clientIndex stores all information about persistent clients.
|
||||
type clientIndex struct {
|
||||
// clientIDToUID maps client ID to UID.
|
||||
clientIDToUID map[string]UID
|
||||
|
||||
// ipToUID maps IP address to UID.
|
||||
ipToUID map[netip.Addr]UID
|
||||
|
||||
// macToUID maps MAC address to UID.
|
||||
macToUID map[macKey]UID
|
||||
|
||||
// uidToClient maps UID to the persistent client.
|
||||
uidToClient map[UID]*persistentClient
|
||||
|
||||
// subnetToUID maps subnet to UID.
|
||||
subnetToUID aghalg.SortedMap[netip.Prefix, UID]
|
||||
}
|
||||
|
||||
// NewClientIndex initializes the new instance of client index.
|
||||
func NewClientIndex() (ci *clientIndex) {
|
||||
return &clientIndex{
|
||||
clientIDToUID: map[string]UID{},
|
||||
ipToUID: map[netip.Addr]UID{},
|
||||
subnetToUID: aghalg.NewSortedMap[netip.Prefix, UID](subnetCompare),
|
||||
macToUID: map[macKey]UID{},
|
||||
uidToClient: map[UID]*persistentClient{},
|
||||
}
|
||||
}
|
||||
|
||||
// add stores information about a persistent client in the index. c must be
|
||||
// non-nil and contain UID.
|
||||
func (ci *clientIndex) add(c *persistentClient) {
|
||||
if (c.UID == UID{}) {
|
||||
panic("client must contain uid")
|
||||
}
|
||||
|
||||
for _, id := range c.ClientIDs {
|
||||
ci.clientIDToUID[id] = c.UID
|
||||
}
|
||||
|
||||
for _, ip := range c.IPs {
|
||||
ci.ipToUID[ip] = c.UID
|
||||
}
|
||||
|
||||
for _, pref := range c.Subnets {
|
||||
ci.subnetToUID.Set(pref, c.UID)
|
||||
}
|
||||
|
||||
for _, mac := range c.MACs {
|
||||
k := macToKey(mac)
|
||||
ci.macToUID[k] = c.UID
|
||||
}
|
||||
|
||||
ci.uidToClient[c.UID] = c
|
||||
}
|
||||
|
||||
// clashes returns an error if the index contains a different persistent client
|
||||
// with at least a single identifier contained by c. c must be non-nil.
|
||||
func (ci *clientIndex) clashes(c *persistentClient) (err error) {
|
||||
for _, id := range c.ClientIDs {
|
||||
existing, ok := ci.clientIDToUID[id]
|
||||
if ok && existing != c.UID {
|
||||
p := ci.uidToClient[existing]
|
||||
|
||||
return fmt.Errorf("another client %q uses the same ID %q", p.Name, id)
|
||||
}
|
||||
}
|
||||
|
||||
p, ip := ci.clashesIP(c)
|
||||
if p != nil {
|
||||
return fmt.Errorf("another client %q uses the same IP %q", p.Name, ip)
|
||||
}
|
||||
|
||||
p, s := ci.clashesSubnet(c)
|
||||
if p != nil {
|
||||
return fmt.Errorf("another client %q uses the same subnet %q", p.Name, s)
|
||||
}
|
||||
|
||||
p, mac := ci.clashesMAC(c)
|
||||
if p != nil {
|
||||
return fmt.Errorf("another client %q uses the same MAC %q", p.Name, mac)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// clashesIP returns a previous client with the same IP address as c. c must be
|
||||
// non-nil.
|
||||
func (ci *clientIndex) clashesIP(c *persistentClient) (p *persistentClient, ip netip.Addr) {
|
||||
for _, ip := range c.IPs {
|
||||
existing, ok := ci.ipToUID[ip]
|
||||
if ok && existing != c.UID {
|
||||
return ci.uidToClient[existing], ip
|
||||
}
|
||||
}
|
||||
|
||||
return nil, netip.Addr{}
|
||||
}
|
||||
|
||||
// clashesSubnet returns a previous client with the same subnet as c. c must be
|
||||
// non-nil.
|
||||
func (ci *clientIndex) clashesSubnet(c *persistentClient) (p *persistentClient, s netip.Prefix) {
|
||||
for _, s = range c.Subnets {
|
||||
var existing UID
|
||||
var ok bool
|
||||
|
||||
ci.subnetToUID.Range(func(p netip.Prefix, uid UID) (cont bool) {
|
||||
if s == p {
|
||||
existing = uid
|
||||
ok = true
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if ok && existing != c.UID {
|
||||
return ci.uidToClient[existing], s
|
||||
}
|
||||
}
|
||||
|
||||
return nil, netip.Prefix{}
|
||||
}
|
||||
|
||||
// clashesMAC returns a previous client with the same MAC address as c. c must
|
||||
// be non-nil.
|
||||
func (ci *clientIndex) clashesMAC(c *persistentClient) (p *persistentClient, mac net.HardwareAddr) {
|
||||
for _, mac = range c.MACs {
|
||||
k := macToKey(mac)
|
||||
existing, ok := ci.macToUID[k]
|
||||
if ok && existing != c.UID {
|
||||
return ci.uidToClient[existing], mac
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// find finds persistent client by string representation of the client ID, IP
|
||||
// address, or MAC.
|
||||
func (ci *clientIndex) find(id string) (c *persistentClient, ok bool) {
|
||||
uid, found := ci.clientIDToUID[id]
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(id)
|
||||
if err == nil {
|
||||
// MAC addresses can be successfully parsed as IP addresses.
|
||||
c, found = ci.findByIP(ip)
|
||||
if found {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
|
||||
mac, err := net.ParseMAC(id)
|
||||
if err == nil {
|
||||
return ci.findByMAC(mac)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// find finds persistent client by IP address.
|
||||
func (ci *clientIndex) findByIP(ip netip.Addr) (c *persistentClient, found bool) {
|
||||
uid, found := ci.ipToUID[ip]
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
ci.subnetToUID.Range(func(pref netip.Prefix, id UID) (cont bool) {
|
||||
if pref.Contains(ip) {
|
||||
uid, found = id, true
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// find finds persistent client by MAC.
|
||||
func (ci *clientIndex) findByMAC(mac net.HardwareAddr) (c *persistentClient, found bool) {
|
||||
k := macToKey(mac)
|
||||
uid, found := ci.macToUID[k]
|
||||
if found {
|
||||
return ci.uidToClient[uid], true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// del removes information about persistent client from the index. c must be
|
||||
// non-nil.
|
||||
func (ci *clientIndex) del(c *persistentClient) {
|
||||
for _, id := range c.ClientIDs {
|
||||
delete(ci.clientIDToUID, id)
|
||||
}
|
||||
|
||||
for _, ip := range c.IPs {
|
||||
delete(ci.ipToUID, ip)
|
||||
}
|
||||
|
||||
for _, pref := range c.Subnets {
|
||||
ci.subnetToUID.Del(pref)
|
||||
}
|
||||
|
||||
for _, mac := range c.MACs {
|
||||
k := macToKey(mac)
|
||||
delete(ci.macToUID, k)
|
||||
}
|
||||
|
||||
delete(ci.uidToClient, c.UID)
|
||||
}
|
210
internal/home/clientindex_internal_test.go
Normal file
210
internal/home/clientindex_internal_test.go
Normal file
@ -0,0 +1,210 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClientIndex(t *testing.T) {
|
||||
const (
|
||||
cliIPNone = "1.2.3.4"
|
||||
cliIP1 = "1.1.1.1"
|
||||
cliIP2 = "2.2.2.2"
|
||||
|
||||
cliIPv6 = "1:2:3::4"
|
||||
|
||||
cliSubnet = "2.2.2.0/24"
|
||||
cliSubnetIP = "2.2.2.222"
|
||||
|
||||
cliID = "client-id"
|
||||
cliMAC = "11:11:11:11:11:11"
|
||||
)
|
||||
|
||||
clients := []*persistentClient{{
|
||||
Name: "client1",
|
||||
IPs: []netip.Addr{
|
||||
netip.MustParseAddr(cliIP1),
|
||||
netip.MustParseAddr(cliIPv6),
|
||||
},
|
||||
}, {
|
||||
Name: "client2",
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP2)},
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
||||
}, {
|
||||
Name: "client_with_mac",
|
||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||
}, {
|
||||
Name: "client_with_id",
|
||||
ClientIDs: []string{cliID},
|
||||
}}
|
||||
|
||||
ci := newIDIndex(clients)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ids []string
|
||||
want *persistentClient
|
||||
}{{
|
||||
name: "ipv4_ipv6",
|
||||
ids: []string{cliIP1, cliIPv6},
|
||||
want: clients[0],
|
||||
}, {
|
||||
name: "ipv4_subnet",
|
||||
ids: []string{cliIP2, cliSubnetIP},
|
||||
want: clients[1],
|
||||
}, {
|
||||
name: "mac",
|
||||
ids: []string{cliMAC},
|
||||
want: clients[2],
|
||||
}, {
|
||||
name: "client_id",
|
||||
ids: []string{cliID},
|
||||
want: clients[3],
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
for _, id := range tc.ids {
|
||||
c, ok := ci.find(id)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Equal(t, tc.want, c)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("not_found", func(t *testing.T) {
|
||||
_, ok := ci.find(cliIPNone)
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClientIndex_Clashes(t *testing.T) {
|
||||
const (
|
||||
cliIP1 = "1.1.1.1"
|
||||
cliSubnet = "2.2.2.0/24"
|
||||
cliSubnetIP = "2.2.2.222"
|
||||
cliID = "client-id"
|
||||
cliMAC = "11:11:11:11:11:11"
|
||||
)
|
||||
|
||||
clients := []*persistentClient{{
|
||||
Name: "client_with_ip",
|
||||
IPs: []netip.Addr{netip.MustParseAddr(cliIP1)},
|
||||
}, {
|
||||
Name: "client_with_subnet",
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix(cliSubnet)},
|
||||
}, {
|
||||
Name: "client_with_mac",
|
||||
MACs: []net.HardwareAddr{mustParseMAC(cliMAC)},
|
||||
}, {
|
||||
Name: "client_with_id",
|
||||
ClientIDs: []string{cliID},
|
||||
}}
|
||||
|
||||
ci := newIDIndex(clients)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
client *persistentClient
|
||||
}{{
|
||||
name: "ipv4",
|
||||
client: clients[0],
|
||||
}, {
|
||||
name: "subnet",
|
||||
client: clients[1],
|
||||
}, {
|
||||
name: "mac",
|
||||
client: clients[2],
|
||||
}, {
|
||||
name: "client_id",
|
||||
client: clients[3],
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
clone := tc.client.shallowClone()
|
||||
clone.UID = MustNewUID()
|
||||
|
||||
err := ci.clashes(clone)
|
||||
require.Error(t, err)
|
||||
|
||||
ci.del(tc.client)
|
||||
err = ci.clashes(clone)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mustParseMAC is wrapper around [net.ParseMAC] that panics if there is an
|
||||
// error.
|
||||
func mustParseMAC(s string) (mac net.HardwareAddr) {
|
||||
mac, err := net.ParseMAC(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return mac
|
||||
}
|
||||
|
||||
func TestMACToKey(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
in string
|
||||
want any
|
||||
}{{
|
||||
name: "column6",
|
||||
in: "00:00:5e:00:53:01",
|
||||
want: [6]byte(mustParseMAC("00:00:5e:00:53:01")),
|
||||
}, {
|
||||
name: "column8",
|
||||
in: "02:00:5e:10:00:00:00:01",
|
||||
want: [8]byte(mustParseMAC("02:00:5e:10:00:00:00:01")),
|
||||
}, {
|
||||
name: "column20",
|
||||
in: "00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01",
|
||||
want: [20]byte(mustParseMAC("00:00:00:00:fe:80:00:00:00:00:00:00:02:00:5e:10:00:00:00:01")),
|
||||
}, {
|
||||
name: "hyphen6",
|
||||
in: "00-00-5e-00-53-01",
|
||||
want: [6]byte(mustParseMAC("00-00-5e-00-53-01")),
|
||||
}, {
|
||||
name: "hyphen8",
|
||||
in: "02-00-5e-10-00-00-00-01",
|
||||
want: [8]byte(mustParseMAC("02-00-5e-10-00-00-00-01")),
|
||||
}, {
|
||||
name: "hyphen20",
|
||||
in: "00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01",
|
||||
want: [20]byte(mustParseMAC("00-00-00-00-fe-80-00-00-00-00-00-00-02-00-5e-10-00-00-00-01")),
|
||||
}, {
|
||||
name: "dot6",
|
||||
in: "0000.5e00.5301",
|
||||
want: [6]byte(mustParseMAC("0000.5e00.5301")),
|
||||
}, {
|
||||
name: "dot8",
|
||||
in: "0200.5e10.0000.0001",
|
||||
want: [8]byte(mustParseMAC("0200.5e10.0000.0001")),
|
||||
}, {
|
||||
name: "dot20",
|
||||
in: "0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001",
|
||||
want: [20]byte(mustParseMAC("0000.0000.fe80.0000.0000.0000.0200.5e10.0000.0001")),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
mac := mustParseMAC(tc.in)
|
||||
|
||||
key := macToKey(mac)
|
||||
assert.Equal(t, tc.want, key)
|
||||
})
|
||||
}
|
||||
|
||||
assert.Panics(t, func() {
|
||||
mac := net.HardwareAddr([]byte{1, 2, 3})
|
||||
_ = macToKey(mac)
|
||||
})
|
||||
}
|
@ -47,8 +47,9 @@ type DHCP interface {
|
||||
type clientsContainer struct {
|
||||
// TODO(a.garipov): Perhaps use a number of separate indices for different
|
||||
// types (string, netip.Addr, and so on).
|
||||
list map[string]*persistentClient // name -> client
|
||||
idIndex map[string]*persistentClient // ID -> client
|
||||
list map[string]*persistentClient // name -> client
|
||||
|
||||
clientIndex *clientIndex
|
||||
|
||||
// ipToRC maps IP addresses to runtime client information.
|
||||
ipToRC map[netip.Addr]*client.Runtime
|
||||
@ -103,9 +104,10 @@ func (clients *clientsContainer) Init(
|
||||
}
|
||||
|
||||
clients.list = map[string]*persistentClient{}
|
||||
clients.idIndex = map[string]*persistentClient{}
|
||||
clients.ipToRC = map[netip.Addr]*client.Runtime{}
|
||||
|
||||
clients.clientIndex = NewClientIndex()
|
||||
|
||||
clients.allTags = stringutil.NewSet(clientTags...)
|
||||
|
||||
// TODO(e.burkov): Use [dhcpsvc] implementation when it's ready.
|
||||
@ -517,7 +519,7 @@ func (clients *clientsContainer) UpstreamConfigByID(
|
||||
// findLocked searches for a client by its ID. clients.lock is expected to be
|
||||
// locked.
|
||||
func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok bool) {
|
||||
c, ok = clients.idIndex[id]
|
||||
c, ok = clients.clientIndex.find(id)
|
||||
if ok {
|
||||
return c, true
|
||||
}
|
||||
@ -527,14 +529,6 @@ func (clients *clientsContainer) findLocked(id string) (c *persistentClient, ok
|
||||
return nil, false
|
||||
}
|
||||
|
||||
for _, c = range clients.list {
|
||||
for _, subnet := range c.Subnets {
|
||||
if subnet.Contains(ip) {
|
||||
return c, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(e.burkov): Iterate through clients.list only once.
|
||||
return clients.findDHCP(ip)
|
||||
}
|
||||
@ -638,18 +632,15 @@ func (clients *clientsContainer) add(c *persistentClient) (ok bool, err error) {
|
||||
}
|
||||
|
||||
// check ID index
|
||||
ids := c.ids()
|
||||
for _, id := range ids {
|
||||
var c2 *persistentClient
|
||||
c2, ok = clients.idIndex[id]
|
||||
if ok {
|
||||
return false, fmt.Errorf("another client uses the same ID (%q): %q", id, c2.Name)
|
||||
}
|
||||
err = clients.clientIndex.clashes(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return false, err
|
||||
}
|
||||
|
||||
clients.addLocked(c)
|
||||
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, ids, len(clients.list))
|
||||
log.Debug("clients: added %q: ID:%q [%d]", c.Name, c.ids(), len(clients.list))
|
||||
|
||||
return true, nil
|
||||
}
|
||||
@ -660,9 +651,7 @@ func (clients *clientsContainer) addLocked(c *persistentClient) {
|
||||
clients.list[c.Name] = c
|
||||
|
||||
// update ID index
|
||||
for _, id := range c.ids() {
|
||||
clients.idIndex[id] = c
|
||||
}
|
||||
clients.clientIndex.add(c)
|
||||
}
|
||||
|
||||
// remove removes a client. ok is false if there is no such client.
|
||||
@ -692,9 +681,7 @@ func (clients *clientsContainer) removeLocked(c *persistentClient) {
|
||||
delete(clients.list, c.Name)
|
||||
|
||||
// Update the ID index.
|
||||
for _, id := range c.ids() {
|
||||
delete(clients.idIndex, id)
|
||||
}
|
||||
clients.clientIndex.del(c)
|
||||
}
|
||||
|
||||
// update updates a client by its name.
|
||||
@ -724,11 +711,10 @@ func (clients *clientsContainer) update(prev, c *persistentClient) (err error) {
|
||||
}
|
||||
|
||||
// Check the ID index.
|
||||
for _, id := range c.ids() {
|
||||
existing, ok := clients.idIndex[id]
|
||||
if ok && existing != prev {
|
||||
return fmt.Errorf("id %q is used by client with name %q", id, existing.Name)
|
||||
}
|
||||
err = clients.clientIndex.clashes(c)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
clients.removeLocked(prev)
|
||||
|
@ -68,6 +68,7 @@ func TestClients(t *testing.T) {
|
||||
|
||||
c := &persistentClient{
|
||||
Name: "client1",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{cli1IP, cliIPv6},
|
||||
}
|
||||
|
||||
@ -78,6 +79,7 @@ func TestClients(t *testing.T) {
|
||||
|
||||
c = &persistentClient{
|
||||
Name: "client2",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{cli2IP},
|
||||
}
|
||||
|
||||
@ -111,6 +113,7 @@ func TestClients(t *testing.T) {
|
||||
t.Run("add_fail_name", func(t *testing.T) {
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client1",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.2.3.5")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -120,6 +123,7 @@ func TestClients(t *testing.T) {
|
||||
t.Run("add_fail_ip", func(t *testing.T) {
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client3",
|
||||
UID: MustNewUID(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.False(t, ok)
|
||||
@ -128,6 +132,7 @@ func TestClients(t *testing.T) {
|
||||
t.Run("update_fail_ip", func(t *testing.T) {
|
||||
err := clients.update(&persistentClient{Name: "client1"}, &persistentClient{
|
||||
Name: "client1",
|
||||
UID: MustNewUID(),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
@ -145,6 +150,7 @@ func TestClients(t *testing.T) {
|
||||
|
||||
err := clients.update(prev, &persistentClient{
|
||||
Name: "client1",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -159,6 +165,7 @@ func TestClients(t *testing.T) {
|
||||
|
||||
err = clients.update(prev, &persistentClient{
|
||||
Name: "client1-renamed",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{cliNewIP},
|
||||
UseOwnSettings: true,
|
||||
})
|
||||
@ -260,6 +267,7 @@ func TestClientsWHOIS(t *testing.T) {
|
||||
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client1",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -282,6 +290,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
// Add a client.
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client1",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{ip, netip.MustParseAddr("1:2:3::4")},
|
||||
Subnets: []netip.Prefix{netip.MustParsePrefix("2.2.2.0/24")},
|
||||
MACs: []net.HardwareAddr{{0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA}},
|
||||
@ -332,6 +341,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
// Add a new client with the same IP as for a client with MAC.
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client2",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{ip},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -340,6 +350,7 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
// Add a new client with the IP from the first client's IP range.
|
||||
ok, err = clients.add(&persistentClient{
|
||||
Name: "client3",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("2.2.2.2")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -353,6 +364,7 @@ func TestClientsCustomUpstream(t *testing.T) {
|
||||
// Add client with upstreams.
|
||||
ok, err := clients.add(&persistentClient{
|
||||
Name: "client1",
|
||||
UID: MustNewUID(),
|
||||
IPs: []netip.Addr{netip.MustParseAddr("1.1.1.1"), netip.MustParseAddr("1:2:3::4")},
|
||||
Upstreams: []string{
|
||||
"1.1.1.1",
|
||||
|
@ -12,6 +12,19 @@ import (
|
||||
|
||||
var testIPv4 = netip.AddrFrom4([4]byte{1, 2, 3, 4})
|
||||
|
||||
// newIDIndex is a helper function that returns a client index filled with
|
||||
// persistent clients from the m. It also generates a UID for each client.
|
||||
func newIDIndex(m []*persistentClient) (ci *clientIndex) {
|
||||
ci = NewClientIndex()
|
||||
|
||||
for _, c := range m {
|
||||
c.UID = MustNewUID()
|
||||
ci.add(c)
|
||||
}
|
||||
|
||||
return ci
|
||||
}
|
||||
|
||||
func TestApplyAdditionalFiltering(t *testing.T) {
|
||||
var err error
|
||||
|
||||
@ -22,29 +35,28 @@ func TestApplyAdditionalFiltering(t *testing.T) {
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*persistentClient{
|
||||
"default": {
|
||||
UseOwnSettings: false,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
|
||||
FilteringEnabled: false,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
},
|
||||
"custom_filtering": {
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
ParentalEnabled: true,
|
||||
},
|
||||
"partial_custom_filtering": {
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
},
|
||||
}
|
||||
Context.clients.clientIndex = newIDIndex([]*persistentClient{{
|
||||
ClientIDs: []string{"default"},
|
||||
UseOwnSettings: false,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: false},
|
||||
FilteringEnabled: false,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
}, {
|
||||
ClientIDs: []string{"custom_filtering"},
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: true,
|
||||
ParentalEnabled: true,
|
||||
}, {
|
||||
ClientIDs: []string{"partial_custom_filtering"},
|
||||
UseOwnSettings: true,
|
||||
safeSearchConf: filtering.SafeSearchConfig{Enabled: true},
|
||||
FilteringEnabled: true,
|
||||
SafeBrowsingEnabled: false,
|
||||
ParentalEnabled: false,
|
||||
}})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -108,38 +120,37 @@ func TestApplyAdditionalFiltering_blockedServices(t *testing.T) {
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
Context.clients.idIndex = map[string]*persistentClient{
|
||||
"default": {
|
||||
UseOwnBlockedServices: false,
|
||||
Context.clients.clientIndex = newIDIndex([]*persistentClient{{
|
||||
ClientIDs: []string{"default"},
|
||||
UseOwnBlockedServices: false,
|
||||
}, {
|
||||
ClientIDs: []string{"no_services"},
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
},
|
||||
"no_services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
UseOwnBlockedServices: true,
|
||||
}, {
|
||||
ClientIDs: []string{"services"},
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
"services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
UseOwnBlockedServices: true,
|
||||
}, {
|
||||
ClientIDs: []string{"invalid_services"},
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: invalidBlockedServices,
|
||||
},
|
||||
"invalid_services": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.EmptyWeekly(),
|
||||
IDs: invalidBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
UseOwnBlockedServices: true,
|
||||
}, {
|
||||
ClientIDs: []string{"allow_all"},
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.FullWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
"allow_all": {
|
||||
BlockedServices: &filtering.BlockedServices{
|
||||
Schedule: schedule.FullWeekly(),
|
||||
IDs: clientBlockedServices,
|
||||
},
|
||||
UseOwnBlockedServices: true,
|
||||
},
|
||||
}
|
||||
UseOwnBlockedServices: true,
|
||||
}})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
|
Loading…
Reference in New Issue
Block a user