author | Alberto Bertogli <albertito@blitiri.com.ar> | 2018-07-21 10:23:26 UTC |
committer | Alberto Bertogli <albertito@blitiri.com.ar> | 2018-07-21 11:13:39 UTC |
parent | 138d7e3c41c7884985ef1ec085fc0374051637e4 |
dnss.go | +7 | -26 |
dnss_test.go | +25 | -26 |
diff --git a/dnss.go b/dnss.go index f182ab1..25d0e53 100644 --- a/dnss.go +++ b/dnss.go @@ -19,6 +19,8 @@ import ( "strings" "sync" + "golang.org/x/net/http/httpproxy" + "blitiri.com.ar/go/dnss/internal/dnsserver" "blitiri.com.ar/go/dnss/internal/httpresolver" "blitiri.com.ar/go/dnss/internal/httpserver" @@ -156,39 +158,18 @@ func main() { // proxyServerDomain checks if we're using an HTTP proxy server, and if so // returns its domain. func proxyServerDomain() string { - req, err := http.NewRequest("GET", *httpsUpstream, nil) + url, err := url.Parse(*httpsUpstream) if err != nil { return "" } - url, err := http.ProxyFromEnvironment(req) - if err != nil || url == nil { + proxyFunc := httpproxy.FromEnvironment().ProxyFunc() + proxyURL, err := proxyFunc(url) + if err != nil || proxyURL == nil { return "" } - return extractHostname(url.Host) -} - -// extractHostname from an URL host, which can be in the form "host" or -// "host:port". -// TODO: Use url.Hostname() instead of this, once we drop support for Go 1.7 -// (the function was added in 1.8). -func extractHostname(host string) string { - // IPv6 URLs have the address between brackets. - // http://www.ietf.org/rfc/rfc2732.txt - if i := strings.Index(host, "]"); i != -1 { - return strings.TrimPrefix(host[:i], "[") - } - - // IPv4 or host URL, we can just drop everything after the ":" (if - // present). - if i := strings.Index(host, ":"); i != -1 { - return host[:i] - } - - // Port is not specified. - return host - + return proxyURL.Hostname() } func launchMonitoringServer(addr string) { diff --git a/dnss_test.go b/dnss_test.go index f6ad15f..d0cde94 100644 --- a/dnss_test.go +++ b/dnss_test.go @@ -26,8 +26,6 @@ func TestMain(m *testing.M) { log.Init() log.Default.Level = log.Error - // We need to do this early, see TestProxyServerDomain. 
- os.Setenv("HTTPS_PROXY", "http://proxy:1234/p") os.Exit(m.Run()) } @@ -200,37 +198,38 @@ func BenchmarkSimple(b *testing.B) { ///////////////////////////////////////////////////////////////////// // Tests for main-specific helpers -// Test proxyServerDomain(). Unfortunately, this function can only be called -// once, as the results of http.ProxyFromEnvironment are cached, so we test it -// for a single case. func TestProxyServerDomain(t *testing.T) { + prevProxy, wasSet := os.LookupEnv("HTTPS_PROXY") + + // Valid case, proxy set. + os.Setenv("HTTPS_PROXY", "http://proxy:1234/p") *httpsUpstream = "https://montoto/xyz" - // In TestMain we set: HTTPS_PROXY=http://proxy:1234/p - // We have to do that earlier to prevent other tests from (indirectly) - // calling http.ProxyFromEnvironment and have it cache a nil result. if got := proxyServerDomain(); got != "proxy" { t.Errorf("got %q, expected 'proxy'", got) } -} -func TestExtractHostname(t *testing.T) { - cases := []struct{ host, expected string }{ - {"host", "host"}, - {"host:1234", "host"}, - {"[host]", "host"}, - {"[host]:1234", "host"}, - {"1.2.3.4", "1.2.3.4"}, - {"1.2.3.4:1234", "1.2.3.4"}, - {"[::192.9.5.5]", "::192.9.5.5"}, - {"[::192.9.5.5]:1234", "::192.9.5.5"}, - {"[3ffe:2a00:100:7031::1]", "3ffe:2a00:100:7031::1"}, - {"[3ffe:2a00:100:7031::1]:1234", "3ffe:2a00:100:7031::1"}, + // Valid case, proxy not set. + os.Unsetenv("HTTPS_PROXY") + *httpsUpstream = "https://montoto/xyz" + if got := proxyServerDomain(); got != "" { + t.Errorf("got %q, expected ''", got) } - for _, c := range cases { - if got := extractHostname(c.host); got != c.expected { - t.Errorf("extractHostname(%q) = %q ; expected %q", - c.host, got, c.expected) - } + + // Invalid upstream URL. + *httpsUpstream = "in%20valid:url" + if got := proxyServerDomain(); got != "" { + t.Errorf("got %q, expected ''", got) + } + + // Invalid proxy. 
+ os.Setenv("HTTPS_PROXY", "invalid value") + *httpsUpstream = "https://montoto/xyz" + if got := proxyServerDomain(); got != "" { + t.Errorf("got %q, expected ''", got) + } + + if wasSet { + os.Setenv("HTTPS_PROXY", prevProxy) } }