git » dnss » commit 3f473eb

dnss: Use httpproxy to find the proxy server domain

author Alberto Bertogli
2018-07-21 10:23:26 UTC
committer Alberto Bertogli
2018-07-21 11:13:39 UTC
parent 138d7e3c41c7884985ef1ec085fc0374051637e4

dnss: Use httpproxy to find the proxy server domain

To extract the proxy server domain, this patch switches from the
built-in net/http package to the external golang.org/x/net/http/httpproxy
package.

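net/http now caches the proxy configuration it reads from the
environment (http.ProxyFromEnvironment only reads it once), which makes
testing more difficult. The httpproxy package is part of
golang.org/x/net, which we already depend on, so it is not a new
dependency.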

As part of this change we also remove extractHostname, a temporary
workaround for the lack of url.URL.Hostname() (added in Go 1.8). It is
no longer needed, as we no longer support Go < 1.9.

dnss.go +7 -26
dnss_test.go +25 -26

diff --git a/dnss.go b/dnss.go
index f182ab1..25d0e53 100644
--- a/dnss.go
+++ b/dnss.go
@@ -19,6 +19,8 @@ import (
 	"strings"
 	"sync"
 
+	"golang.org/x/net/http/httpproxy"
+
 	"blitiri.com.ar/go/dnss/internal/dnsserver"
 	"blitiri.com.ar/go/dnss/internal/httpresolver"
 	"blitiri.com.ar/go/dnss/internal/httpserver"
@@ -156,39 +158,18 @@ func main() {
 // proxyServerDomain checks if we're using an HTTP proxy server, and if so
 // returns its domain.
 func proxyServerDomain() string {
-	req, err := http.NewRequest("GET", *httpsUpstream, nil)
+	url, err := url.Parse(*httpsUpstream)
 	if err != nil {
 		return ""
 	}
 
-	url, err := http.ProxyFromEnvironment(req)
-	if err != nil || url == nil {
+	proxyFunc := httpproxy.FromEnvironment().ProxyFunc()
+	proxyURL, err := proxyFunc(url)
+	if err != nil || proxyURL == nil {
 		return ""
 	}
 
-	return extractHostname(url.Host)
-}
-
-// extractHostname from an URL host, which can be in the form "host" or
-// "host:port".
-// TODO: Use url.Hostname() instead of this, once we drop support for Go 1.7
-// (the function was added in 1.8).
-func extractHostname(host string) string {
-	// IPv6 URLs have the address between brackets.
-	// http://www.ietf.org/rfc/rfc2732.txt
-	if i := strings.Index(host, "]"); i != -1 {
-		return strings.TrimPrefix(host[:i], "[")
-	}
-
-	// IPv4 or host URL, we can just drop everything after the ":" (if
-	// present).
-	if i := strings.Index(host, ":"); i != -1 {
-		return host[:i]
-	}
-
-	// Port is not specified.
-	return host
-
+	return proxyURL.Hostname()
 }
 
 func launchMonitoringServer(addr string) {
diff --git a/dnss_test.go b/dnss_test.go
index f6ad15f..d0cde94 100644
--- a/dnss_test.go
+++ b/dnss_test.go
@@ -26,8 +26,6 @@ func TestMain(m *testing.M) {
 	log.Init()
 	log.Default.Level = log.Error
 
-	// We need to do this early, see TestProxyServerDomain.
-	os.Setenv("HTTPS_PROXY", "http://proxy:1234/p")
 	os.Exit(m.Run())
 }
 
@@ -200,37 +198,38 @@ func BenchmarkSimple(b *testing.B) {
 /////////////////////////////////////////////////////////////////////
 // Tests for main-specific helpers
 
-// Test proxyServerDomain(). Unfortunately, this function can only be called
-// once, as the results of http.ProxyFromEnvironment are cached, so we test it
-// for a single case.
 func TestProxyServerDomain(t *testing.T) {
+	prevProxy, wasSet := os.LookupEnv("HTTPS_PROXY")
+
+	// Valid case, proxy set.
+	os.Setenv("HTTPS_PROXY", "http://proxy:1234/p")
 	*httpsUpstream = "https://montoto/xyz"
-	// In TestMain we set: HTTPS_PROXY=http://proxy:1234/p
-	// We have to do that earlier to prevent other tests from (indirectly)
-	// calling http.ProxyFromEnvironment and have it cache a nil result.
 	if got := proxyServerDomain(); got != "proxy" {
 		t.Errorf("got %q, expected 'proxy'", got)
 	}
-}
 
-func TestExtractHostname(t *testing.T) {
-	cases := []struct{ host, expected string }{
-		{"host", "host"},
-		{"host:1234", "host"},
-		{"[host]", "host"},
-		{"[host]:1234", "host"},
-		{"1.2.3.4", "1.2.3.4"},
-		{"1.2.3.4:1234", "1.2.3.4"},
-		{"[::192.9.5.5]", "::192.9.5.5"},
-		{"[::192.9.5.5]:1234", "::192.9.5.5"},
-		{"[3ffe:2a00:100:7031::1]", "3ffe:2a00:100:7031::1"},
-		{"[3ffe:2a00:100:7031::1]:1234", "3ffe:2a00:100:7031::1"},
+	// Valid case, proxy not set.
+	os.Unsetenv("HTTPS_PROXY")
+	*httpsUpstream = "https://montoto/xyz"
+	if got := proxyServerDomain(); got != "" {
+		t.Errorf("got %q, expected ''", got)
 	}
-	for _, c := range cases {
-		if got := extractHostname(c.host); got != c.expected {
-			t.Errorf("extractHostname(%q) = %q ; expected %q",
-				c.host, got, c.expected)
-		}
+
+	// Invalid upstream URL.
+	*httpsUpstream = "in%20valid:url"
+	if got := proxyServerDomain(); got != "" {
+		t.Errorf("got %q, expected ''", got)
+	}
+
+	// Invalid proxy.
+	os.Setenv("HTTPS_PROXY", "invalid value")
+	*httpsUpstream = "https://montoto/xyz"
+	if got := proxyServerDomain(); got != "" {
+		t.Errorf("got %q, expected ''", got)
+	}
+
+	if wasSet {
+		os.Setenv("HTTPS_PROXY", prevProxy)
 	}
 }