git » dnss » commit 7c06003

Add the http proxy host (if any) to the fallback domain list

author Alberto Bertogli
2017-07-30 15:30:21 UTC
committer Alberto Bertogli
2017-07-30 16:21:46 UTC
parent 1c82f297e6798fbebdf225a091c6ff02f39fb149

Add the http proxy host (if any) to the fallback domain list

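If we're using an HTTP proxy, we should add its hostname to the fallback
domain list, so it gets resolved by the fallback upstream instead of via
the HTTPS upstream, which itself has to be reached through the proxy.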

dnss.go +48 -2
dnss_test.go +50 -2

diff --git a/dnss.go b/dnss.go
index 8de022c..62ac145 100644
--- a/dnss.go
+++ b/dnss.go
@@ -89,8 +89,16 @@ func main() {
 		cr := dnstohttps.NewCachingResolver(r)
 		cr.RegisterDebugHandlers()
 		dth := dnstohttps.New(*dnsListenAddr, cr, *dnsUnqualifiedUpstream)
-		dth.SetFallback(
-			*fallbackUpstream, strings.Split(*fallbackDomains, " "))
+
+		// If we're using an HTTP proxy, add the name to the fallback domain
+		// so we don't have problems resolving it.
+		fallbackDoms := strings.Split(*fallbackDomains, " ")
+		if proxyDomain := proxyServerDomain(); proxyDomain != "" {
+			glog.Infof("Adding proxy %q to fallback domains", proxyDomain)
+			fallbackDoms = append(fallbackDoms, proxyDomain)
+		}
+
+		dth.SetFallback(*fallbackUpstream, fallbackDoms)
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
@@ -116,6 +124,44 @@ func main() {
 	wg.Wait()
 }
 
+// proxyServerDomain returns the hostname (without port) of the proxy that
+// http.ProxyFromEnvironment picks for *httpsUpstream; "" if none or on error.
+func proxyServerDomain() string {
+	req, err := http.NewRequest("GET", *httpsUpstream, nil)
+	if err != nil {
+		return "" // With a fixed "GET" and nil body, only a bad URL gets here.
+	}
+
+	url, err := http.ProxyFromEnvironment(req)
+	if err != nil || url == nil {
+		return "" // Malformed proxy setting, or no proxy for this request.
+	}
+
+	return extractHostname(url.Host)
+}
+
+// extractHostname returns the hostname part of a URL host, which can be in
+// the form "host", "host:port", "[host]" or "[host]:port" (e.g. IPv6).
+// TODO: Use url.Hostname() instead of this, once we drop support for Go 1.7
+// (the function was added in 1.8).
+func extractHostname(host string) string {
+	// Bracketed (IPv6) host: return what's inside "[...]", dropping any port.
+	// http://www.ietf.org/rfc/rfc2732.txt
+	if i := strings.Index(host, "]"); i != -1 {
+		return strings.TrimPrefix(host[:i], "[")
+	}
+
+	// Hostname or IPv4 address: drop everything from the first ":" (the
+	// port separator), if present.
+	if i := strings.Index(host, ":"); i != -1 {
+		return host[:i]
+	}
+
+	// Port is not specified.
+	return host
+
+}
+
 func launchMonitoringServer(addr string) {
 	glog.Infof("Monitoring HTTP server listening on %s", addr)
 
diff --git a/dnss_test.go b/dnss_test.go
index 2b5383b..353f08b 100644
--- a/dnss_test.go
+++ b/dnss_test.go
@@ -1,4 +1,4 @@
-// End to end tests.
+// Tests for dnss (end to end tests, and main-specific helpers).
 package main
 
 import (
@@ -16,6 +16,9 @@ import (
 	"github.com/miekg/dns"
 )
 
+/////////////////////////////////////////////////////////////////////
+// End to end tests
+
 // Setup:
 // DNS client -> DNS-to-HTTPS -> HTTPS-to-DNS -> DNS server
 //
@@ -146,7 +149,7 @@ func handleFakeDNS(w dns.ResponseWriter, r *dns.Msg) {
 // Tests
 //
 
-func TestSimple(t *testing.T) {
+func TestEndToEnd(t *testing.T) {
 	resetAnswers()
 	addAnswers(t, "test.blah. A 1.2.3.4")
 	_, ans, err := testutil.DNSQuery(ServerAddr, "test.blah.", dns.TypeA)
@@ -192,3 +195,48 @@ func BenchmarkSimple(b *testing.B) {
 		}
 	}
 }
+
+/////////////////////////////////////////////////////////////////////
+// Tests for main-specific helpers
+
+// Test proxyServerDomain(). This can only be done once per test binary, as
+// http.ProxyFromEnvironment caches what it reads from the environment, so we
+// test a single case. HTTPS_PROXY and *httpsUpstream are left modified.
+func TestProxyServerDomain(t *testing.T) {
+	*httpsUpstream = "https://montoto/xyz" // Must not be localhost, which is never proxied.
+	os.Setenv("HTTPS_PROXY", "http://proxy:1234/p")
+	if got := proxyServerDomain(); got != "proxy" {
+		t.Errorf("got %q, expected 'proxy'", got)
+	}
+}
+
+func TestExtractHostname(t *testing.T) {
+	cases := []struct{ host, expected string }{ // URL host -> expected hostname.
+		{"host", "host"},
+		{"host:1234", "host"},
+		{"[host]", "host"}, // Brackets are stripped even for non-IPv6 hosts.
+		{"[host]:1234", "host"},
+		{"1.2.3.4", "1.2.3.4"},
+		{"1.2.3.4:1234", "1.2.3.4"},
+		{"[::192.9.5.5]", "::192.9.5.5"},
+		{"[::192.9.5.5]:1234", "::192.9.5.5"},
+		{"[3ffe:2a00:100:7031::1]", "3ffe:2a00:100:7031::1"},
+		{"[3ffe:2a00:100:7031::1]:1234", "3ffe:2a00:100:7031::1"},
+	}
+	for _, c := range cases {
+		if got := extractHostname(c.host); got != c.expected {
+			t.Errorf("extractHostname(%q) = %q ; expected %q",
+				c.host, got, c.expected)
+		}
+	}
+}
+
+func TestDumpFlags(t *testing.T) {
+	flag.Parse()
+	flag.Set("https_upstream", "https://montoto/xyz") // Error ignored; an unknown flag makes the check below fail.
+
+	f := dumpFlags()
+	if !strings.Contains(f, "-https_upstream=https://montoto/xyz\n") {
+		t.Errorf("Flags string missing canary value: %v", f)
+	}
+}