mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-15 18:08:30 -07:00
Pull request: dhcpd: fix ip ranges
Updates #2541. Squashed commit of the following: commit c81299991876f48836d24872d9145331a0bc9e6e Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Mar 16 18:10:07 2021 +0300 agherr: imp docs commit f43a5f5cde0ea16dd38dd533e16e415a1d306cb2 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Mar 16 17:35:59 2021 +0300 all: imp err handling, fix code commit ed26ad0ff53882725f7747264f8094e6fb9b0423 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Tue Mar 16 12:24:17 2021 +0300 dhcpd: fix ip ranges
This commit is contained in:
parent
e6a8fe452c
commit
9736123483
1
go.mod
1
go.mod
@ -32,6 +32,7 @@ require (
|
||||
github.com/stretchr/testify v1.6.1
|
||||
github.com/ti-mo/netfilter v0.4.0
|
||||
github.com/u-root/u-root v7.0.0+incompatible
|
||||
github.com/willf/bitset v1.1.11
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110
|
||||
|
2
go.sum
2
go.sum
@ -425,6 +425,8 @@ github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGr
|
||||
github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0=
|
||||
github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU=
|
||||
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM=
|
||||
github.com/willf/bitset v1.1.11 h1:N7Z7E9UvjW+sGsEl7k/SJrvY2reP1A07MrGuCjIOjRE=
|
||||
github.com/willf/bitset v1.1.11/go.mod h1:83CECat5yLh5zVOf4P1ErAgKA5UDvKtgyUABdr3+MjI=
|
||||
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
|
||||
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
|
||||
go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU=
|
||||
|
@ -1,5 +1,4 @@
|
||||
// Package agherr contains the extended error type, and the function for
|
||||
// wrapping several errors.
|
||||
// Package agherr contains AdGuard Home's error handling helpers.
|
||||
package agherr
|
||||
|
||||
import (
|
||||
@ -23,8 +22,10 @@ type manyError struct {
|
||||
}
|
||||
|
||||
// Many wraps several errors and returns a single error.
|
||||
func Many(message string, underlying ...error) error {
|
||||
err := &manyError{
|
||||
//
|
||||
// TODO(a.garipov): Add formatting to message.
|
||||
func Many(message string, underlying ...error) (err error) {
|
||||
err = &manyError{
|
||||
message: message,
|
||||
underlying: underlying,
|
||||
}
|
||||
@ -33,7 +34,7 @@ func Many(message string, underlying ...error) error {
|
||||
}
|
||||
|
||||
// Error implements the error interface for *manyError.
|
||||
func (e *manyError) Error() string {
|
||||
func (e *manyError) Error() (msg string) {
|
||||
switch len(e.underlying) {
|
||||
case 0:
|
||||
return e.message
|
||||
@ -58,7 +59,7 @@ func (e *manyError) Error() string {
|
||||
}
|
||||
|
||||
// Unwrap implements the hidden errors.wrapper interface for *manyError.
|
||||
func (e *manyError) Unwrap() error {
|
||||
func (e *manyError) Unwrap() (err error) {
|
||||
if len(e.underlying) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -71,3 +72,38 @@ func (e *manyError) Unwrap() error {
|
||||
type wrapper interface {
|
||||
Unwrap() error
|
||||
}
|
||||
|
||||
// Annotate annotates the error with the message, unless the error is nil. This
|
||||
// is a helper function to simplify code like this:
|
||||
//
|
||||
// func (f *foo) doStuff(s string) (err error) {
|
||||
// defer func() {
|
||||
// if err != nil {
|
||||
// err = fmt.Errorf("bad foo string %q: %w", s, err)
|
||||
// }
|
||||
// }()
|
||||
//
|
||||
// // …
|
||||
// }
|
||||
//
|
||||
// Instead, write:
|
||||
//
|
||||
// func (f *foo) doStuff(s string) (err error) {
|
||||
// defer agherr.Annotate("bad foo string %q: %w", &err, s)
|
||||
//
|
||||
// // …
|
||||
// }
|
||||
//
|
||||
// msg must contain the final ": %w" verb.
|
||||
func Annotate(msg string, errPtr *error, args ...interface{}) {
|
||||
if errPtr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := *errPtr
|
||||
if err != nil {
|
||||
args = append(args, err)
|
||||
|
||||
*errPtr = fmt.Errorf(msg, args...)
|
||||
}
|
||||
}
|
||||
|
@ -6,30 +6,32 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestError_Error(t *testing.T) {
|
||||
testCases := []struct {
|
||||
err error
|
||||
name string
|
||||
want string
|
||||
err error
|
||||
}{{
|
||||
err: Many("a"),
|
||||
name: "simple",
|
||||
want: "a",
|
||||
err: Many("a"),
|
||||
}, {
|
||||
err: Many("a", errors.New("b")),
|
||||
name: "wrapping",
|
||||
want: "a: b",
|
||||
err: Many("a", errors.New("b")),
|
||||
}, {
|
||||
err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")),
|
||||
name: "wrapping several",
|
||||
want: "a: b (hidden: c, d)",
|
||||
err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")),
|
||||
}, {
|
||||
err: Many("a", Many("b", errors.New("c"), errors.New("d"))),
|
||||
name: "wrapping wrapper",
|
||||
want: "a: b: c (hidden: d)",
|
||||
err: Many("a", Many("b", errors.New("c"), errors.New("d"))),
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
assert.Equal(t, tc.want, tc.err.Error(), tc.name)
|
||||
}
|
||||
@ -43,33 +45,78 @@ func TestError_Unwrap(t *testing.T) {
|
||||
errWrapped
|
||||
errNil
|
||||
)
|
||||
|
||||
errs := []error{
|
||||
errSimple: errors.New("a"),
|
||||
errWrapped: fmt.Errorf("err: %w", errors.New("nested")),
|
||||
errNil: nil,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
want error
|
||||
wrapped error
|
||||
name string
|
||||
}{{
|
||||
name: "simple",
|
||||
want: errs[errSimple],
|
||||
wrapped: Many("a", errs[errSimple]),
|
||||
name: "simple",
|
||||
}, {
|
||||
name: "nested",
|
||||
want: errs[errWrapped],
|
||||
wrapped: Many("b", errs[errWrapped]),
|
||||
name: "nested",
|
||||
}, {
|
||||
name: "nil passed",
|
||||
want: errs[errNil],
|
||||
wrapped: Many("c", errs[errNil]),
|
||||
name: "nil passed",
|
||||
}, {
|
||||
name: "nil not passed",
|
||||
want: nil,
|
||||
wrapped: Many("d"),
|
||||
name: "nil not passed",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
assert.Equal(t, tc.want, errors.Unwrap(tc.wrapped), tc.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnotate(t *testing.T) {
|
||||
const s = "1234"
|
||||
const wantMsg = `bad string "1234": test`
|
||||
|
||||
// Don't use const, because we can't take a pointer of a constant.
|
||||
var errTest error = Error("test")
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
var errPtr *error
|
||||
assert.NotPanics(t, func() {
|
||||
Annotate("bad string %q: %w", errPtr, s)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("non_nil", func(t *testing.T) {
|
||||
errPtr := &errTest
|
||||
assert.NotPanics(t, func() {
|
||||
Annotate("bad string %q: %w", errPtr, s)
|
||||
})
|
||||
|
||||
require.NotNil(t, errPtr)
|
||||
|
||||
err := *errPtr
|
||||
require.NotNil(t, err)
|
||||
|
||||
assert.Equal(t, wantMsg, err.Error())
|
||||
})
|
||||
|
||||
t.Run("defer", func(t *testing.T) {
|
||||
f := func() (err error) {
|
||||
defer Annotate("bad string %q: %w", &errTest, s)
|
||||
|
||||
return errTest
|
||||
}
|
||||
|
||||
err := f()
|
||||
require.NotNil(t, err)
|
||||
|
||||
assert.Equal(t, wantMsg, err.Error())
|
||||
})
|
||||
}
|
||||
|
99
internal/dhcpd/iprange.go
Normal file
99
internal/dhcpd/iprange.go
Normal file
@ -0,0 +1,99 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/big"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
)
|
||||
|
||||
// ipRange is an inclusive range of IP addresses.
|
||||
//
|
||||
// It is safe for concurrent use.
|
||||
//
|
||||
// TODO(a.garipov): Perhaps create an optimised version with uint32 for
|
||||
// IPv4 ranges? Or use one of uint128 packages?
|
||||
type ipRange struct {
|
||||
start *big.Int
|
||||
end *big.Int
|
||||
}
|
||||
|
||||
// maxRangeLen is the maximum IP range length. The bitsets used in servers only
|
||||
// accept uints, which can have the size of 32 bit.
|
||||
const maxRangeLen = math.MaxUint32
|
||||
|
||||
// newIPRange creates a new IP address range. start must be less than end. The
|
||||
// resulting range must not be greater than maxRangeLen.
|
||||
func newIPRange(start, end net.IP) (r *ipRange, err error) {
|
||||
defer agherr.Annotate("invalid ip range: %w", &err)
|
||||
|
||||
// Make sure that both are 16 bytes long to simplify handling in
|
||||
// methods.
|
||||
start, end = start.To16(), end.To16()
|
||||
|
||||
startInt := (&big.Int{}).SetBytes(start)
|
||||
endInt := (&big.Int{}).SetBytes(end)
|
||||
diff := (&big.Int{}).Sub(endInt, startInt)
|
||||
|
||||
if diff.Sign() <= 0 {
|
||||
return nil, fmt.Errorf("start is greater than or equal to end")
|
||||
} else if !diff.IsUint64() || diff.Uint64() > maxRangeLen {
|
||||
return nil, fmt.Errorf("range is too large")
|
||||
}
|
||||
|
||||
r = &ipRange{
|
||||
start: startInt,
|
||||
end: endInt,
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// contains returns true if r contains ip.
|
||||
func (r *ipRange) contains(ip net.IP) (ok bool) {
|
||||
ipInt := (&big.Int{}).SetBytes(ip.To16())
|
||||
|
||||
return r.containsInt(ipInt)
|
||||
}
|
||||
|
||||
// containsInt returns true if r contains ipInt.
|
||||
func (r *ipRange) containsInt(ipInt *big.Int) (ok bool) {
|
||||
return ipInt.Cmp(r.start) >= 0 && ipInt.Cmp(r.end) <= 0
|
||||
}
|
||||
|
||||
// ipPredicate is a function that is called on every IP address in
|
||||
// (*ipRange).find. ip is given in the 16-byte form.
|
||||
type ipPredicate func(ip net.IP) (ok bool)
|
||||
|
||||
// find finds the first IP address in r for which p returns true. ip is in the
|
||||
// 16-byte form.
|
||||
func (r *ipRange) find(p ipPredicate) (ip net.IP) {
|
||||
ip = make(net.IP, net.IPv6len)
|
||||
_1 := big.NewInt(1)
|
||||
for i := (&big.Int{}).Set(r.start); i.Cmp(r.end) <= 0; i.Add(i, _1) {
|
||||
i.FillBytes(ip)
|
||||
if p(ip) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// offset returns the offset of ip from the beginning of r. It returns 0 and
|
||||
// false if ip is not in r.
|
||||
func (r *ipRange) offset(ip net.IP) (offset uint, ok bool) {
|
||||
ip = ip.To16()
|
||||
ipInt := (&big.Int{}).SetBytes(ip)
|
||||
if !r.containsInt(ipInt) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
offsetInt := (&big.Int{}).Sub(ipInt, r.start)
|
||||
|
||||
// Assume that the range was checked against maxRangeLen during
|
||||
// construction.
|
||||
return uint(offsetInt.Uint64()), true
|
||||
}
|
154
internal/dhcpd/iprange_test.go
Normal file
154
internal/dhcpd/iprange_test.go
Normal file
@ -0,0 +1,154 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewIPRange(t *testing.T) {
|
||||
start4 := net.IP{0, 0, 0, 1}
|
||||
end4 := net.IP{0, 0, 0, 3}
|
||||
start6 := net.IP{
|
||||
0x01, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x01,
|
||||
}
|
||||
end6 := net.IP{
|
||||
0x01, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x03,
|
||||
}
|
||||
end6Large := net.IP{
|
||||
0x02, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x00,
|
||||
0x00, 0x00, 0x00, 0x03,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
wantErrMsg string
|
||||
start net.IP
|
||||
end net.IP
|
||||
}{{
|
||||
name: "success_ipv4",
|
||||
wantErrMsg: "",
|
||||
start: start4,
|
||||
end: end4,
|
||||
}, {
|
||||
name: "success_ipv6",
|
||||
wantErrMsg: "",
|
||||
start: start6,
|
||||
end: end6,
|
||||
}, {
|
||||
name: "start_gt_end",
|
||||
wantErrMsg: "invalid ip range: start is greater than or equal to end",
|
||||
start: end4,
|
||||
end: start4,
|
||||
}, {
|
||||
name: "start_eq_end",
|
||||
wantErrMsg: "invalid ip range: start is greater than or equal to end",
|
||||
start: start4,
|
||||
end: start4,
|
||||
}, {
|
||||
name: "too_large",
|
||||
wantErrMsg: "invalid ip range: range is too large",
|
||||
start: start6,
|
||||
end: end6Large,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
r, err := newIPRange(tc.start, tc.end)
|
||||
if tc.wantErrMsg == "" {
|
||||
assert.Nil(t, err)
|
||||
assert.NotNil(t, r)
|
||||
} else {
|
||||
require.NotNil(t, err)
|
||||
assert.Equal(t, tc.wantErrMsg, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPRange_Contains(t *testing.T) {
|
||||
start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 3}
|
||||
r, err := newIPRange(start, end)
|
||||
require.Nil(t, err)
|
||||
|
||||
assert.True(t, r.contains(start))
|
||||
assert.True(t, r.contains(net.IP{0, 0, 0, 2}))
|
||||
assert.True(t, r.contains(end))
|
||||
|
||||
assert.False(t, r.contains(net.IP{0, 0, 0, 0}))
|
||||
assert.False(t, r.contains(net.IP{0, 0, 0, 4}))
|
||||
}
|
||||
|
||||
func TestIPRange_Find(t *testing.T) {
|
||||
start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5}
|
||||
r, err := newIPRange(start, end)
|
||||
require.Nil(t, err)
|
||||
|
||||
want := net.IPv4(0, 0, 0, 2)
|
||||
got := r.find(func(ip net.IP) (ok bool) {
|
||||
return ip[len(ip)-1]%2 == 0
|
||||
})
|
||||
|
||||
assert.Equal(t, want, got)
|
||||
|
||||
got = r.find(func(ip net.IP) (ok bool) {
|
||||
return ip[len(ip)-1]%10 == 0
|
||||
})
|
||||
assert.Nil(t, got)
|
||||
}
|
||||
|
||||
func TestIPRange_Offset(t *testing.T) {
|
||||
start, end := net.IP{0, 0, 0, 1}, net.IP{0, 0, 0, 5}
|
||||
r, err := newIPRange(start, end)
|
||||
require.Nil(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
in net.IP
|
||||
wantOffset uint
|
||||
wantOK bool
|
||||
}{{
|
||||
name: "in",
|
||||
in: net.IP{0, 0, 0, 2},
|
||||
wantOffset: 1,
|
||||
wantOK: true,
|
||||
}, {
|
||||
name: "in_start",
|
||||
in: start,
|
||||
wantOffset: 0,
|
||||
wantOK: true,
|
||||
}, {
|
||||
name: "in_end",
|
||||
in: end,
|
||||
wantOffset: 4,
|
||||
wantOK: true,
|
||||
}, {
|
||||
name: "out_after",
|
||||
in: net.IP{0, 0, 0, 6},
|
||||
wantOffset: 0,
|
||||
wantOK: false,
|
||||
}, {
|
||||
name: "out_before",
|
||||
in: net.IP{0, 0, 0, 0},
|
||||
wantOffset: 0,
|
||||
wantOK: false,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
offset, ok := r.offset(tc.in)
|
||||
assert.Equal(t, tc.wantOffset, offset)
|
||||
assert.Equal(t, tc.wantOK, ok)
|
||||
})
|
||||
}
|
||||
}
|
@ -100,11 +100,7 @@ func newDHCPOptionParser() (p *dhcpOptionParser) {
|
||||
|
||||
// parse parses an option. See the handlers' documentation for more info.
|
||||
func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid option string %q: %w", s, err)
|
||||
}
|
||||
}()
|
||||
defer agherr.Annotate("invalid option string %q: %w", &err, s)
|
||||
|
||||
s = strings.TrimSpace(s)
|
||||
parts := strings.SplitN(s, " ", 3)
|
||||
|
@ -8,7 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func TestDHCPOptionParser(t *testing.T) {
|
||||
testCasesA := []struct {
|
||||
testCases := []struct {
|
||||
name string
|
||||
in string
|
||||
wantErrMsg string
|
||||
@ -92,7 +92,7 @@ func TestDHCPOptionParser(t *testing.T) {
|
||||
|
||||
p := newDHCPOptionParser()
|
||||
|
||||
for _, tc := range testCasesA {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
code, data, err := p.parse(tc.in)
|
||||
if tc.wantErrMsg == "" {
|
||||
|
@ -60,8 +60,8 @@ type V4ServerConf struct {
|
||||
// DEC_CODE ip IP_ADDR
|
||||
Options []string `yaml:"options" json:"-"`
|
||||
|
||||
ipStart net.IP // starting IP address for dynamic leases
|
||||
ipEnd net.IP // ending IP address for dynamic leases
|
||||
ipRange *ipRange
|
||||
|
||||
leaseTime time.Duration // the time during which a dynamic lease is considered valid
|
||||
dnsIPAddrs []net.IP // IPv4 addresses to return to DHCP clients as DNS server addresses
|
||||
routerIP net.IP // value for Option Router
|
||||
|
@ -4,7 +4,6 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
@ -14,19 +13,25 @@ import (
|
||||
"github.com/go-ping/ping"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/server4"
|
||||
"github.com/willf/bitset"
|
||||
)
|
||||
|
||||
// v4Server is a DHCPv4 server.
|
||||
//
|
||||
// TODO(a.garipov): Think about unifying this and v6Server.
|
||||
type v4Server struct {
|
||||
srv *server4.Server
|
||||
leasesLock sync.Mutex
|
||||
leases []*Lease
|
||||
// TODO(e.burkov): This field type should be a normal bitmap.
|
||||
ipAddrs [256]byte
|
||||
|
||||
conf V4ServerConf
|
||||
srv *server4.Server
|
||||
|
||||
// leasedOffsets contains offsets from conf.ipRange.start that have been
|
||||
// leased.
|
||||
leasedOffsets *bitset.BitSet
|
||||
|
||||
// leases contains all dynamic and static leases.
|
||||
leases []*Lease
|
||||
|
||||
// leasesLock protects leases and leasedOffsets.
|
||||
leasesLock sync.Mutex
|
||||
}
|
||||
|
||||
// WriteDiskConfig4 - write configuration
|
||||
@ -38,27 +43,14 @@ func (s *v4Server) WriteDiskConfig4(c *V4ServerConf) {
|
||||
func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) {
|
||||
}
|
||||
|
||||
// Return TRUE if IP address is within range [start..stop]
|
||||
func ip4InRange(start, stop, ip net.IP) bool {
|
||||
if len(start) != 4 || len(stop) != 4 {
|
||||
return false
|
||||
}
|
||||
from := binary.BigEndian.Uint32(start)
|
||||
to := binary.BigEndian.Uint32(stop)
|
||||
check := binary.BigEndian.Uint32(ip)
|
||||
return from <= check && check <= to
|
||||
}
|
||||
|
||||
// ResetLeases - reset leases
|
||||
func (s *v4Server) ResetLeases(leases []*Lease) {
|
||||
s.leases = nil
|
||||
|
||||
for _, l := range leases {
|
||||
if l.Expiry.Unix() != leaseExpireStatic && !s.conf.ipRange.contains(l.IP) {
|
||||
log.Debug("dhcpv4: skipping a lease with ip %v: not within current ip range", l.IP)
|
||||
|
||||
if l.Expiry.Unix() != leaseExpireStatic &&
|
||||
!ip4InRange(s.conf.ipStart, s.conf.ipEnd, l.IP) {
|
||||
|
||||
log.Debug("dhcpv4: skipping a lease with IP %v: not within current IP range", l.IP)
|
||||
continue
|
||||
}
|
||||
|
||||
@ -127,16 +119,18 @@ func (s *v4Server) blacklistLease(lease *Lease) {
|
||||
lease.Expiry = time.Now().Add(s.conf.leaseTime)
|
||||
}
|
||||
|
||||
// Remove (swap) lease by index
|
||||
func (s *v4Server) leaseRemoveSwapByIndex(i int) {
|
||||
s.ipAddrs[s.leases[i].IP[3]] = 0
|
||||
log.Debug("dhcpv4: removed lease %s", s.leases[i].HWAddr)
|
||||
// rmLeaseByIndex removes a lease by its index in the leases slice.
|
||||
func (s *v4Server) rmLeaseByIndex(i int) {
|
||||
l := s.leases[i]
|
||||
s.leases = append(s.leases[:i], s.leases[i+1:]...)
|
||||
|
||||
n := len(s.leases)
|
||||
if i != n-1 {
|
||||
s.leases[i] = s.leases[n-1] // swap with the last element
|
||||
r := s.conf.ipRange
|
||||
offset, ok := r.offset(l.IP)
|
||||
if ok {
|
||||
s.leasedOffsets.Clear(offset)
|
||||
}
|
||||
s.leases = s.leases[:n-1]
|
||||
|
||||
log.Debug("dhcpv4: removed lease %s (%s)", l.IP, l.HWAddr)
|
||||
}
|
||||
|
||||
// Remove a dynamic lease with the same properties
|
||||
@ -146,51 +140,61 @@ func (s *v4Server) rmDynamicLease(lease Lease) error {
|
||||
l := s.leases[i]
|
||||
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) {
|
||||
|
||||
if l.Expiry.Unix() == leaseExpireStatic {
|
||||
return fmt.Errorf("static lease already exists")
|
||||
}
|
||||
|
||||
s.leaseRemoveSwapByIndex(i)
|
||||
s.rmLeaseByIndex(i)
|
||||
if i == len(s.leases) {
|
||||
break
|
||||
}
|
||||
|
||||
l = s.leases[i]
|
||||
}
|
||||
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
|
||||
if l.Expiry.Unix() == leaseExpireStatic {
|
||||
return fmt.Errorf("static lease already exists")
|
||||
}
|
||||
|
||||
s.leaseRemoveSwapByIndex(i)
|
||||
s.rmLeaseByIndex(i)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Add a lease
|
||||
// addLease adds a lease.
|
||||
func (s *v4Server) addLease(l *Lease) {
|
||||
r := s.conf.ipRange
|
||||
offset, ok := r.offset(l.IP)
|
||||
if !ok {
|
||||
// TODO(a.garipov): Better error handling.
|
||||
log.Debug("dhcpv4: lease %s (%s) out of range, not adding", l.IP, l.HWAddr)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
s.leases = append(s.leases, l)
|
||||
s.ipAddrs[l.IP[3]] = 1
|
||||
log.Debug("dhcpv4: added lease %s <-> %s", l.IP, l.HWAddr)
|
||||
s.leasedOffsets.Set(uint(offset))
|
||||
|
||||
log.Debug("dhcpv4: added lease %s (%s)", l.IP, l.HWAddr)
|
||||
}
|
||||
|
||||
// Remove a lease with the same properties
|
||||
func (s *v4Server) rmLease(lease Lease) error {
|
||||
for i, l := range s.leases {
|
||||
if net.IP.Equal(l.IP, lease.IP) {
|
||||
|
||||
if l.IP.Equal(lease.IP) {
|
||||
if !bytes.Equal(l.HWAddr, lease.HWAddr) ||
|
||||
l.Hostname != lease.Hostname {
|
||||
return fmt.Errorf("lease not found")
|
||||
}
|
||||
|
||||
s.leaseRemoveSwapByIndex(i)
|
||||
s.rmLeaseByIndex(i)
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("lease not found")
|
||||
}
|
||||
|
||||
@ -258,7 +262,7 @@ func (s *v4Server) addrAvailable(target net.IP) bool {
|
||||
pinger.Timeout = time.Duration(s.conf.ICMPTimeout) * time.Millisecond
|
||||
pinger.Count = 1
|
||||
reply := false
|
||||
pinger.OnRecv = func(pkt *ping.Packet) {
|
||||
pinger.OnRecv = func(_ *ping.Packet) {
|
||||
reply = true
|
||||
}
|
||||
log.Debug("dhcpv4: Sending ICMP Echo to %v", target)
|
||||
@ -278,30 +282,31 @@ func (s *v4Server) addrAvailable(target net.IP) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Find lease by MAC
|
||||
func (s *v4Server) findLease(mac net.HardwareAddr) *Lease {
|
||||
for i := range s.leases {
|
||||
if bytes.Equal(mac, s.leases[i].HWAddr) {
|
||||
return s.leases[i]
|
||||
// findLease finds a lease by its MAC-address.
|
||||
func (s *v4Server) findLease(mac net.HardwareAddr) (l *Lease) {
|
||||
for _, l = range s.leases {
|
||||
if bytes.Equal(mac, l.HWAddr) {
|
||||
return l
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get next free IP
|
||||
func (s *v4Server) findFreeIP() net.IP {
|
||||
for i := s.conf.ipStart[3]; ; i++ {
|
||||
if s.ipAddrs[i] == 0 {
|
||||
ip := make([]byte, 4)
|
||||
copy(ip, s.conf.ipStart)
|
||||
ip[3] = i
|
||||
return ip
|
||||
// nextIP generates a new free IP.
|
||||
func (s *v4Server) nextIP() (ip net.IP) {
|
||||
r := s.conf.ipRange
|
||||
ip = r.find(func(next net.IP) (ok bool) {
|
||||
offset, ok := r.offset(next)
|
||||
if !ok {
|
||||
// Shouldn't happen.
|
||||
return false
|
||||
}
|
||||
if i == s.conf.ipEnd[3] {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
return !s.leasedOffsets.Test(uint(offset))
|
||||
})
|
||||
|
||||
return ip.To4()
|
||||
}
|
||||
|
||||
// Find an expired lease and return its index or -1
|
||||
@ -316,24 +321,30 @@ func (s *v4Server) findExpiredLease() int {
|
||||
return -1
|
||||
}
|
||||
|
||||
// Reserve lease for MAC
|
||||
func (s *v4Server) reserveLease(mac net.HardwareAddr) *Lease {
|
||||
l := Lease{}
|
||||
l.HWAddr = make([]byte, 6)
|
||||
// reserveLease reserves a lease for a client by its MAC-address. It returns
|
||||
// nil if it couldn't allocate a new lease.
|
||||
func (s *v4Server) reserveLease(mac net.HardwareAddr) (l *Lease) {
|
||||
l = &Lease{
|
||||
HWAddr: make([]byte, 6),
|
||||
}
|
||||
|
||||
copy(l.HWAddr, mac)
|
||||
|
||||
l.IP = s.findFreeIP()
|
||||
l.IP = s.nextIP()
|
||||
if l.IP == nil {
|
||||
i := s.findExpiredLease()
|
||||
if i < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
copy(s.leases[i].HWAddr, mac)
|
||||
|
||||
return s.leases[i]
|
||||
}
|
||||
|
||||
s.addLease(&l)
|
||||
return &l
|
||||
s.addLease(l)
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
func (s *v4Server) commitLease(l *Lease) {
|
||||
@ -650,22 +661,12 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) {
|
||||
s.conf.subnetMask = make([]byte, 4)
|
||||
copy(s.conf.subnetMask, s.conf.SubnetMask.To4())
|
||||
|
||||
s.conf.ipStart, err = tryTo4(conf.RangeStart)
|
||||
if s.conf.ipStart == nil {
|
||||
s.conf.ipRange, err = newIPRange(conf.RangeStart, conf.RangeEnd)
|
||||
if err != nil {
|
||||
return s, fmt.Errorf("dhcpv4: %w", err)
|
||||
}
|
||||
if s.conf.ipStart[0] == 0 {
|
||||
return s, fmt.Errorf("dhcpv4: invalid range start IP")
|
||||
}
|
||||
|
||||
s.conf.ipEnd, err = tryTo4(conf.RangeEnd)
|
||||
if s.conf.ipEnd == nil {
|
||||
return s, fmt.Errorf("dhcpv4: %w", err)
|
||||
}
|
||||
if !net.IP.Equal(s.conf.ipStart[:3], s.conf.ipEnd[:3]) ||
|
||||
s.conf.ipStart[3] > s.conf.ipEnd[3] {
|
||||
return s, fmt.Errorf("dhcpv4: range end IP should match range start IP")
|
||||
}
|
||||
s.leasedOffsets = &bitset.BitSet{}
|
||||
|
||||
if conf.LeaseDuration == 0 {
|
||||
s.conf.leaseTime = time.Hour * 24
|
||||
|
@ -212,18 +212,26 @@ func TestV4DynamicLease_Get(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 1, s.process(req, resp))
|
||||
})
|
||||
|
||||
// Don't continue if we got any errors in the previous subtest.
|
||||
require.Nil(t, err)
|
||||
|
||||
t.Run("offer", func(t *testing.T) {
|
||||
assert.Equal(t, dhcpv4.MessageTypeOffer, resp.MessageType())
|
||||
assert.Equal(t, mac, resp.ClientHWAddr)
|
||||
assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0]))
|
||||
assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier()))
|
||||
|
||||
assert.Equal(t, s.conf.RangeStart, resp.YourIPAddr)
|
||||
assert.Equal(t, s.conf.GatewayIP, resp.ServerIdentifier())
|
||||
|
||||
router := resp.Router()
|
||||
require.Len(t, router, 1)
|
||||
assert.Equal(t, s.conf.GatewayIP, router[0])
|
||||
|
||||
assert.Equal(t, s.conf.subnetMask, resp.SubnetMask())
|
||||
assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds())
|
||||
assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)])
|
||||
assert.True(t, net.IP{1, 2, 3, 4}.Equal(net.IP(resp.Options[uint8(dhcpv4.OptionRelayAgentInformation)])))
|
||||
|
||||
assert.Equal(t, net.IP{1, 2, 3, 4}, net.IP(resp.RelayAgentInfo().ToBytes()))
|
||||
})
|
||||
|
||||
t.Run("request", func(t *testing.T) {
|
||||
@ -260,31 +268,3 @@ func TestV4DynamicLease_Get(t *testing.T) {
|
||||
assert.Equal(t, mac, ls[0].HWAddr)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIP4InRange(t *testing.T) {
|
||||
start := net.IP{192, 168, 10, 100}
|
||||
stop := net.IP{192, 168, 10, 200}
|
||||
|
||||
testCases := []struct {
|
||||
ip net.IP
|
||||
want bool
|
||||
}{{
|
||||
ip: net.IP{192, 168, 10, 99},
|
||||
want: false,
|
||||
}, {
|
||||
ip: net.IP{192, 168, 11, 100},
|
||||
want: false,
|
||||
}, {
|
||||
ip: net.IP{192, 168, 11, 201},
|
||||
want: false,
|
||||
}, {
|
||||
ip: start,
|
||||
want: true,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.ip.String(), func(t *testing.T) {
|
||||
assert.Equal(t, tc.want, ip4InRange(start, stop, tc.ip))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -231,8 +231,17 @@ func TestClientsAddExisting(t *testing.T) {
|
||||
|
||||
// First, init a DHCP server with a single static lease.
|
||||
config := dhcpd.ServerConfig{
|
||||
Enabled: true,
|
||||
DBFilePath: "leases.db",
|
||||
Conf4: dhcpd.V4ServerConf{
|
||||
Enabled: true,
|
||||
GatewayIP: net.IP{1, 2, 3, 1},
|
||||
SubnetMask: net.IP{255, 255, 255, 0},
|
||||
RangeStart: net.IP{1, 2, 3, 2},
|
||||
RangeEnd: net.IP{1, 2, 3, 10},
|
||||
},
|
||||
}
|
||||
|
||||
clients.dhcpServer = dhcpd.Create(config)
|
||||
t.Cleanup(func() { _ = os.Remove("leases.db") })
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
checks = ["all"]
|
||||
initialisms = [
|
||||
# See https://github.com/dominikh/go-tools/blob/master/config/config.go.
|
||||
#
|
||||
# Do not add "PTR" since we use "Ptr" as a suffix.
|
||||
"inherit"
|
||||
, "DHCP"
|
||||
, "DOH"
|
||||
@ -8,7 +10,6 @@ initialisms = [
|
||||
, "DOT"
|
||||
, "EDNS"
|
||||
, "MX"
|
||||
, "PTR"
|
||||
, "QUIC"
|
||||
, "RA"
|
||||
, "SDNS"
|
||||
|
Loading…
Reference in New Issue
Block a user