Pull request 1814: AG-21291-web-races

Merge in DNS/adguard-home from AG-21291-web-races to master

Squashed commit of the following:

commit 1134013f928aa5e186db3b6d0450e425cb053e9c
Author: Ainar Garipov <A.Garipov@AdGuard.COM>
Date:   Tue Apr 11 16:52:52 2023 +0300

    home: fix web api races
This commit is contained in:
Ainar Garipov 2023-04-11 17:22:51 +03:00
parent 230d7b8c17
commit 0376afb38e
5 changed files with 45 additions and 38 deletions

View File

@ -332,13 +332,17 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
return false return false
} }
var serveHTTP3 bool var (
var portHTTPS int forceHTTPS bool
serveHTTP3 bool
portHTTPS int
)
func() { func() {
config.RLock() config.RLock()
defer config.RUnlock() defer config.RUnlock()
serveHTTP3, portHTTPS = config.DNS.ServeHTTP3, config.TLS.PortHTTPS serveHTTP3, portHTTPS = config.DNS.ServeHTTP3, config.TLS.PortHTTPS
forceHTTPS = config.TLS.ForceHTTPS && config.TLS.Enabled && config.TLS.PortHTTPS != 0
}() }()
respHdr := w.Header() respHdr := w.Header()
@ -354,10 +358,10 @@ func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (ok bool) {
respHdr.Set(httphdr.AltSvc, altSvc) respHdr.Set(httphdr.AltSvc, altSvc)
} }
if r.TLS == nil && web.forceHTTPS { if r.TLS == nil && forceHTTPS {
hostPort := host hostPort := host
if port := web.conf.PortHTTPS; port != defaultPortHTTPS { if portHTTPS != defaultPortHTTPS {
hostPort = netutil.JoinHostPort(host, port) hostPort = netutil.JoinHostPort(host, portHTTPS)
} }
httpsURL := &url.URL{ httpsURL := &url.URL{

View File

@ -39,7 +39,7 @@ type getAddrsResponse struct {
} }
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint. // handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
func (web *Web) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) { func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
data := getAddrsResponse{ data := getAddrsResponse{
Version: version.Version(), Version: version.Version(),
@ -167,7 +167,7 @@ func (req *checkConfReq) validateDNS(
} }
// handleInstallCheckConfig handles the /check_config endpoint. // handleInstallCheckConfig handles the /check_config endpoint.
func (web *Web) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) { func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
req := &checkConfReq{} req := &checkConfReq{}
err := json.NewDecoder(r.Body).Decode(req) err := json.NewDecoder(r.Body).Decode(req)
@ -375,7 +375,7 @@ func shutdownSrv3(srv *http3.Server) {
const PasswordMinRunes = 8 const PasswordMinRunes = 8
// Apply new configuration, start DNS server, restart Web server // Apply new configuration, start DNS server, restart Web server
func (web *Web) handleInstallConfigure(w http.ResponseWriter, r *http.Request) { func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
req, restartHTTP, err := decodeApplyConfigReq(r.Body) req, restartHTTP, err := decodeApplyConfigReq(r.Body)
if err != nil { if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
@ -503,7 +503,7 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
return req, restartHTTP, err return req, restartHTTP, err
} }
func (web *Web) registerInstallHandlers() { func (web *webAPI) registerInstallHandlers() {
Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses))) Context.mux.HandleFunc("/control/install/get_addresses", preInstall(ensureGET(web.handleInstallGetAddresses)))
Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig))) Context.mux.HandleFunc("/control/install/check_config", preInstall(ensurePOST(web.handleInstallCheckConfig)))
Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure))) Context.mux.HandleFunc("/control/install/configure", preInstall(ensurePOST(web.handleInstallConfigure)))

View File

@ -58,7 +58,7 @@ type homeContext struct {
dhcpServer dhcpd.Interface // DHCP module dhcpServer dhcpd.Interface // DHCP module
auth *Auth // HTTP authentication module auth *Auth // HTTP authentication module
filters *filtering.DNSFilter // DNS filtering module filters *filtering.DNSFilter // DNS filtering module
web *Web // Web (HTTP, HTTPS) module web *webAPI // Web (HTTP, HTTPS) module
tls *tlsManager // TLS module tls *tlsManager // TLS module
// etcHosts contains IP-hostname mappings taken from the OS-specific hosts // etcHosts contains IP-hostname mappings taken from the OS-specific hosts
@ -387,7 +387,7 @@ func checkPorts() (err error) {
return nil return nil
} }
func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) { func initWeb(opts options, clientBuildFS fs.FS) (web *webAPI, err error) {
var clientFS fs.FS var clientFS fs.FS
if opts.localFrontend { if opts.localFrontend {
log.Info("warning: using local frontend files") log.Info("warning: using local frontend files")
@ -414,7 +414,7 @@ func initWeb(opts options, clientBuildFS fs.FS) (web *Web, err error) {
serveHTTP3: config.DNS.ServeHTTP3, serveHTTP3: config.DNS.ServeHTTP3,
} }
web = newWeb(&webConf) web = newWebAPI(&webConf)
if web == nil { if web == nil {
return nil, fmt.Errorf("initializing web: %w", err) return nil, fmt.Errorf("initializing web: %w", err)
} }
@ -533,7 +533,7 @@ func run(opts options, clientBuildFS fs.FS) {
} }
} }
Context.web.Start() Context.web.start()
// wait indefinitely for other go-routines to complete their job // wait indefinitely for other go-routines to complete their job
select {} select {}
@ -713,7 +713,7 @@ func cleanup(ctx context.Context) {
log.Info("stopping AdGuard Home") log.Info("stopping AdGuard Home")
if Context.web != nil { if Context.web != nil {
Context.web.Close(ctx) Context.web.close(ctx)
Context.web = nil Context.web = nil
} }
if Context.auth != nil { if Context.auth != nil {

View File

@ -108,7 +108,7 @@ func (m *tlsManager) start() {
// The background context is used because the TLSConfigChanged wraps context // The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current // with timeout on its own and shuts down the server, which handles current
// request. // request.
Context.web.TLSConfigChanged(context.Background(), tlsConf) Context.web.tlsConfigChanged(context.Background(), tlsConf)
} }
// reload updates the configuration and restarts t. // reload updates the configuration and restarts t.
@ -156,7 +156,7 @@ func (m *tlsManager) reload() {
// The background context is used because the TLSConfigChanged wraps context // The background context is used because the TLSConfigChanged wraps context
// with timeout on its own and shuts down the server, which handles current // with timeout on its own and shuts down the server, which handles current
// request. // request.
Context.web.TLSConfigChanged(context.Background(), tlsConf) Context.web.tlsConfigChanged(context.Background(), tlsConf)
} }
// loadTLSConf loads and validates the TLS configuration. The returned error is // loadTLSConf loads and validates the TLS configuration. The returned error is
@ -454,7 +454,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
// same reason. // same reason.
if restartHTTPS { if restartHTTPS {
go func() { go func() {
Context.web.TLSConfigChanged(context.Background(), req.tlsConfigSettings) Context.web.tlsConfigChanged(context.Background(), req.tlsConfigSettings)
}() }()
} }
} }

View File

@ -37,7 +37,6 @@ type webConfig struct {
BindHost netip.Addr BindHost netip.Addr
BindPort int BindPort int
PortHTTPS int
// ReadTimeout is an option to pass to http.Server for setting an // ReadTimeout is an option to pass to http.Server for setting an
// appropriate field. // appropriate field.
@ -72,8 +71,8 @@ type httpsServer struct {
enabled bool enabled bool
} }
// Web is the web UI and API server. // webAPI is the web UI and API server.
type Web struct { type webAPI struct {
conf *webConfig conf *webConfig
// TODO(a.garipov): Refactor all these servers. // TODO(a.garipov): Refactor all these servers.
@ -82,15 +81,13 @@ type Web struct {
// httpsServer is the server that handles HTTPS traffic. If it is not nil, // httpsServer is the server that handles HTTPS traffic. If it is not nil,
// [Web.http3Server] must also not be nil. // [Web.http3Server] must also not be nil.
httpsServer httpsServer httpsServer httpsServer
forceHTTPS bool
} }
// newWeb creates a new instance of the web UI and API server. // newWebAPI creates a new instance of the web UI and API server.
func newWeb(conf *webConfig) (w *Web) { func newWebAPI(conf *webConfig) (w *webAPI) {
log.Info("web: initializing") log.Info("web: initializing")
w = &Web{ w = &webAPI{
conf: conf, conf: conf,
} }
@ -125,12 +122,10 @@ func webCheckPortAvailable(port int) (ok bool) {
return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil return aghnet.CheckPort("tcp", netip.AddrPortFrom(config.BindHost, uint16(port))) == nil
} }
// TLSConfigChanged updates the TLS configuration and restarts the HTTPS server // tlsConfigChanged updates the TLS configuration and restarts the HTTPS server
// if necessary. // if necessary.
func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) { func (web *webAPI) tlsConfigChanged(ctx context.Context, tlsConf tlsConfigSettings) {
log.Debug("web: applying new tls configuration") log.Debug("web: applying new tls configuration")
web.conf.PortHTTPS = tlsConf.PortHTTPS
web.forceHTTPS = (tlsConf.ForceHTTPS && tlsConf.Enabled && tlsConf.PortHTTPS != 0)
enabled := tlsConf.Enabled && enabled := tlsConf.Enabled &&
tlsConf.PortHTTPS != 0 && tlsConf.PortHTTPS != 0 &&
@ -161,8 +156,8 @@ func (web *Web) TLSConfigChanged(ctx context.Context, tlsConf tlsConfigSettings)
web.httpsServer.cond.L.Unlock() web.httpsServer.cond.L.Unlock()
} }
// Start - start serving HTTP requests // start - start serving HTTP requests
func (web *Web) Start() { func (web *webAPI) start() {
log.Println("AdGuard Home is available at the following addresses:") log.Println("AdGuard Home is available at the following addresses:")
// for https, we have a separate goroutine loop // for https, we have a separate goroutine loop
@ -203,8 +198,8 @@ func (web *Web) Start() {
} }
} }
// Close gracefully shuts down the HTTP servers. // close gracefully shuts down the HTTP servers.
func (web *Web) Close(ctx context.Context) { func (web *webAPI) close(ctx context.Context) {
log.Info("stopping http server...") log.Info("stopping http server...")
web.httpsServer.cond.L.Lock() web.httpsServer.cond.L.Lock()
@ -222,7 +217,7 @@ func (web *Web) Close(ctx context.Context) {
log.Info("stopped http server") log.Info("stopped http server")
} }
func (web *Web) tlsServerLoop() { func (web *webAPI) tlsServerLoop() {
for { for {
web.httpsServer.cond.L.Lock() web.httpsServer.cond.L.Lock()
if web.httpsServer.inShutdown { if web.httpsServer.inShutdown {
@ -241,7 +236,15 @@ func (web *Web) tlsServerLoop() {
web.httpsServer.cond.L.Unlock() web.httpsServer.cond.L.Unlock()
addr := netutil.JoinHostPort(web.conf.BindHost.String(), web.conf.PortHTTPS) var portHTTPS int
func() {
config.RLock()
defer config.RUnlock()
portHTTPS = config.TLS.PortHTTPS
}()
addr := netutil.JoinHostPort(web.conf.BindHost.String(), portHTTPS)
web.httpsServer.server = &http.Server{ web.httpsServer.server = &http.Server{
ErrorLog: log.StdLog("web: https", log.DEBUG), ErrorLog: log.StdLog("web: https", log.DEBUG),
Addr: addr, Addr: addr,
@ -272,7 +275,7 @@ func (web *Web) tlsServerLoop() {
} }
} }
func (web *Web) mustStartHTTP3(address string) { func (web *webAPI) mustStartHTTP3(address string) {
defer log.OnPanic("web: http3") defer log.OnPanic("web: http3")
web.httpsServer.server3 = &http3.Server{ web.httpsServer.server3 = &http3.Server{