diff --git a/app.go b/app.go index dacf1034..acb463de 100644 --- a/app.go +++ b/app.go @@ -1,16 +1,20 @@ package main import ( + "bufio" "crypto/tls" "fmt" + "io" "io/ioutil" "net" "net/http" "os" + "os/exec" "os/signal" "path/filepath" "runtime" "strconv" + "strings" "sync" "syscall" @@ -45,15 +49,6 @@ func main() { return } - signalChannel := make(chan os.Signal) - signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) - go func() { - <-signalChannel - cleanup() - cleanupAlways() - os.Exit(0) - }() - // run the protection run(args) } @@ -83,6 +78,18 @@ func run(args options) { } config.firstRun = detectFirstRun() + if config.firstRun { + requireAdminRights() + } + + signalChannel := make(chan os.Signal) + signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP, syscall.SIGQUIT) + go func() { + <-signalChannel + cleanup() + cleanupAlways() + os.Exit(0) + }() // Do the upgrade if necessary err := upgradeConfig() @@ -228,6 +235,37 @@ func run(args options) { } } +// Check if the current user has root (administrator) rights +// and if not, ask and try to run as root +func requireAdminRights() { + admin, _ := haveAdminRights() + if admin { + return + } + + if runtime.GOOS == "windows" { + log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.") + + } else { + log.Error("This is the first launch of AdGuard Home. You must run it as root.") + + _, _ = io.WriteString(os.Stdout, "Do you want to start AdGuard Home as root user? [y/n] ") + stdin := bufio.NewReader(os.Stdin) + buf, _ := stdin.ReadString('\n') + buf = strings.TrimSpace(buf) + if buf != "y" { + os.Exit(1) + } + + cmd := exec.Command("sudo", os.Args...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + _ = cmd.Run() + os.Exit(1) + } +} + // Write PID to a file func writePIDFile(fn string) bool { data := fmt.Sprintf("%d", os.Getpid()) diff --git a/os_unix.go b/os_unix.go index 12a918c8..9baa357d 100644 --- a/os_unix.go +++ b/os_unix.go @@ -3,6 +3,7 @@ package main import ( + "os" "syscall" "github.com/AdguardTeam/golibs/log" @@ -19,3 +20,8 @@ func setRlimit(val uint) { log.Error("Setrlimit() failed: %v", err) } } + +// Check if the current user has root (administrator) rights +func haveAdminRights() (bool, error) { + return os.Getuid() == 0, nil +} diff --git a/os_windows.go b/os_windows.go index 1155e04b..e847ccce 100644 --- a/os_windows.go +++ b/os_windows.go @@ -1,5 +1,28 @@ package main +import "golang.org/x/sys/windows" + // Set user-specified limit of how many fd's we can use func setRlimit(val uint) { } + +func haveAdminRights() (bool, error) { + var token windows.Token + h, _ := windows.GetCurrentProcess() + err := windows.OpenProcessToken(h, windows.TOKEN_QUERY, &token) + if err != nil { + return false, err + } + + info := make([]byte, 4) + var returnedLen uint32 + err = windows.GetTokenInformation(token, windows.TokenElevation, &info[0], uint32(len(info)), &returnedLen) + token.Close() + if err != nil { + return false, err + } + if info[0] == 0 { + return false, nil + } + return true, nil +}