package dnsforward import ( "net" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestIsBlockedIP(t *testing.T) { const ( ip int = iota cidr ) rules := []string{ ip: "1.1.1.1", cidr: "2.2.0.0/16", } testCases := []struct { name string allowed bool ip net.IP wantDis bool wantRule string }{{ name: "allow_ip", allowed: true, ip: net.IPv4(1, 1, 1, 1), wantDis: false, wantRule: "", }, { name: "disallow_ip", allowed: true, ip: net.IPv4(1, 1, 1, 2), wantDis: true, wantRule: "", }, { name: "allow_cidr", allowed: true, ip: net.IPv4(2, 2, 1, 1), wantDis: false, wantRule: "", }, { name: "disallow_cidr", allowed: true, ip: net.IPv4(2, 3, 1, 1), wantDis: true, wantRule: "", }, { name: "allow_ip", allowed: false, ip: net.IPv4(1, 1, 1, 1), wantDis: true, wantRule: rules[ip], }, { name: "disallow_ip", allowed: false, ip: net.IPv4(1, 1, 1, 2), wantDis: false, wantRule: "", }, { name: "allow_cidr", allowed: false, ip: net.IPv4(2, 2, 1, 1), wantDis: true, wantRule: rules[cidr], }, { name: "disallow_cidr", allowed: false, ip: net.IPv4(2, 3, 1, 1), wantDis: false, wantRule: "", }} for _, tc := range testCases { prefix := "allowed_" if !tc.allowed { prefix = "disallowed_" } t.Run(prefix+tc.name, func(t *testing.T) { allowedRules := rules var disallowedRules []string if !tc.allowed { allowedRules, disallowedRules = disallowedRules, allowedRules } aCtx, err := newAccessCtx(allowedRules, disallowedRules, nil) require.NoError(t, err) disallowed, rule := aCtx.IsBlockedIP(tc.ip) assert.Equal(t, tc.wantDis, disallowed) assert.Equal(t, tc.wantRule, rule) }) } } func TestIsBlockedDomain(t *testing.T) { aCtx, err := newAccessCtx(nil, nil, []string{ "host1", "*.host.com", "||host3.com^", }) require.NoError(t, err) testCases := []struct { name string domain string want bool }{{ name: "plain_match", domain: "host1", want: true, }, { name: "plain_mismatch", domain: "host2", want: false, }, { name: "wildcard_type-1_match_short", domain: "asdf.host.com", want: true, }, { name: "wildcard_type-1_match_long", domain: "qwer.asdf.host.com", want: true, }, { name: "wildcard_type-1_mismatch_no-lead", domain: "host.com", want: false, }, { name: "wildcard_type-1_mismatch_bad-asterisk", domain: "asdf.zhost.com", want: false, }, { name: "wildcard_type-2_match_simple", domain: "host3.com", want: true, }, { name: "wildcard_type-2_match_complex", domain: "asdf.host3.com", want: true, }, { name: "wildcard_type-2_mismatch", domain: ".host3.com", want: false, }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { assert.Equal(t, tc.want, aCtx.IsBlockedDomain(tc.domain)) }) } }