author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-04-07 12:29:32 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-04-08 18:14:05 UTC |
parent | e47dfd984d428805caf80b1f1122b235a130e19b |
dnss.go | +6 | -1 |
dnss_test.go | +7 | -1 |
internal/dnstohttps/https_test.go | +8 | -2 |
internal/dnstohttps/resolver.go | +9 | -8 |
diff --git a/dnss.go b/dnss.go index 0b5ed7a..df88362 100644 --- a/dnss.go +++ b/dnss.go @@ -15,6 +15,7 @@ import ( "flag" "fmt" "net/http" + "net/url" "strings" "sync" "time" @@ -105,8 +106,12 @@ func main() { // DNS to HTTPS. if *enableDNStoHTTPS { + upstream, err := url.Parse(*httpsUpstream) + if err != nil { + glog.Fatalf("-https_upstream is not a valid URL: %v", err) + } var resolver dnsserver.Resolver = dnstohttps.NewHTTPSResolver( - *httpsUpstream, *httpsClientCAFile) + upstream, *httpsClientCAFile) if *enableCache { cr := dnsserver.NewCachingResolver(resolver) cr.RegisterDebugHandlers() diff --git a/dnss_test.go b/dnss_test.go index 44b97a5..2007cf9 100644 --- a/dnss_test.go +++ b/dnss_test.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "net/http" + "net/url" "os" "strings" "sync" @@ -46,7 +47,12 @@ func realMain(m *testing.M) int { ServerAddr = DNSToHTTPSAddr // DNS to HTTPS server. - r := dnstohttps.NewHTTPSResolver("http://"+HTTPSToDNSAddr+"/resolve", "") + HTTPSToDNSURL, err := url.Parse("http://" + HTTPSToDNSAddr + "/resolve") + if err != nil { + fmt.Printf("invalid URL: %v", err) + return 1 + } + r := dnstohttps.NewHTTPSResolver(HTTPSToDNSURL, "") dtoh := dnsserver.New(DNSToHTTPSAddr, r, "") go dtoh.ListenAndServe() diff --git a/internal/dnstohttps/https_test.go b/internal/dnstohttps/https_test.go index 3da1e20..8ef9cb4 100644 --- a/internal/dnstohttps/https_test.go +++ b/internal/dnstohttps/https_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "net/url" "os" "testing" @@ -130,12 +131,17 @@ func realMain(m *testing.M) int { httpsrv := httptest.NewServer(http.HandlerFunc(DNSHandler)) // DNS to HTTPS server. - r := NewHTTPSResolver(httpsrv.URL, "") + srvURL, err := url.Parse(httpsrv.URL) + if err != nil { + fmt.Printf("Failed to parse test http server URL: %v\n", err) + return 1 + } + r := NewHTTPSResolver(srvURL, "") dth := dnsserver.New(DNSAddr, r, "") go dth.ListenAndServe() // Wait for the servers to start up. - err := testutil.WaitForDNSServer(DNSAddr) + err = testutil.WaitForDNSServer(DNSAddr) if err != nil { fmt.Printf("Error waiting for the test servers to start: %v\n", err) fmt.Printf("Check the INFO logs for more details\n") diff --git a/internal/dnstohttps/resolver.go b/internal/dnstohttps/resolver.go index e464a8e..49463e8 100644 --- a/internal/dnstohttps/resolver.go +++ b/internal/dnstohttps/resolver.go @@ -21,7 +21,7 @@ import ( // httpsResolver implements the dnsserver.Resolver interface by querying a // server via DNS-over-HTTPS (like https://dns.google.com). type httpsResolver struct { - Upstream string + Upstream *url.URL CAFile string client *http.Client } @@ -42,7 +42,7 @@ func loadCertPool(caFile string) (*x509.CertPool, error) { // NewHTTPSResolver creates a new resolver which uses the given upstream URL // to resolve queries. -func NewHTTPSResolver(upstream, caFile string) *httpsResolver { +func NewHTTPSResolver(upstream *url.URL, caFile string) *httpsResolver { return &httpsResolver{ Upstream: upstream, CAFile: caFile, @@ -99,17 +99,18 @@ func (r *httpsResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) { } // Build the query and send the request. - v := url.Values{} - v.Set("name", question.Name) - v.Set("type", dns.TypeToString[question.Qtype]) + url := *r.Upstream + vs := url.Query() + vs.Set("name", question.Name) + vs.Set("type", dns.TypeToString[question.Qtype]) + url.RawQuery = vs.Encode() // TODO: add random_padding. - url := r.Upstream + "?" + v.Encode() if glog.V(3) { - tr.LazyPrintf("GET %q", url) + tr.LazyPrintf("GET %v", url) } - hr, err := r.client.Get(url) + hr, err := r.client.Get(url.String()) if err != nil { return nil, fmt.Errorf("GET failed: %v", err) }