git » dnss » commit 9677b91

Add an end-to-end test

author Alberto Bertogli
2017-07-30 11:51:49 UTC
committer Alberto Bertogli
2017-07-30 16:21:46 UTC
parent 8cdf69a054b39dbe955c3e7ca68deb5f0b7b8a0e

Add an end-to-end test

This patch adds an end-to-end test, which sets up a chain of

  DNS to HTTPS proxy -> HTTPS to DNS proxy -> DNS server

and then issues requests to the first proxy, and validates that the
responses have made it correctly all the way through.

dnss_test.go +194 -0
internal/httpstodns/server.go +8 -1
testing/https/https_test.go +4 -17
testing/util/util.go +38 -0

diff --git a/dnss_test.go b/dnss_test.go
new file mode 100644
index 0000000..4ad62a0
--- /dev/null
+++ b/dnss_test.go
@@ -0,0 +1,194 @@
+// End to end tests.
+package main
+
+import (
+	"flag"
+	"fmt"
+	"os"
+	"strings"
+	"sync"
+	"testing"
+
+	"blitiri.com.ar/go/dnss/internal/dnstohttps"
+	"blitiri.com.ar/go/dnss/internal/httpstodns"
+	"blitiri.com.ar/go/dnss/testing/util"
+	"github.com/golang/glog"
+	"github.com/miekg/dns"
+)
+
+// Setup:
+// DNS client -> DNS-to-HTTPS -> HTTPS-to-DNS -> DNS server
+//
+// The DNS client will be created on each test.
+// The DNS server will be created below, and the tests can adjust its
+// responses as needed.
+
+// Address of the DNS-to-HTTPS server, for the tests to use.
+var ServerAddr string
+
+// realMain is the real main function: it starts the chain of test servers,
+// runs the tests, and returns the value to pass to os.Exit().
+// We have to do this so we can use defer, which os.Exit() would skip.
+func realMain(m *testing.M) int {
+	flag.Parse()
+	defer glog.Flush()
+
+	// One address for each of the three servers in the chain.
+	frontAddr := util.GetFreePort()
+	proxyAddr := util.GetFreePort()
+	upstreamAddr := util.GetFreePort()
+
+	// Tests talk to the DNS-to-HTTPS server, the first in the chain.
+	ServerAddr = frontAddr
+
+	// DNS to HTTPS server, resolving via the HTTPS-to-DNS one.
+	resolver := dnstohttps.NewHTTPSResolver("http://"+proxyAddr+"/resolve", "")
+	dtoh := dnstohttps.New(frontAddr, resolver, "")
+	go dtoh.ListenAndServe()
+
+	// HTTPS to DNS server, serving plain HTTP for the test.
+	httpstodns.InsecureForTesting = true
+	htod := httpstodns.Server{Addr: proxyAddr, Upstream: upstreamAddr}
+	go htod.ListenAndServe()
+
+	// Fake DNS server.
+	go ServeFakeDNSServer(upstreamAddr)
+
+	// Wait for the servers to start up.
+	dtohErr := util.WaitForDNSServer(frontAddr)
+	htodErr := util.WaitForHTTPServer(proxyAddr)
+	dnsErr := util.WaitForDNSServer(upstreamAddr)
+	if dtohErr == nil && htodErr == nil && dnsErr == nil {
+		return m.Run()
+	}
+
+	fmt.Printf("Error waiting for the test servers to start:\n")
+	fmt.Printf("  DNS to HTTPS: %v\n", dtohErr)
+	fmt.Printf("  HTTPS to DNS: %v\n", htodErr)
+	fmt.Printf("  DNS server:   %v\n", dnsErr)
+	fmt.Printf("Check the INFO logs for more details\n")
+	return 1
+}
+
+// TestMain exits with the code returned by realMain.
+func TestMain(m *testing.M) {
+	os.Exit(realMain(m))
+}
+
+// ServeFakeDNSServer runs the fake DNS server on the given UDP address,
+// answering with handleFakeDNS; it panics with the result once serving ends.
+func ServeFakeDNSServer(addr string) {
+	srv := &dns.Server{
+		Net:     "udp",
+		Addr:    addr,
+		Handler: dns.HandlerFunc(handleFakeDNS),
+	}
+	panic(srv.ListenAndServe())
+}
+
+// DNS answers to give, keyed by "name type"; tests adjust them as needed.
+var answers map[string][]dns.RR
+var answersMu sync.Mutex
+
+// resetAnswers replaces the answers map with a new, empty one.
+func resetAnswers() {
+	answersMu.Lock()
+	defer answersMu.Unlock()
+	answers = map[string][]dns.RR{}
+}
+
+// addAnswers parses zone and appends each record to answers under its "name
+// type" key, failing the test on a parse error. answers must not be nil.
+func addAnswers(tb testing.TB, zone string) {
+	for rec := range dns.ParseZone(strings.NewReader(zone), "", "") {
+		if rec.Error != nil {
+			tb.Fatalf("error parsing zone: %v\n", rec.Error)
+			return
+		}
+		k := fmt.Sprintf("%s %d", rec.RR.Header().Name, rec.RR.Header().Rrtype)
+		answersMu.Lock()
+		answers[k] = append(answers[k], rec.RR)
+		answersMu.Unlock()
+	}
+}
+
+// handleFakeDNS answers single-question queries from the answers map, and
+// replies with a name error when there is no entry for the name and type.
+func handleFakeDNS(w dns.ResponseWriter, r *dns.Msg) {
+	reply := new(dns.Msg)
+	reply.SetReply(r)
+	if len(r.Question) != 1 {
+		w.WriteMsg(reply)
+		return
+	}
+
+	question := r.Question[0]
+	if testing.Verbose() {
+		fmt.Printf("fake dns <- %v\n", question)
+	}
+
+	answersMu.Lock()
+	rrs, found := answers[fmt.Sprintf("%s %d", question.Name, question.Qtype)]
+	answersMu.Unlock()
+	if found {
+		reply.Answer = rrs
+	} else {
+		reply.Rcode = dns.RcodeNameError
+	}
+	if testing.Verbose() {
+		fmt.Printf("fake dns -> %v | %v\n",
+			dns.RcodeToString[reply.Rcode], reply.Answer)
+	}
+	w.WriteMsg(reply)
+}
+
+//
+// Tests
+//
+
+// TestSimple checks A, MX and unknown-name queries through the whole chain.
+// Replies are only inspected when the query succeeded: on error they are nil,
+// and asserting on them would panic instead of reporting the failure.
+func TestSimple(t *testing.T) {
+	resetAnswers()
+	addAnswers(t, "test.blah. A 1.2.3.4")
+	_, ans, err := util.DNSQuery(ServerAddr, "test.blah.", dns.TypeA)
+	if err != nil {
+		t.Errorf("dns query returned error: %v", err)
+	} else if a, ok := ans.(*dns.A); !ok || a.A.String() != "1.2.3.4" {
+		t.Errorf("unexpected result: %q", ans)
+	}
+
+	addAnswers(t, "test.blah. MX 10 mail.test.blah.")
+	_, ans, err = util.DNSQuery(ServerAddr, "test.blah.", dns.TypeMX)
+	if err != nil {
+		t.Errorf("dns query returned error: %v", err)
+	} else if mx, ok := ans.(*dns.MX); !ok || mx.Mx != "mail.test.blah." {
+		t.Errorf("unexpected result: %q", ans)
+	}
+
+	in, _, err := util.DNSQuery(ServerAddr, "unknown.", dns.TypeA)
+	if err != nil {
+		t.Errorf("dns query returned error: %v", err)
+	} else if in.Rcode != dns.RcodeNameError {
+		t.Errorf("unexpected result: %q", in)
+	}
+}
+
+//
+// Benchmarks
+//
+
+func BenchmarkSimple(b *testing.B) {
+	resetAnswers()
+	addAnswers(b, "test.blah. A 1.2.3.4")
+	b.ResetTimer()
+
+	var err error
+	for i := 0; i < b.N; i++ {
+		_, _, err = util.DNSQuery(ServerAddr, "test.blah.", dns.TypeA)
+		if err != nil {
+			b.Errorf("dns query returned error: %v", err)
+		}
+	}
+}
diff --git a/internal/httpstodns/server.go b/internal/httpstodns/server.go
index be0150e..c66d661 100644
--- a/internal/httpstodns/server.go
+++ b/internal/httpstodns/server.go
@@ -25,6 +25,8 @@ type Server struct {
 	KeyFile  string
 }
 
+var InsecureForTesting = false
+
 func (s *Server) ListenAndServe() {
 	mux := http.NewServeMux()
 	mux.HandleFunc("/resolve", s.Resolve)
@@ -34,7 +36,12 @@ func (s *Server) ListenAndServe() {
 	}
 
 	glog.Infof("HTTPS listening on %s", s.Addr)
-	err := srv.ListenAndServeTLS(s.CertFile, s.KeyFile)
+	var err error
+	if InsecureForTesting {
+		err = srv.ListenAndServe()
+	} else {
+		err = srv.ListenAndServeTLS(s.CertFile, s.KeyFile)
+	}
 	glog.Fatalf("HTTPS exiting: %s", err)
 }
 
diff --git a/testing/https/https_test.go b/testing/https/https_test.go
index 3b5d31c..a29ced9 100644
--- a/testing/https/https_test.go
+++ b/testing/https/https_test.go
@@ -19,22 +19,9 @@ import (
 //
 // === Tests ===
 //
-func dnsQuery(addr string, qtype uint16) (*dns.Msg, dns.RR, error) {
-	m := new(dns.Msg)
-	m.SetQuestion(addr, qtype)
-	in, err := dns.Exchange(m, DNSAddr)
-
-	if err != nil {
-		return nil, nil, err
-	} else if len(in.Answer) > 0 {
-		return in, in.Answer[0], nil
-	} else {
-		return in, nil, nil
-	}
-}
 
 func TestSimple(t *testing.T) {
-	_, ans, err := dnsQuery("test.blah.", dns.TypeA)
+	_, ans, err := util.DNSQuery(DNSAddr, "test.blah.", dns.TypeA)
 	if err != nil {
 		t.Errorf("dns query returned error: %v", err)
 	}
@@ -42,7 +29,7 @@ func TestSimple(t *testing.T) {
 		t.Errorf("unexpected result: %q", ans)
 	}
 
-	_, ans, err = dnsQuery("test.blah.", dns.TypeMX)
+	_, ans, err = util.DNSQuery(DNSAddr, "test.blah.", dns.TypeMX)
 	if err != nil {
 		t.Errorf("dns query returned error: %v", err)
 	}
@@ -50,7 +37,7 @@ func TestSimple(t *testing.T) {
 		t.Errorf("unexpected result: %q", ans.(*dns.MX).Mx)
 	}
 
-	in, _, err := dnsQuery("unknown.", dns.TypeA)
+	in, _, err := util.DNSQuery(DNSAddr, "unknown.", dns.TypeA)
 	if err != nil {
 		t.Errorf("dns query returned error: %v", err)
 	}
@@ -66,7 +53,7 @@ func TestSimple(t *testing.T) {
 func BenchmarkHTTPSimple(b *testing.B) {
 	var err error
 	for i := 0; i < b.N; i++ {
-		_, _, err = dnsQuery("test.blah.", dns.TypeA)
+		_, _, err = util.DNSQuery(DNSAddr, "test.blah.", dns.TypeA)
 		if err != nil {
 			b.Errorf("dns query returned error: %v", err)
 		}
diff --git a/testing/util/util.go b/testing/util/util.go
index 52a080f..748fe5b 100644
--- a/testing/util/util.go
+++ b/testing/util/util.go
@@ -4,6 +4,7 @@ package util
 import (
 	"fmt"
 	"net"
+	"net/http"
 	"testing"
 	"time"
 
@@ -39,6 +40,28 @@ func WaitForDNSServer(addr string) error {
 	return fmt.Errorf("timed out")
 }
 
+// WaitForHTTPServer waits up to 5 seconds for an HTTP server to start, and
+// returns an error if it fails to do so.
+// It does this by repeatedly querying the server until it either replies
+// (with any HTTP status) or times out.
+func WaitForHTTPServer(addr string) error {
+	c := http.Client{Timeout: 100 * time.Millisecond}
+	deadline := time.Now().Add(5 * time.Second)
+	// Unlike time.Tick, a Ticker can be stopped once we are done with it.
+	tick := time.NewTicker(100 * time.Millisecond)
+	defer tick.Stop()
+
+	for (<-tick.C).Before(deadline) {
+		resp, err := c.Get("http://" + addr + "/testpoke")
+		if err == nil {
+			resp.Body.Close() // only reachability matters; don't leak it
+			return nil
+		}
+	}
+
+	return fmt.Errorf("timed out")
+}
+
 // Get a free (TCP) port. This is hacky and not race-free, but it works well
 // enough for testing purposes.
 func GetFreePort() string {
@@ -47,6 +70,21 @@ func GetFreePort() string {
 	return l.Addr().String()
 }
 
+// DNSQuery is a convenient wrapper to issue simple DNS queries: it asks srv
+// for addr with the given type, returning the reply and its first answer.
+func DNSQuery(srv, addr string, qtype uint16) (*dns.Msg, dns.RR, error) {
+	query := new(dns.Msg)
+	query.SetQuestion(addr, qtype)
+	reply, err := dns.Exchange(query, srv)
+	if err != nil {
+		return nil, nil, err
+	}
+	if len(reply.Answer) == 0 {
+		return reply, nil, nil
+	}
+	return reply, reply.Answer[0], nil
+}
+
 // TestTrace implements the tracer.Trace interface, but prints using the test
 // logging infrastructure.
 type TestTrace struct {