author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-04-15 13:40:13 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-04-15 13:42:27 UTC |
parent | 432b60afdeb687da5c3c8a9f9d14775ab7494f5f |
internal/dnsserver/server.go | +3 | -0 |
internal/dnsserver/server_test.go | +32 | -0 |
diff --git a/internal/dnsserver/server.go b/internal/dnsserver/server.go index 2874f7d..d8f13de 100644 --- a/internal/dnsserver/server.go +++ b/internal/dnsserver/server.go @@ -134,6 +134,9 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) { log.Infof("resolver query error: %v", err) tr.LazyPrintf(err.Error()) tr.SetError() + + r.Id = oldid + dns.HandleFailed(w, r) return } diff --git a/internal/dnsserver/server_test.go b/internal/dnsserver/server_test.go index 8612c9c..a83e544 100644 --- a/internal/dnsserver/server_test.go +++ b/internal/dnsserver/server_test.go @@ -1,6 +1,7 @@ package dnsserver import ( + "fmt" "testing" "github.com/miekg/dns" @@ -47,3 +48,34 @@ func query(t *testing.T, srv, domain, expected string) { t.Errorf("query %q: expected %q but got %q", domain, expected, result) } } + +func TestBadUpstreams(t *testing.T) { + res := testutil.NewTestResolver() + res.RespError = fmt.Errorf("response error for testing") + + // Get addresses but don't start the servers, so we get an error when + // trying to reach them. + unqUpstreamAddr := testutil.GetFreePort() + fallbackAddr := testutil.GetFreePort() + + srv := New(testutil.GetFreePort(), res, unqUpstreamAddr) + srv.SetFallback(fallbackAddr, []string{"one.fallback.", "two.fallback."}) + go srv.ListenAndServe() + testutil.WaitForDNSServer(srv.Addr) + + queryFailure(t, srv.Addr, "response.test.") + queryFailure(t, srv.Addr, "unqualified.") + queryFailure(t, srv.Addr, "one.fallback.") + queryFailure(t, srv.Addr, "two.fallback.") +} + +func queryFailure(t *testing.T, srv, domain string) { + m, _, err := testutil.DNSQuery(srv, domain, dns.TypeA) + if err != nil { + t.Errorf("error querying %q: %v", domain, err) + } + + if m.Rcode != dns.RcodeServerFailure { + t.Errorf("query %q: expected SERVFAIL, got message: %v", domain, m) + } +}