Pull request: replace agherr with golibs' errors

Merge in DNS/adguard-home from golibs-errors to master

Squashed commit of the following:

commit 5aba278a31c5a213bd9e08273ce7277c57713b22
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon May 24 17:05:18 2021 +0300

    all: imp code

commit f447eb875b81779fa9e391d98c31c1eeba7ef323
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Mon May 24 15:33:45 2021 +0300

    replace agherr with golibs' errors
This commit is contained in:
Ainar Garipov 2021-05-24 17:28:11 +03:00
parent 14250821ab
commit 03a828ef51
60 changed files with 406 additions and 672 deletions

View File

@ -66,8 +66,7 @@ on GitHub and most other Markdown renderers. -->
### <a id="code" href="#code">Code</a>
* Always `recover` from panics in new goroutines. Preferably in the very
first statement. If all you want there is a log message, use
`agherr.LogPanic`.
first statement. If all you want there is a log message, use `log.OnPanic`.
* Avoid `fallthrough`. It makes it harder to rearrange `case`s, to reason
about the code, and also to switch the code to a handler approach, if that

2
go.mod
View File

@ -4,7 +4,7 @@ go 1.16
require (
github.com/AdguardTeam/dnsproxy v0.37.4
github.com/AdguardTeam/golibs v0.5.0
github.com/AdguardTeam/golibs v0.8.0
github.com/AdguardTeam/urlfilter v0.14.5
github.com/NYTimes/gziphandler v1.1.1
github.com/ameshkov/dnscrypt/v2 v2.1.3

4
go.sum
View File

@ -14,8 +14,8 @@ github.com/AdguardTeam/dnsproxy v0.37.4/go.mod h1:xkJWEuTr550gPDmB9azsciKZzSXjf9
github.com/AdguardTeam/golibs v0.4.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.2/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.4.4/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.5.0 h1:qwhEKjDrT0UcwDnHrNU2Yg/DLR9b/GsUncnXYW6VzAU=
github.com/AdguardTeam/golibs v0.5.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/golibs v0.8.0 h1:rHo+yIgT2fivFG0yW2Cwk/DPc2+t/Aw6QvzPpiIFre0=
github.com/AdguardTeam/golibs v0.8.0/go.mod h1:skKsDKIBB7kkFflLJBpfGX+G8QFTx0WKUzB6TIgtUj4=
github.com/AdguardTeam/gomitmproxy v0.2.0/go.mod h1:Qdv0Mktnzer5zpdpi5rAwixNJzW2FN91LjKJCkVbYGU=
github.com/AdguardTeam/urlfilter v0.14.5 h1:WyF0hg0MwKevsqNPkoaZFH8f5WRi/yuy/7qePtYt5Ts=
github.com/AdguardTeam/urlfilter v0.14.5/go.mod h1:klx4JbOfc4EaNb5lWLqOwfg+pVcyRukmoJRvO55lL5U=

View File

@ -1,128 +0,0 @@
// Package agherr contains AdGuard Home's error handling helpers.
package agherr
import (
"fmt"
"strings"
"github.com/AdguardTeam/golibs/log"
)
// Error is the constant error type.
type Error string
// Error implements the error interface for Error.
func (err Error) Error() (msg string) {
return string(err)
}
// manyError is an error containing several wrapped errors. It is created to be
// a simpler version of the API provided by github.com/joomcode/errorx.
type manyError struct {
message string
underlying []error
}
// Many wraps several errors and returns a single error.
//
// TODO(a.garipov): Add formatting to message.
func Many(message string, underlying ...error) (err error) {
err = &manyError{
message: message,
underlying: underlying,
}
return err
}
// Error implements the error interface for *manyError.
func (e *manyError) Error() (msg string) {
switch len(e.underlying) {
case 0:
return e.message
case 1:
return fmt.Sprintf("%s: %s", e.message, e.underlying[0])
default:
b := &strings.Builder{}
// Ignore errors, since strings.(*Buffer).Write never returns
// errors. We don't use aghstrings.WriteToBuilder here since
// this package should be importable for any other.
_, _ = fmt.Fprintf(b, "%s: %s (hidden: %s", e.message, e.underlying[0], e.underlying[1])
for _, u := range e.underlying[2:] {
// See comment above.
_, _ = fmt.Fprintf(b, ", %s", u)
}
// See comment above.
_, _ = b.WriteString(")")
return b.String()
}
}
// Unwrap implements the hidden errors.wrapper interface for *manyError.
func (e *manyError) Unwrap() (err error) {
if len(e.underlying) == 0 {
return nil
}
return e.underlying[0]
}
// wrapper is a copy of the hidden errors.wrapper interface for tests, linting,
// etc.
type wrapper interface {
Unwrap() error
}
// Annotate annotates the error with the message, unless the error is nil. This
// is a helper function to simplify code like this:
//
// func (f *foo) doStuff(s string) (err error) {
// defer func() {
// if err != nil {
// err = fmt.Errorf("bad foo string %q: %w", s, err)
// }
// }()
//
// // …
// }
//
// Instead, write:
//
// func (f *foo) doStuff(s string) (err error) {
// defer agherr.Annotate("bad foo string %q: %w", &err, s)
//
// // …
// }
//
// msg must contain the final ": %w" verb.
//
// TODO(a.garipov): Clearify the function usage.
func Annotate(msg string, errPtr *error, args ...interface{}) {
if errPtr == nil {
return
}
err := *errPtr
if err != nil {
args = append(args, err)
*errPtr = fmt.Errorf(msg, args...)
}
}
// LogPanic is a convinient helper function to log a panic in a goroutine. It
// should not be used where proper error handling is required.
func LogPanic(prefix string) {
if v := recover(); v != nil {
if prefix != "" {
log.Error("%s: recovered from panic: %v", prefix, v)
return
}
log.Error("recovered from panic: %v", v)
}
}

View File

@ -1,160 +0,0 @@
package agherr
import (
"bytes"
"errors"
"fmt"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestError_Error(t *testing.T) {
testCases := []struct {
err error
name string
want string
}{{
err: Many("a"),
name: "simple",
want: "a",
}, {
err: Many("a", errors.New("b")),
name: "wrapping",
want: "a: b",
}, {
err: Many("a", errors.New("b"), errors.New("c"), errors.New("d")),
name: "wrapping several",
want: "a: b (hidden: c, d)",
}, {
err: Many("a", Many("b", errors.New("c"), errors.New("d"))),
name: "wrapping wrapper",
want: "a: b: c (hidden: d)",
}}
for _, tc := range testCases {
assert.Equal(t, tc.want, tc.err.Error(), tc.name)
}
}
func TestError_Unwrap(t *testing.T) {
var _ wrapper = &manyError{}
const (
errSimple = iota
errWrapped
errNil
)
errs := []error{
errSimple: errors.New("a"),
errWrapped: fmt.Errorf("err: %w", errors.New("nested")),
errNil: nil,
}
testCases := []struct {
want error
wrapped error
name string
}{{
want: errs[errSimple],
wrapped: Many("a", errs[errSimple]),
name: "simple",
}, {
want: errs[errWrapped],
wrapped: Many("b", errs[errWrapped]),
name: "nested",
}, {
want: errs[errNil],
wrapped: Many("c", errs[errNil]),
name: "nil passed",
}, {
want: nil,
wrapped: Many("d"),
name: "nil not passed",
}}
for _, tc := range testCases {
assert.Equal(t, tc.want, errors.Unwrap(tc.wrapped), tc.name)
}
}
func TestAnnotate(t *testing.T) {
const s = "1234"
const wantMsg = `bad string "1234": test`
// Don't use const, because we can't take a pointer of a constant.
var errTest error = Error("test")
t.Run("nil", func(t *testing.T) {
var errPtr *error
assert.NotPanics(t, func() {
Annotate("bad string %q: %w", errPtr, s)
})
})
t.Run("non_nil", func(t *testing.T) {
errPtr := &errTest
assert.NotPanics(t, func() {
Annotate("bad string %q: %w", errPtr, s)
})
require.NotNil(t, errPtr)
err := *errPtr
require.Error(t, err)
assert.Equal(t, wantMsg, err.Error())
})
t.Run("defer", func(t *testing.T) {
f := func() (err error) {
defer Annotate("bad string %q: %w", &errTest, s)
return errTest
}
err := f()
require.Error(t, err)
assert.Equal(t, wantMsg, err.Error())
})
}
func TestLogPanic(t *testing.T) {
buf := &bytes.Buffer{}
aghtest.ReplaceLogWriter(t, buf)
t.Run("prefix", func(t *testing.T) {
const (
panicMsg = "spooky!"
prefix = "packagename"
errWithNoPrefix = "[error] recovered from panic: spooky!"
errWithPrefix = "[error] packagename: recovered from panic: spooky!"
)
panicFunc := func(prefix string) {
defer LogPanic(prefix)
panic(panicMsg)
}
panicFunc("")
assert.Contains(t, buf.String(), errWithNoPrefix)
buf.Reset()
panicFunc(prefix)
assert.Contains(t, buf.String(), errWithPrefix)
buf.Reset()
})
t.Run("don't_panic", func(t *testing.T) {
require.NotPanics(t, func() {
defer LogPanic("")
})
assert.Empty(t, buf.String())
})
}

View File

@ -1,59 +0,0 @@
// Package aghio contains extensions for io package's types and methods
package aghio
import (
"fmt"
"io"
)
// LimitReachedError records the limit and the operation that caused it.
type LimitReachedError struct {
Limit int64
}
// Error implements error interface for LimitReachedError.
// TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string {
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
}
// limitedReadCloser is a wrapper for io.ReadCloser with limited reader and
// dealing with agherr package.
type limitedReadCloser struct {
limit int64
n int64
rc io.ReadCloser
}
// Read implements Reader interface.
func (lrc *limitedReadCloser) Read(p []byte) (n int, err error) {
if lrc.n == 0 {
return 0, &LimitReachedError{
Limit: lrc.limit,
}
}
if int64(len(p)) > lrc.n {
p = p[0:lrc.n]
}
n, err = lrc.rc.Read(p)
lrc.n -= int64(n)
return n, err
}
// Close implements Closer interface.
func (lrc *limitedReadCloser) Close() error {
return lrc.rc.Close()
}
// LimitReadCloser wraps ReadCloser to make it's Reader stop with
// ErrLimitReached after n bytes read.
func LimitReadCloser(rc io.ReadCloser, n int64) (limited io.ReadCloser, err error) {
if n < 0 {
return nil, fmt.Errorf("aghio: invalid n in LimitReadCloser: %d", n)
}
return &limitedReadCloser{
limit: n,
n: n,
rc: rc,
}, nil
}

View File

@ -0,0 +1,59 @@
// Package aghio contains extensions for io package's types and methods
package aghio
import (
"fmt"
"io"
)
// LimitReachedError records the limit and the operation that caused it.
type LimitReachedError struct {
Limit int64
}
// Error implements error interface for LimitReachedError.
//
// TODO(a.garipov): Think about error string format.
func (lre *LimitReachedError) Error() string {
return fmt.Sprintf("attempted to read more than %d bytes", lre.Limit)
}
// limitedReader is a wrapper for io.Reader with limited reader and dealing with
// errors package.
type limitedReader struct {
r io.Reader
limit int64
n int64
}
// Read implements Reader interface.
func (lr *limitedReader) Read(p []byte) (n int, err error) {
if lr.n == 0 {
return 0, &LimitReachedError{
Limit: lr.limit,
}
}
if int64(len(p)) > lr.n {
p = p[0:lr.n]
}
n, err = lr.r.Read(p)
lr.n -= int64(n)
return n, err
}
// LimitReader wraps Reader to make it's Reader stop with ErrLimitReached after
// n bytes read.
func LimitReader(r io.Reader, n int64) (limited io.Reader, err error) {
if n < 0 {
return nil, fmt.Errorf("aghio: invalid n in LimitReader: %d", n)
}
return &limitedReader{
r: r,
limit: n,
n: n,
}, nil
}

View File

@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestLimitReadCloser(t *testing.T) {
func TestLimitReader(t *testing.T) {
testCases := []struct {
want error
name string
@ -24,20 +24,20 @@ func TestLimitReadCloser(t *testing.T) {
name: "zero",
n: 0,
}, {
want: fmt.Errorf("aghio: invalid n in LimitReadCloser: -1"),
want: fmt.Errorf("aghio: invalid n in LimitReader: -1"),
name: "negative",
n: -1,
}}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := LimitReadCloser(nil, tc.n)
_, err := LimitReader(nil, tc.n)
assert.Equal(t, tc.want, err)
})
}
}
func TestLimitedReadCloser_Read(t *testing.T) {
func TestLimitedReader_Read(t *testing.T) {
testCases := []struct {
err error
name string
@ -77,7 +77,7 @@ func TestLimitedReadCloser_Read(t *testing.T) {
readCloser := io.NopCloser(strings.NewReader(tc.rStr))
buf := make([]byte, tc.limit+1)
lreader, err := LimitReadCloser(readCloser, tc.limit)
lreader, err := LimitReader(readCloser, tc.limit)
require.NoError(t, err)
n, err := lreader.Read(buf)
@ -87,7 +87,7 @@ func TestLimitedReadCloser_Read(t *testing.T) {
}
}
func TestLimitedReadCloser_LimitReachedError(t *testing.T) {
func TestLimitedReader_LimitReachedError(t *testing.T) {
testCases := []struct {
err error
name string

View File

@ -6,7 +6,7 @@ import (
"strconv"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/net/idna"
)
@ -26,11 +26,11 @@ func isValidHostRune(r rune) (ok bool) {
// ValidateHardwareAddress returns an error if hwa is not a valid EUI-48,
// EUI-64, or 20-octet InfiniBand link-layer address.
func ValidateHardwareAddress(hwa net.HardwareAddr) (err error) {
defer agherr.Annotate("validating hardware address %q: %w", &err, hwa)
defer func() { err = errors.Annotate(err, "validating hardware address %q: %w", hwa) }()
switch l := len(hwa); l {
case 0:
return agherr.Error("address is empty")
return errors.Error("address is empty")
case 6, 8, 20:
return nil
default:
@ -51,13 +51,13 @@ const maxDomainNameLen = 253
// ValidateDomainNameLabel returns an error if label is not a valid label of
// a domain name.
func ValidateDomainNameLabel(label string) (err error) {
defer agherr.Annotate("validating label %q: %w", &err, label)
defer func() { err = errors.Annotate(err, "validating label %q: %w", label) }()
l := len(label)
if l > maxDomainLabelLen {
return fmt.Errorf("label is too long, max: %d", maxDomainLabelLen)
} else if l == 0 {
return agherr.Error("label is empty")
return errors.Error("label is empty")
}
if r := label[0]; !IsValidHostOuterRune(rune(r)) {
@ -87,7 +87,7 @@ func ValidateDomainNameLabel(label string) (err error) {
// TODO(a.garipov): After making sure that this works correctly, port this into
// module golibs.
func ValidateDomainName(name string) (err error) {
defer agherr.Annotate("validating domain name %q: %w", &err, name)
defer func() { err = errors.Annotate(err, "validating domain name %q: %w", name) }()
name, err = idna.ToASCII(name)
if err != nil {
@ -96,7 +96,7 @@ func ValidateDomainName(name string) (err error) {
l := len(name)
if l == 0 {
return agherr.Error("domain name is empty")
return errors.Error("domain name is empty")
} else if l > maxDomainNameLen {
return fmt.Errorf("too long, max: %d", maxDomainNameLen)
}

View File

@ -2,7 +2,6 @@ package aghnet
import (
"bufio"
"errors"
"io"
"net"
"os"
@ -12,6 +11,7 @@ import (
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/fsnotify/fsnotify"
"github.com/miekg/dns"
@ -239,7 +239,13 @@ func (ehc *EtcHostsContainer) load(table map[string][]net.IP, tableRev map[strin
log.Error("etchostscontainer: %s", err)
return
}
defer f.Close()
defer func() {
derr := f.Close()
if derr != nil {
log.Error("etchostscontainer: closing file: %s", err)
}
}()
r := bufio.NewReader(f)
log.Debug("etchostscontainer: loading hosts from file %s", fn)

View File

@ -3,7 +3,6 @@ package aghnet
import (
"encoding/json"
"errors"
"fmt"
"net"
"os"
@ -14,14 +13,14 @@ import (
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// ErrNoStaticIPInfo is returned by IfaceHasStaticIP when no information about
// the IP being static is available.
const ErrNoStaticIPInfo agherr.Error = "no information about static ip"
const ErrNoStaticIPInfo errors.Error = "no information about static ip"
// IfaceHasStaticIP checks if interface is configured to have static IP address.
// If it can't give a definitive answer, it returns false and an error for which
@ -106,7 +105,7 @@ func GetValidNetInterfacesForWeb() ([]*NetInterface, error) {
return nil, fmt.Errorf("couldn't get interfaces: %w", err)
}
if len(ifaces) == 0 {
return nil, errors.New("couldn't find any legible interface")
return nil, errors.Error("couldn't find any legible interface")
}
var netInterfaces []*NetInterface

View File

@ -5,13 +5,13 @@
package aghnet
import (
"errors"
"fmt"
"os"
"regexp"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
)
// hardwarePortInfo - information obtained using MacOS networksetup
@ -83,7 +83,7 @@ func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
match := re.FindStringSubmatch(out)
if len(match) == 0 {
return h, errors.New("could not find hardware port info")
return h, errors.Error("could not find hardware port info")
}
h.name = hardwarePort
@ -105,7 +105,7 @@ func ifaceSetStaticIP(ifaceName string) (err error) {
}
if portInfo.static {
return errors.New("IP address is already static")
return errors.Error("IP address is already static")
}
dnsAddrs, err := getEtcResolvConfServers()
@ -151,7 +151,7 @@ func getEtcResolvConfServers() ([]string, error) {
matches := re.FindAllStringSubmatch(string(body), -1)
if len(matches) == 0 {
return nil, errors.New("found no DNS servers in /etc/resolv.conf")
return nil, errors.Error("found no DNS servers in /etc/resolv.conf")
}
addrs := make([]string, 0)

View File

@ -6,7 +6,6 @@ package aghnet
import (
"bufio"
"errors"
"fmt"
"io"
"net"
@ -14,6 +13,7 @@ import (
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/errors"
"github.com/google/renameio/maybe"
)
@ -49,16 +49,15 @@ func ifaceHasStaticIP(ifaceName string) (has bool, err error) {
return false, err
}
defer f.Close()
defer func() { err = errors.WithDeferred(err, f.Close()) }()
var fileReadCloser io.ReadCloser
fileReadCloser, err = aghio.LimitReadCloser(f, maxConfigFileSize)
var fileReader io.Reader
fileReader, err = aghio.LimitReader(f, maxConfigFileSize)
if err != nil {
return false, err
}
defer fileReadCloser.Close()
has, err = check.checker(fileReadCloser, ifaceName)
has, err = check.checker(fileReader, ifaceName)
if err != nil {
return false, err
}
@ -134,7 +133,7 @@ func ifacesStaticConfig(r io.Reader, ifaceName string) (has bool, err error) {
func ifaceSetStaticIP(ifaceName string) (err error) {
ipNet := GetSubnet(ifaceName)
if ipNet.IP == nil {
return errors.New("can't get IP address")
return errors.Error("can't get IP address")
}
gatewayIP := GatewayIP(ifaceName)

View File

@ -3,7 +3,7 @@ package aghnet
import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
@ -27,19 +27,19 @@ type SystemResolvers interface {
const (
// errBadAddrPassed is returned when dialFunc can't parse an IP address.
errBadAddrPassed agherr.Error = "the passed string is not a valid IP address"
errBadAddrPassed errors.Error = "the passed string is not a valid IP address"
// errFakeDial is an error which dialFunc is expected to return.
errFakeDial agherr.Error = "this error signals the successful dialFunc work"
errFakeDial errors.Error = "this error signals the successful dialFunc work"
// errUnexpectedHostFormat is returned by validateDialedHost when the host has
// more than one percent sign.
errUnexpectedHostFormat agherr.Error = "unexpected host format"
errUnexpectedHostFormat errors.Error = "unexpected host format"
)
// refreshWithTicker refreshes the cache of sr after each tick form tickCh.
func refreshWithTicker(sr SystemResolvers, tickCh <-chan time.Time) {
defer agherr.LogPanic("systemResolvers")
defer log.OnPanic("systemResolvers")
// TODO(e.burkov): Implement a functionality to stop ticker.
for range tickCh {

View File

@ -6,15 +6,14 @@ package aghnet
import (
"context"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/errors"
)
// defaultHostGen is the default method of generating host for Refresh.
@ -34,7 +33,7 @@ type systemResolvers struct {
}
func (sr *systemResolvers) refresh() (err error) {
defer agherr.Annotate("systemResolvers: %w", &err)
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
_, err = sr.resolver.LookupHost(context.Background(), sr.hostGenFunc())
dnserr := &net.DNSError{}
@ -63,7 +62,7 @@ func newSystemResolvers(refreshIvl time.Duration, hostGenFunc HostGenFunc) (sr S
// validateDialedHost validated the host used by resolvers in dialFunc.
func validateDialedHost(host string) (err error) {
defer agherr.Annotate("parsing %q: %w", &err, host)
defer func() { err = errors.Annotate(err, "parsing %q: %w", host) }()
var ipStr string
parts := strings.Split(host, "%")

View File

@ -14,9 +14,9 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
@ -65,14 +65,15 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
return nil, fmt.Errorf("getting the command's stdout pipe: %w", err)
}
var stdoutLimited io.ReadCloser
stdoutLimited, err = aghio.LimitReadCloser(stdout, aghos.MaxCmdOutputSize)
var stdoutLimited io.Reader
stdoutLimited, err = aghio.LimitReader(stdout, aghos.MaxCmdOutputSize)
if err != nil {
return nil, fmt.Errorf("limiting stdout reader: %w", err)
}
go func() {
defer agherr.LogPanic("systemResolvers")
defer log.OnPanic("systemResolvers")
defer func() {
derr := stdin.Close()
if derr != nil {
@ -141,7 +142,7 @@ func (sr *systemResolvers) getAddrs() (addrs []string, err error) {
}
func (sr *systemResolvers) refresh() (err error) {
defer agherr.Annotate("systemResolvers: %w", &err)
defer func() { err = errors.Annotate(err, "systemResolvers: %w") }()
got, err := sr.getAddrs()
if err != nil {

View File

@ -30,7 +30,7 @@ func ReplaceLogWriter(t *testing.T, w io.Writer) {
// ReplaceLogLevel sets logging level to l and uses Cleanup method of t to
// revert changes.
func ReplaceLogLevel(t *testing.T, l int) {
func ReplaceLogLevel(t *testing.T, l log.Level) {
switch l {
case log.INFO, log.DEBUG, log.ERROR:
// Go on.

View File

@ -168,9 +168,6 @@ type TestErrUpstream struct {
// Exchange always returns nil Msg and non-nil error.
func (u *TestErrUpstream) Exchange(*dns.Msg) (*dns.Msg, error) {
// We don't use an agherr.Error to avoid the import cycle since aghtests
// used to provide the utilities for testing which agherr (and any other
// testable package) should be able to use.
return nil, fmt.Errorf("errupstream: %w", u.Err)
}

View File

@ -12,6 +12,7 @@ import (
"runtime"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/nclient4"
@ -78,7 +79,7 @@ func CheckIfOtherDHCPServersPresentV4(ifaceName string) (ok bool, err error) {
return false, fmt.Errorf("couldn't listen on :68: %w", err)
}
if c != nil {
defer c.Close()
defer func() { err = errors.WithDeferred(err, c.Close()) }()
}
// send to 255.255.255.255:67
@ -202,7 +203,7 @@ func CheckIfOtherDHCPServersPresentV6(ifaceName string) (ok bool, err error) {
return false, fmt.Errorf("dhcpv6: Couldn't listen on :546: %w", err)
}
if c != nil {
defer c.Close()
defer func() { err = errors.WithDeferred(err, c.Close()) }()
}
_, err = c.WriteTo(req.ToBytes(), dstAddr)

View File

@ -4,11 +4,11 @@ package dhcpd
import (
"encoding/json"
"errors"
"net"
"os"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/maybe"
)

View File

@ -2,7 +2,6 @@ package dhcpd
import (
"encoding/json"
"errors"
"fmt"
"io"
"net"
@ -11,6 +10,7 @@ import (
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)

View File

@ -6,7 +6,7 @@ import (
"math/big"
"net"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/errors"
)
// ipRange is an inclusive range of IP addresses. A nil range is a range that
@ -28,7 +28,7 @@ const maxRangeLen = math.MaxUint32
// newIPRange creates a new IP address range. start must be less than end. The
// resulting range must not be greater than maxRangeLen.
func newIPRange(start, end net.IP) (r *ipRange, err error) {
defer agherr.Annotate("invalid ip range: %w", &err)
defer func() { err = errors.Annotate(err, "invalid ip range: %w") }()
// Make sure that both are 16 bytes long to simplify handling in
// methods.

View File

@ -7,7 +7,7 @@ import (
"strconv"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/errors"
)
// hexDHCPOptionParserHandler parses a DHCP option as a hex-encoded string.
@ -32,7 +32,7 @@ func hexDHCPOptionParserHandler(s string) (data []byte, err error) {
func ipDHCPOptionParserHandler(s string) (data []byte, err error) {
ip := net.ParseIP(s)
if ip == nil {
return nil, agherr.Error("invalid ip")
return nil, errors.Error("invalid ip")
}
// Most DHCP options require IPv4, so do not put the 16-byte
@ -100,12 +100,12 @@ func newDHCPOptionParser() (p *dhcpOptionParser) {
// parse parses an option. See the handlers' documentation for more info.
func (p *dhcpOptionParser) parse(s string) (code uint8, data []byte, err error) {
defer agherr.Annotate("invalid option string %q: %w", &err, s)
defer func() { err = errors.Annotate(err, "invalid option string %q: %w", s) }()
s = strings.TrimSpace(s)
parts := strings.SplitN(s, " ", 3)
if len(parts) < 3 {
return 0, nil, agherr.Error("need at least three fields")
return 0, nil, errors.Error("need at least three fields")
}
codeStr := parts[0]

View File

@ -1,13 +1,13 @@
package dhcpd
import (
"errors"
"net"
"github.com/AdguardTeam/golibs/errors"
"golang.org/x/net/ipv4"
)
// Create a socket for receiving broadcast packets
func newBroadcastPacketConn(bindAddr net.IP, port int, ifname string) (*ipv4.PacketConn, error) {
return nil, errors.New("newBroadcastPacketConn(): not supported on Windows")
return nil, errors.Error("newBroadcastPacketConn(): not supported on Windows")
}

View File

@ -6,16 +6,15 @@ package dhcpd
import (
"bytes"
"errors"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/go-ping/ping"
"github.com/insomniacslk/dhcp/dhcpv4"
@ -55,7 +54,7 @@ func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) {
// normalizeHostname normalizes a hostname sent by the client. If err is not
// nil, norm is an empty string.
func normalizeHostname(hostname string) (norm string, err error) {
defer agherr.Annotate("normalizing %q: %w", &err, hostname)
defer func() { err = errors.Annotate(err, "normalizing %q: %w", hostname) }()
if hostname == "" {
return "", nil
@ -249,7 +248,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
if bytes.Equal(l.HWAddr, lease.HWAddr) {
if l.IsStatic() {
return agherr.Error("static lease already exists")
return errors.Error("static lease already exists")
}
s.rmLeaseByIndex(i)
@ -262,7 +261,7 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) {
if l.IP.Equal(lease.IP) {
if l.IsStatic() {
return agherr.Error("static lease already exists")
return errors.Error("static lease already exists")
}
s.rmLeaseByIndex(i)
@ -322,12 +321,12 @@ func (s *v4Server) rmLease(lease Lease) (err error) {
}
}
return agherr.Error("lease not found")
return errors.Error("lease not found")
}
// AddStaticLease adds a static lease. It is safe for concurrent use.
func (s *v4Server) AddStaticLease(l Lease) (err error) {
defer agherr.Annotate("dhcpv4: adding static lease: %w", &err)
defer func() { err = errors.Annotate(err, "dhcpv4: adding static lease: %w") }()
if ip4 := l.IP.To4(); ip4 == nil {
return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP)
@ -397,7 +396,7 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) {
// RemoveStaticLease removes a static lease. It is safe for concurrent use.
func (s *v4Server) RemoveStaticLease(l Lease) (err error) {
defer agherr.Annotate("dhcpv4: %w", &err)
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if len(l.IP) != 4 {
return fmt.Errorf("invalid IP")
@ -937,7 +936,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
// Start starts the IPv4 DHCP server.
func (s *v4Server) Start() (err error) {
defer agherr.Annotate("dhcpv4: %w", &err)
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if !s.conf.Enabled {
return nil

View File

@ -1,11 +1,10 @@
package dhcpd
import (
"errors"
"net"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -25,7 +24,7 @@ func (iface *fakeIface) Addrs() (addrs []net.Addr, err error) {
}
func TestIfaceIPAddrs(t *testing.T) {
const errTest agherr.Error = "test error"
const errTest errors.Error = "test error"
ip4 := net.IP{1, 2, 3, 4}
addr4 := &net.IPNet{IP: ip4}
@ -108,7 +107,7 @@ func (iface *waitingFakeIface) Addrs() (addrs []net.Addr, err error) {
}
func TestIfaceDNSIPAddrs(t *testing.T) {
const errTest agherr.Error = "test error"
const errTest errors.Error = "test error"
ip4 := net.IP{1, 2, 3, 4}
addr4 := &net.IPNet{IP: ip4}

View File

@ -11,8 +11,8 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/insomniacslk/dhcp/dhcpv6"
"github.com/insomniacslk/dhcp/dhcpv6/server6"
@ -165,7 +165,7 @@ func (s *v6Server) rmDynamicLease(lease Lease) error {
// AddStaticLease adds a static lease. It is safe for concurrent use.
func (s *v6Server) AddStaticLease(l Lease) (err error) {
defer agherr.Annotate("dhcpv6: %w", &err)
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if len(l.IP) != 16 {
return fmt.Errorf("invalid IP")
@ -194,7 +194,7 @@ func (s *v6Server) AddStaticLease(l Lease) (err error) {
// RemoveStaticLease removes a static lease. It is safe for concurrent use.
func (s *v6Server) RemoveStaticLease(l Lease) (err error) {
defer agherr.Annotate("dhcpv6: %w", &err)
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if len(l.IP) != 16 {
return fmt.Errorf("invalid IP")
@ -585,7 +585,7 @@ func (s *v6Server) initRA(iface *net.Interface) error {
// Start starts the IPv6 DHCP server.
func (s *v6Server) Start() (err error) {
defer agherr.Annotate("dhcpv6: %w", &err)
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if !s.conf.Enabled {
return nil

View File

@ -2,13 +2,13 @@ package dnsforward
import (
"crypto/tls"
"errors"
"fmt"
"path"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/lucas-clemente/quic-go"
)

View File

@ -3,7 +3,6 @@ package dnsforward
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/http"
@ -14,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnscrypt/v2"
)
@ -220,7 +220,7 @@ func (s *Server) createProxyConfig() (proxy.Config, error) {
// Validate proxy config
if proxyConfig.UpstreamConfig == nil || len(proxyConfig.UpstreamConfig.Upstreams) == 0 {
return proxyConfig, errors.New("no default upstream servers configured")
return proxyConfig, errors.Error("no default upstream servers configured")
}
return proxyConfig, nil

View File

@ -2,7 +2,6 @@
package dnsforward
import (
"errors"
"fmt"
"net"
"net/http"
@ -12,7 +11,6 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
@ -21,6 +19,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
@ -226,11 +225,11 @@ type RDNSExchanger interface {
const (
// rDNSEmptyAnswerErr is returned by Exchange method when the answer
// section of respond is empty.
rDNSEmptyAnswerErr agherr.Error = "the answer section is empty"
rDNSEmptyAnswerErr errors.Error = "the answer section is empty"
// rDNSNotPTRErr is returned by Exchange method when the response is not
// of PTR type.
rDNSNotPTRErr agherr.Error = "the response is not a ptr"
rDNSNotPTRErr errors.Error = "the response is not a ptr"
)
// Exchange implements the RDNSExchanger interface for *Server.

View File

@ -17,13 +17,13 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -1198,7 +1198,7 @@ func TestServer_Exchange(t *testing.T) {
"2.1.168.192.in-addr.arpa.": {},
},
}
upstreamErr := agherr.Error("upstream error")
upstreamErr := errors.Error("upstream error")
errUpstream := &aghtest.TestErrUpstream{
Err: upstreamErr,
}

View File

@ -4,9 +4,9 @@ import (
"fmt"
"net"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
@ -83,7 +83,7 @@ func (s *Server) filterDNSRewrite(req *dns.Msg, res filtering.Result, d *proxy.D
resp := s.makeResponse(req)
dnsrr := res.DNSRewriteResult
if dnsrr == nil {
return agherr.Error("no dns rewrite rule content")
return errors.Error("no dns rewrite rule content")
}
resp.Rcode = dnsrr.RCode
@ -94,7 +94,7 @@ func (s *Server) filterDNSRewrite(req *dns.Msg, res filtering.Result, d *proxy.D
}
if dnsrr.Response == nil {
return agherr.Error("no dns rewrite rule responses")
return errors.Error("no dns rewrite rule responses")
}
rr := req.Question[0].Qtype

View File

@ -8,11 +8,11 @@ import (
"strconv"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
@ -399,7 +399,7 @@ func validateUpstream(u string) (bool, error) {
// separateUpstream returns the upstream without the specified domains.
// useDefault is true when a default upstream must be used.
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) {
defer agherr.Annotate("bad upstream for domain spec %q: %w", &err, upstreamStr)
defer func() { err = errors.Annotate(err, "bad upstream for domain spec %q: %w", upstreamStr) }()
if !strings.HasPrefix(upstreamStr, "[/") {
return upstreamStr, true, nil
@ -407,7 +407,7 @@ func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err
parts := strings.Split(upstreamStr[2:], "/]")
if len(parts) != 2 {
return "", false, agherr.Error("duplicated separator")
return "", false, errors.Error("duplicated separator")
}
domains := parts[0]

View File

@ -10,7 +10,7 @@ import (
"strings"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/digineo/go-ipset/v2"
"github.com/mdlayher/netlink"
@ -83,7 +83,7 @@ func (c *ipsetCtx) ipsetProps(name string) (set ipsetProps, err error) {
}
if res == nil || res.Family == nil {
return set, agherr.Error("empty response or no family data")
return set, errors.Error("empty response or no family data")
}
family := netfilter.ProtoFamily(res.Family.Value)
@ -193,23 +193,23 @@ func (c *ipsetCtx) init(ipsetConfig []string) (err error) {
// Close closes the Linux Netfilter connections.
func (c *ipsetCtx) Close() (err error) {
var errors []error
var errs []error
if c.ipv4Conn != nil {
err = c.ipv4Conn.Close()
if err != nil {
errors = append(errors, err)
errs = append(errs, err)
}
}
if c.ipv6Conn != nil {
err = c.ipv6Conn.Close()
if err != nil {
errors = append(errors, err)
errs = append(errs, err)
}
}
if len(errors) != 0 {
return agherr.Many("closing ipsets", errors...)
if len(errs) != 0 {
return errors.List("closing ipsets", errs...)
}
return nil

View File

@ -73,22 +73,27 @@ func glGetTokenDate(file string) uint32 {
f, err := os.Open(file)
if err != nil {
log.Error("os.Open: %s", err)
return 0
}
defer f.Close()
defer func() {
derr := f.Close()
if derr != nil {
log.Error("glinet: closing file: %s", err)
}
}()
fileReadCloser, err := aghio.LimitReadCloser(f, MaxFileSize)
fileReader, err := aghio.LimitReader(f, MaxFileSize)
if err != nil {
log.Error("creating limited reader: %s", err)
return 0
}
defer fileReadCloser.Close()
var dateToken uint32
// This use of ReadAll is now safe, because we limited reader.
bs, err := io.ReadAll(fileReadCloser)
bs, err := io.ReadAll(fileReader)
if err != nil {
log.Error("reading token: %s", err)

View File

@ -11,7 +11,6 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghstrings"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
@ -20,6 +19,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
@ -438,11 +438,11 @@ func (clients *clientsContainer) FindRuntimeClient(ip string) (RuntimeClient, bo
func (clients *clientsContainer) check(c *Client) (err error) {
switch {
case c == nil:
return agherr.Error("client is nil")
return errors.Error("client is nil")
case c.Name == "":
return agherr.Error("invalid name")
return errors.Error("invalid name")
case len(c.IDs) == 0:
return agherr.Error("id required")
return errors.Error("id required")
default:
// Go on.
}
@ -570,14 +570,14 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
prev, ok := clients.list[name]
if !ok {
return agherr.Error("client not found")
return errors.Error("client not found")
}
// First, check the name index.
if prev.Name != c.Name {
_, ok = clients.list[c.Name]
if ok {
return agherr.Error("client already exists")
return errors.Error("client already exists")
}
}

View File

@ -1,7 +1,6 @@
package home
import (
"errors"
"fmt"
"net"
"os"
@ -14,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/maybe"
yaml "gopkg.in/yaml.v2"

View File

@ -15,7 +15,6 @@ import (
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/log"
)
@ -271,7 +270,7 @@ func copyInstallSettings(dst, src *configuration) {
const shutdownTimeout = 5 * time.Second
func shutdownSrv(ctx context.Context, cancel context.CancelFunc, srv *http.Server) {
defer agherr.LogPanic("")
defer log.OnPanic("")
if srv == nil {
return

View File

@ -3,7 +3,6 @@ package home
import (
"context"
"encoding/json"
"errors"
"net/http"
"os"
"os/exec"
@ -14,6 +13,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)

View File

@ -8,12 +8,12 @@ import (
"path/filepath"
"strconv"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/ameshkov/dnscrypt/v2"
yaml "gopkg.in/yaml.v2"
@ -207,14 +207,14 @@ func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
func newDNSCrypt(hosts []net.IP, tlsConf tlsConfigSettings) (dnscc dnsforward.DNSCryptConfig, err error) {
if tlsConf.DNSCryptConfigFile == "" {
return dnscc, agherr.Error("no dnscrypt_config_file")
return dnscc, errors.Error("no dnscrypt_config_file")
}
f, err := os.Open(tlsConf.DNSCryptConfigFile)
if err != nil {
return dnscc, fmt.Errorf("opening dnscrypt config: %w", err)
}
defer f.Close()
defer func() { err = errors.WithDeferred(err, f.Close()) }()
rc := &dnscrypt.ResolverConfig{}
err = yaml.NewDecoder(f).Decode(rc)

View File

@ -2,7 +2,6 @@ package home
import (
"bufio"
"errors"
"fmt"
"hash/crc32"
"io"
@ -17,6 +16,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
@ -566,17 +566,18 @@ func (f *Filtering) updateIntl(filter *filter) (updated bool, err error) {
if err != nil {
return updated, fmt.Errorf("open file: %w", err)
}
defer func() { err = errors.WithDeferred(err, f.Close()) }()
defer f.Close()
reader = f
} else {
var resp *http.Response
resp, err = Context.client.Get(filter.URL)
if err != nil {
log.Printf("Couldn't request filter from URL %s, skipping: %s", filter.URL, err)
return updated, err
}
defer resp.Body.Close()
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
if resp.StatusCode != http.StatusOK {
log.Printf("Got status code %d from URL %s, skipping", resp.StatusCode, filter.URL)
@ -634,7 +635,7 @@ func (f *Filtering) load(filter *filter) (err error) {
} else if err != nil {
return fmt.Errorf("opening filter file: %w", err)
}
defer file.Close()
defer func() { err = errors.WithDeferred(err, file.Close()) }()
st, err := file.Stat()
if err != nil {

View File

@ -5,7 +5,6 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/fs"
"net"
@ -21,7 +20,6 @@ import (
"syscall"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpd"
@ -31,6 +29,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"gopkg.in/natefinch/lumberjack.v2"
)
@ -736,7 +735,7 @@ func customDialContext(ctx context.Context, network, addr string) (conn net.Conn
return conn, err
}
return nil, agherr.Many(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
return nil, errors.List(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
}
func getHTTPProxy(_ *http.Request) (*url.URL, error) {

View File

@ -1,6 +1,7 @@
package home
import (
"io"
"net/http"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
@ -58,14 +59,20 @@ func limitRequestBody(h http.Handler) (limited http.Handler) {
szLim = largerReqBodySzLim
}
r.Body, err = aghio.LimitReadCloser(r.Body, szLim)
var reader io.Reader
reader, err = aghio.LimitReader(r.Body, szLim)
if err != nil {
log.Error("limitRequestBody: %s", err)
return
}
h.ServeHTTP(w, r)
// HTTP handlers aren't supposed to call r.Body.Close(), so just
// replace the body in a clone.
rr := r.Clone(r.Context())
rr.Body = io.NopCloser(reader)
h.ServeHTTP(w, rr)
})
}

View File

@ -5,7 +5,6 @@ import (
"net"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/log"
@ -85,7 +84,7 @@ func (r *RDNS) Begin(ip net.IP) {
// workerLoop handles incoming IP addresses from ipChan and adds it into
// clients.
func (r *RDNS) workerLoop() {
defer agherr.LogPanic("rdns")
defer log.OnPanic("rdns")
for ip := range r.ipCh {
host, err := r.exchanger.Exchange(ip)

View File

@ -3,7 +3,6 @@ package home
import (
"bytes"
"encoding/binary"
"errors"
"net"
"sync"
"testing"
@ -13,6 +12,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@ -141,7 +141,7 @@ func TestRDNS_WorkerLoop(t *testing.T) {
},
}
errUpstream := &aghtest.TestErrUpstream{
Err: errors.New("1234"),
Err: errors.Error("1234"),
}
testCases := []struct {

View File

@ -1,7 +1,6 @@
package home
import (
"errors"
"fmt"
"io/fs"
"os"
@ -11,6 +10,7 @@ import (
"syscall"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/kardianos/service"
)

View File

@ -10,7 +10,6 @@ import (
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"net/http"
"os"
@ -21,6 +20,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"golang.org/x/sys/cpu"
)
@ -341,14 +341,14 @@ func verifyCertChain(data *tlsConfigStatus, certChain, serverName string) error
parsed, err := x509.ParseCertificate(cert.Bytes)
if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse certificate: %s", err)
return errors.New(data.WarningValidation)
return errors.Error(data.WarningValidation)
}
parsedCerts = append(parsedCerts, parsed)
}
if len(parsedCerts) == 0 {
data.WarningValidation = "You have specified an empty certificate"
return errors.New(data.WarningValidation)
return errors.Error(data.WarningValidation)
}
data.ValidCert = true
@ -415,14 +415,14 @@ func validatePkey(data *tlsConfigStatus, pkey string) error {
if key == nil {
data.WarningValidation = "No valid keys were found"
return errors.New(data.WarningValidation)
return errors.Error(data.WarningValidation)
}
// parse the decoded key
_, keytype, err := parsePrivateKey(key.Bytes)
if err != nil {
data.WarningValidation = fmt.Sprintf("Failed to parse private key: %s", err)
return errors.New(data.WarningValidation)
return errors.Error(data.WarningValidation)
}
data.ValidKey = true
@ -479,7 +479,7 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
case *ecdsa.PrivateKey:
return key, "ECDSA", nil
default:
return nil, "", errors.New("tls: found unknown private key type in PKCS#8 wrapping")
return nil, "", errors.Error("tls: found unknown private key type in PKCS#8 wrapping")
}
}
@ -487,7 +487,7 @@ func parsePrivateKey(der []byte) (crypto.PrivateKey, string, error) {
return key, "ECDSA", nil
}
return nil, "", errors.New("tls: failed to parse private key")
return nil, "", errors.Error("tls: failed to parse private key")
}
// unmarshalTLS handles base64-encoded certificates transparently

View File

@ -1,7 +1,6 @@
package home
import (
"errors"
"fmt"
"net/url"
"os"
@ -12,6 +11,7 @@ import (
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/google/renameio/maybe"
"golang.org/x/crypto/bcrypt"

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/cache"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
@ -139,22 +140,22 @@ func whoisParse(data string) (m strmap) {
const MaxConnReadSize = 64 * 1024
// Send request to a server and receive the response
func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, error) {
func (w *Whois) query(ctx context.Context, target, serverAddr string) (data string, err error) {
addr, _, _ := net.SplitHostPort(serverAddr)
if addr == "whois.arin.net" {
target = "n + " + target
}
conn, err := w.dialContext(ctx, "tcp", serverAddr)
if err != nil {
return "", err
}
defer conn.Close()
defer func() { err = errors.WithDeferred(err, conn.Close()) }()
connReadCloser, err := aghio.LimitReadCloser(conn, MaxConnReadSize)
r, err := aghio.LimitReader(conn, MaxConnReadSize)
if err != nil {
return "", err
}
defer connReadCloser.Close()
_ = conn.SetReadDeadline(time.Now().Add(time.Duration(w.timeoutMsec) * time.Millisecond))
_, err = conn.Write([]byte(target + "\r\n"))
@ -163,12 +164,13 @@ func (w *Whois) query(ctx context.Context, target, serverAddr string) (string, e
}
// This use of ReadAll is now safe, because we limited the conn Reader.
data, err := io.ReadAll(connReadCloser)
var whoisData []byte
whoisData, err = io.ReadAll(r)
if err != nil {
return "", err
}
return string(data), nil
return string(whoisData), nil
}
// Query WHOIS servers (handle redirects)

View File

@ -2,7 +2,6 @@
package querylog
import (
"errors"
"fmt"
"net"
"os"
@ -11,6 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)

View File

@ -8,15 +8,15 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
// Timestamp not found errors.
const (
ErrTSNotFound agherr.Error = "ts not found"
ErrTSTooLate agherr.Error = "ts too late"
ErrTSTooEarly agherr.Error = "ts too early"
ErrTSNotFound errors.Error = "ts not found"
ErrTSTooLate errors.Error = "ts too late"
ErrTSTooEarly errors.Error = "ts too early"
)
// TODO: Find a way to grow buffer instead of relying on this value when reading strings

View File

@ -2,7 +2,6 @@ package querylog
import (
"encoding/binary"
"errors"
"fmt"
"io"
"math"
@ -12,6 +11,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

View File

@ -1,12 +1,11 @@
package querylog
import (
"errors"
"fmt"
"io"
"os"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
@ -154,7 +153,7 @@ func closeQFiles(qFiles []*QLogFile) error {
}
if len(errs) > 0 {
return agherr.Many("Error while closing QLogReader", errs...)
return errors.List("error while closing QLogReader", errs...)
}
return nil

View File

@ -6,8 +6,8 @@ import (
"path/filepath"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/miekg/dns"
)
@ -78,13 +78,13 @@ type AddParams struct {
func (p *AddParams) validate() (err error) {
switch {
case p.Question == nil:
return agherr.Error("question is nil")
return errors.Error("question is nil")
case len(p.Question.Question) != 1:
return agherr.Error("more than one question")
return errors.Error("more than one question")
case len(p.Question.Question[0].Name) == 0:
return agherr.Error("no host in question")
return errors.Error("no host in question")
case p.ClientIP == nil:
return agherr.Error("no client ip")
return errors.Error("no client ip")
default:
return nil
}

View File

@ -3,10 +3,10 @@ package querylog
import (
"bytes"
"encoding/json"
"errors"
"os"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
@ -39,7 +39,7 @@ func (l *queryLog) flushLogBuffer(fullFlush bool) error {
}
// flushToFile saves the specified log entries to the query log file
func (l *queryLog) flushToFile(buffer []*logEntry) error {
func (l *queryLog) flushToFile(buffer []*logEntry) (err error) {
if len(buffer) == 0 {
log.Debug("querylog: there's nothing to write to a file")
return nil
@ -49,9 +49,10 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
var b bytes.Buffer
e := json.NewEncoder(&b)
for _, entry := range buffer {
err := e.Encode(entry)
err = e.Encode(entry)
if err != nil {
log.Error("Failed to marshal entry: %s", err)
return err
}
}
@ -59,7 +60,6 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
elapsed := time.Since(start)
log.Debug("%d elements serialized via json in %v: %d kB, %v/entry, %v/entry", len(buffer), elapsed, b.Len()/1024, float64(b.Len())/float64(len(buffer)), elapsed/time.Duration(len(buffer)))
var err error
var zb bytes.Buffer
filename := l.logFile
zb = b
@ -71,7 +71,7 @@ func (l *queryLog) flushToFile(buffer []*logEntry) error {
log.Error("failed to create file \"%s\": %s", filename, err)
return err
}
defer f.Close()
defer func() { err = errors.WithDeferred(err, f.Close()) }()
n, err := f.Write(zb.Bytes())
if err != nil {
@ -109,7 +109,12 @@ func (l *queryLog) readFileFirstTimeValue() int64 {
if err != nil {
return -1
}
defer f.Close()
defer func() {
derr := f.Close()
if derr != nil {
log.Error("querylog: closing file: %s", derr)
}
}()
buf := make([]byte, 500)
r, err := f.Read(buf)

View File

@ -142,7 +142,12 @@ func (l *queryLog) searchFiles(
return entries, oldest, 0
}
defer r.Close()
defer func() {
derr := r.Close()
if derr != nil {
log.Error("querylog: closing file: %s", err)
}
}()
if params.olderThan.IsZero() {
err = r.SeekStart()

View File

@ -4,7 +4,6 @@ import (
"bytes"
"encoding/binary"
"encoding/gob"
"errors"
"fmt"
"net"
"os"
@ -12,7 +11,7 @@ import (
"sync"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
bolt "go.etcd.io/bbolt"
)
@ -93,7 +92,7 @@ func createObject(conf Config) (s *statsCtx, err error) {
// TODO(a.garipov): See if this is actually necessary. Looks
// like a rather bizarre solution.
errStop := agherr.Error("stop iteration")
errStop := errors.Error("stop iteration")
forEachBkt := func(name []byte, _ *bolt.Bucket) (cberr error) {
nameID := uint32(btoi(name))
if nameID < firstID {

View File

@ -4,10 +4,12 @@ import (
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/golibs/errors"
)
// TODO(a.garipov): Make configurable.
@ -27,7 +29,7 @@ const MaxResponseSize = 64 * 1024
// VersionInfo downloads the latest version information. If forceRecheck is
// false and there are cached results, those results are returned.
func (u *Updater) VersionInfo(forceRecheck bool) (VersionInfo, error) {
func (u *Updater) VersionInfo(forceRecheck bool) (vi VersionInfo, err error) {
u.mu.Lock()
defer u.mu.Unlock()
@ -37,22 +39,23 @@ func (u *Updater) VersionInfo(forceRecheck bool) (VersionInfo, error) {
return u.prevCheckResult, u.prevCheckError
}
var resp *http.Response
vcu := u.versionCheckURL
resp, err := u.client.Get(vcu)
resp, err = u.client.Get(vcu)
if err != nil {
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
}
defer resp.Body.Close()
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxResponseSize)
var r io.Reader
r, err = aghio.LimitReader(resp.Body, MaxResponseSize)
if err != nil {
return VersionInfo{}, fmt.Errorf("updater: LimitReadCloser: %w", err)
}
defer resp.Body.Close()
// This use of ReadAll is safe, because we just limited the appropriate
// ReadCloser.
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(r)
if err != nil {
return VersionInfo{}, fmt.Errorf("updater: HTTP GET %s: %w", vcu, err)
}

View File

@ -5,7 +5,6 @@ import (
"archive/tar"
"archive/zip"
"compress/gzip"
"errors"
"fmt"
"io"
"net/http"
@ -20,6 +19,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghio"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
)
@ -283,22 +283,23 @@ func (u *Updater) clean() {
const MaxPackageFileSize = 32 * 1024 * 1024
// Download package file and save it to disk
func (u *Updater) downloadPackageFile(url, filename string) error {
resp, err := u.client.Get(url)
func (u *Updater) downloadPackageFile(url, filename string) (err error) {
var resp *http.Response
resp, err = u.client.Get(url)
if err != nil {
return fmt.Errorf("http request failed: %w", err)
}
defer resp.Body.Close()
defer func() { err = errors.WithDeferred(err, resp.Body.Close()) }()
resp.Body, err = aghio.LimitReadCloser(resp.Body, MaxPackageFileSize)
var r io.Reader
r, err = aghio.LimitReader(resp.Body, MaxPackageFileSize)
if err != nil {
return fmt.Errorf("http request failed: %w", err)
}
defer resp.Body.Close()
log.Debug("updater: reading HTTP body")
// This use of ReadAll is now safe, because we limited body's Reader.
body, err := io.ReadAll(resp.Body)
body, err := io.ReadAll(r)
if err != nil {
return fmt.Errorf("io.ReadAll() failed: %w", err)
}
@ -313,172 +314,178 @@ func (u *Updater) downloadPackageFile(url, filename string) error {
return nil
}
func tarGzFileUnpackOne(outDir string, tr *tar.Reader, hdr *tar.Header) (name string, err error) {
name = filepath.Base(hdr.Name)
if name == "" {
return "", nil
}
outputName := filepath.Join(outDir, name)
if hdr.Typeflag == tar.TypeDir {
if name == "AdGuardHome" {
// Top-level AdGuardHome/. Skip it.
//
// TODO(a.garipov): This whole package needs to be
// rewritten and covered in more integration tests. It
// has weird assumptions and file mode issues.
return "", nil
}
err = os.Mkdir(outputName, os.FileMode(hdr.Mode&0o777))
if err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
}
log.Debug("updater: created directory %q", outputName)
return "", nil
}
if hdr.Typeflag != tar.TypeReg {
log.Debug("updater: %s: unknown file type %d, skipping", name, hdr.Typeflag)
return "", nil
}
var wc io.WriteCloser
wc, err = os.OpenFile(
outputName,
os.O_WRONLY|os.O_CREATE|os.O_TRUNC,
os.FileMode(hdr.Mode&0o777),
)
if err != nil {
return "", fmt.Errorf("os.OpenFile(%s): %w", outputName, err)
}
defer func() { err = errors.WithDeferred(err, wc.Close()) }()
_, err = io.Copy(wc, tr)
if err != nil {
return "", fmt.Errorf("io.Copy(): %w", err)
}
log.Tracef("updater: created file %s", outputName)
return name, nil
}
// Unpack all files from .tar.gz file to the specified directory
// Existing files are overwritten
// All files are created inside 'outdir', subdirectories are not created
// All files are created inside outDir, subdirectories are not created
// Return the list of files (not directories) written
func tarGzFileUnpack(tarfile, outdir string) ([]string, error) {
func tarGzFileUnpack(tarfile, outDir string) (files []string, err error) {
f, err := os.Open(tarfile)
if err != nil {
return nil, fmt.Errorf("os.Open(): %w", err)
}
defer func() {
_ = f.Close()
}()
defer func() { err = errors.WithDeferred(err, f.Close()) }()
gzReader, err := gzip.NewReader(f)
if err != nil {
return nil, fmt.Errorf("gzip.NewReader(): %w", err)
}
defer func() { err = errors.WithDeferred(err, gzReader.Close()) }()
var files []string
var err2 error
tarReader := tar.NewReader(gzReader)
for {
var header *tar.Header
header, err = tarReader.Next()
if err == io.EOF {
err2 = nil
var hdr *tar.Header
hdr, err = tarReader.Next()
if errors.Is(err, io.EOF) {
err = nil
break
}
if err != nil {
err2 = fmt.Errorf("tarReader.Next(): %w", err)
} else if err != nil {
err = fmt.Errorf("tarReader.Next(): %w", err)
break
}
_, inputNameOnly := filepath.Split(header.Name)
if inputNameOnly == "" {
continue
var name string
name, err = tarGzFileUnpackOne(outDir, tarReader, hdr)
if name != "" {
files = append(files, name)
}
outputName := filepath.Join(outdir, inputNameOnly)
if header.Typeflag == tar.TypeDir {
if inputNameOnly == "AdGuardHome" {
// Top-level AdGuardHome/. Skip it.
//
// TODO(a.garipov): This whole package needs to
// be rewritten and covered in more integration
// tests. It has weird assumptions and file
// mode issues.
continue
}
err = os.Mkdir(outputName, os.FileMode(header.Mode&0o777))
if err != nil && !errors.Is(err, os.ErrExist) {
err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
break
}
log.Debug("updater: created directory %q", outputName)
continue
} else if header.Typeflag != tar.TypeReg {
log.Debug("updater: %s: unknown file type %d, skipping", inputNameOnly, header.Typeflag)
continue
}
var f io.WriteCloser
f, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode&0o777))
if err != nil {
err2 = fmt.Errorf("os.OpenFile(%s): %w", outputName, err)
break
}
_, err = io.Copy(f, tarReader)
if err != nil {
_ = f.Close()
err2 = fmt.Errorf("io.Copy(): %w", err)
break
}
err = f.Close()
if err != nil {
err2 = fmt.Errorf("f.Close(): %w", err)
break
}
log.Debug("updater: created file %s", outputName)
files = append(files, header.Name)
}
_ = gzReader.Close()
return files, err2
return files, err
}
func zipFileUnpackOne(outDir string, zf *zip.File) (name string, err error) {
var rc io.ReadCloser
rc, err = zf.Open()
if err != nil {
return "", fmt.Errorf("zip file Open(): %w", err)
}
defer func() { err = errors.WithDeferred(err, rc.Close()) }()
fi := zf.FileInfo()
name = fi.Name()
if name == "" {
return "", nil
}
outputName := filepath.Join(outDir, name)
if fi.IsDir() {
if name == "AdGuardHome" {
// Top-level AdGuardHome/. Skip it.
//
// TODO(a.garipov): See the similar todo in
// tarGzFileUnpack.
return "", nil
}
err = os.Mkdir(outputName, fi.Mode())
if err != nil && !errors.Is(err, os.ErrExist) {
return "", fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
}
log.Tracef("created directory %q", outputName)
return "", nil
}
var wc io.WriteCloser
wc, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
if err != nil {
return "", fmt.Errorf("os.OpenFile(): %w", err)
}
defer func() { err = errors.WithDeferred(err, wc.Close()) }()
_, err = io.Copy(wc, rc)
if err != nil {
return "", fmt.Errorf("io.Copy(): %w", err)
}
log.Tracef("created file %s", outputName)
return name, nil
}
// Unpack all files from .zip file to the specified directory
// Existing files are overwritten
// All files are created inside 'outdir', subdirectories are not created
// All files are created inside 'outDir', subdirectories are not created
// Return the list of files (not directories) written
func zipFileUnpack(zipfile, outdir string) ([]string, error) {
r, err := zip.OpenReader(zipfile)
func zipFileUnpack(zipfile, outDir string) (files []string, err error) {
zrc, err := zip.OpenReader(zipfile)
if err != nil {
return nil, fmt.Errorf("zip.OpenReader(): %w", err)
}
defer r.Close()
defer func() { err = errors.WithDeferred(err, zrc.Close()) }()
var files []string
var err2 error
var zr io.ReadCloser
for _, zf := range r.File {
zr, err = zf.Open()
for _, zf := range zrc.File {
var name string
name, err = zipFileUnpackOne(outDir, zf)
if err != nil {
err2 = fmt.Errorf("zip file Open(): %w", err)
break
}
fi := zf.FileInfo()
inputNameOnly := fi.Name()
if inputNameOnly == "" {
continue
if name != "" {
files = append(files, name)
}
outputName := filepath.Join(outdir, inputNameOnly)
if fi.IsDir() {
if inputNameOnly == "AdGuardHome" {
// Top-level AdGuardHome/. Skip it.
//
// TODO(a.garipov): See the similar todo in
// tarGzFileUnpack.
continue
}
err = os.Mkdir(outputName, fi.Mode())
if err != nil && !errors.Is(err, os.ErrExist) {
err2 = fmt.Errorf("os.Mkdir(%q): %w", outputName, err)
break
}
log.Tracef("created directory %q", outputName)
continue
}
var f io.WriteCloser
f, err = os.OpenFile(outputName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, fi.Mode())
if err != nil {
err2 = fmt.Errorf("os.OpenFile(): %w", err)
break
}
_, err = io.Copy(f, zr)
if err != nil {
_ = f.Close()
err2 = fmt.Errorf("io.Copy(): %w", err)
break
}
err = f.Close()
if err != nil {
err2 = fmt.Errorf("f.Close(): %w", err)
break
}
log.Tracef("created file %s", outputName)
files = append(files, inputNameOnly)
}
_ = zr.Close()
return files, err2
return files, err
}
// Copy file on disk

View File

@ -75,10 +75,10 @@ esac
# Simple Analyzers
# blocklist_imports is a simple check against unwanted packages. Package
# io/ioutil is soft-deprecated. Package log is replaced by our own package
# github.com/AdguardTeam/golibs/log.
# io/ioutil is soft-deprecated. Packages errors and log are replaced by our own
# packages in the github.com/AdguardTeam/golibs module.
blocklist_imports() {
git grep -F -e '"io/ioutil"' -e '"log"' -- '*.go' || exit 0;
git grep -F -e '"errors"' -e '"io/ioutil"' -e '"log"' -- '*.go' || exit 0;
}
# method_const is a simple check against the usage of some raw strings and
@ -192,13 +192,7 @@ nilness ./...
exit_on_output shadow --strict ./...
# TODO(a.garipov): Enable errcheck fully after handling all errors, including
# the deferred and generated ones, properly. Also, perhaps, enable --blank.
#
# errcheck ./...
exit_on_output sh -c '
errcheck --asserts --ignoregenerated ./... |\
{ grep -e "defer" -v || exit 0; }
'
# TODO(a.garipov): Enable --blank?
errcheck --asserts ./...
staticcheck ./...