author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-07-30 15:30:21 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-07-30 16:21:46 UTC |
parent | 1c82f297e6798fbebdf225a091c6ff02f39fb149 |
dnss.go | +48 | -2 |
dnss_test.go | +50 | -2 |
diff --git a/dnss.go b/dnss.go index 8de022c..62ac145 100644 --- a/dnss.go +++ b/dnss.go @@ -89,8 +89,16 @@ func main() { cr := dnstohttps.NewCachingResolver(r) cr.RegisterDebugHandlers() dth := dnstohttps.New(*dnsListenAddr, cr, *dnsUnqualifiedUpstream) - dth.SetFallback( - *fallbackUpstream, strings.Split(*fallbackDomains, " ")) + + // If we're using an HTTP proxy, add the name to the fallback domain + // so we don't have problems resolving it. + fallbackDoms := strings.Split(*fallbackDomains, " ") + if proxyDomain := proxyServerDomain(); proxyDomain != "" { + glog.Infof("Adding proxy %q to fallback domains", proxyDomain) + fallbackDoms = append(fallbackDoms, proxyDomain) + } + + dth.SetFallback(*fallbackUpstream, fallbackDoms) wg.Add(1) go func() { defer wg.Done() @@ -116,6 +124,44 @@ func main() { wg.Wait() } +// proxyServerDomain checks if we're using an HTTP proxy server, and if so +// returns its domain. +func proxyServerDomain() string { + req, err := http.NewRequest("GET", *httpsUpstream, nil) + if err != nil { + return "" + } + + url, err := http.ProxyFromEnvironment(req) + if err != nil || url == nil { + return "" + } + + return extractHostname(url.Host) +} + +// extractHostname from an URL host, which can be in the form "host" or +// "host:port". +// TODO: Use url.Hostname() instead of this, once we drop support for Go 1.7 +// (the function was added in 1.8). +func extractHostname(host string) string { + // IPv6 URLs have the address between brackets. + // http://www.ietf.org/rfc/rfc2732.txt + if i := strings.Index(host, "]"); i != -1 { + return strings.TrimPrefix(host[:i], "[") + } + + // IPv4 or host URL, we can just drop everything after the ":" (if + // present). + if i := strings.Index(host, ":"); i != -1 { + return host[:i] + } + + // Port is not specified. 
+ return host + +} + func launchMonitoringServer(addr string) { glog.Infof("Monitoring HTTP server listening on %s", addr) diff --git a/dnss_test.go b/dnss_test.go index 2b5383b..353f08b 100644 --- a/dnss_test.go +++ b/dnss_test.go @@ -1,4 +1,4 @@ -// End to end tests. +// Tests for dnss (end to end tests, and main-specific helpers). package main import ( @@ -16,6 +16,9 @@ import ( "github.com/miekg/dns" ) +///////////////////////////////////////////////////////////////////// +// End to end tests + // Setup: // DNS client -> DNS-to-HTTPS -> HTTPS-to-DNS -> DNS server // @@ -146,7 +149,7 @@ func handleFakeDNS(w dns.ResponseWriter, r *dns.Msg) { // Tests // -func TestSimple(t *testing.T) { +func TestEndToEnd(t *testing.T) { resetAnswers() addAnswers(t, "test.blah. A 1.2.3.4") _, ans, err := testutil.DNSQuery(ServerAddr, "test.blah.", dns.TypeA) @@ -192,3 +195,48 @@ func BenchmarkSimple(b *testing.B) { } } } + +///////////////////////////////////////////////////////////////////// +// Tests for main-specific helpers + +// Test proxyServerDomain(). Unfortunately, this function can only be called +// once, as the results of http.ProxyFromEnvironment are cached, so we test it +// for a single case. 
+func TestProxyServerDomain(t *testing.T) {
+	*httpsUpstream = "https://montoto/xyz"
+	os.Setenv("HTTPS_PROXY", "http://proxy:1234/p")
+	if got := proxyServerDomain(); got != "proxy" {
+		t.Errorf("got %q, expected 'proxy'", got)
+	}
+}
+
+func TestExtractHostname(t *testing.T) {
+	cases := []struct{ host, expected string }{
+		{"host", "host"},
+		{"host:1234", "host"},
+		{"[host]", "host"},
+		{"[host]:1234", "host"},
+		{"1.2.3.4", "1.2.3.4"},
+		{"1.2.3.4:1234", "1.2.3.4"},
+		{"[::192.9.5.5]", "::192.9.5.5"},
+		{"[::192.9.5.5]:1234", "::192.9.5.5"},
+		{"[3ffe:2a00:100:7031::1]", "3ffe:2a00:100:7031::1"},
+		{"[3ffe:2a00:100:7031::1]:1234", "3ffe:2a00:100:7031::1"},
+	}
+	for _, c := range cases {
+		if got := extractHostname(c.host); got != c.expected {
+			t.Errorf("extractHostname(%q) = %q ; expected %q",
+				c.host, got, c.expected)
+		}
+	}
+}
+
+func TestDumpFlags(t *testing.T) {
+	flag.Parse()
+	flag.Set("https_upstream", "https://montoto/xyz")
+
+	f := dumpFlags()
+	if !strings.Contains(f, "-https_upstream=https://montoto/xyz\n") {
+		t.Errorf("Flags string missing canary value: %v", f)
+	}
+}