diff --git a/control.go b/control.go index 480f9fdc..726bf2b7 100644 --- a/control.go +++ b/control.go @@ -323,44 +323,94 @@ func sortByValue(m map[string]int) []string { // ----------------------- func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) { + setDNSServers(&w, r, true) +} + +func handleSetBootstrapDNS(w http.ResponseWriter, r *http.Request) { + setDNSServers(&w, r, false) +} + +// setDNSServers sets upstream and bootstrap DNS servers +func setDNSServers(w *http.ResponseWriter, r *http.Request, upstreams bool) { body, err := ioutil.ReadAll(r.Body) if err != nil { errorText := fmt.Sprintf("Failed to read request body: %s", err) log.Println(errorText) - http.Error(w, errorText, http.StatusBadRequest) + http.Error(*w, errorText, http.StatusBadRequest) return } // if empty body -- user is asking for default servers hosts := strings.Fields(string(body)) - if len(hosts) == 0 { - config.DNS.UpstreamDNS = defaultDNS + // bootstrap servers are plain DNS only. We should remove tls:// https:// and sdns:// hosts from slice + bootstraps := []string{} + if !upstreams && len(hosts) > 0 { + for _, host := range hosts { + err = checkBootstrapDNS(host) + if err != nil { + log.Tracef("%s can not be used as bootstrap DNS cause: %s", host, err) + continue + } + hosts = append(bootstraps, host) + } + } + + // count of upstream or bootstrap servers + var count int + if upstreams { + count = len(hosts) } else { - config.DNS.UpstreamDNS = hosts + count = len(bootstraps) + } + + if upstreams { + if count == 0 { + config.DNS.UpstreamDNS = defaultDNS + } else { + config.DNS.UpstreamDNS = hosts + } + } else { + if count == 0 { + config.DNS.BootstrapDNS = defaultBootstrap + } else { + config.DNS.BootstrapDNS = bootstraps + } } err = writeAllConfigs() if err != nil { errorText := fmt.Sprintf("Couldn't write config file: %s", err) log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) + http.Error(*w, errorText, http.StatusInternalServerError) return } err = reconfigureDNSServer() if err != nil { errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err) log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) + http.Error(*w, errorText, http.StatusInternalServerError) return } - _, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts)) + + _, err = fmt.Fprintf(*w, "OK %d servers\n", count) if err != nil { errorText := fmt.Sprintf("Couldn't write body: %s", err) log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) + http.Error(*w, errorText, http.StatusInternalServerError) } } +func checkBootstrapDNS(host string) error { + // Check if host is ip without port + if net.ParseIP(host) != nil { + return nil + } + + // Check if host is ip with port + _, _, err := net.SplitHostPort(host) + return err +} + func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) { body, err := ioutil.ReadAll(r.Body) if err != nil { @@ -437,45 +487,6 @@ func checkDNS(input string) error { return nil } -func handleSetBootstrapDNS(w http.ResponseWriter, r *http.Request) { - body, err := ioutil.ReadAll(r.Body) - if err != nil { - errorText := fmt.Sprintf("Failed to read request body: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusBadRequest) - return - } - // if empty body -- user is asking for default servers - hosts := strings.Fields(string(body)) - - if len(hosts) == 0 { - config.DNS.BootstrapDNS = defaultBootstrap - } else { - config.DNS.BootstrapDNS = hosts - } - - err = writeAllConfigs() - if err != nil { - errorText := fmt.Sprintf("Couldn't write config file: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - return - } - err = reconfigureDNSServer() - if err != nil { - errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - return - } - _, err = fmt.Fprintf(w, "OK %d bootsrap servers\n", len(hosts)) - if err != nil { - errorText := fmt.Sprintf("Couldn't write body: %s", err) - log.Println(errorText) - http.Error(w, errorText, http.StatusInternalServerError) - } -} - func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) { now := time.Now() if now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 {