Refactoring for set upstream and bootstrap DNS

This commit is contained in:
Aleksey Dmitrevskiy 2019-02-27 12:58:42 +03:00
parent dc05556c5a
commit bf893d488a

View File

@ -323,44 +323,94 @@ func sortByValue(m map[string]int) []string {
// -----------------------
func handleSetUpstreamDNS(w http.ResponseWriter, r *http.Request) {
setDNSServers(&w, r, true)
}
func handleSetBootstrapDNS(w http.ResponseWriter, r *http.Request) {
setDNSServers(&w, r, false)
}
// setDNSServers sets upstream and bootstrap DNS servers
func setDNSServers(w *http.ResponseWriter, r *http.Request, upstreams bool) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
http.Error(*w, errorText, http.StatusBadRequest)
return
}
// if empty body -- user is asking for default servers
hosts := strings.Fields(string(body))
if len(hosts) == 0 {
config.DNS.UpstreamDNS = defaultDNS
// bootstrap servers are plain DNS only. We should remove tls:// https:// and sdns:// hosts from slice
bootstraps := []string{}
if !upstreams && len(hosts) > 0 {
for _, host := range hosts {
err = checkBootstrapDNS(host)
if err != nil {
log.Tracef("%s can not be used as bootstrap DNS cause: %s", host, err)
continue
}
hosts = append(bootstraps, host)
}
}
// count of upstream or bootstrap servers
var count int
if upstreams {
count = len(hosts)
} else {
config.DNS.UpstreamDNS = hosts
count = len(bootstraps)
}
if upstreams {
if count == 0 {
config.DNS.UpstreamDNS = defaultDNS
} else {
config.DNS.UpstreamDNS = hosts
}
} else {
if count == 0 {
config.DNS.BootstrapDNS = defaultBootstrap
} else {
config.DNS.BootstrapDNS = bootstraps
}
}
err = writeAllConfigs()
if err != nil {
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
http.Error(*w, errorText, http.StatusInternalServerError)
return
}
err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
http.Error(*w, errorText, http.StatusInternalServerError)
return
}
_, err = fmt.Fprintf(w, "OK %d servers\n", len(hosts))
_, err = fmt.Fprintf(*w, "OK %d servers\n", count)
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
http.Error(*w, errorText, http.StatusInternalServerError)
}
}
func checkBootstrapDNS(host string) error {
// Check if host is ip without port
if net.ParseIP(host) != nil {
return nil
}
// Check if host is ip with port
_, _, err := net.SplitHostPort(host)
return err
}
func handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
@ -437,45 +487,6 @@ func checkDNS(input string) error {
return nil
}
func handleSetBootstrapDNS(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
errorText := fmt.Sprintf("Failed to read request body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusBadRequest)
return
}
// if empty body -- user is asking for default servers
hosts := strings.Fields(string(body))
if len(hosts) == 0 {
config.DNS.BootstrapDNS = defaultBootstrap
} else {
config.DNS.BootstrapDNS = hosts
}
err = writeAllConfigs()
if err != nil {
errorText := fmt.Sprintf("Couldn't write config file: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
err = reconfigureDNSServer()
if err != nil {
errorText := fmt.Sprintf("Couldn't reconfigure the DNS server: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
return
}
_, err = fmt.Fprintf(w, "OK %d bootsrap servers\n", len(hosts))
if err != nil {
errorText := fmt.Sprintf("Couldn't write body: %s", err)
log.Println(errorText)
http.Error(w, errorText, http.StatusInternalServerError)
}
}
func handleGetVersionJSON(w http.ResponseWriter, r *http.Request) {
now := time.Now()
if now.Sub(versionCheckLastTime) <= versionCheckPeriod && len(versionCheckJSON) != 0 {