author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-07-30 11:51:49 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-07-30 16:21:46 UTC |
parent | 8cdf69a054b39dbe955c3e7ca68deb5f0b7b8a0e |
dnss_test.go | +194 | -0 |
internal/httpstodns/server.go | +8 | -1 |
testing/https/https_test.go | +4 | -17 |
testing/util/util.go | +38 | -0 |
diff --git a/dnss_test.go b/dnss_test.go
new file mode 100644
index 0000000..4ad62a0
--- /dev/null
+++ b/dnss_test.go
@@ -0,0 +1,194 @@
+// End to end tests.
+package main
+
+import (
+	"flag"
+	"fmt"
+	"os"
+	"strings"
+	"sync"
+	"testing"
+
+	"blitiri.com.ar/go/dnss/internal/dnstohttps"
+	"blitiri.com.ar/go/dnss/internal/httpstodns"
+	"blitiri.com.ar/go/dnss/testing/util"
+	"github.com/golang/glog"
+	"github.com/miekg/dns"
+)
+
+// Setup:
+//   DNS client -> DNS-to-HTTPS -> HTTPS-to-DNS -> DNS server
+//
+// The DNS client will be created on each test.
+// The DNS server will be created below, and the tests can adjust its
+// responses as needed.
+
+// Address of the DNS-to-HTTPS server, for the tests to use.
+var ServerAddr string
+
+// realMain is the real main function, which returns the value to pass to
+// os.Exit(). We have to do this so we can use defer.
+func realMain(m *testing.M) int {
+	flag.Parse()
+	defer glog.Flush()
+
+	DNSToHTTPSAddr := util.GetFreePort()
+	HTTPSToDNSAddr := util.GetFreePort()
+	DNSServerAddr := util.GetFreePort()
+
+	// We want tests talking to the DNS-to-HTTPS server, the first in the
+	// chain.
+	ServerAddr = DNSToHTTPSAddr
+
+	// DNS to HTTPS server.
+	r := dnstohttps.NewHTTPSResolver("http://"+HTTPSToDNSAddr+"/resolve", "")
+	dtoh := dnstohttps.New(DNSToHTTPSAddr, r, "")
+	go dtoh.ListenAndServe()
+
+	// HTTPS to DNS server.
+	htod := httpstodns.Server{
+		Addr:     HTTPSToDNSAddr,
+		Upstream: DNSServerAddr,
+	}
+	httpstodns.InsecureForTesting = true
+	go htod.ListenAndServe()
+
+	// Fake DNS server.
+	go ServeFakeDNSServer(DNSServerAddr)
+
+	// Wait for the servers to start up.
+	err1 := util.WaitForDNSServer(DNSToHTTPSAddr)
+	err2 := util.WaitForHTTPServer(HTTPSToDNSAddr)
+	err3 := util.WaitForDNSServer(DNSServerAddr)
+	if err1 != nil || err2 != nil || err3 != nil {
+		fmt.Printf("Error waiting for the test servers to start:\n")
+		fmt.Printf("  DNS to HTTPS: %v\n", err1)
+		fmt.Printf("  HTTPS to DNS: %v\n", err2)
+		fmt.Printf("  DNS server:   %v\n", err3)
+		fmt.Printf("Check the INFO logs for more details\n")
+		return 1
+	}
+	return m.Run()
+}
+
+func TestMain(m *testing.M) {
+	os.Exit(realMain(m))
+}
+
+// Fake DNS server.
+func ServeFakeDNSServer(addr string) {
+	server := &dns.Server{
+		Addr:    addr,
+		Handler: dns.HandlerFunc(handleFakeDNS),
+		Net:     "udp",
+	}
+	err := server.ListenAndServe()
+	panic(err)
+}
+
+// DNS answers to give, as a map of "name type" -> []RR.
+// Tests will modify this according to their needs.
+var answers map[string][]dns.RR
+var answersMu sync.Mutex
+
+func resetAnswers() {
+	answersMu.Lock()
+	answers = map[string][]dns.RR{}
+	answersMu.Unlock()
+}
+
+func addAnswers(tb testing.TB, zone string) {
+	for x := range dns.ParseZone(strings.NewReader(zone), "", "") {
+		if x.Error != nil {
+			tb.Fatalf("error parsing zone: %v\n", x.Error)
+			return
+		}
+
+		hdr := x.RR.Header()
+		key := fmt.Sprintf("%s %d", hdr.Name, hdr.Rrtype)
+		answersMu.Lock()
+		answers[key] = append(answers[key], x.RR)
+		answersMu.Unlock()
+	}
+}
+
+func handleFakeDNS(w dns.ResponseWriter, r *dns.Msg) {
+	m := &dns.Msg{}
+	m.SetReply(r)
+
+	if len(r.Question) != 1 {
+		w.WriteMsg(m)
+		return
+	}
+
+	q := r.Question[0]
+	if testing.Verbose() {
+		fmt.Printf("fake dns <- %v\n", q)
+	}
+
+	key := fmt.Sprintf("%s %d", q.Name, q.Qtype)
+	answersMu.Lock()
+	if rrs, ok := answers[key]; ok {
+		m.Answer = rrs
+	} else {
+		m.Rcode = dns.RcodeNameError
+	}
+	answersMu.Unlock()
+
+	if testing.Verbose() {
+		fmt.Printf("fake dns -> %v | %v\n",
+			dns.RcodeToString[m.Rcode], m.Answer)
+	}
+	w.WriteMsg(m)
+}
+
+//
+// Tests
+//
+
+func TestSimple(t *testing.T) {
+	resetAnswers()
+	addAnswers(t, "test.blah. A 1.2.3.4")
+	_, ans, err := util.DNSQuery(ServerAddr, "test.blah.", dns.TypeA)
+	if err != nil {
+		t.Fatalf("dns query returned error: %v", err) // ans is nil; stop here.
+	}
+	if a, ok := ans.(*dns.A); !ok || a.A.String() != "1.2.3.4" {
+		t.Errorf("unexpected result: %v", ans)
+	}
+
+	addAnswers(t, "test.blah. MX 10 mail.test.blah.")
+	_, ans, err = util.DNSQuery(ServerAddr, "test.blah.", dns.TypeMX)
+	if err != nil {
+		t.Fatalf("dns query returned error: %v", err)
+	}
+	if mx, ok := ans.(*dns.MX); !ok || mx.Mx != "mail.test.blah." {
+		t.Errorf("unexpected result: %v", ans)
+	}
+
+	in, _, err := util.DNSQuery(ServerAddr, "unknown.", dns.TypeA)
+	if err != nil {
+		t.Fatalf("dns query returned error: %v", err) // in is nil; stop here.
+	}
+	if in.Rcode != dns.RcodeNameError {
+		t.Errorf("unexpected result: %q", in)
+	}
+}
+
+//
+// Benchmarks
+//
+
+func BenchmarkSimple(b *testing.B) {
+	resetAnswers()
+	addAnswers(b, "test.blah. A 1.2.3.4")
+	b.ResetTimer()
+
+	var err error
+	for i := 0; i < b.N; i++ {
+		_, _, err = util.DNSQuery(ServerAddr, "test.blah.", dns.TypeA)
+		if err != nil {
+			b.Errorf("dns query returned error: %v", err)
+		}
+	}
+}
diff --git a/internal/httpstodns/server.go b/internal/httpstodns/server.go
index be0150e..c66d661 100644
--- a/internal/httpstodns/server.go
+++ b/internal/httpstodns/server.go
@@ -25,6 +25,8 @@ type Server struct {
 	KeyFile  string
 }
 
+var InsecureForTesting = false // If true, serve plain HTTP instead of HTTPS; for tests only.
+
 func (s *Server) ListenAndServe() {
 	mux := http.NewServeMux()
 	mux.HandleFunc("/resolve", s.Resolve)
@@ -34,7 +36,12 @@ func (s *Server) ListenAndServe() {
 	}
 
 	glog.Infof("HTTPS listening on %s", s.Addr)
-	err := srv.ListenAndServeTLS(s.CertFile, s.KeyFile)
+	var err error
+	if InsecureForTesting {
+		err = srv.ListenAndServe()
+	} else {
+		err = srv.ListenAndServeTLS(s.CertFile, s.KeyFile)
+	}
 	glog.Fatalf("HTTPS exiting: %s", err)
 }
 
diff --git a/testing/https/https_test.go b/testing/https/https_test.go
index 3b5d31c..a29ced9 100644
--- a/testing/https/https_test.go
+++ b/testing/https/https_test.go
@@ -19,22 +19,9 @@ import (
 //
 // === Tests ===
 //
-func dnsQuery(addr string, qtype uint16) (*dns.Msg, dns.RR, error) {
-	m := new(dns.Msg)
-	m.SetQuestion(addr, qtype)
-	in, err := dns.Exchange(m, DNSAddr)
-
-	if err != nil {
-		return nil, nil, err
-	} else if len(in.Answer) > 0 {
-		return in, in.Answer[0], nil
-	} else {
-		return in, nil, nil
-	}
-}
 
 func TestSimple(t *testing.T) {
-	_, ans, err := dnsQuery("test.blah.", dns.TypeA)
+	_, ans, err := util.DNSQuery(DNSAddr, "test.blah.", dns.TypeA)
 	if err != nil {
 		t.Errorf("dns query returned error: %v", err)
 	}
@@ -42,7 +29,7 @@ func TestSimple(t *testing.T) {
 		t.Errorf("unexpected result: %q", ans)
 	}
 
-	_, ans, err = dnsQuery("test.blah.", dns.TypeMX)
+	_, ans, err = util.DNSQuery(DNSAddr, "test.blah.", dns.TypeMX)
 	if err != nil {
 		t.Errorf("dns query returned error: %v", err)
 	}
@@ -50,7 +37,7 @@ func TestSimple(t *testing.T) {
 		t.Errorf("unexpected result: %q", ans.(*dns.MX).Mx)
 	}
 
-	in, _, err := dnsQuery("unknown.", dns.TypeA)
+	in, _, err := util.DNSQuery(DNSAddr, "unknown.", dns.TypeA)
 	if err != nil {
 		t.Errorf("dns query returned error: %v", err)
 	}
@@ -66,7 +53,7 @@ func TestSimple(t *testing.T) {
 func BenchmarkHTTPSimple(b *testing.B) {
 	var err error
 	for i := 0; i < b.N; i++ {
-		_, _, err = dnsQuery("test.blah.", dns.TypeA)
+		_, _, err = util.DNSQuery(DNSAddr, "test.blah.", dns.TypeA)
 		if err != nil {
 			b.Errorf("dns query returned error: %v", err)
 		}
diff --git a/testing/util/util.go b/testing/util/util.go
index 52a080f..748fe5b 100644
--- a/testing/util/util.go
+++ b/testing/util/util.go
@@ -4,6 +4,7 @@ package util
 import (
 	"fmt"
 	"net"
+	"net/http"
 	"testing"
 	"time"
 
@@ -39,6 +40,28 @@ func WaitForDNSServer(addr string) error {
 	return fmt.Errorf("timed out")
 }
 
+// WaitForHTTPServer waits 5 seconds for an HTTP server to start, and returns
+// an error if it fails to do so.
+// It does this by repeatedly querying the server until it either replies or
+// times out.
+func WaitForHTTPServer(addr string) error {
+	c := http.Client{Timeout: 100 * time.Millisecond}
+	deadline := time.Now().Add(5 * time.Second)
+	// NewTicker (not time.Tick) so the ticker is released when we return.
+	tick := time.NewTicker(100 * time.Millisecond)
+	defer tick.Stop()
+
+	for (<-tick.C).Before(deadline) {
+		resp, err := c.Get("http://" + addr + "/testpoke")
+		if err == nil {
+			resp.Body.Close() // Any reply means the server is up.
+			return nil
+		}
+	}
+
+	return fmt.Errorf("timed out")
+}
+
 // Get a free (TCP) port. This is hacky and not race-free, but it works well
 // enough for testing purposes.
 func GetFreePort() string {
@@ -47,6 +70,21 @@ func GetFreePort() string {
 	return l.Addr().String()
 }
 
+// DNSQuery sends a query for name addr and type qtype to the server srv, and
+// returns the reply plus its first answer RR (nil if there are no answers).
+func DNSQuery(srv, addr string, qtype uint16) (*dns.Msg, dns.RR, error) {
+	m := new(dns.Msg)
+	m.SetQuestion(addr, qtype)
+	in, err := dns.Exchange(m, srv)
+	if err != nil {
+		return nil, nil, err
+	}
+	if len(in.Answer) == 0 {
+		return in, nil, nil
+	}
+	return in, in.Answer[0], nil
+}
+
 // TestTrace implements the tracer.Trace interface, but prints using the test
 // logging infrastructure.
 type TestTrace struct {