mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-11-15 18:08:30 -07:00
+(dhcpd): added static IP for MacOS
This commit is contained in:
parent
7afa16fbe7
commit
c27852537d
@ -467,6 +467,6 @@
|
|||||||
"static_ip_desc": "AdGuard Home is a server so it needs a static IP address to function properly. Otherwise, at some point, your router may assign a different IP address to this device.",
|
"static_ip_desc": "AdGuard Home is a server so it needs a static IP address to function properly. Otherwise, at some point, your router may assign a different IP address to this device.",
|
||||||
"set_static_ip": "Set a static IP address",
|
"set_static_ip": "Set a static IP address",
|
||||||
"install_static_ok": "Good news! The static IP address is already configured",
|
"install_static_ok": "Good news! The static IP address is already configured",
|
||||||
"install_static_error": "AdGuard Home cannot configure it automatically for your OS. Please look for an instruction on how to do this manually",
|
"install_static_error": "AdGuard Home cannot configure it automatically for this network interface. Please look for an instruction on how to do this manually.",
|
||||||
"install_static_configure": "We have detected that a dynamic IP address is used — <0>{{ip}}</0>. Do you want to use it as your static address?"
|
"install_static_configure": "We have detected that a dynamic IP address is used — <0>{{ip}}</0>. Do you want to use it as your static address?"
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -121,7 +123,7 @@ type netInterfaceJSON struct {
|
|||||||
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||||
response := map[string]interface{}{}
|
response := map[string]interface{}{}
|
||||||
|
|
||||||
ifaces, err := GetValidNetInterfaces()
|
ifaces, err := util.GetValidNetInterfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
httpError(r, w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||||
return
|
return
|
||||||
@ -219,7 +221,7 @@ func (s *Server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
|
|||||||
staticIP["error"] = err.Error()
|
staticIP["error"] = err.Error()
|
||||||
} else if !isStaticIP {
|
} else if !isStaticIP {
|
||||||
staticIPStatus = "no"
|
staticIPStatus = "no"
|
||||||
staticIP["ip"] = GetFullIP(interfaceName)
|
staticIP["ip"] = util.GetSubnet(interfaceName)
|
||||||
}
|
}
|
||||||
staticIP["static"] = staticIPStatus
|
staticIP["static"] = staticIPStatus
|
||||||
|
|
||||||
|
@ -6,37 +6,17 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/file"
|
"github.com/AdguardTeam/golibs/file"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP
|
|
||||||
// invalid interface is a ppp interface or the one that doesn't allow broadcasts
|
|
||||||
func GetValidNetInterfaces() ([]net.Interface, error) {
|
|
||||||
ifaces, err := net.Interfaces()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
netIfaces := []net.Interface{}
|
|
||||||
|
|
||||||
for i := range ifaces {
|
|
||||||
if ifaces[i].Flags&net.FlagPointToPoint != 0 {
|
|
||||||
// this interface is ppp, we're not interested in this one
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
iface := ifaces[i]
|
|
||||||
netIfaces = append(netIfaces, iface)
|
|
||||||
}
|
|
||||||
|
|
||||||
return netIfaces, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if network interface has a static IP configured
|
// Check if network interface has a static IP configured
|
||||||
// Supports: Raspbian.
|
// Supports: Raspbian.
|
||||||
func HasStaticIP(ifaceName string) (bool, error) {
|
func HasStaticIP(ifaceName string) (bool, error) {
|
||||||
@ -56,54 +36,18 @@ func HasStaticIP(ifaceName string) (bool, error) {
|
|||||||
return false, fmt.Errorf("Cannot check if IP is static: not supported on %s", runtime.GOOS)
|
return false, fmt.Errorf("Cannot check if IP is static: not supported on %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get IP address with netmask
|
// Set a static IP for the specified network interface
|
||||||
func GetFullIP(ifaceName string) string {
|
|
||||||
cmd := exec.Command("ip", "-oneline", "-family", "inet", "address", "show", ifaceName)
|
|
||||||
log.Tracef("executing %s %v", cmd.Path, cmd.Args)
|
|
||||||
d, err := cmd.Output()
|
|
||||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := strings.Fields(string(d))
|
|
||||||
if len(fields) < 4 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
_, _, err = net.ParseCIDR(fields[3])
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return fields[3]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set a static IP for network interface
|
|
||||||
// Supports: Raspbian.
|
|
||||||
func SetStaticIP(ifaceName string) error {
|
func SetStaticIP(ifaceName string) error {
|
||||||
ip := GetFullIP(ifaceName)
|
if runtime.GOOS == "linux" {
|
||||||
if len(ip) == 0 {
|
return setStaticIPDhcpdConf(ifaceName)
|
||||||
return errors.New("Can't get IP address")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ip4, _, err := net.ParseCIDR(ip)
|
if runtime.GOOS == "darwin" {
|
||||||
if err != nil {
|
return fmt.Errorf("cannot do that")
|
||||||
return err
|
// return setStaticIPDarwin(ifaceName)
|
||||||
}
|
|
||||||
gatewayIP := getGatewayIP(ifaceName)
|
|
||||||
add := setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, ip4.String())
|
|
||||||
|
|
||||||
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
body = append(body, []byte(add)...)
|
return fmt.Errorf("Cannot set static IP on %s", runtime.GOOS)
|
||||||
err = file.SafeWrite("/etc/dhcpcd.conf", body)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// for dhcpcd.conf
|
// for dhcpcd.conf
|
||||||
@ -167,8 +111,37 @@ func getGatewayIP(ifaceName string) string {
|
|||||||
return fields[2]
|
return fields[2]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setStaticIPDhcpdConf - updates /etc/dhcpd.conf and sets the current IP address to be static
|
||||||
|
func setStaticIPDhcpdConf(ifaceName string) error {
|
||||||
|
ip := util.GetSubnet(ifaceName)
|
||||||
|
if len(ip) == 0 {
|
||||||
|
return errors.New("Can't get IP address")
|
||||||
|
}
|
||||||
|
|
||||||
|
ip4, _, err := net.ParseCIDR(ip)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
gatewayIP := getGatewayIP(ifaceName)
|
||||||
|
add := updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, ip4.String())
|
||||||
|
|
||||||
|
body, err := ioutil.ReadFile("/etc/dhcpcd.conf")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = append(body, []byte(add)...)
|
||||||
|
err = file.SafeWrite("/etc/dhcpcd.conf", body)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updates dhcpd.conf content -- sets static IP address there
|
||||||
// for dhcpcd.conf
|
// for dhcpcd.conf
|
||||||
func setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string {
|
func updateStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string {
|
||||||
var body []byte
|
var body []byte
|
||||||
|
|
||||||
add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n",
|
add := fmt.Sprintf("\ninterface %s\nstatic ip_address=%s\n",
|
||||||
@ -187,3 +160,154 @@ func setStaticIPDhcpcdConf(ifaceName, ip, gatewayIP, dnsIP string) string {
|
|||||||
|
|
||||||
return string(body)
|
return string(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if network interface has a static IP configured
|
||||||
|
// Supports: MacOS.
|
||||||
|
func hasStaticIPDarwin(ifaceName string) (bool, error) {
|
||||||
|
portInfo, err := getCurrentHardwarePortInfo(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return portInfo.static, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setStaticIPDarwin - uses networksetup util to set the current IP address to be static
|
||||||
|
// Additionally it configures the current DNS servers as well
|
||||||
|
func setStaticIPDarwin(ifaceName string) error {
|
||||||
|
portInfo, err := getCurrentHardwarePortInfo(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if portInfo.static {
|
||||||
|
return errors.New("IP address is already static")
|
||||||
|
}
|
||||||
|
|
||||||
|
dnsAddrs, err := getEtcResolvConfServers()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
args := make([]string, 0)
|
||||||
|
args = append(args, "-setdnsservers")
|
||||||
|
args = append(args, dnsAddrs...)
|
||||||
|
|
||||||
|
// Setting DNS servers is necessary when configuring a static IP
|
||||||
|
code, _, err := util.RunCommand("networksetup", args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if code != 0 {
|
||||||
|
return fmt.Errorf("Failed to set DNS servers, code=%d", code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Actually configures hardware port to have static IP
|
||||||
|
code, _, err = util.RunCommand("networksetup", "-setmanual",
|
||||||
|
portInfo.name, portInfo.ip, portInfo.subnet, portInfo.gatewayIP)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if code != 0 {
|
||||||
|
return fmt.Errorf("Failed to set DNS servers, code=%d", code)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCurrentHardwarePortInfo gets information the specified network interface
|
||||||
|
func getCurrentHardwarePortInfo(ifaceName string) (hardwarePortInfo, error) {
|
||||||
|
// First of all we should find hardware port name
|
||||||
|
m := getNetworkSetupHardwareReports()
|
||||||
|
hardwarePort, ok := m[ifaceName]
|
||||||
|
if !ok {
|
||||||
|
return hardwarePortInfo{}, fmt.Errorf("Could not find hardware port for %s", ifaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return getHardwarePortInfo(hardwarePort)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNetworkSetupHardwareReports parses the output of the `networksetup -listallhardwareports` command
|
||||||
|
// it returns a map where the key is the interface name, and the value is the "hardware port"
|
||||||
|
// returns nil if it fails to parse the output
|
||||||
|
func getNetworkSetupHardwareReports() map[string]string {
|
||||||
|
_, out, err := util.RunCommand("networksetup", "-listallhardwareports")
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
re, err := regexp.Compile("Hardware Port: (.*?)\nDevice: (.*?)\n")
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m := make(map[string]string, 0)
|
||||||
|
|
||||||
|
matches := re.FindAllStringSubmatch(out, -1)
|
||||||
|
for i := range matches {
|
||||||
|
port := matches[i][1]
|
||||||
|
device := matches[i][2]
|
||||||
|
m[device] = port
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// hardwarePortInfo - information obtained using MacOS networksetup
|
||||||
|
// about the current state of the internet connection
|
||||||
|
type hardwarePortInfo struct {
|
||||||
|
name string
|
||||||
|
ip string
|
||||||
|
subnet string
|
||||||
|
gatewayIP string
|
||||||
|
static bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func getHardwarePortInfo(hardwarePort string) (hardwarePortInfo, error) {
|
||||||
|
h := hardwarePortInfo{}
|
||||||
|
|
||||||
|
_, out, err := util.RunCommand("networksetup", "-getinfo", hardwarePort)
|
||||||
|
if err != nil {
|
||||||
|
return h, err
|
||||||
|
}
|
||||||
|
|
||||||
|
re := regexp.MustCompile("IP address: (.*?)\nSubnet mask: (.*?)\nRouter: (.*?)\n")
|
||||||
|
|
||||||
|
match := re.FindStringSubmatch(out)
|
||||||
|
if len(match) == 0 {
|
||||||
|
return h, errors.New("Could not find hardware port info")
|
||||||
|
}
|
||||||
|
|
||||||
|
h.name = hardwarePort
|
||||||
|
h.ip = match[1]
|
||||||
|
h.subnet = match[2]
|
||||||
|
h.gatewayIP = match[3]
|
||||||
|
|
||||||
|
if strings.Index(out, "Manual Configuration") == 0 {
|
||||||
|
h.static = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Gets a list of nameservers currently configured in the /etc/resolv.conf
|
||||||
|
func getEtcResolvConfServers() ([]string, error) {
|
||||||
|
body, err := ioutil.ReadFile("/etc/resolv.conf")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
re := regexp.MustCompile("nameserver ([a-zA-Z0-9.:]+)")
|
||||||
|
|
||||||
|
matches := re.FindAllStringSubmatch(string(body), -1)
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return nil, errors.New("Found no DNS servers in /etc/resolv.conf")
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs := make([]string, 0)
|
||||||
|
for i := range matches {
|
||||||
|
addrs = append(addrs, matches[i][1])
|
||||||
|
}
|
||||||
|
|
||||||
|
return addrs, nil
|
||||||
|
}
|
||||||
|
@ -1,8 +0,0 @@
|
|||||||
package dhcpd
|
|
||||||
|
|
||||||
// Check if network interface has a static IP configured
|
|
||||||
// Supports: Raspbian.
|
|
||||||
func hasStaticIPDarwin(ifaceName string) (bool, error) {
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
@ -46,7 +46,7 @@ static routers=192.168.0.1
|
|||||||
static domain_name_servers=192.168.0.2
|
static domain_name_servers=192.168.0.2
|
||||||
|
|
||||||
`
|
`
|
||||||
s := setStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2")
|
s := updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "192.168.0.1", "192.168.0.2")
|
||||||
assert.Equal(t, dhcpcdConf, s)
|
assert.Equal(t, dhcpcdConf, s)
|
||||||
|
|
||||||
// without gateway
|
// without gateway
|
||||||
@ -56,6 +56,6 @@ static ip_address=192.168.0.2/24
|
|||||||
static domain_name_servers=192.168.0.2
|
static domain_name_servers=192.168.0.2
|
||||||
|
|
||||||
`
|
`
|
||||||
s = setStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2")
|
s = updateStaticIPDhcpcdConf("wlan0", "192.168.0.2/24", "", "192.168.0.2")
|
||||||
assert.Equal(t, dhcpcdConf, s)
|
assert.Equal(t, dhcpcdConf, s)
|
||||||
}
|
}
|
||||||
|
125
home/control.go
125
home/control.go
@ -3,7 +3,13 @@ package home
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
"github.com/AdguardTeam/AdGuardHome/dnsforward"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
@ -54,8 +60,7 @@ func getDNSAddresses() []string {
|
|||||||
dnsAddresses := []string{}
|
dnsAddresses := []string{}
|
||||||
|
|
||||||
if config.DNS.BindHost == "0.0.0.0" {
|
if config.DNS.BindHost == "0.0.0.0" {
|
||||||
|
ifaces, e := util.GetValidNetInterfacesForWeb()
|
||||||
ifaces, e := getValidNetInterfacesForWeb()
|
|
||||||
if e != nil {
|
if e != nil {
|
||||||
log.Error("Couldn't get network interfaces: %v", e)
|
log.Error("Couldn't get network interfaces: %v", e)
|
||||||
return []string{}
|
return []string{}
|
||||||
@ -66,7 +71,6 @@ func getDNSAddresses() []string {
|
|||||||
addDNSAddress(&dnsAddresses, addr)
|
addDNSAddress(&dnsAddresses, addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
addDNSAddress(&dnsAddresses, config.DNS.BindHost)
|
||||||
}
|
}
|
||||||
@ -180,3 +184,118 @@ func registerControlHandlers() {
|
|||||||
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
|
func httpRegister(method string, url string, handler func(http.ResponseWriter, *http.Request)) {
|
||||||
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
http.Handle(url, postInstallHandler(optionalAuthHandler(gziphandler.GzipHandler(ensureHandler(method, handler)))))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ----------------------------------
|
||||||
|
// helper functions for HTTP handlers
|
||||||
|
// ----------------------------------
|
||||||
|
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Debug("%s %v", r.Method, r.URL)
|
||||||
|
|
||||||
|
if r.Method != method {
|
||||||
|
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if method == "POST" || method == "PUT" || method == "DELETE" {
|
||||||
|
Context.controlLock.Lock()
|
||||||
|
defer Context.controlLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
handler(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||||
|
return ensure("POST", handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||||
|
return ensure("GET", handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bridge between http.Handler object and Go function
|
||||||
|
type httpHandler struct {
|
||||||
|
handler func(http.ResponseWriter, *http.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
h.handler(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler {
|
||||||
|
h := httpHandler{}
|
||||||
|
h.handler = ensure(method, handler)
|
||||||
|
return &h
|
||||||
|
}
|
||||||
|
|
||||||
|
// preInstall lets the handler run only if firstRun is true, no redirects
|
||||||
|
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !Context.firstRun {
|
||||||
|
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
|
||||||
|
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary
|
||||||
|
type preInstallHandlerStruct struct {
|
||||||
|
handler http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
preInstall(p.handler.ServeHTTP)(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// preInstallHandler returns http.Handler interface for preInstall wrapper
|
||||||
|
func preInstallHandler(handler http.Handler) http.Handler {
|
||||||
|
return &preInstallHandlerStruct{handler}
|
||||||
|
}
|
||||||
|
|
||||||
|
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
|
||||||
|
// it also enforces HTTPS if it is enabled and configured
|
||||||
|
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if Context.firstRun &&
|
||||||
|
!strings.HasPrefix(r.URL.Path, "/install.") &&
|
||||||
|
r.URL.Path != "/favicon.png" {
|
||||||
|
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// enforce https?
|
||||||
|
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
|
||||||
|
// yes, and we want host from host:port
|
||||||
|
host, _, err := net.SplitHostPort(r.Host)
|
||||||
|
if err != nil {
|
||||||
|
// no port in host
|
||||||
|
host = r.Host
|
||||||
|
}
|
||||||
|
// construct new URL to redirect to
|
||||||
|
newURL := url.URL{
|
||||||
|
Scheme: "https",
|
||||||
|
Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)),
|
||||||
|
Path: r.URL.Path,
|
||||||
|
RawQuery: r.URL.RawQuery,
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
handler(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type postInstallHandlerStruct struct {
|
||||||
|
handler http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
postInstall(p.handler.ServeHTTP)(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func postInstallHandler(handler http.Handler) http.Handler {
|
||||||
|
return &postInstallHandlerStruct{handler}
|
||||||
|
}
|
||||||
|
@ -13,6 +13,8 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
@ -38,7 +40,7 @@ func handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
|||||||
data.WebPort = 80
|
data.WebPort = 80
|
||||||
data.DNSPort = 53
|
data.DNSPort = 53
|
||||||
|
|
||||||
ifaces, err := getValidNetInterfacesForWeb()
|
ifaces, err := util.GetValidNetInterfacesForWeb()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
httpError(w, http.StatusInternalServerError, "Couldn't get interfaces: %s", err)
|
||||||
return
|
return
|
||||||
@ -101,16 +103,16 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort {
|
if reqData.Web.Port != 0 && reqData.Web.Port != config.BindPort {
|
||||||
err = checkPortAvailable(reqData.Web.IP, reqData.Web.Port)
|
err = util.CheckPortAvailable(reqData.Web.IP, reqData.Web.Port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respData.Web.Status = fmt.Sprintf("%v", err)
|
respData.Web.Status = fmt.Sprintf("%v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if reqData.DNS.Port != 0 {
|
if reqData.DNS.Port != 0 {
|
||||||
err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||||
|
|
||||||
if errorIsAddrInUse(err) {
|
if util.ErrorIsAddrInUse(err) {
|
||||||
canAutofix := checkDNSStubListener()
|
canAutofix := checkDNSStubListener()
|
||||||
if canAutofix && reqData.DNS.Autofix {
|
if canAutofix && reqData.DNS.Autofix {
|
||||||
|
|
||||||
@ -119,7 +121,7 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
|||||||
log.Error("Couldn't disable DNSStubListener: %s", err)
|
log.Error("Couldn't disable DNSStubListener: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = checkPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
err = util.CheckPacketPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||||
canAutofix = false
|
canAutofix = false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -127,26 +129,22 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = checkPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
err = util.CheckPortAvailable(reqData.DNS.IP, reqData.DNS.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respData.DNS.Status = fmt.Sprintf("%v", err)
|
respData.DNS.Status = fmt.Sprintf("%v", err)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
interfaceName := util.GetInterfaceByIP(reqData.DNS.IP)
|
||||||
interfaceName := getInterfaceByIP(reqData.DNS.IP)
|
|
||||||
staticIPStatus := "yes"
|
staticIPStatus := "yes"
|
||||||
|
|
||||||
if len(interfaceName) == 0 {
|
if len(interfaceName) == 0 {
|
||||||
staticIPStatus = "error"
|
staticIPStatus = "error"
|
||||||
respData.StaticIP.Error = fmt.Sprintf("Couldn't find network interface by IP %s", reqData.DNS.IP)
|
respData.StaticIP.Error = fmt.Sprintf("Couldn't find network interface by IP %s", reqData.DNS.IP)
|
||||||
|
|
||||||
} else if reqData.DNS.SetStaticIP {
|
} else if reqData.DNS.SetStaticIP {
|
||||||
err = dhcpd.SetStaticIP(interfaceName)
|
err = dhcpd.SetStaticIP(interfaceName)
|
||||||
staticIPStatus = "error"
|
staticIPStatus = "error"
|
||||||
respData.StaticIP.Error = err.Error()
|
respData.StaticIP.Error = err.Error()
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// check if we have a static IP
|
// check if we have a static IP
|
||||||
isStaticIP, err := dhcpd.HasStaticIP(interfaceName)
|
isStaticIP, err := dhcpd.HasStaticIP(interfaceName)
|
||||||
@ -155,7 +153,7 @@ func handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
|||||||
respData.StaticIP.Error = err.Error()
|
respData.StaticIP.Error = err.Error()
|
||||||
} else if !isStaticIP {
|
} else if !isStaticIP {
|
||||||
staticIPStatus = "no"
|
staticIPStatus = "no"
|
||||||
respData.StaticIP.IP = dhcpd.GetFullIP(interfaceName)
|
respData.StaticIP.IP = util.GetSubnet(interfaceName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
respData.StaticIP.Static = staticIPStatus
|
respData.StaticIP.Static = staticIPStatus
|
||||||
@ -279,7 +277,7 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// validate that hosts and ports are bindable
|
// validate that hosts and ports are bindable
|
||||||
if restartHTTP {
|
if restartHTTP {
|
||||||
err = checkPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
|
err = util.CheckPortAvailable(newSettings.Web.IP, newSettings.Web.Port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
|
httpError(w, http.StatusBadRequest, "Impossible to listen on IP:port %s due to %s",
|
||||||
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
|
net.JoinHostPort(newSettings.Web.IP, strconv.Itoa(newSettings.Web.Port)), err)
|
||||||
@ -287,13 +285,13 @@ func handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = checkPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
|
err = util.CheckPacketPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "%s", err)
|
httpError(w, http.StatusBadRequest, "%s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = checkPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
|
err = util.CheckPortAvailable(newSettings.DNS.IP, newSettings.DNS.Port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "%s", err)
|
httpError(w, http.StatusBadRequest, "%s", err)
|
||||||
return
|
return
|
||||||
|
@ -20,6 +20,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
)
|
)
|
||||||
@ -84,7 +86,7 @@ func handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
|||||||
alreadyRunning = true
|
alreadyRunning = true
|
||||||
}
|
}
|
||||||
if !alreadyRunning {
|
if !alreadyRunning {
|
||||||
err = checkPortAvailable(config.BindHost, data.PortHTTPS)
|
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
||||||
return
|
return
|
||||||
@ -114,7 +116,7 @@ func handleTLSConfigure(w http.ResponseWriter, r *http.Request) {
|
|||||||
alreadyRunning = true
|
alreadyRunning = true
|
||||||
}
|
}
|
||||||
if !alreadyRunning {
|
if !alreadyRunning {
|
||||||
err = checkPortAvailable(config.BindHost, data.PortHTTPS)
|
err = util.CheckPortAvailable(config.BindHost, data.PortHTTPS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
httpError(w, http.StatusBadRequest, "port %d is not available, cannot enable HTTPS on it", data.PortHTTPS)
|
||||||
return
|
return
|
||||||
|
@ -17,6 +17,8 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -196,7 +198,7 @@ func getUpdateInfo(jsonData []byte) (*updateInfo, error) {
|
|||||||
binName = "AdGuardHome.exe"
|
binName = "AdGuardHome.exe"
|
||||||
}
|
}
|
||||||
u.curBinName = filepath.Join(workDir, binName)
|
u.curBinName = filepath.Join(workDir, binName)
|
||||||
if !fileExists(u.curBinName) {
|
if !util.FileExists(u.curBinName) {
|
||||||
return nil, fmt.Errorf("Executable file %s doesn't exist", u.curBinName)
|
return nil, fmt.Errorf("Executable file %s doesn't exist", u.curBinName)
|
||||||
}
|
}
|
||||||
u.bkpBinName = filepath.Join(u.backupDir, binName)
|
u.bkpBinName = filepath.Join(u.backupDir, binName)
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestDoUpdate(t *testing.T) {
|
func TestDoUpdate(t *testing.T) {
|
||||||
|
|
||||||
config.DNS.Port = 0
|
config.DNS.Port = 0
|
||||||
Context.workDir = "..." // set absolute path
|
Context.workDir = "..." // set absolute path
|
||||||
newver := "v0.96"
|
newver := "v0.96"
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
"github.com/AdguardTeam/AdGuardHome/dnsfilter"
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
"github.com/AdguardTeam/golibs/file"
|
"github.com/AdguardTeam/golibs/file"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
@ -401,7 +402,7 @@ func parseFilterContents(contents []byte) (int, string) {
|
|||||||
|
|
||||||
// Count lines in the filter
|
// Count lines in the filter
|
||||||
for len(data) != 0 {
|
for len(data) != 0 {
|
||||||
line := SplitNext(&data, '\n')
|
line := util.SplitNext(&data, '\n')
|
||||||
if len(line) == 0 {
|
if len(line) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestFilters(t *testing.T) {
|
func TestFilters(t *testing.T) {
|
||||||
|
dir := prepareTestDir()
|
||||||
|
defer func() { _ = os.RemoveAll(dir) }()
|
||||||
|
|
||||||
|
Context = homeContext{}
|
||||||
|
Context.workDir = dir
|
||||||
Context.client = &http.Client{
|
Context.client = &http.Client{
|
||||||
Timeout: time.Minute * 5,
|
Timeout: time.Minute * 5,
|
||||||
}
|
}
|
||||||
@ -33,5 +38,5 @@ func TestFilters(t *testing.T) {
|
|||||||
assert.True(t, err == nil)
|
assert.True(t, err == nil)
|
||||||
|
|
||||||
f.unload()
|
f.unload()
|
||||||
os.Remove(f.Path())
|
_ = os.Remove(f.Path())
|
||||||
}
|
}
|
||||||
|
241
home/helpers.go
241
home/helpers.go
@ -1,241 +0,0 @@
|
|||||||
package home
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"path"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
|
||||||
"github.com/joomcode/errorx"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ----------------------------------
|
|
||||||
// helper functions for HTTP handlers
|
|
||||||
// ----------------------------------
|
|
||||||
func ensure(method string, handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
log.Debug("%s %v", r.Method, r.URL)
|
|
||||||
|
|
||||||
if r.Method != method {
|
|
||||||
http.Error(w, "This request must be "+method, http.StatusMethodNotAllowed)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if method == "POST" || method == "PUT" || method == "DELETE" {
|
|
||||||
Context.controlLock.Lock()
|
|
||||||
defer Context.controlLock.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
handler(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ensurePOST(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
|
||||||
return ensure("POST", handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ensureGET(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
|
||||||
return ensure("GET", handler)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bridge between http.Handler object and Go function
|
|
||||||
type httpHandler struct {
|
|
||||||
handler func(http.ResponseWriter, *http.Request)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
h.handler(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ensureHandler(method string, handler func(http.ResponseWriter, *http.Request)) http.Handler {
|
|
||||||
h := httpHandler{}
|
|
||||||
h.handler = ensure(method, handler)
|
|
||||||
return &h
|
|
||||||
}
|
|
||||||
|
|
||||||
// -------------------
|
|
||||||
// first run / install
|
|
||||||
// -------------------
|
|
||||||
func detectFirstRun() bool {
|
|
||||||
configfile := Context.configFilename
|
|
||||||
if !filepath.IsAbs(configfile) {
|
|
||||||
configfile = filepath.Join(Context.workDir, Context.configFilename)
|
|
||||||
}
|
|
||||||
_, err := os.Stat(configfile)
|
|
||||||
if !os.IsNotExist(err) {
|
|
||||||
// do nothing, file exists
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// preInstall lets the handler run only if firstRun is true, no redirects
|
|
||||||
func preInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if !Context.firstRun {
|
|
||||||
// if it's not first run, don't let users access it (for example /install.html when configuration is done)
|
|
||||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
handler(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// preInstallStruct wraps preInstall into a struct that can be returned as an interface where necessary
|
|
||||||
type preInstallHandlerStruct struct {
|
|
||||||
handler http.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *preInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
preInstall(p.handler.ServeHTTP)(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// preInstallHandler returns http.Handler interface for preInstall wrapper
|
|
||||||
func preInstallHandler(handler http.Handler) http.Handler {
|
|
||||||
return &preInstallHandlerStruct{handler}
|
|
||||||
}
|
|
||||||
|
|
||||||
// postInstall lets the handler run only if firstRun is false, and redirects to /install.html otherwise
|
|
||||||
// it also enforces HTTPS if it is enabled and configured
|
|
||||||
func postInstall(handler func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if Context.firstRun &&
|
|
||||||
!strings.HasPrefix(r.URL.Path, "/install.") &&
|
|
||||||
r.URL.Path != "/favicon.png" {
|
|
||||||
http.Redirect(w, r, "/install.html", http.StatusSeeOther) // should not be cacheable
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// enforce https?
|
|
||||||
if config.TLS.ForceHTTPS && r.TLS == nil && config.TLS.Enabled && config.TLS.PortHTTPS != 0 && Context.httpsServer.server != nil {
|
|
||||||
// yes, and we want host from host:port
|
|
||||||
host, _, err := net.SplitHostPort(r.Host)
|
|
||||||
if err != nil {
|
|
||||||
// no port in host
|
|
||||||
host = r.Host
|
|
||||||
}
|
|
||||||
// construct new URL to redirect to
|
|
||||||
newURL := url.URL{
|
|
||||||
Scheme: "https",
|
|
||||||
Host: net.JoinHostPort(host, strconv.Itoa(config.TLS.PortHTTPS)),
|
|
||||||
Path: r.URL.Path,
|
|
||||||
RawQuery: r.URL.RawQuery,
|
|
||||||
}
|
|
||||||
http.Redirect(w, r, newURL.String(), http.StatusTemporaryRedirect)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
|
||||||
handler(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type postInstallHandlerStruct struct {
|
|
||||||
handler http.Handler
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *postInstallHandlerStruct) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
||||||
postInstall(p.handler.ServeHTTP)(w, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func postInstallHandler(handler http.Handler) http.Handler {
|
|
||||||
return &postInstallHandlerStruct{handler}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect to a remote server resolving hostname using our own DNS server
|
|
||||||
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
||||||
log.Tracef("network:%v addr:%v", network, addr)
|
|
||||||
|
|
||||||
host, port, err := net.SplitHostPort(addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dialer := &net.Dialer{
|
|
||||||
Timeout: time.Minute * 5,
|
|
||||||
}
|
|
||||||
|
|
||||||
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
|
|
||||||
con, err := dialer.DialContext(ctx, network, addr)
|
|
||||||
return con, err
|
|
||||||
}
|
|
||||||
|
|
||||||
addrs, e := Context.dnsServer.Resolve(host)
|
|
||||||
log.Debug("dnsServer.Resolve: %s: %v", host, addrs)
|
|
||||||
if e != nil {
|
|
||||||
return nil, e
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(addrs) == 0 {
|
|
||||||
return nil, fmt.Errorf("couldn't lookup host: %s", host)
|
|
||||||
}
|
|
||||||
|
|
||||||
var dialErrs []error
|
|
||||||
for _, a := range addrs {
|
|
||||||
addr = net.JoinHostPort(a.String(), port)
|
|
||||||
con, err := dialer.DialContext(ctx, network, addr)
|
|
||||||
if err != nil {
|
|
||||||
dialErrs = append(dialErrs, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return con, err
|
|
||||||
}
|
|
||||||
return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------
|
|
||||||
// general helpers
|
|
||||||
// ---------------------
|
|
||||||
|
|
||||||
// fileExists returns TRUE if file exists
|
|
||||||
func fileExists(fn string) bool {
|
|
||||||
_, err := os.Stat(fn)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// runCommand runs shell command
|
|
||||||
func runCommand(command string, arguments ...string) (int, string, error) {
|
|
||||||
cmd := exec.Command(command, arguments...)
|
|
||||||
out, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
return 1, "", fmt.Errorf("exec.Command(%s) failed: %s", command, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return cmd.ProcessState.ExitCode(), string(out), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------
|
|
||||||
// debug logging helpers
|
|
||||||
// ---------------------
|
|
||||||
func _Func() string {
|
|
||||||
pc := make([]uintptr, 10) // at least 1 entry needed
|
|
||||||
runtime.Callers(2, pc)
|
|
||||||
f := runtime.FuncForPC(pc[0])
|
|
||||||
return path.Base(f.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
// SplitNext - split string by a byte and return the first chunk
|
|
||||||
// Whitespace is trimmed
|
|
||||||
func SplitNext(str *string, splitBy byte) string {
|
|
||||||
i := strings.IndexByte(*str, splitBy)
|
|
||||||
s := ""
|
|
||||||
if i != -1 {
|
|
||||||
s = (*str)[0:i]
|
|
||||||
*str = (*str)[i+1:]
|
|
||||||
} else {
|
|
||||||
s = *str
|
|
||||||
*str = ""
|
|
||||||
}
|
|
||||||
return strings.TrimSpace(s)
|
|
||||||
}
|
|
73
home/home.go
73
home/home.go
@ -20,6 +20,10 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
|
||||||
|
"github.com/joomcode/errorx"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/isdelve"
|
"github.com/AdguardTeam/AdGuardHome/isdelve"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
||||||
@ -193,7 +197,7 @@ func run(args options) {
|
|||||||
|
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
|
if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") &&
|
||||||
config.RlimitNoFile != 0 {
|
config.RlimitNoFile != 0 {
|
||||||
setRlimit(config.RlimitNoFile)
|
util.SetRlimit(config.RlimitNoFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// override bind host/port from the console
|
// override bind host/port from the console
|
||||||
@ -327,7 +331,7 @@ func httpServerLoop() {
|
|||||||
// Check if the current user has root (administrator) rights
|
// Check if the current user has root (administrator) rights
|
||||||
// and if not, ask and try to run as root
|
// and if not, ask and try to run as root
|
||||||
func requireAdminRights() {
|
func requireAdminRights() {
|
||||||
admin, _ := haveAdminRights()
|
admin, _ := util.HaveAdminRights()
|
||||||
if //noinspection ALL
|
if //noinspection ALL
|
||||||
admin || isdelve.Enabled {
|
admin || isdelve.Enabled {
|
||||||
return
|
return
|
||||||
@ -412,7 +416,7 @@ func configureLogger(args options) {
|
|||||||
|
|
||||||
if ls.LogFile == configSyslog {
|
if ls.LogFile == configSyslog {
|
||||||
// Use syslog where it is possible and eventlog on Windows
|
// Use syslog where it is possible and eventlog on Windows
|
||||||
err := configureSyslog()
|
err := util.ConfigureSyslog(serviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("cannot initialize syslog: %s", err)
|
log.Fatalf("cannot initialize syslog: %s", err)
|
||||||
}
|
}
|
||||||
@ -448,9 +452,9 @@ func stopHTTPServer() {
|
|||||||
log.Info("Stopping HTTP server...")
|
log.Info("Stopping HTTP server...")
|
||||||
Context.httpsServer.shutdown = true
|
Context.httpsServer.shutdown = true
|
||||||
if Context.httpsServer.server != nil {
|
if Context.httpsServer.server != nil {
|
||||||
Context.httpsServer.server.Shutdown(context.TODO())
|
_ = Context.httpsServer.server.Shutdown(context.TODO())
|
||||||
}
|
}
|
||||||
Context.httpServer.Shutdown(context.TODO())
|
_ = Context.httpServer.Shutdown(context.TODO())
|
||||||
log.Info("Stopped HTTP server")
|
log.Info("Stopped HTTP server")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -580,7 +584,7 @@ func printHTTPAddresses(proto string) {
|
|||||||
}
|
}
|
||||||
} else if config.BindHost == "0.0.0.0" {
|
} else if config.BindHost == "0.0.0.0" {
|
||||||
log.Println("AdGuard Home is available on the following addresses:")
|
log.Println("AdGuard Home is available on the following addresses:")
|
||||||
ifaces, err := getValidNetInterfacesForWeb()
|
ifaces, err := util.GetValidNetInterfacesForWeb()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// That's weird, but we'll ignore it
|
// That's weird, but we'll ignore it
|
||||||
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
address = net.JoinHostPort(config.BindHost, strconv.Itoa(config.BindPort))
|
||||||
@ -597,3 +601,60 @@ func printHTTPAddresses(proto string) {
|
|||||||
log.Printf("Go to %s://%s", proto, address)
|
log.Printf("Go to %s://%s", proto, address)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -------------------
|
||||||
|
// first run / install
|
||||||
|
// -------------------
|
||||||
|
func detectFirstRun() bool {
|
||||||
|
configfile := Context.configFilename
|
||||||
|
if !filepath.IsAbs(configfile) {
|
||||||
|
configfile = filepath.Join(Context.workDir, Context.configFilename)
|
||||||
|
}
|
||||||
|
_, err := os.Stat(configfile)
|
||||||
|
if !os.IsNotExist(err) {
|
||||||
|
// do nothing, file exists
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to a remote server resolving hostname using our own DNS server
|
||||||
|
func customDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
|
log.Tracef("network:%v addr:%v", network, addr)
|
||||||
|
|
||||||
|
host, port, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: time.Minute * 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
if net.ParseIP(host) != nil || config.DNS.Port == 0 {
|
||||||
|
con, err := dialer.DialContext(ctx, network, addr)
|
||||||
|
return con, err
|
||||||
|
}
|
||||||
|
|
||||||
|
addrs, e := Context.dnsServer.Resolve(host)
|
||||||
|
log.Debug("dnsServer.Resolve: %s: %v", host, addrs)
|
||||||
|
if e != nil {
|
||||||
|
return nil, e
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(addrs) == 0 {
|
||||||
|
return nil, fmt.Errorf("couldn't lookup host: %s", host)
|
||||||
|
}
|
||||||
|
|
||||||
|
var dialErrs []error
|
||||||
|
for _, a := range addrs {
|
||||||
|
addr = net.JoinHostPort(a.String(), port)
|
||||||
|
con, err := dialer.DialContext(ctx, network, addr)
|
||||||
|
if err != nil {
|
||||||
|
dialErrs = append(dialErrs, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return con, err
|
||||||
|
}
|
||||||
|
return nil, errorx.DecorateMany(fmt.Sprintf("couldn't dial to %s", addr), dialErrs...)
|
||||||
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
)
|
)
|
||||||
@ -229,7 +230,7 @@ func configureService(c *service.Config) {
|
|||||||
// returns command code or error if any
|
// returns command code or error if any
|
||||||
func runInitdCommand(action string) (int, error) {
|
func runInitdCommand(action string) (int, error) {
|
||||||
confPath := "/etc/init.d/" + serviceName
|
confPath := "/etc/init.d/" + serviceName
|
||||||
code, _, err := runCommand("sh", "-c", confPath+" "+action)
|
code, _, err := util.RunCommand("sh", "-c", confPath+" "+action)
|
||||||
return code, err
|
return code, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,8 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/file"
|
"github.com/AdguardTeam/golibs/file"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
@ -114,7 +116,7 @@ func upgradeConfigSchema(oldVersion int, diskConfig *map[string]interface{}) err
|
|||||||
// The first schema upgrade:
|
// The first schema upgrade:
|
||||||
// No more "dnsfilter.txt", filters are now kept in data/filters/
|
// No more "dnsfilter.txt", filters are now kept in data/filters/
|
||||||
func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
|
func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
|
||||||
log.Printf("%s(): called", _Func())
|
log.Printf("%s(): called", util.FuncName())
|
||||||
|
|
||||||
dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt")
|
dnsFilterPath := filepath.Join(Context.workDir, "dnsfilter.txt")
|
||||||
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
|
if _, err := os.Stat(dnsFilterPath); !os.IsNotExist(err) {
|
||||||
@ -135,7 +137,7 @@ func upgradeSchema0to1(diskConfig *map[string]interface{}) error {
|
|||||||
// coredns is now dns in config
|
// coredns is now dns in config
|
||||||
// delete 'Corefile', since we don't use that anymore
|
// delete 'Corefile', since we don't use that anymore
|
||||||
func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
|
func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
|
||||||
log.Printf("%s(): called", _Func())
|
log.Printf("%s(): called", util.FuncName())
|
||||||
|
|
||||||
coreFilePath := filepath.Join(Context.workDir, "Corefile")
|
coreFilePath := filepath.Join(Context.workDir, "Corefile")
|
||||||
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
|
if _, err := os.Stat(coreFilePath); !os.IsNotExist(err) {
|
||||||
@ -159,7 +161,7 @@ func upgradeSchema1to2(diskConfig *map[string]interface{}) error {
|
|||||||
// Third schema upgrade:
|
// Third schema upgrade:
|
||||||
// Bootstrap DNS becomes an array
|
// Bootstrap DNS becomes an array
|
||||||
func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
|
func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
|
||||||
log.Printf("%s(): called", _Func())
|
log.Printf("%s(): called", util.FuncName())
|
||||||
|
|
||||||
// Let's read dns configuration from diskConfig
|
// Let's read dns configuration from diskConfig
|
||||||
dnsConfig, ok := (*diskConfig)["dns"]
|
dnsConfig, ok := (*diskConfig)["dns"]
|
||||||
@ -196,7 +198,7 @@ func upgradeSchema2to3(diskConfig *map[string]interface{}) error {
|
|||||||
|
|
||||||
// Add use_global_blocked_services=true setting for existing "clients" array
|
// Add use_global_blocked_services=true setting for existing "clients" array
|
||||||
func upgradeSchema3to4(diskConfig *map[string]interface{}) error {
|
func upgradeSchema3to4(diskConfig *map[string]interface{}) error {
|
||||||
log.Printf("%s(): called", _Func())
|
log.Printf("%s(): called", util.FuncName())
|
||||||
|
|
||||||
(*diskConfig)["schema_version"] = 4
|
(*diskConfig)["schema_version"] = 4
|
||||||
|
|
||||||
@ -233,7 +235,7 @@ func upgradeSchema3to4(diskConfig *map[string]interface{}) error {
|
|||||||
// password: "..."
|
// password: "..."
|
||||||
// ...
|
// ...
|
||||||
func upgradeSchema4to5(diskConfig *map[string]interface{}) error {
|
func upgradeSchema4to5(diskConfig *map[string]interface{}) error {
|
||||||
log.Printf("%s(): called", _Func())
|
log.Printf("%s(): called", util.FuncName())
|
||||||
|
|
||||||
(*diskConfig)["schema_version"] = 5
|
(*diskConfig)["schema_version"] = 5
|
||||||
|
|
||||||
@ -288,7 +290,7 @@ func upgradeSchema4to5(diskConfig *map[string]interface{}) error {
|
|||||||
// - 127.0.0.1
|
// - 127.0.0.1
|
||||||
// - ...
|
// - ...
|
||||||
func upgradeSchema5to6(diskConfig *map[string]interface{}) error {
|
func upgradeSchema5to6(diskConfig *map[string]interface{}) error {
|
||||||
log.Printf("%s(): called", _Func())
|
log.Printf("%s(): called", util.FuncName())
|
||||||
|
|
||||||
(*diskConfig)["schema_version"] = 6
|
(*diskConfig)["schema_version"] = 6
|
||||||
|
|
||||||
|
@ -8,6 +8,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/AdguardTeam/AdGuardHome/util"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/cache"
|
"github.com/AdguardTeam/golibs/cache"
|
||||||
"github.com/AdguardTeam/golibs/log"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
)
|
)
|
||||||
@ -61,7 +63,7 @@ func whoisParse(data string) map[string]string {
|
|||||||
descr := ""
|
descr := ""
|
||||||
netname := ""
|
netname := ""
|
||||||
for len(data) != 0 {
|
for len(data) != 0 {
|
||||||
ln := SplitNext(&data, '\n')
|
ln := util.SplitNext(&data, '\n')
|
||||||
if len(ln) == 0 || ln[0] == '#' || ln[0] == '%' {
|
if len(ln) == 0 || ln[0] == '#' || ln[0] == '%' {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
59
util/helpers.go
Normal file
59
util/helpers.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------------------
|
||||||
|
// general helpers
|
||||||
|
// ---------------------
|
||||||
|
|
||||||
|
// fileExists returns TRUE if file exists
|
||||||
|
func FileExists(fn string) bool {
|
||||||
|
_, err := os.Stat(fn)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// runCommand runs shell command
|
||||||
|
func RunCommand(command string, arguments ...string) (int, string, error) {
|
||||||
|
cmd := exec.Command(command, arguments...)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return 1, "", fmt.Errorf("exec.Command(%s) failed: %s", command, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cmd.ProcessState.ExitCode(), string(out), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------
|
||||||
|
// debug logging helpers
|
||||||
|
// ---------------------
|
||||||
|
func FuncName() string {
|
||||||
|
pc := make([]uintptr, 10) // at least 1 entry needed
|
||||||
|
runtime.Callers(2, pc)
|
||||||
|
f := runtime.FuncForPC(pc[0])
|
||||||
|
return path.Base(f.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
// SplitNext - split string by a byte and return the first chunk
|
||||||
|
// Whitespace is trimmed
|
||||||
|
func SplitNext(str *string, splitBy byte) string {
|
||||||
|
i := strings.IndexByte(*str, splitBy)
|
||||||
|
s := ""
|
||||||
|
if i != -1 {
|
||||||
|
s = (*str)[0:i]
|
||||||
|
*str = (*str)[i+1:]
|
||||||
|
} else {
|
||||||
|
s = *str
|
||||||
|
*str = ""
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(s)
|
||||||
|
}
|
14
util/helpers_test.go
Normal file
14
util/helpers_test.go
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSplitNext(t *testing.T) {
|
||||||
|
s := " a,b , c "
|
||||||
|
assert.True(t, SplitNext(&s, ',') == "a")
|
||||||
|
assert.True(t, SplitNext(&s, ',') == "b")
|
||||||
|
assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0)
|
||||||
|
}
|
@ -1,4 +1,4 @@
|
|||||||
package home
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
@ -10,23 +10,48 @@ import (
|
|||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/AdguardTeam/AdGuardHome/dhcpd"
|
"github.com/AdguardTeam/golibs/log"
|
||||||
|
|
||||||
"github.com/joomcode/errorx"
|
"github.com/joomcode/errorx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type netInterface struct {
|
// NetInterface represents a list of network interfaces
|
||||||
Name string
|
type NetInterface struct {
|
||||||
MTU int
|
Name string // Network interface name
|
||||||
HardwareAddr string
|
MTU int // MTU
|
||||||
Addresses []string
|
HardwareAddr string // Hardware address
|
||||||
Flags string
|
Addresses []string // Array with the network interface addresses
|
||||||
|
Subnets []string // Array with CIDR addresses of this network interface
|
||||||
|
Flags string // Network interface flags (up, broadcast, etc)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetValidNetInterfaces returns interfaces that are eligible for DNS and/or DHCP
|
||||||
|
// invalid interface is a ppp interface or the one that doesn't allow broadcasts
|
||||||
|
func GetValidNetInterfaces() ([]net.Interface, error) {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Couldn't get list of interfaces: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
netIfaces := []net.Interface{}
|
||||||
|
|
||||||
|
for i := range ifaces {
|
||||||
|
if ifaces[i].Flags&net.FlagPointToPoint != 0 {
|
||||||
|
// this interface is ppp, we're not interested in this one
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
iface := ifaces[i]
|
||||||
|
netIfaces = append(netIfaces, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
return netIfaces, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only
|
// getValidNetInterfacesMap returns interfaces that are eligible for DNS and WEB only
|
||||||
// we do not return link-local addresses here
|
// we do not return link-local addresses here
|
||||||
func getValidNetInterfacesForWeb() ([]netInterface, error) {
|
func GetValidNetInterfacesForWeb() ([]NetInterface, error) {
|
||||||
ifaces, err := dhcpd.GetValidNetInterfaces()
|
ifaces, err := GetValidNetInterfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errorx.Decorate(err, "Couldn't get interfaces")
|
return nil, errorx.Decorate(err, "Couldn't get interfaces")
|
||||||
}
|
}
|
||||||
@ -34,7 +59,7 @@ func getValidNetInterfacesForWeb() ([]netInterface, error) {
|
|||||||
return nil, errors.New("couldn't find any legible interface")
|
return nil, errors.New("couldn't find any legible interface")
|
||||||
}
|
}
|
||||||
|
|
||||||
var netInterfaces []netInterface
|
var netInterfaces []NetInterface
|
||||||
|
|
||||||
for _, iface := range ifaces {
|
for _, iface := range ifaces {
|
||||||
addrs, e := iface.Addrs()
|
addrs, e := iface.Addrs()
|
||||||
@ -42,7 +67,7 @@ func getValidNetInterfacesForWeb() ([]netInterface, error) {
|
|||||||
return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name)
|
return nil, errorx.Decorate(e, "Failed to get addresses for interface %s", iface.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
netIface := netInterface{
|
netIface := NetInterface{
|
||||||
Name: iface.Name,
|
Name: iface.Name,
|
||||||
MTU: iface.MTU,
|
MTU: iface.MTU,
|
||||||
HardwareAddr: iface.HardwareAddr.String(),
|
HardwareAddr: iface.HardwareAddr.String(),
|
||||||
@ -52,19 +77,26 @@ func getValidNetInterfacesForWeb() ([]netInterface, error) {
|
|||||||
netIface.Flags = iface.Flags.String()
|
netIface.Flags = iface.Flags.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// we don't want link-local addresses in json, so skip them
|
// Collect network interface addresses
|
||||||
for _, addr := range addrs {
|
for _, addr := range addrs {
|
||||||
ipnet, ok := addr.(*net.IPNet)
|
ipNet, ok := addr.(*net.IPNet)
|
||||||
if !ok {
|
if !ok {
|
||||||
// not an IPNet, should not happen
|
// not an IPNet, should not happen
|
||||||
return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
|
return nil, fmt.Errorf("got iface.Addrs() element %s that is not net.IPNet, it is %T", addr, addr)
|
||||||
}
|
}
|
||||||
// ignore link-local
|
// ignore link-local
|
||||||
if ipnet.IP.IsLinkLocalUnicast() {
|
if ipNet.IP.IsLinkLocalUnicast() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
netIface.Addresses = append(netIface.Addresses, ipnet.IP.String())
|
// ignore IPv6
|
||||||
|
if ipNet.IP.To4() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
netIface.Addresses = append(netIface.Addresses, ipNet.IP.String())
|
||||||
|
netIface.Subnets = append(netIface.Subnets, ipNet.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Discard interfaces with no addresses
|
||||||
if len(netIface.Addresses) != 0 {
|
if len(netIface.Addresses) != 0 {
|
||||||
netInterfaces = append(netInterfaces, netIface)
|
netInterfaces = append(netInterfaces, netIface)
|
||||||
}
|
}
|
||||||
@ -74,8 +106,8 @@ func getValidNetInterfacesForWeb() ([]netInterface, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get interface name by its IP address.
|
// Get interface name by its IP address.
|
||||||
func getInterfaceByIP(ip string) string {
|
func GetInterfaceByIP(ip string) string {
|
||||||
ifaces, err := getValidNetInterfacesForWeb()
|
ifaces, err := GetValidNetInterfacesForWeb()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@ -91,8 +123,26 @@ func getInterfaceByIP(ip string) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get IP address with netmask for the specified interface
|
||||||
|
// Returns an empty string if it fails to find it
|
||||||
|
func GetSubnet(ifaceName string) string {
|
||||||
|
netIfaces, err := GetValidNetInterfacesForWeb()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Could not get network interfaces info: %v", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, netIface := range netIfaces {
|
||||||
|
if netIface.Name == ifaceName && len(netIface.Subnets) > 0 {
|
||||||
|
return netIface.Subnets[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily
|
// checkPortAvailable is not a cheap test to see if the port is bindable, because it's actually doing the bind momentarily
|
||||||
func checkPortAvailable(host string, port int) error {
|
func CheckPortAvailable(host string, port int) error {
|
||||||
ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port)))
|
ln, err := net.Listen("tcp", net.JoinHostPort(host, strconv.Itoa(port)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -105,7 +155,7 @@ func checkPortAvailable(host string, port int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkPacketPortAvailable(host string, port int) error {
|
func CheckPacketPortAvailable(host string, port int) error {
|
||||||
ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port)))
|
ln, err := net.ListenPacket("udp", net.JoinHostPort(host, strconv.Itoa(port)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -119,7 +169,7 @@ func checkPacketPortAvailable(host string, port int) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// check if error is "address already in use"
|
// check if error is "address already in use"
|
||||||
func errorIsAddrInUse(err error) bool {
|
func ErrorIsAddrInUse(err error) bool {
|
||||||
errOpError, ok := err.(*net.OpError)
|
errOpError, ok := err.(*net.OpError)
|
||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
@ -1,14 +1,12 @@
|
|||||||
package home
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/AdguardTeam/golibs/log"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestGetValidNetInterfacesForWeb(t *testing.T) {
|
func TestGetValidNetInterfacesForWeb(t *testing.T) {
|
||||||
ifaces, err := getValidNetInterfacesForWeb()
|
ifaces, err := GetValidNetInterfacesForWeb()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Cannot get net interfaces: %s", err)
|
t.Fatalf("Cannot get net interfaces: %s", err)
|
||||||
}
|
}
|
||||||
@ -24,10 +22,3 @@ func TestGetValidNetInterfacesForWeb(t *testing.T) {
|
|||||||
log.Printf("%v", iface)
|
log.Printf("%v", iface)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSplitNext(t *testing.T) {
|
|
||||||
s := " a,b , c "
|
|
||||||
assert.True(t, SplitNext(&s, ',') == "a")
|
|
||||||
assert.True(t, SplitNext(&s, ',') == "b")
|
|
||||||
assert.True(t, SplitNext(&s, ',') == "c" && len(s) == 0)
|
|
||||||
}
|
|
@ -1,6 +1,6 @@
|
|||||||
// +build freebsd
|
// +build freebsd
|
||||||
|
|
||||||
package home
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
// Set user-specified limit of how many fd's we can use
|
// Set user-specified limit of how many fd's we can use
|
||||||
// https://github.com/AdguardTeam/AdGuardHome/issues/659
|
// https://github.com/AdguardTeam/AdGuardHome/issues/659
|
||||||
func setRlimit(val uint) {
|
func SetRlimit(val uint) {
|
||||||
var rlim syscall.Rlimit
|
var rlim syscall.Rlimit
|
||||||
rlim.Max = int64(val)
|
rlim.Max = int64(val)
|
||||||
rlim.Cur = int64(val)
|
rlim.Cur = int64(val)
|
||||||
@ -22,6 +22,6 @@ func setRlimit(val uint) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the current user has root (administrator) rights
|
// Check if the current user has root (administrator) rights
|
||||||
func haveAdminRights() (bool, error) {
|
func HaveAdminRights() (bool, error) {
|
||||||
return os.Getuid() == 0, nil
|
return os.Getuid() == 0, nil
|
||||||
}
|
}
|
@ -1,6 +1,6 @@
|
|||||||
// +build aix darwin dragonfly linux netbsd openbsd solaris
|
// +build aix darwin dragonfly linux netbsd openbsd solaris
|
||||||
|
|
||||||
package home
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os"
|
"os"
|
||||||
@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
// Set user-specified limit of how many fd's we can use
|
// Set user-specified limit of how many fd's we can use
|
||||||
// https://github.com/AdguardTeam/AdGuardHome/issues/659
|
// https://github.com/AdguardTeam/AdGuardHome/issues/659
|
||||||
func setRlimit(val uint) {
|
func SetRlimit(val uint) {
|
||||||
var rlim syscall.Rlimit
|
var rlim syscall.Rlimit
|
||||||
rlim.Max = uint64(val)
|
rlim.Max = uint64(val)
|
||||||
rlim.Cur = uint64(val)
|
rlim.Cur = uint64(val)
|
||||||
@ -22,6 +22,6 @@ func setRlimit(val uint) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the current user has root (administrator) rights
|
// Check if the current user has root (administrator) rights
|
||||||
func haveAdminRights() (bool, error) {
|
func HaveAdminRights() (bool, error) {
|
||||||
return os.Getuid() == 0, nil
|
return os.Getuid() == 0, nil
|
||||||
}
|
}
|
@ -1,12 +1,12 @@
|
|||||||
package home
|
package util
|
||||||
|
|
||||||
import "golang.org/x/sys/windows"
|
import "golang.org/x/sys/windows"
|
||||||
|
|
||||||
// Set user-specified limit of how many fd's we can use
|
// Set user-specified limit of how many fd's we can use
|
||||||
func setRlimit(val uint) {
|
func SetRlimit(val uint) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func haveAdminRights() (bool, error) {
|
func HaveAdminRights() (bool, error) {
|
||||||
var token windows.Token
|
var token windows.Token
|
||||||
h, _ := windows.GetCurrentProcess()
|
h, _ := windows.GetCurrentProcess()
|
||||||
err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token)
|
err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token)
|
@ -1,14 +1,14 @@
|
|||||||
// +build !windows,!nacl,!plan9
|
// +build !windows,!nacl,!plan9
|
||||||
|
|
||||||
package home
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
"log/syslog"
|
"log/syslog"
|
||||||
)
|
)
|
||||||
|
|
||||||
// configureSyslog reroutes standard logger output to syslog
|
// ConfigureSyslog reroutes standard logger output to syslog
|
||||||
func configureSyslog() error {
|
func ConfigureSyslog(serviceName string) error {
|
||||||
w, err := syslog.New(syslog.LOG_NOTICE|syslog.LOG_USER, serviceName)
|
w, err := syslog.New(syslog.LOG_NOTICE|syslog.LOG_USER, serviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
@ -1,4 +1,4 @@
|
|||||||
package home
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
@ -17,7 +17,7 @@ func (w *eventLogWriter) Write(b []byte) (int, error) {
|
|||||||
return len(b), w.el.Info(1, string(b))
|
return len(b), w.el.Info(1, string(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
func configureSyslog() error {
|
func ConfigureSyslog(serviceName string) error {
|
||||||
// Note that the eventlog src is the same as the service name
|
// Note that the eventlog src is the same as the service name
|
||||||
// Otherwise, we will get "the description for event id cannot be found" warning in every log record
|
// Otherwise, we will get "the description for event id cannot be found" warning in every log record
|
||||||
|
|
Loading…
Reference in New Issue
Block a user