mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-15 09:58:42 -07:00
Pull request: replace agherr with golibs' errors
Merge in DNS/adguard-home from golibs-errors to master Squashed commit of the following: commit 5aba278a31c5a213bd9e08273ce7277c57713b22 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon May 24 17:05:18 2021 +0300 all: imp code commit f447eb875b81779fa9e391d98c31c1eeba7ef323 Author: Ainar Garipov <A.Garipov@AdGuard.COM> Date: Mon May 24 15:33:45 2021 +0300 replace agherr with golibs' errors
This commit is contained in:
parent
14250821ab
commit
03a828ef51
@ -66,8 +66,7 @@ on GitHub and most other Markdown renderers. -->
|
||||
### <a id="code" href="#code">Code</a>
|
||||
|
||||
* Always `recover` from panics in new goroutines. Preferably in the very
|
||||
first statement. If all you want there is a log message, use
|
||||
`agherr.LogPanic`.
|
||||
first statement. If all you want there is a log message, use `log.OnPanic`.
|
||||
|
||||
* Avoid `fallthrough`. It makes it harder to rearrange `case`s, to reason
|
||||
about the code, and also to switch the code to a handler approach, if that
|
||||
|
2
go.mod
2
go.mod
@ -4,7 +4,7 @@ go 1.16
|
||||
|
||||
require (
|
||||
github.com/AdguardTeam/dnsproxy v0.37.4
|
||||
github.com/AdguardTeam/golibs v0.5.0
|
||||
github.com/AdguardTeam/golibs v0.8.0
|
||||
github.com/AdguardTeam/urlfilter v0.14.5
|
||||
github.com/NYTimes/gziphandler v1.1.1
|
||||
github.com/ameshkov/dnscrypt/v2 v2.1.3
|
||||
|
4
go.sum
4
go.sum
@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.37.4/go.mod h1:xkJWEuTr550gPDmB9azsciKZzSXjf9
|
||||
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/golibs v0.5.0 h1:qwhEKjDrT0UcwDnHrNU2Yg/DLR9b/GsUncnXYW6VzAU=
|
||||
github.com/AdguardTeam/golibs v0.5.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/golibs v0.8.0 h1:rHo+yIgT2fivFG0yW2Cwk/DPc2+t/Aw6QvzPpiIFre0=
|
||||
github.com/AdguardTeam/golibs v0.8.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
|
||||
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
|
||||
github.com/AdguardTeam/urlfilter v0.14.5 h1:WyF0hg0MwKevsqNPkoaZFH8f5WRi/yuy/7qePtYt5Ts=
|
||||
github.com/AdguardTeam/urlfilter v0.14.5/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U=
|
||||
|
@ -1,128 +0,0 @@
|
||||
// Package agherr contains AdGuard Home's error handling helpers.
|
||||
package agherr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Error is the constant error type.
|
||||
type Error string
|
||||
|
||||
// Error implements the error interface for Error.
|
||||
func (err Error) Error() (msg string) {
|
||||
return string(err)
|
||||
}
|
||||
|
||||
// manyError is an error containing several wrapped errors. It is created to be
|
||||
// a simpler version of the API provided by github.com/joomcode/errorx.
|
||||
type manyError struct {
|
||||
message string
|
||||
underlying []error
|
||||
}
|
||||
|
||||
// Many wraps several errors and returns a single error.
|
||||
//
|
||||
// TODO(a.garipov): Add formatting to message.
|
||||
func Many(message string, underlying ...error) (err error) {
|
||||
err = &manyError{
|
||||
message: message,
|
||||
underlying: underlying,
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Error implements the error interface for *manyError.
|
||||
func (e *manyError) Error() (msg string) {
|
||||
switch len(e.underlying) {
|
||||
case 0:
|
||||
return e.message
|
||||
case 1:
|
||||
return fmt.Sprintf("%s: %s", e.message, e.underlying[0])
|
||||
default:
|
||||
b := &strings.Builder{}
|
||||
|
||||
// Ignore errors, since strings.(*Buffer).Write never returns
|
||||
// errors. We don't use aghstrings.WriteToBuilder here since
|
||||
// this package should be importable for any other.
|
||||
_, _ = fmt.Fprintf(b, "%s: %s (hidden: %s", e.message, e.underlying[0], e.underlying[1])
|
||||
for _, u := range e.underlying[2:] {
|
||||
// See comment above.
|
||||
_, _ = fmt.Fprintf(b, ", %s", u)
|
||||
}
|
||||
|
||||
// See comment above.
|
||||
_, _ = b.WriteString(")")
|
||||
|
||||
return b.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Unwrap implements the hidden errors.wrapper interface for *manyError.
|
||||
func (e *manyError) Unwrap() (err error) {
|
||||
if len(e.underlying) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return e.underlying[0]
|
||||
}
|
||||
|
||||
// wrapper is a copy of the hidden errors.wrapper interface for tests, linting,
|
||||
// etc.
|
||||
type wrapper interface {
|
||||
Unwrap() error
|
||||
}
|
||||
|
||||
// Annotate annotates the error with the message, unless the error is nil. This
|
||||
// is a helper function to simplify code like this:
|
||||
//
|
||||
// func (f *foo) doStuff(s string) (err error) {
|
||||
// defer func() {
|
||||
// if err != nil {
|
||||
// err = fmt.Errorf("bad foo string %q: %w", s, err)
|
||||
// }
|
||||
// }()
|
||||
//
|
||||
// // …
|
||||
// }
|
||||
//
|
||||
// Instead, write:
|
||||
//
|
||||
// func (f *foo) doStuff(s string) (err error) {
|
||||
// defer agherr.Annotate("bad foo string %q: %w", &err, s)
|
||||
//
|
||||
// // …
|
||||
// }
|
||||
//
|
||||
// msg must contain the final ": %w" verb.
|
||||
//
|
||||
// TODO(a.garipov): Clearify the function usage.
|
||||
func Annotate(msg string, errPtr *error, args ...interface{}) {
|
||||
if errPtr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := *errPtr
|
||||
if err != nil {
|
||||
args = append(args, err)
|
||||
|
||||
*errPtr = fmt.Errorf(msg, args...)
|
||||
}
|
||||
}
|
||||
|
||||
// LogPanic is a convinient helper function to log a panic in a goroutine. It
|
||||
// should not be used where proper error handling is required.
|
||||
func LogPanic(prefix string) {
|
||||
if v := recover(); v != nil {
|
||||
if prefix != "" {
|
||||
log.Error("%s: recovered from panic: %v", prefix, v)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Error("recovered from panic: %v", v)
|
||||
}
|
||||
}
|
@ -1,160 +0,0 @@
|
||||
package agherr
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestError_Error(t *testing.T) {
|
||||
testCases := []struct {
|
||||
err error
|
||||
name string
|
||||
want string
|
||||
}{{
|
||||
err: Many("a"),
|
||||
name: "simple",
|
||||
want: "a",
|
||||
}, {
|
||||
err: Many("a", errors.New("b")),
|
||||
name: "wrapping",
|
||||
want: "a: b",
|
||||
}, {
|
||||
err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")),
|
||||
name: "wrapping several",
|
||||
want: "a: b (hidden: c, d)",
|
||||
}, {
|
||||
err: Many("a", Many("b", errors.New("c"), errors.New("d"))),
|
||||
name: "wrapping wrapper",
|
||||
want: "a: b: c (hidden: d)",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
assert.Equal(t, tc.want, tc.err.Error(), tc.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestError_Unwrap(t *testing.T) {
|
||||
var _ wrapper = &manyError{}
|
||||
|
||||
const (
|
||||
errSimple = iota
|
||||
errWrapped
|
||||
errNil
|
||||
)
|
||||
|
||||
errs := []error{
|
||||
errSimple: errors.New("a"),
|
||||
errWrapped: fmt.Errorf("err: %w", errors.New("nested")),
|
||||
errNil: nil,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
want error
|
||||
wrapped error
|
||||
name string
|
||||
}{{
|
||||
want: errs[errSimple],
|
||||
wrapped: Many("a", errs[errSimple]),
|
||||
name: "simple",
|
||||
}, {
|
||||
want: errs[errWrapped],
|
||||
wrapped: Many("b", errs[errWrapped]),
|
||||
name: "nested",
|
||||
}, {
|
||||
want: errs[errNil],
|
||||
wrapped: Many("c", errs[errNil]),
|
||||
name: "nil passed",
|
||||
}, {
|
||||
want: nil,
|
||||
wrapped: Many("d"),
|
||||
name: "nil not passed",
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
assert.Equal(t, tc.want, errors.Unwrap(tc.wrapped), tc.name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnnotate(t *testing.T) {
|
||||
const s = "1234"
|
||||
const wantMsg = `bad string "1234": test`
|
||||
|
||||
// Don't use const, because we can't take a pointer of a constant.
|
||||
var errTest error = Error("test")
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
var errPtr *error
|
||||
assert.NotPanics(t, func() {
|
||||
Annotate("bad string %q: %w", errPtr, s)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("non_nil", func(t *testing.T) {
|
||||
errPtr := &errTest
|
||||
assert.NotPanics(t, func() {
|
||||
Annotate("bad string %q: %w", errPtr, s)
|
||||
})
|
||||
|
||||
require.NotNil(t, errPtr)
|
||||
|
||||
err := *errPtr
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, wantMsg, err.Error())
|
||||
})
|
||||
|
||||
t.Run("defer", func(t *testing.T) {
|
||||
f := func() (err error) {
|
||||
defer Annotate("bad string %q: %w", &errTest, s)
|
||||
|
||||
return errTest
|
||||
}
|
||||
|
||||
err := f()
|
||||
require.Error(t, err)
|
||||
|
||||
assert.Equal(t, wantMsg, err.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogPanic(t *testing.T) {
|
||||
buf := &bytes.Buffer{}
|
||||
aghtest.ReplaceLogWriter(t, buf)
|
||||
|
||||
t.Run("prefix", func(t *testing.T) {
|
||||
const (
|
||||
panicMsg = "spooky!"
|
||||
prefix = "packagename"
|
||||
errWithNoPrefix = "[error] recovered from panic: spooky!"
|
||||
errWithPrefix = "[error] packagename: recovered from panic: spooky!"
|
||||
)
|
||||
|
||||
panicFunc := func(prefix string) {
|
||||
defer LogPanic(prefix)
|
||||
|
||||
panic(panicMsg)
|
||||
}
|
||||
|
||||
panicFunc("")
|
||||
assert.Contains(t, buf.String(), errWithNoPrefix)
|
||||
buf.Reset()
|
||||
|
||||
panicFunc(prefix)
|
||||
assert.Contains(t, buf.String(), errWithPrefix)
|
||||
buf.Reset()
|
||||
})
|
||||
|
||||
t.Run("don't_panic", func(t *testing.T) {
|
||||
require.NotPanics(t, func() {
|
||||
defer LogPanic("")
|
||||
})
|
||||
|
||||
assert.Empty(t, buf.String())
|
||||
})
|
||||
}
|
@ -1,59 +0,0 @@
|
||||
// Package aghio contains extensions for io package's types and methods
|
||||
package aghio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// LimitReachedError records the limit and the operation that caused it.
|
||||
type LimitReachedError struct {
|
||||
Limit int64
|
||||
}
|
||||
|
||||
// Error implements error interface for LimitReachedError.
|
||||
// TODO(a.garipov): Think about error string format.
|
||||
func (lre *LimitReachedError) Error() string {
|
||||
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
|
||||
}
|
||||
|
||||
// limitedReadCloser is a wrapper for io.ReadCloser with limited reader and
|
||||
// dealing with agherr package.
|
||||
type limitedReadCloser struct {
|
||||
limit int64
|
||||
n int64
|
||||
rc io.ReadCloser
|
||||
}
|
||||
|
||||
// Read implements Reader interface.
|
||||
func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) {
|
||||
if lrc.n == 0 {
|
||||
return 0, &LimitReachedError{
|
||||
Limit: lrc.limit,
|
||||
}
|
||||
}
|
||||
if int64(len(p)) > lrc.n {
|
||||
p = p[0:lrc.n]
|
||||
}
|
||||
n, err = lrc.rc.Read(p)
|
||||
lrc.n -= int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Close implements Closer interface.
|
||||
func (lrc *limitedReadCloser) Close() error {
|
||||
return lrc.rc.Close()
|
||||
}
|
||||
|
||||
// LimitReadCloser wraps ReadCloser to make it's Reader stop with
|
||||
// ErrLimitReached after n bytes read.
|
||||
func LimitReadCloser(rc io.ReadCloser, n int64) (limited io.ReadCloser, err error) {
|
||||
if n < 0 {
|
||||
return nil, fmt.Errorf("aghio: invalid n in LimitReadCloser: %d", n)
|
||||
}
|
||||
return &limitedReadCloser{
|
||||
limit: n,
|
||||
n: n,
|
||||
rc: rc,
|
||||
}, nil
|
||||
}
|
59
internal/aghio/limitedreader.go
Normal file
59
internal/aghio/limitedreader.go
Normal file
@ -0,0 +1,59 @@
|
||||
// Package aghio contains extensions for io package's types and methods
|
||||
package aghio
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// LimitReachedError records the limit and the operation that caused it.
|
||||
type LimitReachedError struct {
|
||||
Limit int64
|
||||
}
|
||||
|
||||
// Error implements error interface for LimitReachedError.
|
||||
//
|
||||
// TODO(a.garipov): Think about error string format.
|
||||
func (lre *LimitReachedError) Error() string {
|
||||
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
|
||||
}
|
||||
|
||||
// limitedReader is a wrapper for io.Reader with limited reader and dealing with
|
||||
// errors package.
|
||||
type limitedReader struct {
|
||||
r io.Reader
|
||||
limit int64
|
||||
n int64
|
||||
}
|
||||
|
||||
// Read implements Reader interface.
|
||||
func (lr *limitedReader) Read(p []byte) (n int, err error) {
|
||||
if lr.n == 0 {
|
||||
return 0, &LimitReachedError{
|
||||
Limit: lr.limit,
|
||||
}
|
||||
}
|
||||
|
||||
if int64(len(p)) > lr.n {
|
||||
p = p[0:lr.n]
|
||||
}
|
||||
|
||||
n, err = lr.r.Read(p)
|
||||
lr.n -= int64(n)
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
// LimitReader wraps Reader to make it's Reader stop with ErrLimitReached after
|
||||
// n bytes read.
|
||||
func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) {
|
||||
if n < 0 {
|
||||
return nil, fmt.Errorf("aghio: invalid n in LimitReader: %d", n)
|
||||
}
|
||||
|
||||
return &limitedReader{
|
||||
r: r,
|
||||
limit: n,
|
||||
n: n,
|
||||
}, nil
|
||||
}
|
@ -10,7 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLimitReadCloser(t *testing.T) {
|
||||
func TestLimitReader(t *testing.T) {
|
||||
testCases := []struct {
|
||||
want error
|
||||
name string
|
||||
@ -24,20 +24,20 @@ func TestLimitReadCloser(t *testing.T) {
|
||||
name: "zero",
|
||||
n: 0,
|
||||
}, {
|
||||
want: fmt.Errorf("aghio: invalid n in LimitReadCloser: -1"),
|
||||
want: fmt.Errorf("aghio: invalid n in LimitReader: -1"),
|
||||
name: "negative",
|
||||
n: -1,
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := LimitReadCloser(nil, tc.n)
|
||||
_, err := LimitReader(nil, tc.n)
|
||||
assert.Equal(t, tc.want, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitedReadCloser_Read(t *testing.T) {
|
||||
func TestLimitedReader_Read(t *testing.T) {
|
||||
testCases := []struct {
|
||||
err error
|
||||
name string
|
||||
@ -77,7 +77,7 @@ func TestLimitedReadCloser_Read(t *testing.T) {
|
||||
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
|
||||
buf := make([]byte, tc.limit+1)
|
||||
|
||||
lreader, err := LimitReadCloser(readCloser, tc.limit)
|
||||
lreader, err := LimitReader(readCloser, tc.limit)
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err := lreader.Read(buf)
|
||||
@ -87,7 +87,7 @@ func TestLimitedReadCloser_Read(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLimitedReadCloser_LimitReachedError(t *testing.T) {
|
||||
func TestLimitedReader_LimitReachedError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
err error
|
||||
name string
|
@ -6,7 +6,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/net/idna"
|
||||
)
|
||||
|
||||
@ -26,11 +26,11 @@ func isValidHostRune(r rune) (ok bool) {
|
||||
// ValidateHardwareAddress returns an error if hwa is not a valid EUI-48,
|
||||
// EUI-64, or 20-octet InfiniBand link-layer address.
|
||||
func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) {
|
||||
defer agherr.Annotate("validating hardware address %q: %w", &err, hwa)
|
||||
defer func() { err = errors.Annotate(err, "validating hardware address %q: %w", hwa) }()
|
||||
|
||||
switch l := len(hwa); l {
|
||||
case 0:
|
||||
return agherr.Error("address is empty")
|
||||
return errors.Error("address is empty")
|
||||
case 6, 8, 20:
|
||||
return nil
|
||||
default:
|
||||
@ -51,13 +51,13 @@ const maxDomainNameLen = 253
|
||||
// ValidateDomainNameLabel returns an error if label is not a valid label of
|
||||
// a domain name.
|
||||
func ValidateDomainNameLabel(label string) (err error) {
|
||||
defer agherr.Annotate("validating label %q: %w", &err, label)
|
||||
defer func() { err = errors.Annotate(err, "validating label %q: %w", label) }()
|
||||
|
||||
l := len(label)
|
||||
if l > maxDomainLabelLen {
|
||||
return fmt.Errorf("label is too long, max: %d", maxDomainLabelLen)
|
||||
} else if l == 0 {
|
||||
return agherr.Error("label is empty")
|
||||
return errors.Error("label is empty")
|
||||
}
|
||||
|
||||
if r := label[0]; !IsValidHostOuterRune(rune(r)) {
|
||||
@ -87,7 +87,7 @@ func ValidateDomainNameLabel(label string) (err error) {
|
||||
// TODO(a.garipov): After making sure that this works correctly, port this into
|
||||
// module golibs.
|
||||
func ValidateDomainName(name string) (err error) {
|
||||
defer agherr.Annotate("validating domain name %q: %w", &err, name)
|
||||
defer func() { err = errors.Annotate(err, "validating domain name %q: %w", name) }()
|
||||
|
||||
name, err = idna.ToASCII(name)
|
||||
if err != nil {
|
||||
@ -96,7 +96,7 @@ func ValidateDomainName(name string) (err error) {
|
||||
|
||||
l := len(name)
|
||||
if l == 0 {
|
||||
return agherr.Error("domain name is empty")
|
||||
return errors.Error("domain name is empty")
|
||||
} else if l > maxDomainNameLen {
|
||||
return fmt.Errorf("too long, max: %d", maxDomainNameLen)
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
@ -12,6 +11,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/miekg/dns"
|
||||
@ -239,7 +239,13 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin
|
||||
log.Error("etchostscontainer: %s", err)
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() {
|
||||
derr := f.Close()
|
||||
if derr != nil {
|
||||
log.Error("etchostscontainer: closing file: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
r := bufio.NewReader(f)
|
||||
log.Debug("etchostscontainer: loading hosts from file %s", fn)
|
||||
|
||||
|
@ -3,7 +3,6 @@ package aghnet
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@ -14,14 +13,14 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about
|
||||
// the IP being static is available.
|
||||
const ErrNoStaticIPInfo agherr.Error = "no information about static ip"
|
||||
const ErrNoStaticIPInfo errors.Error = "no information about static ip"
|
||||
|
||||
// IfaceHasStaticIP checks if interface is configured to have static IP address.
|
||||
// If it can't give a definitive answer, it returns false and an error for which
|
||||
@ -106,7 +105,7 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
|
||||
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
|
||||
}
|
||||
if len(ifaces) == 0 {
|
||||
return nil, errors.New("couldn't find any legible interface")
|
||||
return nil, errors.Error("couldn't find any legible interface")
|
||||
}
|
||||
|
||||
var netInterfaces []*NetInterface
|
||||
|
@ -5,13 +5,13 @@
|
||||
package aghnet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// hardwarePortInfo - information obtained using MacOS networksetup
|
||||
@ -83,7 +83,7 @@ func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
|
||||
|
||||
match := re.FindStringSubmatch(out)
|
||||
if len(match) == 0 {
|
||||
return h, errors.New("could not find hardware port info")
|
||||
return h, errors.Error("could not find hardware port info")
|
||||
}
|
||||
|
||||
h.name = hardwarePort
|
||||
@ -105,7 +105,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||
}
|
||||
|
||||
if portInfo.static {
|
||||
return errors.New("IP address is already static")
|
||||
return errors.Error("IP address is already static")
|
||||
}
|
||||
|
||||
dnsAddrs, err := getEtcResolvConfServers()
|
||||
@ -151,7 +151,7 @@ func getEtcResolvConfServers() ([]string, error) {
|
||||
|
||||
matches := re.FindAllStringSubmatch(string(body), -1)
|
||||
if len(matches) == 0 {
|
||||
return nil, errors.New("found no DNS servers in /etc/resolv.conf")
|
||||
return nil, errors.Error("found no DNS servers in /etc/resolv.conf")
|
||||
}
|
||||
|
||||
addrs := make([]string, 0)
|
||||
|
@ -6,7 +6,6 @@ package aghnet
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -14,6 +13,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/google/renameio/maybe"
|
||||
)
|
||||
|
||||
@ -49,16 +49,15 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
|
||||
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
var fileReadCloser io.ReadCloser
|
||||
fileReadCloser, err = aghio.LimitReadCloser(f, maxConfigFileSize)
|
||||
var fileReader io.Reader
|
||||
fileReader, err = aghio.LimitReader(f, maxConfigFileSize)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer fileReadCloser.Close()
|
||||
|
||||
has, err = check.checker(fileReadCloser, ifaceName)
|
||||
has, err = check.checker(fileReader, ifaceName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -134,7 +133,7 @@ func ifacesStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
|
||||
func ifaceSetStaticIP(ifaceName string) (err error) {
|
||||
ipNet := GetSubnet(ifaceName)
|
||||
if ipNet.IP == nil {
|
||||
return errors.New("can't get IP address")
|
||||
return errors.Error("can't get IP address")
|
||||
}
|
||||
|
||||
gatewayIP := GatewayIP(ifaceName)
|
||||
|
@ -3,7 +3,7 @@ package aghnet
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@ -27,19 +27,19 @@ type SystemResolvers interface {
|
||||
|
||||
const (
|
||||
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
|
||||
errBadAddrPassed agherr.Error = "the passed string is not a valid IP address"
|
||||
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
|
||||
|
||||
// errFakeDial is an error which dialFunc is expected to return.
|
||||
errFakeDial agherr.Error = "this error signals the successful dialFunc work"
|
||||
errFakeDial errors.Error = "this error signals the successful dialFunc work"
|
||||
|
||||
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
|
||||
// more than one percent sign.
|
||||
errUnexpectedHostFormat agherr.Error = "unexpected host format"
|
||||
errUnexpectedHostFormat errors.Error = "unexpected host format"
|
||||
)
|
||||
|
||||
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
|
||||
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
|
||||
defer agherr.LogPanic("systemResolvers")
|
||||
defer log.OnPanic("systemResolvers")
|
||||
|
||||
// TODO(e.burkov): Implement a functionality to stop ticker.
|
||||
for range tickCh {
|
||||
|
@ -6,15 +6,14 @@ package aghnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// defaultHostGen is the default method of generating host for Refresh.
|
||||
@ -34,7 +33,7 @@ type systemResolvers struct {
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) refresh() (err error) {
|
||||
defer agherr.Annotate("systemResolvers: %w", &err)
|
||||
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
|
||||
|
||||
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
|
||||
dnserr := &net.DNSError{}
|
||||
@ -63,7 +62,7 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
|
||||
|
||||
// validateDialedHost validated the host used by resolvers in dialFunc.
|
||||
func validateDialedHost(host string) (err error) {
|
||||
defer agherr.Annotate("parsing %q: %w", &err, host)
|
||||
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
|
||||
|
||||
var ipStr string
|
||||
parts := strings.Split(host, "%")
|
||||
|
@ -14,9 +14,9 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@ -65,14 +65,15 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
|
||||
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
var stdoutLimited io.ReadCloser
|
||||
stdoutLimited, err = aghio.LimitReadCloser(stdout, aghos.MaxCmdOutputSize)
|
||||
var stdoutLimited io.Reader
|
||||
stdoutLimited, err = aghio.LimitReader(stdout, aghos.MaxCmdOutputSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("limiting stdout reader: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer agherr.LogPanic("systemResolvers")
|
||||
defer log.OnPanic("systemResolvers")
|
||||
|
||||
defer func() {
|
||||
derr := stdin.Close()
|
||||
if derr != nil {
|
||||
@ -141,7 +142,7 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
|
||||
}
|
||||
|
||||
func (sr *systemResolvers) refresh() (err error) {
|
||||
defer agherr.Annotate("systemResolvers: %w", &err)
|
||||
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
|
||||
|
||||
got, err := sr.getAddrs()
|
||||
if err != nil {
|
||||
|
@ -30,7 +30,7 @@ func ReplaceLogWriter(t *testing.T, w io.Writer) {
|
||||
|
||||
// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
|
||||
// revert changes.
|
||||
func ReplaceLogLevel(t *testing.T, l int) {
|
||||
func ReplaceLogLevel(t *testing.T, l log.Level) {
|
||||
switch l {
|
||||
case log.INFO, log.DEBUG, log.ERROR:
|
||||
// Go on.
|
||||
|
@ -168,9 +168,6 @@ type TestErrUpstream struct {
|
||||
|
||||
// Exchange always returns nil Msg and non-nil error.
|
||||
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
|
||||
// We don't use an agherr.Error to avoid the import cycle since aghtests
|
||||
// used to provide the utilities for testing which agherr (and any other
|
||||
// testable package) should be able to use.
|
||||
return nil, fmt.Errorf("errupstream: %w", u.Err)
|
||||
}
|
||||
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
|
||||
@ -78,7 +79,7 @@ func CheckIfOtherDHCPServersPresentV4(ifaceName string) (ok bool, err error) {
|
||||
return false, fmt.Errorf("couldn't listen on :68: %w", err)
|
||||
}
|
||||
if c != nil {
|
||||
defer c.Close()
|
||||
defer func() { err = errors.WithDeferred(err, c.Close()) }()
|
||||
}
|
||||
|
||||
// send to 255.255.255.255:67
|
||||
@ -202,7 +203,7 @@ func CheckIfOtherDHCPServersPresentV6(ifaceName string) (ok bool, err error) {
|
||||
return false, fmt.Errorf("dhcpv6: Couldn't listen on :546: %w", err)
|
||||
}
|
||||
if c != nil {
|
||||
defer c.Close()
|
||||
defer func() { err = errors.WithDeferred(err, c.Close()) }()
|
||||
}
|
||||
|
||||
_, err = c.WriteTo(req.ToBytes(), dstAddr)
|
||||
|
@ -4,11 +4,11 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/google/renameio/maybe"
|
||||
)
|
||||
|
@ -2,7 +2,6 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@ -11,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
"math/big"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// ipRange is an inclusive range of IP addresses. A nil range is a range that
|
||||
@ -28,7 +28,7 @@ const maxRangeLen = math.MaxUint32
|
||||
// newIPRange creates a new IP address range. start must be less than end. The
|
||||
// resulting range must not be greater than maxRangeLen.
|
||||
func newIPRange(start, end net.IP) (r *ipRange, err error) {
|
||||
defer agherr.Annotate("invalid ip range: %w", &err)
|
||||
defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()
|
||||
|
||||
// Make sure that both are 16 bytes long to simplify handling in
|
||||
// methods.
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// hexDHCPOptionParserHandler parses a DHCP option as a hex-encoded string.
|
||||
@ -32,7 +32,7 @@ func hexDHCPOptionParserHandler(s string) (data []byte, err error) {
|
||||
func ipDHCPOptionParserHandler(s string) (data []byte, err error) {
|
||||
ip := net.ParseIP(s)
|
||||
if ip == nil {
|
||||
return nil, agherr.Error("invalid ip")
|
||||
return nil, errors.Error("invalid ip")
|
||||
}
|
||||
|
||||
// Most DHCP options require IPv4, so do not put the 16-byte
|
||||
@ -100,12 +100,12 @@ func newDHCPOptionParser() (p *dhcpOptionParser) {
|
||||
|
||||
// parse parses an option. See the handlers' documentation for more info.
|
||||
func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) {
|
||||
defer agherr.Annotate("invalid option string %q: %w", &err, s)
|
||||
defer func() { err = errors.Annotate(err, "invalid option string %q: %w", s) }()
|
||||
|
||||
s = strings.TrimSpace(s)
|
||||
parts := strings.SplitN(s, " ", 3)
|
||||
if len(parts) < 3 {
|
||||
return 0, nil, agherr.Error("need at least three fields")
|
||||
return 0, nil, errors.Error("need at least three fields")
|
||||
}
|
||||
|
||||
codeStr := parts[0]
|
||||
|
@ -1,13 +1,13 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"golang.org/x/net/ipv4"
|
||||
)
|
||||
|
||||
// Create a socket for receiving broadcast packets
|
||||
func newBroadcastPacketConn(bindAddr net.IP, port int, ifname string) (*ipv4.PacketConn, error) {
|
||||
return nil, errors.New("newBroadcastPacketConn(): not supported on Windows")
|
||||
return nil, errors.Error("newBroadcastPacketConn(): not supported on Windows")
|
||||
}
|
||||
|
@ -6,16 +6,15 @@ package dhcpd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/go-ping/ping"
|
||||
"github.com/insomniacslk/dhcp/dhcpv4"
|
||||
@ -55,7 +54,7 @@ func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) {
|
||||
// normalizeHostname normalizes a hostname sent by the client. If err is not
|
||||
// nil, norm is an empty string.
|
||||
func normalizeHostname(hostname string) (norm string, err error) {
|
||||
defer agherr.Annotate("normalizing %q: %w", &err, hostname)
|
||||
defer func() { err = errors.Annotate(err, "normalizing %q: %w", hostname) }()
|
||||
|
||||
if hostname == "" {
|
||||
return "", nil
|
||||
@ -249,7 +248,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
|
||||
if bytes.Equal(l.HWAddr, lease.HWAddr) {
|
||||
if l.IsStatic() {
|
||||
return agherr.Error("static lease already exists")
|
||||
return errors.Error("static lease already exists")
|
||||
}
|
||||
|
||||
s.rmLeaseByIndex(i)
|
||||
@ -262,7 +261,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
|
||||
|
||||
if l.IP.Equal(lease.IP) {
|
||||
if l.IsStatic() {
|
||||
return agherr.Error("static lease already exists")
|
||||
return errors.Error("static lease already exists")
|
||||
}
|
||||
|
||||
s.rmLeaseByIndex(i)
|
||||
@ -322,12 +321,12 @@ func (s *v4Server) rmLease(lease Lease) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
return agherr.Error("lease not found")
|
||||
return errors.Error("lease not found")
|
||||
}
|
||||
|
||||
// AddStaticLease adds a static lease. It is safe for concurrent use.
|
||||
func (s *v4Server) AddStaticLease(l Lease) (err error) {
|
||||
defer agherr.Annotate("dhcpv4: adding static lease: %w", &err)
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }()
|
||||
|
||||
if ip4 := l.IP.To4(); ip4 == nil {
|
||||
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
|
||||
@ -397,7 +396,7 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) {
|
||||
|
||||
// RemoveStaticLease removes a static lease. It is safe for concurrent use.
|
||||
func (s *v4Server) RemoveStaticLease(l Lease) (err error) {
|
||||
defer agherr.Annotate("dhcpv4: %w", &err)
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
|
||||
|
||||
if len(l.IP) != 4 {
|
||||
return fmt.Errorf("invalid IP")
|
||||
@ -937,7 +936,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
|
||||
|
||||
// Start starts the IPv4 DHCP server.
|
||||
func (s *v4Server) Start() (err error) {
|
||||
defer agherr.Annotate("dhcpv4: %w", &err)
|
||||
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
|
||||
|
||||
if !s.conf.Enabled {
|
||||
return nil
|
||||
|
@ -1,11 +1,10 @@
|
||||
package dhcpd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -25,7 +24,7 @@ func (iface *fakeIface) Addrs() (addrs []net.Addr, err error) {
|
||||
}
|
||||
|
||||
func TestIfaceIPAddrs(t *testing.T) {
|
||||
const errTest agherr.Error = "test error"
|
||||
const errTest errors.Error = "test error"
|
||||
|
||||
ip4 := net.IP{1, 2, 3, 4}
|
||||
addr4 := &net.IPNet{IP: ip4}
|
||||
@ -108,7 +107,7 @@ func (iface *waitingFakeIface) Addrs() (addrs []net.Addr, err error) {
|
||||
}
|
||||
|
||||
func TestIfaceDNSIPAddrs(t *testing.T) {
|
||||
const errTest agherr.Error = "test error"
|
||||
const errTest errors.Error = "test error"
|
||||
|
||||
ip4 := net.IP{1, 2, 3, 4}
|
||||
addr4 := &net.IPNet{IP: ip4}
|
||||
|
@ -11,8 +11,8 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6"
|
||||
"github.com/insomniacslk/dhcp/dhcpv6/server6"
|
||||
@ -165,7 +165,7 @@ func (s *v6Server) rmDynamicLease(lease Lease) error {
|
||||
|
||||
// AddStaticLease adds a static lease. It is safe for concurrent use.
|
||||
func (s *v6Server) AddStaticLease(l Lease) (err error) {
|
||||
defer agherr.Annotate("dhcpv6: %w", &err)
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
if len(l.IP) != 16 {
|
||||
return fmt.Errorf("invalid IP")
|
||||
@ -194,7 +194,7 @@ func (s *v6Server) AddStaticLease(l Lease) (err error) {
|
||||
|
||||
// RemoveStaticLease removes a static lease. It is safe for concurrent use.
|
||||
func (s *v6Server) RemoveStaticLease(l Lease) (err error) {
|
||||
defer agherr.Annotate("dhcpv6: %w", &err)
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
if len(l.IP) != 16 {
|
||||
return fmt.Errorf("invalid IP")
|
||||
@ -585,7 +585,7 @@ func (s *v6Server) initRA(iface *net.Interface) error {
|
||||
|
||||
// Start starts the IPv6 DHCP server.
|
||||
func (s *v6Server) Start() (err error) {
|
||||
defer agherr.Annotate("dhcpv6: %w", &err)
|
||||
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
|
||||
|
||||
if !s.conf.Enabled {
|
||||
return nil
|
||||
|
@ -2,13 +2,13 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
)
|
||||
|
||||
|
@ -3,7 +3,6 @@ package dnsforward
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -14,6 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
)
|
||||
@ -220,7 +220,7 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
|
||||
|
||||
// Validate proxy config
|
||||
if proxyConfig.UpstreamConfig == nil || len(proxyConfig.UpstreamConfig.Upstreams) == 0 {
|
||||
return proxyConfig, errors.New("no default upstream servers configured")
|
||||
return proxyConfig, errors.Error("no default upstream servers configured")
|
||||
}
|
||||
|
||||
return proxyConfig, nil
|
||||
|
@ -2,7 +2,6 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -12,7 +11,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
@ -21,6 +19,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@ -226,11 +225,11 @@ type RDNSExchanger interface {
|
||||
const (
|
||||
// rDNSEmptyAnswerErr is returned by Exchange method when the answer
|
||||
// section of respond is empty.
|
||||
rDNSEmptyAnswerErr agherr.Error = "the answer section is empty"
|
||||
rDNSEmptyAnswerErr errors.Error = "the answer section is empty"
|
||||
|
||||
// rDNSNotPTRErr is returned by Exchange method when the response is not
|
||||
// of PTR type.
|
||||
rDNSNotPTRErr agherr.Error = "the response is not a ptr"
|
||||
rDNSNotPTRErr errors.Error = "the response is not a ptr"
|
||||
)
|
||||
|
||||
// Exchange implements the RDNSExchanger interface for *Server.
|
||||
|
@ -17,13 +17,13 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -1198,7 +1198,7 @@ func TestServer_Exchange(t *testing.T) {
|
||||
"2.1.168.192.in-addr.arpa.": {},
|
||||
},
|
||||
}
|
||||
upstreamErr := agherr.Error("upstream error")
|
||||
upstreamErr := errors.Error("upstream error")
|
||||
errUpstream := &aghtest.TestErrUpstream{
|
||||
Err: upstreamErr,
|
||||
}
|
||||
|
@ -4,9 +4,9 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
@ -83,7 +83,7 @@ func (s *Server) filterDNSRewrite(req *dns.Msg, res filtering.Result, d *proxy.D
|
||||
resp := s.makeResponse(req)
|
||||
dnsrr := res.DNSRewriteResult
|
||||
if dnsrr == nil {
|
||||
return agherr.Error("no dns rewrite rule content")
|
||||
return errors.Error("no dns rewrite rule content")
|
||||
}
|
||||
|
||||
resp.Rcode = dnsrr.RCode
|
||||
@ -94,7 +94,7 @@ func (s *Server) filterDNSRewrite(req *dns.Msg, res filtering.Result, d *proxy.D
|
||||
}
|
||||
|
||||
if dnsrr.Response == nil {
|
||||
return agherr.Error("no dns rewrite rule responses")
|
||||
return errors.Error("no dns rewrite rule responses")
|
||||
}
|
||||
|
||||
rr := req.Question[0].Qtype
|
||||
|
@ -8,11 +8,11 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@ -399,7 +399,7 @@ func validateUpstream(u string) (bool, error) {
|
||||
// separateUpstream returns the upstream without the specified domains.
|
||||
// useDefault is true when a default upstream must be used.
|
||||
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) {
|
||||
defer agherr.Annotate("bad upstream for domain spec %q: %w", &err, upstreamStr)
|
||||
defer func() { err = errors.Annotate(err, "bad upstream for domain spec %q: %w", upstreamStr) }()
|
||||
|
||||
if !strings.HasPrefix(upstreamStr, "[/") {
|
||||
return upstreamStr, true, nil
|
||||
@ -407,7 +407,7 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err
|
||||
|
||||
parts := strings.Split(upstreamStr[2:], "/]")
|
||||
if len(parts) != 2 {
|
||||
return "", false, agherr.Error("duplicated separator")
|
||||
return "", false, errors.Error("duplicated separator")
|
||||
}
|
||||
|
||||
domains := parts[0]
|
||||
|
@ -10,7 +10,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/digineo/go-ipset/v2"
|
||||
"github.com/mdlayher/netlink"
|
||||
@ -83,7 +83,7 @@ func (c *ipsetCtx) ipsetProps(name string) (set ipsetProps, err error) {
|
||||
}
|
||||
|
||||
if res == nil || res.Family == nil {
|
||||
return set, agherr.Error("empty response or no family data")
|
||||
return set, errors.Error("empty response or no family data")
|
||||
}
|
||||
|
||||
family := netfilter.ProtoFamily(res.Family.Value)
|
||||
@ -193,23 +193,23 @@ func (c *ipsetCtx) init(ipsetConfig []string) (err error) {
|
||||
|
||||
// Close closes the Linux Netfilter connections.
|
||||
func (c *ipsetCtx) Close() (err error) {
|
||||
var errors []error
|
||||
var errs []error
|
||||
if c.ipv4Conn != nil {
|
||||
err = c.ipv4Conn.Close()
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.ipv6Conn != nil {
|
||||
err = c.ipv6Conn.Close()
|
||||
if err != nil {
|
||||
errors = append(errors, err)
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(errors) != 0 {
|
||||
return agherr.Many("closing ipsets", errors...)
|
||||
if len(errs) != 0 {
|
||||
return errors.List("closing ipsets", errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -73,22 +73,27 @@ func glGetTokenDate(file string) uint32 {
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
log.Error("os.Open: %s", err)
|
||||
|
||||
return 0
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() {
|
||||
derr := f.Close()
|
||||
if derr != nil {
|
||||
log.Error("glinet: closing file: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
fileReadCloser, err := aghio.LimitReadCloser(f, MaxFileSize)
|
||||
fileReader, err := aghio.LimitReader(f, MaxFileSize)
|
||||
if err != nil {
|
||||
log.Error("creating limited reader: %s", err)
|
||||
|
||||
return 0
|
||||
}
|
||||
defer fileReadCloser.Close()
|
||||
|
||||
var dateToken uint32
|
||||
|
||||
// This use of ReadAll is now safe, because we limited reader.
|
||||
bs, err := io.ReadAll(fileReadCloser)
|
||||
bs, err := io.ReadAll(fileReader)
|
||||
if err != nil {
|
||||
log.Error("reading token: %s", err)
|
||||
|
||||
|
@ -11,7 +11,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
@ -20,6 +19,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@ -438,11 +438,11 @@ func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bo
|
||||
func (clients *clientsContainer) check(c *Client) (err error) {
|
||||
switch {
|
||||
case c == nil:
|
||||
return agherr.Error("client is nil")
|
||||
return errors.Error("client is nil")
|
||||
case c.Name == "":
|
||||
return agherr.Error("invalid name")
|
||||
return errors.Error("invalid name")
|
||||
case len(c.IDs) == 0:
|
||||
return agherr.Error("id required")
|
||||
return errors.Error("id required")
|
||||
default:
|
||||
// Go on.
|
||||
}
|
||||
@ -570,14 +570,14 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
|
||||
|
||||
prev, ok := clients.list[name]
|
||||
if !ok {
|
||||
return agherr.Error("client not found")
|
||||
return errors.Error("client not found")
|
||||
}
|
||||
|
||||
// First, check the name index.
|
||||
if prev.Name != c.Name {
|
||||
_, ok = clients.list[c.Name]
|
||||
if ok {
|
||||
return agherr.Error("client already exists")
|
||||
return errors.Error("client already exists")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,6 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@ -14,6 +13,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/google/renameio/maybe"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
|
@ -15,7 +15,6 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
@ -271,7 +270,7 @@ func copyInstallSettings(dst, src *configuration) {
|
||||
const shutdownTimeout = 5 * time.Second
|
||||
|
||||
func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) {
|
||||
defer agherr.LogPanic("")
|
||||
defer log.OnPanic("")
|
||||
|
||||
if srv == nil {
|
||||
return
|
||||
|
@ -3,7 +3,6 @@ package home
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@ -14,6 +13,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
|
@ -8,12 +8,12 @@ import (
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/ameshkov/dnscrypt/v2"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
@ -207,14 +207,14 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
|
||||
|
||||
func newDNSCrypt(hosts []net.IP, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) {
|
||||
if tlsConf.DNSCryptConfigFile == "" {
|
||||
return dnscc, agherr.Error("no dnscrypt_config_file")
|
||||
return dnscc, errors.Error("no dnscrypt_config_file")
|
||||
}
|
||||
|
||||
f, err := os.Open(tlsConf.DNSCryptConfigFile)
|
||||
if err != nil {
|
||||
return dnscc, fmt.Errorf("opening dnscrypt config: %w", err)
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
rc := &dnscrypt.ResolverConfig{}
|
||||
err = yaml.NewDecoder(f).Decode(rc)
|
||||
|
@ -2,7 +2,6 @@ package home
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
@ -17,6 +16,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@ -566,17 +566,18 @@ func (f *Filtering) updateIntl(filter *filter) (updated bool, err error) {
|
||||
if err != nil {
|
||||
return updated, fmt.Errorf("open file: %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
defer f.Close()
|
||||
reader = f
|
||||
} else {
|
||||
var resp *http.Response
|
||||
resp, err = Context.client.Get(filter.URL)
|
||||
if err != nil {
|
||||
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
|
||||
|
||||
return updated, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
|
||||
@ -634,7 +635,7 @@ func (f *Filtering) load(filter *filter) (err error) {
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("opening filter file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
defer func() { err = errors.WithDeferred(err, file.Close()) }()
|
||||
|
||||
st, err := file.Stat()
|
||||
if err != nil {
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
@ -21,7 +20,6 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
|
||||
@ -31,6 +29,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
@ -736,7 +735,7 @@ func customDialContext(ctx context.Context, network, addr string) (conn net.Conn
|
||||
return conn, err
|
||||
}
|
||||
|
||||
return nil, agherr.Many(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
|
||||
return nil, errors.List(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
|
||||
}
|
||||
|
||||
func getHTTPProxy(_ *http.Request) (*url.URL, error) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
@ -58,14 +59,20 @@ func limitRequestBody(h http.Handler) (limited http.Handler) {
|
||||
szLim = largerReqBodySzLim
|
||||
}
|
||||
|
||||
r.Body, err = aghio.LimitReadCloser(r.Body, szLim)
|
||||
var reader io.Reader
|
||||
reader, err = aghio.LimitReader(r.Body, szLim)
|
||||
if err != nil {
|
||||
log.Error("limitRequestBody: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
// HTTP handlers aren't supposed to call r.Body.Close(), so just
|
||||
// replace the body in a clone.
|
||||
rr := r.Clone(r.Context())
|
||||
rr.Body = io.NopCloser(reader)
|
||||
|
||||
h.ServeHTTP(w, rr)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
@ -85,7 +84,7 @@ func (r *RDNS) Begin(ip net.IP) {
|
||||
// workerLoop handles incoming IP addresses from ipChan and adds it into
|
||||
// clients.
|
||||
func (r *RDNS) workerLoop() {
|
||||
defer agherr.LogPanic("rdns")
|
||||
defer log.OnPanic("rdns")
|
||||
|
||||
for ip := range r.ipCh {
|
||||
host, err := r.exchanger.Exchange(ip)
|
||||
|
@ -3,7 +3,6 @@ package home
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
@ -13,6 +12,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -141,7 +141,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
|
||||
},
|
||||
}
|
||||
errUpstream := &aghtest.TestErrUpstream{
|
||||
Err: errors.New("1234"),
|
||||
Err: errors.Error("1234"),
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
@ -1,7 +1,6 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
@ -11,6 +10,7 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -21,6 +20,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
@ -341,14 +341,14 @@ func verifyCertChain(data *tlsConfigStatus, certChain, serverName string) error
|
||||
parsed, err := x509.ParseCertificate(cert.Bytes)
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err)
|
||||
return errors.New(data.WarningValidation)
|
||||
return errors.Error(data.WarningValidation)
|
||||
}
|
||||
parsedCerts = append(parsedCerts, parsed)
|
||||
}
|
||||
|
||||
if len(parsedCerts) == 0 {
|
||||
data.WarningValidation = "You have specified an empty certificate"
|
||||
return errors.New(data.WarningValidation)
|
||||
return errors.Error(data.WarningValidation)
|
||||
}
|
||||
|
||||
data.ValidCert = true
|
||||
@ -415,14 +415,14 @@ func validatePkey(data *tlsConfigStatus, pkey string) error {
|
||||
|
||||
if key == nil {
|
||||
data.WarningValidation = "No valid keys were found"
|
||||
return errors.New(data.WarningValidation)
|
||||
return errors.Error(data.WarningValidation)
|
||||
}
|
||||
|
||||
// parse the decoded key
|
||||
_, keytype, err := parsePrivateKey(key.Bytes)
|
||||
if err != nil {
|
||||
data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err)
|
||||
return errors.New(data.WarningValidation)
|
||||
return errors.Error(data.WarningValidation)
|
||||
}
|
||||
|
||||
data.ValidKey = true
|
||||
@ -479,7 +479,7 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
|
||||
case *ecdsa.PrivateKey:
|
||||
return key, "ECDSA", nil
|
||||
default:
|
||||
return nil, "", errors.New("tls: found unknown private key type in PKCS#8 wrapping")
|
||||
return nil, "", errors.Error("tls: found unknown private key type in PKCS#8 wrapping")
|
||||
}
|
||||
}
|
||||
|
||||
@ -487,7 +487,7 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
|
||||
return key, "ECDSA", nil
|
||||
}
|
||||
|
||||
return nil, "", errors.New("tls: failed to parse private key")
|
||||
return nil, "", errors.Error("tls: failed to parse private key")
|
||||
}
|
||||
|
||||
// unmarshalTLS handles base64-encoded certificates transparently
|
||||
|
@ -1,7 +1,6 @@
|
||||
package home
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -12,6 +11,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/google/renameio/maybe"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/cache"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@ -139,22 +140,22 @@ func whoisParse(data string) (m strmap) {
|
||||
const MaxConnReadSize = 64 * 1024
|
||||
|
||||
// Send request to a server and receive the response
|
||||
func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, error) {
|
||||
func (w *Whois) query(ctx context.Context, target, serverAddr string) (data string, err error) {
|
||||
addr, _, _ := net.SplitHostPort(serverAddr)
|
||||
if addr == "whois.arin.net" {
|
||||
target = "n + " + target
|
||||
}
|
||||
|
||||
conn, err := w.dialContext(ctx, "tcp", serverAddr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
|
||||
|
||||
connReadCloser, err := aghio.LimitReadCloser(conn, MaxConnReadSize)
|
||||
r, err := aghio.LimitReader(conn, MaxConnReadSize)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer connReadCloser.Close()
|
||||
|
||||
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
|
||||
_, err = conn.Write([]byte(target + "\r\n"))
|
||||
@ -163,12 +164,13 @@ func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, e
|
||||
}
|
||||
|
||||
// This use of ReadAll is now safe, because we limited the conn Reader.
|
||||
data, err := io.ReadAll(connReadCloser)
|
||||
var whoisData []byte
|
||||
whoisData, err = io.ReadAll(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(data), nil
|
||||
return string(whoisData), nil
|
||||
}
|
||||
|
||||
// Query WHOIS servers (handle redirects)
|
||||
|
@ -2,7 +2,6 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@ -11,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
@ -8,15 +8,15 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Timestamp not found errors.
|
||||
const (
|
||||
ErrTSNotFound agherr.Error = "ts not found"
|
||||
ErrTSTooLate agherr.Error = "ts too late"
|
||||
ErrTSTooEarly agherr.Error = "ts too early"
|
||||
ErrTSNotFound errors.Error = "ts not found"
|
||||
ErrTSTooLate errors.Error = "ts too late"
|
||||
ErrTSTooEarly errors.Error = "ts too early"
|
||||
)
|
||||
|
||||
// TODO: Find a way to grow buffer instead of relying on this value when reading strings
|
||||
|
@ -2,7 +2,6 @@ package querylog
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
@ -12,6 +11,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -1,12 +1,11 @@
|
||||
package querylog
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@ -154,7 +153,7 @@ func closeQFiles(qFiles []*QLogFile) error {
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return agherr.Many("Error while closing QLogReader", errs...)
|
||||
return errors.List("error while closing QLogReader", errs...)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -6,8 +6,8 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@ -78,13 +78,13 @@ type AddParams struct {
|
||||
func (p *AddParams) validate() (err error) {
|
||||
switch {
|
||||
case p.Question == nil:
|
||||
return agherr.Error("question is nil")
|
||||
return errors.Error("question is nil")
|
||||
case len(p.Question.Question) != 1:
|
||||
return agherr.Error("more than one question")
|
||||
return errors.Error("more than one question")
|
||||
case len(p.Question.Question[0].Name) == 0:
|
||||
return agherr.Error("no host in question")
|
||||
return errors.Error("no host in question")
|
||||
case p.ClientIP == nil:
|
||||
return agherr.Error("no client ip")
|
||||
return errors.Error("no client ip")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
@ -3,10 +3,10 @@ package querylog
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@ -39,7 +39,7 @@ func (l *queryLog) flushLogBuffer(fullFlush bool) error {
|
||||
}
|
||||
|
||||
// flushToFile saves the specified log entries to the query log file
|
||||
func (l *queryLog) flushToFile(buffer []*logEntry) error {
|
||||
func (l *queryLog) flushToFile(buffer []*logEntry) (err error) {
|
||||
if len(buffer) == 0 {
|
||||
log.Debug("querylog: there's nothing to write to a file")
|
||||
return nil
|
||||
@ -49,9 +49,10 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
|
||||
var b bytes.Buffer
|
||||
e := json.NewEncoder(&b)
|
||||
for _, entry := range buffer {
|
||||
err := e.Encode(entry)
|
||||
err = e.Encode(entry)
|
||||
if err != nil {
|
||||
log.Error("Failed to marshal entry: %s", err)
|
||||
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -59,7 +60,6 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
|
||||
elapsed := time.Since(start)
|
||||
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer)))
|
||||
|
||||
var err error
|
||||
var zb bytes.Buffer
|
||||
filename := l.logFile
|
||||
zb = b
|
||||
@ -71,7 +71,7 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
|
||||
log.Error("failed to create file \"%s\": %s", filename, err)
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
n, err := f.Write(zb.Bytes())
|
||||
if err != nil {
|
||||
@ -109,7 +109,12 @@ func (l *queryLog) readFileFirstTimeValue() int64 {
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
defer f.Close()
|
||||
defer func() {
|
||||
derr := f.Close()
|
||||
if derr != nil {
|
||||
log.Error("querylog: closing file: %s", derr)
|
||||
}
|
||||
}()
|
||||
|
||||
buf := make([]byte, 500)
|
||||
r, err := f.Read(buf)
|
||||
|
@ -142,7 +142,12 @@ func (l *queryLog) searchFiles(
|
||||
|
||||
return entries, oldest, 0
|
||||
}
|
||||
defer r.Close()
|
||||
defer func() {
|
||||
derr := r.Close()
|
||||
if derr != nil {
|
||||
log.Error("querylog: closing file: %s", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if params.olderThan.IsZero() {
|
||||
err = r.SeekStart()
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
@ -12,7 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
)
|
||||
@ -93,7 +92,7 @@ func createObject(conf Config) (s *statsCtx, err error) {
|
||||
|
||||
// TODO(a.garipov): See if this is actually necessary. Looks
|
||||
// like a rather bizarre solution.
|
||||
errStop := agherr.Error("stop iteration")
|
||||
errStop := errors.Error("stop iteration")
|
||||
forEachBkt := func(name []byte, _ *bolt.Bucket) (cberr error) {
|
||||
nameID := uint32(btoi(name))
|
||||
if nameID < firstID {
|
||||
|
@ -4,10 +4,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Make configurable.
|
||||
@ -27,7 +29,7 @@ const MaxResponseSize = 64 * 1024
|
||||
|
||||
// VersionInfo downloads the latest version information. If forceRecheck is
|
||||
// false and there are cached results, those results are returned.
|
||||
func (u *Updater) VersionInfo(forceRecheck bool) (VersionInfo, error) {
|
||||
func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
|
||||
@ -37,22 +39,23 @@ func (u *Updater) VersionInfo(forceRecheck bool) (VersionInfo, error) {
|
||||
return u.prevCheckResult, u.prevCheckError
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
vcu := u.versionCheckURL
|
||||
resp, err := u.client.Get(vcu)
|
||||
resp, err = u.client.Get(vcu)
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxResponseSize)
|
||||
var r io.Reader
|
||||
r, err = aghio.LimitReader(resp.Body, MaxResponseSize)
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// This use of ReadAll is safe, because we just limited the appropriate
|
||||
// ReadCloser.
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -20,6 +19,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
@ -283,22 +283,23 @@ func (u *Updater) clean() {
|
||||
const MaxPackageFileSize = 32 * 1024 * 1024
|
||||
|
||||
// Download package file and save it to disk
|
||||
func (u *Updater) downloadPackageFile(url, filename string) error {
|
||||
resp, err := u.client.Get(url)
|
||||
func (u *Updater) downloadPackageFile(url, filename string) (err error) {
|
||||
var resp *http.Response
|
||||
resp, err = u.client.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
|
||||
|
||||
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize)
|
||||
var r io.Reader
|
||||
r, err = aghio.LimitReader(resp.Body, MaxPackageFileSize)
|
||||
if err != nil {
|
||||
return fmt.Errorf("http request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
log.Debug("updater: reading HTTP body")
|
||||
// This use of ReadAll is now safe, because we limited body's Reader.
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
body, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("io.ReadAll() failed: %w", err)
|
||||
}
|
||||
@ -313,172 +314,178 @@ func (u *Updater) downloadPackageFile(url, filename string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name string, err error) {
|
||||
name = filepath.Base(hdr.Name)
|
||||
if name == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
outputName := filepath.Join(outDir, name)
|
||||
|
||||
if hdr.Typeflag == tar.TypeDir {
|
||||
if name == "AdGuardHome" {
|
||||
// Top-level AdGuardHome/. Skip it.
|
||||
//
|
||||
// TODO(a.garipov): This whole package needs to be
|
||||
// rewritten and covered in more integration tests. It
|
||||
// has weird assumptions and file mode issues.
|
||||
return "", nil
|
||||
}
|
||||
|
||||
err = os.Mkdir(outputName, os.FileMode(hdr.Mode&0o777))
|
||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
||||
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
|
||||
}
|
||||
|
||||
log.Debug("updater: created directory %q", outputName)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if hdr.Typeflag != tar.TypeReg {
|
||||
log.Debug("updater: %s: unknown file type %d, skipping", name, hdr.Typeflag)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var wc io.WriteCloser
|
||||
wc, err = os.OpenFile(
|
||||
outputName,
|
||||
os.O_WRONLY|os.O_CREATE|os.O_TRUNC,
|
||||
os.FileMode(hdr.Mode&0o777),
|
||||
)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("os.OpenFile(%s): %w", outputName, err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, wc.Close()) }()
|
||||
|
||||
_, err = io.Copy(wc, tr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("updater: created file %s", outputName)
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// Unpack all files from .tar.gz file to the specified directory
|
||||
// Existing files are overwritten
|
||||
// All files are created inside 'outdir', subdirectories are not created
|
||||
// All files are created inside outDir, subdirectories are not created
|
||||
// Return the list of files (not directories) written
|
||||
func tarGzFileUnpack(tarfile, outdir string) ([]string, error) {
|
||||
func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) {
|
||||
f, err := os.Open(tarfile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("os.Open(): %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
defer func() { err = errors.WithDeferred(err, f.Close()) }()
|
||||
|
||||
gzReader, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gzip.NewReader(): %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, gzReader.Close()) }()
|
||||
|
||||
var files []string
|
||||
var err2 error
|
||||
tarReader := tar.NewReader(gzReader)
|
||||
for {
|
||||
var header *tar.Header
|
||||
header, err = tarReader.Next()
|
||||
if err == io.EOF {
|
||||
err2 = nil
|
||||
var hdr *tar.Header
|
||||
hdr, err = tarReader.Next()
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("tarReader.Next(): %w", err)
|
||||
} else if err != nil {
|
||||
err = fmt.Errorf("tarReader.Next(): %w", err)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
_, inputNameOnly := filepath.Split(header.Name)
|
||||
if inputNameOnly == "" {
|
||||
continue
|
||||
var name string
|
||||
name, err = tarGzFileUnpackOne(outDir, tarReader, hdr)
|
||||
|
||||
if name != "" {
|
||||
files = append(files, name)
|
||||
}
|
||||
|
||||
outputName := filepath.Join(outdir, inputNameOnly)
|
||||
|
||||
if header.Typeflag == tar.TypeDir {
|
||||
if inputNameOnly == "AdGuardHome" {
|
||||
// Top-level AdGuardHome/. Skip it.
|
||||
//
|
||||
// TODO(a.garipov): This whole package needs to
|
||||
// be rewritten and covered in more integration
|
||||
// tests. It has weird assumptions and file
|
||||
// mode issues.
|
||||
continue
|
||||
}
|
||||
|
||||
err = os.Mkdir(outputName, os.FileMode(header.Mode&0o777))
|
||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
||||
err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
log.Debug("updater: created directory %q", outputName)
|
||||
|
||||
continue
|
||||
} else if header.Typeflag != tar.TypeReg {
|
||||
log.Debug("updater: %s: unknown file type %d, skipping", inputNameOnly, header.Typeflag)
|
||||
continue
|
||||
}
|
||||
|
||||
var f io.WriteCloser
|
||||
f, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0o777))
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("os.OpenFile(%s): %w", outputName, err)
|
||||
break
|
||||
}
|
||||
_, err = io.Copy(f, tarReader)
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
err2 = fmt.Errorf("io.Copy(): %w", err)
|
||||
break
|
||||
}
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("f.Close(): %w", err)
|
||||
break
|
||||
}
|
||||
|
||||
log.Debug("updater: created file %s", outputName)
|
||||
files = append(files, header.Name)
|
||||
}
|
||||
|
||||
_ = gzReader.Close()
|
||||
return files, err2
|
||||
return files, err
|
||||
}
|
||||
|
||||
func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
|
||||
var rc io.ReadCloser
|
||||
rc, err = zf.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("zip file Open(): %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, rc.Close()) }()
|
||||
|
||||
fi := zf.FileInfo()
|
||||
name = fi.Name()
|
||||
if name == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
outputName := filepath.Join(outDir, name)
|
||||
if fi.IsDir() {
|
||||
if name == "AdGuardHome" {
|
||||
// Top-level AdGuardHome/. Skip it.
|
||||
//
|
||||
// TODO(a.garipov): See the similar todo in
|
||||
// tarGzFileUnpack.
|
||||
return "", nil
|
||||
}
|
||||
|
||||
err = os.Mkdir(outputName, fi.Mode())
|
||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
||||
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
|
||||
}
|
||||
|
||||
log.Tracef("created directory %q", outputName)
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var wc io.WriteCloser
|
||||
wc, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("os.OpenFile(): %w", err)
|
||||
}
|
||||
defer func() { err = errors.WithDeferred(err, wc.Close()) }()
|
||||
|
||||
_, err = io.Copy(wc, rc)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("io.Copy(): %w", err)
|
||||
}
|
||||
|
||||
log.Tracef("created file %s", outputName)
|
||||
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// Unpack all files from .zip file to the specified directory
|
||||
// Existing files are overwritten
|
||||
// All files are created inside 'outdir', subdirectories are not created
|
||||
// All files are created inside 'outDir', subdirectories are not created
|
||||
// Return the list of files (not directories) written
|
||||
func zipFileUnpack(zipfile, outdir string) ([]string, error) {
|
||||
r, err := zip.OpenReader(zipfile)
|
||||
func zipFileUnpack(zipfile, outDir string) (files []string, err error) {
|
||||
zrc, err := zip.OpenReader(zipfile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("zip.OpenReader(): %w", err)
|
||||
}
|
||||
defer r.Close()
|
||||
defer func() { err = errors.WithDeferred(err, zrc.Close()) }()
|
||||
|
||||
var files []string
|
||||
var err2 error
|
||||
var zr io.ReadCloser
|
||||
for _, zf := range r.File {
|
||||
zr, err = zf.Open()
|
||||
for _, zf := range zrc.File {
|
||||
var name string
|
||||
name, err = zipFileUnpackOne(outDir, zf)
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("zip file Open(): %w", err)
|
||||
break
|
||||
}
|
||||
|
||||
fi := zf.FileInfo()
|
||||
inputNameOnly := fi.Name()
|
||||
if inputNameOnly == "" {
|
||||
continue
|
||||
if name != "" {
|
||||
files = append(files, name)
|
||||
}
|
||||
|
||||
outputName := filepath.Join(outdir, inputNameOnly)
|
||||
|
||||
if fi.IsDir() {
|
||||
if inputNameOnly == "AdGuardHome" {
|
||||
// Top-level AdGuardHome/. Skip it.
|
||||
//
|
||||
// TODO(a.garipov): See the similar todo in
|
||||
// tarGzFileUnpack.
|
||||
continue
|
||||
}
|
||||
|
||||
err = os.Mkdir(outputName, fi.Mode())
|
||||
if err != nil && !errors.Is(err, os.ErrExist) {
|
||||
err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
log.Tracef("created directory %q", outputName)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
var f io.WriteCloser
|
||||
f, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("os.OpenFile(): %w", err)
|
||||
break
|
||||
}
|
||||
_, err = io.Copy(f, zr)
|
||||
if err != nil {
|
||||
_ = f.Close()
|
||||
err2 = fmt.Errorf("io.Copy(): %w", err)
|
||||
break
|
||||
}
|
||||
err = f.Close()
|
||||
if err != nil {
|
||||
err2 = fmt.Errorf("f.Close(): %w", err)
|
||||
break
|
||||
}
|
||||
|
||||
log.Tracef("created file %s", outputName)
|
||||
files = append(files, inputNameOnly)
|
||||
}
|
||||
|
||||
_ = zr.Close()
|
||||
return files, err2
|
||||
return files, err
|
||||
}
|
||||
|
||||
// Copy file on disk
|
||||
|
@ -75,10 +75,10 @@ esac
|
||||
# Simple Analyzers
|
||||
|
||||
# blocklist_imports is a simple check against unwanted packages. Package
|
||||
# io/ioutil is soft-deprecated. Package log is replaced by our own package
|
||||
# github.com/AdguardTeam/golibs/log.
|
||||
# io/ioutil is soft-deprecated. Packages errors and log are replaced by our own
|
||||
# packages in the github.com/AdguardTeam/golibs module.
|
||||
blocklist_imports() {
|
||||
git grep -F -e '"io/ioutil"' -e '"log"' -- '*.go' || exit 0;
|
||||
git grep -F -e '"errors"' -e '"io/ioutil"' -e '"log"' -- '*.go' || exit 0;
|
||||
}
|
||||
|
||||
# method_const is a simple check against the usage of some raw strings and
|
||||
@ -192,13 +192,7 @@ nilness ./...
|
||||
|
||||
exit_on_output shadow --strict ./...
|
||||
|
||||
# TODO(a.garipov): Enable errcheck fully after handling all errors, including
|
||||
# the deferred and generated ones, properly. Also, perhaps, enable --blank.
|
||||
#
|
||||
# errcheck ./...
|
||||
exit_on_output sh -c '
|
||||
errcheck --asserts --ignoregenerated ./... |\
|
||||
{ grep -e "defer" -v || exit 0; }
|
||||
'
|
||||
# TODO(a.garipov): Enable --blank?
|
||||
errcheck --asserts ./...
|
||||
|
||||
staticcheck ./...
|
||||
|
Loading…
Reference in New Issue
Block a user