author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-04-14 19:34:27 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-04-14 19:34:27 UTC |
parent | 2901f5756e642282a8a5c8d48adde7d917452f26 |
dnss_test.go | +3 | -14 |
internal/dnsserver/caching_test.go | +17 | -60 |
internal/testutil/testutil.go | +60 | -2 |
diff --git a/dnss_test.go b/dnss_test.go index e4573e3..eb2e66d 100644 --- a/dnss_test.go +++ b/dnss_test.go @@ -67,8 +67,8 @@ func Setup(tb testing.TB, mode string) string { httpserver.InsecureForTesting = true go htod.ListenAndServe() - // Fake DNS server. - go ServeFakeDNSServer(DNSServerAddr) + // Test DNS server. + go testutil.ServeTestDNSServer(DNSServerAddr, handleTestDNS) // Wait for the servers to start up. err1 := testutil.WaitForDNSServer(DNSToHTTPSAddr) @@ -85,17 +85,6 @@ func Setup(tb testing.TB, mode string) string { return DNSToHTTPSAddr } -// Fake DNS server. -func ServeFakeDNSServer(addr string) { - server := &dns.Server{ - Addr: addr, - Handler: dns.HandlerFunc(handleFakeDNS), - Net: "udp", - } - err := server.ListenAndServe() - panic(err) -} - // DNS answers to give, as a map of "name type" -> []RR. // Tests will modify this according to their needs. var answers map[string][]dns.RR @@ -122,7 +111,7 @@ func addAnswers(tb testing.TB, zone string) { } } -func handleFakeDNS(w dns.ResponseWriter, r *dns.Msg) { +func handleTestDNS(w dns.ResponseWriter, r *dns.Msg) { m := &dns.Msg{} m.SetReply(r) diff --git a/internal/dnsserver/caching_test.go b/internal/dnsserver/caching_test.go index bb0e559..5d0f6e5 100644 --- a/internal/dnsserver/caching_test.go +++ b/internal/dnsserver/caching_test.go @@ -1,8 +1,6 @@ package dnsserver // Tests for the caching resolver. -// Note the other resolvers have more functional tests in the testing/ -// directory. import ( "fmt" @@ -14,61 +12,20 @@ import ( "blitiri.com.ar/go/dnss/internal/testutil" "github.com/miekg/dns" - "golang.org/x/net/trace" ) -// A test resolver that we use as backing for the caching resolver under test. -type TestResolver struct { - // Has this resolver been initialized? - init bool - - // Maintain() sends a value over this channel. - maintain chan bool - - // The last query we've seen. - lastQuery *dns.Msg - - // What we will respond to queries. - response *dns.Msg - respError error -} - -func NewTestResolver() *TestResolver { - return &TestResolver{ - maintain: make(chan bool, 1), - } -} - -func (r *TestResolver) Init() error { - r.init = true - return nil -} - -func (r *TestResolver) Maintain() { - r.maintain <- true -} - -func (r *TestResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) { - r.lastQuery = req - if r.response != nil { - r.response.Question = req.Question - r.response.Authoritative = true - } - return r.response, r.respError -} - // // === Tests === // // Test basic functionality. func TestBasic(t *testing.T) { - r := NewTestResolver() + r := testutil.NewTestResolver() c := NewCachingResolver(r) c.Init() - if !r.init { + if !r.Initialized { t.Errorf("caching resolver did not initialize backing") } @@ -94,7 +51,7 @@ func TestBasic(t *testing.T) { // Test TTL handling. func TestTTL(t *testing.T) { - r := NewTestResolver() + r := testutil.NewTestResolver() c := NewCachingResolver(r) c.Init() resetStats() @@ -133,7 +90,7 @@ func TestTTL(t *testing.T) { // Check that the back resolver's Maintain() is called. select { - case <-r.maintain: + case <-r.MaintainC: t.Log("Maintain() called") case <-time.After(1 * time.Second): t.Errorf("back resolver Maintain() was not called") @@ -155,7 +112,7 @@ func TestTTL(t *testing.T) { // Test that we don't cache failed queries. func TestFailedQueries(t *testing.T) { - r := NewTestResolver() + r := testutil.NewTestResolver() c := NewCachingResolver(r) c.Init() resetStats() @@ -177,12 +134,12 @@ func TestFailedQueries(t *testing.T) { // when we're full, which is not ideal and will likely be changed in the // future. func TestCacheFull(t *testing.T) { - r := NewTestResolver() + r := testutil.NewTestResolver() c := NewCachingResolver(r) c.Init() resetStats() - r.response = newReply(mustNewRR(t, "test. A 1.2.3.4")) + r.Response = newReply(mustNewRR(t, "test. A 1.2.3.4")) // Do maxCacheSize+1 different requests. for i := 0; i < maxCacheSize+1; i++ { @@ -212,7 +169,7 @@ func TestCacheFull(t *testing.T) { // Test behaviour when the size of the cache is 0 (so users can disable it // that way). func TestZeroSize(t *testing.T) { - r := NewTestResolver() + r := testutil.NewTestResolver() c := NewCachingResolver(r) c.Init() resetStats() @@ -222,7 +179,7 @@ func TestZeroSize(t *testing.T) { maxCacheSize = 0 defer func() { maxCacheSize = prevMaxCacheSize }() - r.response = newReply(mustNewRR(t, "test. A 1.2.3.4")) + r.Response = newReply(mustNewRR(t, "test. A 1.2.3.4")) // Do 5 different requests. for i := 0; i < 5; i++ { @@ -249,8 +206,8 @@ func TestZeroSize(t *testing.T) { func BenchmarkCacheSimple(b *testing.B) { var err error - r := NewTestResolver() - r.response = newReply(mustNewRR(b, "test. A 1.2.3.4")) + r := testutil.NewTestResolver() + r.Response = newReply(mustNewRR(b, "test. A 1.2.3.4")) c := NewCachingResolver(r) c.Init() @@ -293,8 +250,8 @@ func dumpStats() string { func queryA(t *testing.T, c *cachingResolver, rr, domain, expected string) *dns.Msg { // Set up the response from the given RR (if any). if rr != "" { - back := c.back.(*TestResolver) - back.response = newReply(mustNewRR(t, rr)) + back := c.back.(*testutil.TestResolver) + back.Response = newReply(mustNewRR(t, rr)) } tr := testutil.NewTestTrace(t) @@ -320,10 +277,10 @@ func queryA(t *testing.T, c *cachingResolver, rr, domain, expected string) *dns. } func queryFail(t *testing.T, c *cachingResolver) *dns.Msg { - back := c.back.(*TestResolver) - back.response = &dns.Msg{} - back.response.Response = true - back.response.Rcode = dns.RcodeNameError + back := c.back.(*testutil.TestResolver) + back.Response = &dns.Msg{} + back.Response.Response = true + back.Response.Rcode = dns.RcodeNameError tr := testutil.NewTestTrace(t) defer tr.Finish() diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index fe2dc52..4737911 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -8,6 +8,8 @@ import ( "testing" "time" + "golang.org/x/net/trace" + "github.com/miekg/dns" ) @@ -62,8 +64,8 @@ func WaitForHTTPServer(addr string) error { return fmt.Errorf("timed out") } -// Get a free (TCP) port. This is hacky and not race-free, but it works well -// enough for testing purposes. +// GetFreePort returns a free TCP port. This is hacky and not race-free, but +// it works well enough for testing purposes. func GetFreePort() string { l, _ := net.Listen("tcp", "localhost:0") defer l.Close() @@ -85,6 +87,62 @@ func DNSQuery(srv, addr string, qtype uint16) (*dns.Msg, dns.RR, error) { } } +// TestResolver is a dnsserver.Resolver implementation for testing, so we can +// control its responses during tests. +type TestResolver struct { + // Has this resolver been initialized? + Initialized bool + + // Maintain() sends a value over this channel. + MaintainC chan bool + + // The last query we've seen. + LastQuery *dns.Msg + + // What we will respond to queries. + Response *dns.Msg + RespError error +} + +// NewTestResolver creates a new TestResolver with minimal initialization. +func NewTestResolver() *TestResolver { + return &TestResolver{ + MaintainC: make(chan bool, 1), + } +} + +// Init the resolver. +func (r *TestResolver) Init() error { + r.Initialized = true + return nil +} + +// Maintain the resolver. +func (r *TestResolver) Maintain() { + r.MaintainC <- true +} + +// Query handles the given query, returning the pre-recorded response. +func (r *TestResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) { + r.LastQuery = req + if r.Response != nil { + r.Response.Question = req.Question + r.Response.Authoritative = true + } + return r.Response, r.RespError +} + +// ServeTestDNSServer starts the fake DNS server. +func ServeTestDNSServer(addr string, handler func(dns.ResponseWriter, *dns.Msg)) { + server := &dns.Server{ + Addr: addr, + Handler: dns.HandlerFunc(handler), + Net: "udp", + } + err := server.ListenAndServe() + panic(err) +} + // TestTrace implements the tracer.Trace interface, but prints using the test // logging infrastructure. type TestTrace struct {