mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-17 10:58:29 -07:00
208 lines
3.6 KiB
Go
208 lines
3.6 KiB
Go
|
package home
|
||
|
|
||
|
import (
|
||
|
"net"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
)
|
||
|
|
||
|
func TestAuthRateLimiter_Cleanup(t *testing.T) {
|
||
|
const key = "some-key"
|
||
|
now := time.Now()
|
||
|
|
||
|
testCases := []struct {
|
||
|
name string
|
||
|
att failedAuth
|
||
|
wantExp bool
|
||
|
}{{
|
||
|
name: "expired",
|
||
|
att: failedAuth{
|
||
|
until: now.Add(-100 * time.Hour),
|
||
|
},
|
||
|
wantExp: true,
|
||
|
}, {
|
||
|
name: "nope_yet",
|
||
|
att: failedAuth{
|
||
|
until: now.Add(failedAuthTTL / 2),
|
||
|
},
|
||
|
wantExp: false,
|
||
|
}, {
|
||
|
name: "blocked",
|
||
|
att: failedAuth{
|
||
|
until: now.Add(100 * time.Hour),
|
||
|
},
|
||
|
wantExp: false,
|
||
|
}}
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
ab := &authRateLimiter{
|
||
|
failedAuths: map[string]failedAuth{
|
||
|
key: tc.att,
|
||
|
},
|
||
|
}
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
ab.cleanupLocked(now)
|
||
|
if tc.wantExp {
|
||
|
assert.Empty(t, ab.failedAuths)
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
require.Len(t, ab.failedAuths, 1)
|
||
|
|
||
|
_, ok := ab.failedAuths[key]
|
||
|
require.True(t, ok)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestAuthRateLimiter_Check(t *testing.T) {
|
||
|
key := string(net.IP{127, 0, 0, 1})
|
||
|
const maxAtt = 1
|
||
|
now := time.Now()
|
||
|
|
||
|
testCases := []struct {
|
||
|
until time.Time
|
||
|
name string
|
||
|
num uint
|
||
|
wantExp bool
|
||
|
}{{
|
||
|
until: now.Add(-100 * time.Hour),
|
||
|
name: "expired",
|
||
|
num: 0,
|
||
|
wantExp: true,
|
||
|
}, {
|
||
|
until: now.Add(failedAuthTTL),
|
||
|
name: "not_blocked_but_tracked",
|
||
|
num: 0,
|
||
|
wantExp: true,
|
||
|
}, {
|
||
|
until: now,
|
||
|
name: "expired_but_stayed",
|
||
|
num: 2,
|
||
|
wantExp: true,
|
||
|
}, {
|
||
|
until: now.Add(100 * time.Hour),
|
||
|
name: "blocked",
|
||
|
num: 2,
|
||
|
wantExp: false,
|
||
|
}}
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
failedAuths := map[string]failedAuth{
|
||
|
key: {
|
||
|
num: tc.num,
|
||
|
until: tc.until,
|
||
|
},
|
||
|
}
|
||
|
ab := &authRateLimiter{
|
||
|
maxAttempts: maxAtt,
|
||
|
failedAuths: failedAuths,
|
||
|
}
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
until := ab.check(key)
|
||
|
|
||
|
if tc.wantExp {
|
||
|
assert.LessOrEqual(t, until, time.Duration(0))
|
||
|
} else {
|
||
|
assert.Greater(t, until, time.Duration(0))
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
|
||
|
t.Run("non-existent", func(t *testing.T) {
|
||
|
ab := &authRateLimiter{
|
||
|
failedAuths: map[string]failedAuth{
|
||
|
key + "smthng": {},
|
||
|
},
|
||
|
}
|
||
|
|
||
|
until := ab.check(key)
|
||
|
|
||
|
assert.Zero(t, until)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestAuthRateLimiter_Inc(t *testing.T) {
|
||
|
ip := net.IP{127, 0, 0, 1}
|
||
|
key := string(ip)
|
||
|
now := time.Now()
|
||
|
const maxAtt = 2
|
||
|
const blockDur = 15 * time.Minute
|
||
|
|
||
|
testCases := []struct {
|
||
|
until time.Time
|
||
|
wantUntil time.Time
|
||
|
name string
|
||
|
num uint
|
||
|
wantNum uint
|
||
|
}{{
|
||
|
name: "only_inc",
|
||
|
until: now,
|
||
|
wantUntil: now,
|
||
|
num: maxAtt - 1,
|
||
|
wantNum: maxAtt,
|
||
|
}, {
|
||
|
name: "inc_and_block",
|
||
|
until: now,
|
||
|
wantUntil: now.Add(failedAuthTTL),
|
||
|
num: maxAtt,
|
||
|
wantNum: maxAtt + 1,
|
||
|
}}
|
||
|
|
||
|
for _, tc := range testCases {
|
||
|
failedAuths := map[string]failedAuth{
|
||
|
key: {
|
||
|
num: tc.num,
|
||
|
until: tc.until,
|
||
|
},
|
||
|
}
|
||
|
ab := &authRateLimiter{
|
||
|
blockDur: blockDur,
|
||
|
maxAttempts: maxAtt,
|
||
|
failedAuths: failedAuths,
|
||
|
}
|
||
|
t.Run(tc.name, func(t *testing.T) {
|
||
|
ab.inc(key)
|
||
|
|
||
|
a, ok := ab.failedAuths[key]
|
||
|
require.True(t, ok)
|
||
|
|
||
|
assert.Equal(t, tc.wantNum, a.num)
|
||
|
assert.LessOrEqual(t, tc.wantUntil.Unix(), a.until.Unix())
|
||
|
})
|
||
|
}
|
||
|
|
||
|
t.Run("non-existent", func(t *testing.T) {
|
||
|
ab := &authRateLimiter{
|
||
|
blockDur: blockDur,
|
||
|
maxAttempts: maxAtt,
|
||
|
failedAuths: map[string]failedAuth{},
|
||
|
}
|
||
|
|
||
|
ab.inc(key)
|
||
|
|
||
|
a, ok := ab.failedAuths[key]
|
||
|
require.True(t, ok)
|
||
|
assert.EqualValues(t, 1, a.num)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func TestAuthRateLimiter_Remove(t *testing.T) {
|
||
|
const key = "some-key"
|
||
|
|
||
|
failedAuths := map[string]failedAuth{
|
||
|
key: {},
|
||
|
}
|
||
|
ab := &authRateLimiter{
|
||
|
failedAuths: failedAuths,
|
||
|
}
|
||
|
|
||
|
ab.remove(key)
|
||
|
|
||
|
assert.Empty(t, ab.failedAuths)
|
||
|
}
|