package querylog

import (
	"fmt"
	"math/rand"
	"net"
	"sort"
	"testing"
	"time"

	"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
	"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
	"github.com/AdguardTeam/dnsproxy/proxyutil"
	"github.com/miekg/dns"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestMain(m *testing.M) {
	aghtest.DiscardLogOutput(m)
}

// TestQueryLog tests adding and loading (with filtering) entries from disk and
// memory.
func TestQueryLog(t *testing.T) {
	l := newQueryLog(Config{
		Enabled:     true,
		FileEnabled: true,
		Interval:    1,
		MemSize:     100,
		BaseDir:     aghtest.PrepareTestDir(t),
	})

	// Add disk entries.
	addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	// Write to disk (first file).
	require.NoError(t, l.flushLogBuffer(true))
	// Start writing to the second file.
	require.NoError(t, l.rotate())
	// Add disk entries.
	addEntry(l, "example.org", net.IPv4(1, 1, 1, 2), net.IPv4(2, 2, 2, 2))
	// Write to disk.
	require.NoError(t, l.flushLogBuffer(true))
	// Add memory entries.
	addEntry(l, "test.example.org", net.IPv4(1, 1, 1, 3), net.IPv4(2, 2, 2, 3))
	addEntry(l, "example.com", net.IPv4(1, 1, 1, 4), net.IPv4(2, 2, 2, 4))

	// tcAssertion describes one expected search result: its index in the
	// returned slice and the fields that assertLogEntry checks.
	type tcAssertion struct {
		num            int
		host           string
		answer, client net.IP
	}

	// Expected results are listed newest first.  Some search values use mixed
	// case on purpose: domain matching is expected to be case-insensitive.
	testCases := []struct {
		name string
		sCr  []searchCriteria
		want []tcAssertion
	}{{
		name: "all",
		sCr:  []searchCriteria{},
		want: []tcAssertion{
			{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
			{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
			{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
			{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
		},
	}, {
		name: "by_domain_strict",
		sCr: []searchCriteria{{
			criteriaType: ctDomainOrClient,
			strict:       true,
			value:        "TEST.example.org",
		}},
		want: []tcAssertion{{
			num: 0, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3),
		}},
	}, {
		name: "by_domain_non-strict",
		sCr: []searchCriteria{{
			criteriaType: ctDomainOrClient,
			strict:       false,
			value:        "example.ORG",
		}},
		want: []tcAssertion{
			{num: 0, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
			{num: 1, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
			{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
		},
	}, {
		name: "by_client_ip_strict",
		sCr: []searchCriteria{{
			criteriaType: ctDomainOrClient,
			strict:       true,
			value:        "2.2.2.2",
		}},
		want: []tcAssertion{{
			num: 0, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2),
		}},
	}, {
		name: "by_client_ip_non-strict",
		sCr: []searchCriteria{{
			criteriaType: ctDomainOrClient,
			strict:       false,
			value:        "2.2.2",
		}},
		want: []tcAssertion{
			{num: 0, host: "example.com", answer: net.IPv4(1, 1, 1, 4), client: net.IPv4(2, 2, 2, 4)},
			{num: 1, host: "test.example.org", answer: net.IPv4(1, 1, 1, 3), client: net.IPv4(2, 2, 2, 3)},
			{num: 2, host: "example.org", answer: net.IPv4(1, 1, 1, 2), client: net.IPv4(2, 2, 2, 2)},
			{num: 3, host: "example.org", answer: net.IPv4(1, 1, 1, 1), client: net.IPv4(2, 2, 2, 1)},
		},
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			params := newSearchParams()
			params.searchCriteria = tc.sCr

			// Only the entries are checked here; the second result is ignored.
			entries, _ := l.search(params)
			require.Len(t, entries, len(tc.want))

			for _, want := range tc.want {
				assertLogEntry(t, entries[want.num], want.host, want.answer, want.client)
			}
		})
	}
}

// TestQueryLogOffsetLimit tests paginating search results with offset and
// limit across entries that were flushed and entries still held in memory.
func TestQueryLogOffsetLimit(t *testing.T) {
	l := newQueryLog(Config{
		Enabled:  true,
		Interval: 1,
		MemSize:  100,
		BaseDir:  aghtest.PrepareTestDir(t),
	})

	const (
		entNum           = 10
		firstPageDomain  = "first.example.org"
		secondPageDomain = "second.example.org"
	)

	// Add entries to the log.
	for i := 0; i < entNum; i++ {
		addEntry(l, secondPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	}

	// Write them to the first file.
	//
	// NOTE(review): FileEnabled is not set in the config above.  Confirm that
	// flushLogBuffer writes to disk regardless, as this comment assumes.
	require.NoError(t, l.flushLogBuffer(true))

	// Add more to the in-memory part of log.
	for i := 0; i < entNum; i++ {
		addEntry(l, firstPageDomain, net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	}

	// params is shared by the sequentially run subtests below; each one
	// overwrites offset and limit before searching.
	params := newSearchParams()

	testCases := []struct {
		name    string
		offset  int
		limit   int
		wantLen int
		want    string
	}{{
		name:    "page_1",
		offset:  0,
		limit:   10,
		wantLen: 10,
		want:    firstPageDomain,
	}, {
		name:    "page_2",
		offset:  10,
		limit:   10,
		wantLen: 10,
		want:    secondPageDomain,
	}, {
		name:    "page_2.5",
		offset:  15,
		limit:   10,
		wantLen: 5,
		want:    secondPageDomain,
	}, {
		name:    "page_3",
		offset:  20,
		limit:   10,
		wantLen: 0,
	}}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			params.offset = tc.offset
			params.limit = tc.limit

			entries, _ := l.search(params)
			require.Len(t, entries, tc.wantLen)

			if tc.wantLen > 0 {
				// Check both ends of the page so that a page that straddles
				// the two domains is caught.
				assert.Equal(t, tc.want, entries[0].QHost)
				assert.Equal(t, tc.want, entries[tc.wantLen-1].QHost)
			}
		})
	}
}

// TestQueryLogMaxFileScanEntries tests that maxFileScanEntries bounds the
// number of entries read from disk.  For each limit, the expected result size
// is entNum-maxFileScanEntries, so a limit of 0 expects all entries.
func TestQueryLogMaxFileScanEntries(t *testing.T) {
	l := newQueryLog(Config{
		Enabled:     true,
		FileEnabled: true,
		Interval:    1,
		MemSize:     100,
		BaseDir:     aghtest.PrepareTestDir(t),
	})

	const entNum = 10

	// Add entries to the log.
	for i := 0; i < entNum; i++ {
		addEntry(l, "example.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	}

	// Write them to disk.
	require.NoError(t, l.flushLogBuffer(true))

	params := newSearchParams()

	for _, maxFileScanEntries := range []int{5, 0} {
		t.Run(fmt.Sprintf("limit_%d", maxFileScanEntries), func(t *testing.T) {
			params.maxFileScanEntries = maxFileScanEntries

			entries, _ := l.search(params)
			assert.Len(t, entries, entNum-maxFileScanEntries)
		})
	}
}

// TestQueryLogFileDisabled tests that, with the file disabled, the in-memory
// buffer holds at most MemSize entries and drops the oldest one first.
func TestQueryLogFileDisabled(t *testing.T) {
	l := newQueryLog(Config{
		Enabled:     true,
		FileEnabled: false,
		Interval:    1,
		MemSize:     2,
		BaseDir:     aghtest.PrepareTestDir(t),
	})

	addEntry(l, "example1.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	addEntry(l, "example2.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))
	// The oldest entry is going to be removed from memory buffer.
	addEntry(l, "example3.org", net.IPv4(1, 1, 1, 1), net.IPv4(2, 2, 2, 1))

	params := newSearchParams()
	ll, _ := l.search(params)
	require.Len(t, ll, 2)

	assert.Equal(t, "example3.org", ll[0].QHost)
	assert.Equal(t, "example2.org", ll[1].QHost)
}

// addEntry adds to l a filtered (rewritten) A query for host from client.  The
// response contains a single A record with the address answer.  The same
// message is used as both the answer and the original answer.
func addEntry(l *queryLog, host string, answer, client net.IP) {
	q := dns.Msg{
		Question: []dns.Question{{
			Name:   host + ".",
			Qtype:  dns.TypeA,
			Qclass: dns.ClassINET,
		}},
	}

	a := dns.Msg{
		Question: q.Question,
		Answer: []dns.RR{&dns.A{
			Hdr: dns.RR_Header{
				Name:   q.Question[0].Name,
				Rrtype: dns.TypeA,
				Class:  dns.ClassINET,
			},
			A: answer,
		}},
	}

	res := dnsfilter.Result{
		IsFiltered:  true,
		Reason:      dnsfilter.Rewritten,
		ServiceName: "SomeService",
		Rules: []*dnsfilter.ResultRule{{
			FilterListID: 1,
			Text:         "SomeRule",
		}},
	}

	params := AddParams{
		Question:   &q,
		Answer:     &a,
		OrigAnswer: &a,
		Result:     &res,
		ClientIP:   client,
		Upstream:   "upstream",
	}

	l.Add(params)
}

// assertLogEntry checks that entry is a non-nil "A"/"IN" log entry for host
// from client.  Its packed answer must hold exactly one record whose address
// equals answer.
func assertLogEntry(t *testing.T, entry *logEntry, host string, answer, client net.IP) {
	t.Helper()

	require.NotNil(t, entry)

	assert.Equal(t, host, entry.QHost)
	assert.Equal(t, client, entry.IP)
	assert.Equal(t, "A", entry.QType)
	assert.Equal(t, "IN", entry.QClass)

	msg := &dns.Msg{}
	require.NoError(t, msg.Unpack(entry.Answer))
	require.Len(t, msg.Answer, 1)

	// Convert to the 16-byte form, which net.IPv4 also returns, so that the
	// byte-wise comparison with answer does not depend on the IP length.
	ip := proxyutil.GetIPFromDNSRecord(msg.Answer[0]).To16()
	assert.Equal(t, answer, ip)
}

// testEntries returns 1000 log entries.  Each has a random minute and second
// within the first hour of 2020-01-01 UTC.  The generator is seeded from the
// current time, so the contents differ between runs.
func testEntries() (entries []*logEntry) {
	rsrc := rand.NewSource(time.Now().UnixNano())
	rgen := rand.New(rsrc)

	entries = make([]*logEntry, 1000)
	for i := range entries {
		min := rgen.Intn(60)
		sec := rgen.Intn(60)
		entries[i] = &logEntry{
			Time: time.Date(2020, 1, 1, 0, min, sec, 0, time.UTC),
		}
	}

	return entries
}

// logEntriesByTimeDesc is a wrapper over []*logEntry for sorting.
//
// NOTE(a.garipov): Weirdly enough, on my machine this gets consistently
// outperformed by sort.Slice, see the benchmark below.  I'm leaving this
// implementation here, in tests, in case we want to make sure it outperforms on
// most machines, but for now this is unused in the actual code.
type logEntriesByTimeDesc []*logEntry

// Len implements the sort.Interface interface for logEntriesByTimeDesc.
func (les logEntriesByTimeDesc) Len() (n int) { return len(les) }

// Less implements the sort.Interface interface for logEntriesByTimeDesc.  An
// entry sorts first when its time is later, which gives newest-first order.
func (les logEntriesByTimeDesc) Less(i, j int) (less bool) {
	return les[i].Time.After(les[j].Time)
}

// Swap implements the sort.Interface interface for logEntriesByTimeDesc.
func (les logEntriesByTimeDesc) Swap(i, j int) { les[i], les[j] = les[j], les[i] } func BenchmarkLogEntry_sort(b *testing.B) { b.Run("methods", func(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() entries := testEntries() b.StartTimer() sort.Stable(logEntriesByTimeDesc(entries)) } }) b.Run("reflect", func(b *testing.B) { for i := 0; i < b.N; i++ { b.StopTimer() entries := testEntries() b.StartTimer() sort.SliceStable(entries, func(i, j int) (less bool) { return entries[i].Time.After(entries[j].Time) }) } }) } func TestLogEntriesByTime_sort(t *testing.T) { entries := testEntries() sort.Sort(logEntriesByTimeDesc(entries)) for i := range entries[1:] { assert.False(t, entries[i+1].Time.After(entries[i].Time), "%s %s", entries[i+1].Time, entries[i].Time) } }