git » dnss » commit 207b26b

test: Move the test resolver and DNS server to testutil

author Alberto Bertogli
2018-04-14 19:34:27 UTC
committer Alberto Bertogli
2018-04-14 19:34:27 UTC
parent 2901f5756e642282a8a5c8d48adde7d917452f26

test: Move the test resolver and DNS server to testutil

In the future we will want to use the test resolver and DNS server from
other tests, so this patch moves them to testutil.

dnss_test.go +3 -14
internal/dnsserver/caching_test.go +17 -60
internal/testutil/testutil.go +60 -2

diff --git a/dnss_test.go b/dnss_test.go
index e4573e3..eb2e66d 100644
--- a/dnss_test.go
+++ b/dnss_test.go
@@ -67,8 +67,8 @@ func Setup(tb testing.TB, mode string) string {
 	httpserver.InsecureForTesting = true
 	go htod.ListenAndServe()
 
-	// Fake DNS server.
-	go ServeFakeDNSServer(DNSServerAddr)
+	// Test DNS server.
+	go testutil.ServeTestDNSServer(DNSServerAddr, handleTestDNS)
 
 	// Wait for the servers to start up.
 	err1 := testutil.WaitForDNSServer(DNSToHTTPSAddr)
@@ -85,17 +85,6 @@ func Setup(tb testing.TB, mode string) string {
 	return DNSToHTTPSAddr
 }
 
-// Fake DNS server.
-func ServeFakeDNSServer(addr string) {
-	server := &dns.Server{
-		Addr:    addr,
-		Handler: dns.HandlerFunc(handleFakeDNS),
-		Net:     "udp",
-	}
-	err := server.ListenAndServe()
-	panic(err)
-}
-
 // DNS answers to give, as a map of "name type" -> []RR.
 // Tests will modify this according to their needs.
 var answers map[string][]dns.RR
@@ -122,7 +111,7 @@ func addAnswers(tb testing.TB, zone string) {
 	}
 }
 
-func handleFakeDNS(w dns.ResponseWriter, r *dns.Msg) {
+func handleTestDNS(w dns.ResponseWriter, r *dns.Msg) {
 	m := &dns.Msg{}
 	m.SetReply(r)
 
diff --git a/internal/dnsserver/caching_test.go b/internal/dnsserver/caching_test.go
index bb0e559..5d0f6e5 100644
--- a/internal/dnsserver/caching_test.go
+++ b/internal/dnsserver/caching_test.go
@@ -1,8 +1,6 @@
 package dnsserver
 
 // Tests for the caching resolver.
-// Note the other resolvers have more functional tests in the testing/
-// directory.
 
 import (
 	"fmt"
@@ -14,61 +12,20 @@ import (
 	"blitiri.com.ar/go/dnss/internal/testutil"
 
 	"github.com/miekg/dns"
-	"golang.org/x/net/trace"
 )
 
-// A test resolver that we use as backing for the caching resolver under test.
-type TestResolver struct {
-	// Has this resolver been initialized?
-	init bool
-
-	// Maintain() sends a value over this channel.
-	maintain chan bool
-
-	// The last query we've seen.
-	lastQuery *dns.Msg
-
-	// What we will respond to queries.
-	response  *dns.Msg
-	respError error
-}
-
-func NewTestResolver() *TestResolver {
-	return &TestResolver{
-		maintain: make(chan bool, 1),
-	}
-}
-
-func (r *TestResolver) Init() error {
-	r.init = true
-	return nil
-}
-
-func (r *TestResolver) Maintain() {
-	r.maintain <- true
-}
-
-func (r *TestResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
-	r.lastQuery = req
-	if r.response != nil {
-		r.response.Question = req.Question
-		r.response.Authoritative = true
-	}
-	return r.response, r.respError
-}
-
 //
 // === Tests ===
 //
 
 // Test basic functionality.
 func TestBasic(t *testing.T) {
-	r := NewTestResolver()
+	r := testutil.NewTestResolver()
 
 	c := NewCachingResolver(r)
 
 	c.Init()
-	if !r.init {
+	if !r.Initialized {
 		t.Errorf("caching resolver did not initialize backing")
 	}
 
@@ -94,7 +51,7 @@ func TestBasic(t *testing.T) {
 
 // Test TTL handling.
 func TestTTL(t *testing.T) {
-	r := NewTestResolver()
+	r := testutil.NewTestResolver()
 	c := NewCachingResolver(r)
 	c.Init()
 	resetStats()
@@ -133,7 +90,7 @@ func TestTTL(t *testing.T) {
 
 	// Check that the back resolver's Maintain() is called.
 	select {
-	case <-r.maintain:
+	case <-r.MaintainC:
 		t.Log("Maintain() called")
 	case <-time.After(1 * time.Second):
 		t.Errorf("back resolver Maintain() was not called")
@@ -155,7 +112,7 @@ func TestTTL(t *testing.T) {
 
 // Test that we don't cache failed queries.
 func TestFailedQueries(t *testing.T) {
-	r := NewTestResolver()
+	r := testutil.NewTestResolver()
 	c := NewCachingResolver(r)
 	c.Init()
 	resetStats()
@@ -177,12 +134,12 @@ func TestFailedQueries(t *testing.T) {
 // when we're full, which is not ideal and will likely be changed in the
 // future.
 func TestCacheFull(t *testing.T) {
-	r := NewTestResolver()
+	r := testutil.NewTestResolver()
 	c := NewCachingResolver(r)
 	c.Init()
 	resetStats()
 
-	r.response = newReply(mustNewRR(t, "test. A 1.2.3.4"))
+	r.Response = newReply(mustNewRR(t, "test. A 1.2.3.4"))
 
 	// Do maxCacheSize+1 different requests.
 	for i := 0; i < maxCacheSize+1; i++ {
@@ -212,7 +169,7 @@ func TestCacheFull(t *testing.T) {
 // Test behaviour when the size of the cache is 0 (so users can disable it
 // that way).
 func TestZeroSize(t *testing.T) {
-	r := NewTestResolver()
+	r := testutil.NewTestResolver()
 	c := NewCachingResolver(r)
 	c.Init()
 	resetStats()
@@ -222,7 +179,7 @@ func TestZeroSize(t *testing.T) {
 	maxCacheSize = 0
 	defer func() { maxCacheSize = prevMaxCacheSize }()
 
-	r.response = newReply(mustNewRR(t, "test. A 1.2.3.4"))
+	r.Response = newReply(mustNewRR(t, "test. A 1.2.3.4"))
 
 	// Do 5 different requests.
 	for i := 0; i < 5; i++ {
@@ -249,8 +206,8 @@ func TestZeroSize(t *testing.T) {
 func BenchmarkCacheSimple(b *testing.B) {
 	var err error
 
-	r := NewTestResolver()
-	r.response = newReply(mustNewRR(b, "test. A 1.2.3.4"))
+	r := testutil.NewTestResolver()
+	r.Response = newReply(mustNewRR(b, "test. A 1.2.3.4"))
 
 	c := NewCachingResolver(r)
 	c.Init()
@@ -293,8 +250,8 @@ func dumpStats() string {
 func queryA(t *testing.T, c *cachingResolver, rr, domain, expected string) *dns.Msg {
 	// Set up the response from the given RR (if any).
 	if rr != "" {
-		back := c.back.(*TestResolver)
-		back.response = newReply(mustNewRR(t, rr))
+		back := c.back.(*testutil.TestResolver)
+		back.Response = newReply(mustNewRR(t, rr))
 	}
 
 	tr := testutil.NewTestTrace(t)
@@ -320,10 +277,10 @@ func queryA(t *testing.T, c *cachingResolver, rr, domain, expected string) *dns.
 }
 
 func queryFail(t *testing.T, c *cachingResolver) *dns.Msg {
-	back := c.back.(*TestResolver)
-	back.response = &dns.Msg{}
-	back.response.Response = true
-	back.response.Rcode = dns.RcodeNameError
+	back := c.back.(*testutil.TestResolver)
+	back.Response = &dns.Msg{}
+	back.Response.Response = true
+	back.Response.Rcode = dns.RcodeNameError
 
 	tr := testutil.NewTestTrace(t)
 	defer tr.Finish()
diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go
index fe2dc52..4737911 100644
--- a/internal/testutil/testutil.go
+++ b/internal/testutil/testutil.go
@@ -8,6 +8,8 @@ import (
 	"testing"
 	"time"
 
+	"golang.org/x/net/trace"
+
 	"github.com/miekg/dns"
 )
 
@@ -62,8 +64,8 @@ func WaitForHTTPServer(addr string) error {
 	return fmt.Errorf("timed out")
 }
 
-// Get a free (TCP) port. This is hacky and not race-free, but it works well
-// enough for testing purposes.
+// GetFreePort returns a free TCP port. This is hacky and not race-free, but
+// it works well enough for testing purposes.
 func GetFreePort() string {
 	l, _ := net.Listen("tcp", "localhost:0")
 	defer l.Close()
@@ -85,6 +87,62 @@ func DNSQuery(srv, addr string, qtype uint16) (*dns.Msg, dns.RR, error) {
 	}
 }
 
+// TestResolver is a dnsserver.Resolver implementation for testing, so we can
+// control its responses during tests.
+type TestResolver struct {
+	// Initialized is set to true by Init.
+	Initialized bool
+
+	// MaintainC receives a value each time Maintain is called.
+	MaintainC chan bool
+
+	// LastQuery is the most recent request passed to Query.
+	LastQuery *dns.Msg
+
+	// Response and RespError are what Query returns.
+	Response  *dns.Msg
+	RespError error
+}
+
+// NewTestResolver returns a TestResolver with a 1-slot buffered MaintainC.
+func NewTestResolver() *TestResolver {
+	return &TestResolver{
+		MaintainC: make(chan bool, 1),
+	}
+}
+
+// Init marks the resolver as initialized. It always returns nil.
+func (r *TestResolver) Init() error {
+	r.Initialized = true
+	return nil
+}
+
+// Maintain sends true on MaintainC, blocking while the channel is full.
+func (r *TestResolver) Maintain() {
+	r.MaintainC <- true
+}
+
+// Query records req in LastQuery and returns Response and RespError.
+func (r *TestResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
+	r.LastQuery = req
+	if r.Response != nil {
+		r.Response.Question = req.Question // Note: mutates the shared Response.
+		r.Response.Authoritative = true
+	}
+	return r.Response, r.RespError
+}
+
+// ServeTestDNSServer serves DNS over UDP on addr, dispatching to handler.
+func ServeTestDNSServer(addr string, handler func(dns.ResponseWriter, *dns.Msg)) {
+	server := &dns.Server{
+		Addr:    addr,
+		Handler: dns.HandlerFunc(handler),
+		Net:     "udp",
+	}
+	err := server.ListenAndServe()
+	panic(err) // Any return from ListenAndServe is treated as fatal.
+}
+
 // TestTrace implements the tracer.Trace interface, but prints using the test
 // logging infrastructure.
 type TestTrace struct {