git » dnss » commit cbfa0a0

dnstohttps: Take a parsed URL as parameter

author Alberto Bertogli
2018-04-07 12:29:32 UTC
committer Alberto Bertogli
2018-04-08 18:14:05 UTC
parent e47dfd984d428805caf80b1f1122b235a130e19b

dnstohttps: Take a parsed URL as parameter

The current way we construct URLs for HTTP queries is very brittle and
prone to issues; this patch parses the upstream URL early, so we can add
query parameters to it safely.

dnss.go +6 -1
dnss_test.go +7 -1
internal/dnstohttps/https_test.go +8 -2
internal/dnstohttps/resolver.go +9 -8

diff --git a/dnss.go b/dnss.go
index 0b5ed7a..df88362 100644
--- a/dnss.go
+++ b/dnss.go
@@ -15,6 +15,7 @@ import (
 	"flag"
 	"fmt"
 	"net/http"
+	"net/url"
 	"strings"
 	"sync"
 	"time"
@@ -105,8 +106,12 @@ func main() {
 
 	// DNS to HTTPS.
 	if *enableDNStoHTTPS {
+		upstream, err := url.Parse(*httpsUpstream)
+		if err != nil {
+			glog.Fatalf("-https_upstream is not a valid URL: %v", err)
+		}
 		var resolver dnsserver.Resolver = dnstohttps.NewHTTPSResolver(
-			*httpsUpstream, *httpsClientCAFile)
+			upstream, *httpsClientCAFile)
 		if *enableCache {
 			cr := dnsserver.NewCachingResolver(resolver)
 			cr.RegisterDebugHandlers()
diff --git a/dnss_test.go b/dnss_test.go
index 44b97a5..2007cf9 100644
--- a/dnss_test.go
+++ b/dnss_test.go
@@ -5,6 +5,7 @@ import (
 	"flag"
 	"fmt"
 	"net/http"
+	"net/url"
 	"os"
 	"strings"
 	"sync"
@@ -46,7 +47,12 @@ func realMain(m *testing.M) int {
 	ServerAddr = DNSToHTTPSAddr
 
 	// DNS to HTTPS server.
-	r := dnstohttps.NewHTTPSResolver("http://"+HTTPSToDNSAddr+"/resolve", "")
+	HTTPSToDNSURL, err := url.Parse("http://" + HTTPSToDNSAddr + "/resolve")
+	if err != nil {
+		fmt.Printf("invalid URL: %v", err)
+		return 1
+	}
+	r := dnstohttps.NewHTTPSResolver(HTTPSToDNSURL, "")
 	dtoh := dnsserver.New(DNSToHTTPSAddr, r, "")
 	go dtoh.ListenAndServe()
 
diff --git a/internal/dnstohttps/https_test.go b/internal/dnstohttps/https_test.go
index 3da1e20..8ef9cb4 100644
--- a/internal/dnstohttps/https_test.go
+++ b/internal/dnstohttps/https_test.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net/http"
 	"net/http/httptest"
+	"net/url"
 	"os"
 	"testing"
 
@@ -130,12 +131,17 @@ func realMain(m *testing.M) int {
 	httpsrv := httptest.NewServer(http.HandlerFunc(DNSHandler))
 
 	// DNS to HTTPS server.
-	r := NewHTTPSResolver(httpsrv.URL, "")
+	srvURL, err := url.Parse(httpsrv.URL)
+	if err != nil {
+		fmt.Printf("Failed to parse test http server URL: %v\n", err)
+		return 1
+	}
+	r := NewHTTPSResolver(srvURL, "")
 	dth := dnsserver.New(DNSAddr, r, "")
 	go dth.ListenAndServe()
 
 	// Wait for the servers to start up.
-	err := testutil.WaitForDNSServer(DNSAddr)
+	err = testutil.WaitForDNSServer(DNSAddr)
 	if err != nil {
 		fmt.Printf("Error waiting for the test servers to start: %v\n", err)
 		fmt.Printf("Check the INFO logs for more details\n")
diff --git a/internal/dnstohttps/resolver.go b/internal/dnstohttps/resolver.go
index e464a8e..49463e8 100644
--- a/internal/dnstohttps/resolver.go
+++ b/internal/dnstohttps/resolver.go
@@ -21,7 +21,7 @@ import (
 // httpsResolver implements the dnsserver.Resolver interface by querying a
 // server via DNS-over-HTTPS (like https://dns.google.com).
 type httpsResolver struct {
-	Upstream string
+	Upstream *url.URL
 	CAFile   string
 	client   *http.Client
 }
@@ -42,7 +42,7 @@ func loadCertPool(caFile string) (*x509.CertPool, error) {
 
 // NewHTTPSResolver creates a new resolver which uses the given upstream URL
 // to resolve queries.
-func NewHTTPSResolver(upstream, caFile string) *httpsResolver {
+func NewHTTPSResolver(upstream *url.URL, caFile string) *httpsResolver {
 	return &httpsResolver{
 		Upstream: upstream,
 		CAFile:   caFile,
@@ -99,17 +99,18 @@ func (r *httpsResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
 	}
 
 	// Build the query and send the request.
-	v := url.Values{}
-	v.Set("name", question.Name)
-	v.Set("type", dns.TypeToString[question.Qtype])
+	url := *r.Upstream
+	vs := url.Query()
+	vs.Set("name", question.Name)
+	vs.Set("type", dns.TypeToString[question.Qtype])
+	url.RawQuery = vs.Encode()
 	// TODO: add random_padding.
 
-	url := r.Upstream + "?" + v.Encode()
 	if glog.V(3) {
-		tr.LazyPrintf("GET %q", url)
+		tr.LazyPrintf("GET %v", url)
 	}
 
-	hr, err := r.client.Get(url)
+	hr, err := r.client.Get(url.String())
 	if err != nil {
 		return nil, fmt.Errorf("GET failed: %v", err)
 	}