git » dnss » main » tree

[main] / internal / testutil / testutil.go

// Package testutil implements common testing utilities.
package testutil

import (
	"fmt"
	"net"
	"net/http"
	"testing"
	"time"

	"blitiri.com.ar/go/dnss/internal/trace"

	"github.com/miekg/dns"
)

// WaitForDNSServer waits roughly 5 seconds for a DNS server at addr to start,
// and returns an error if it fails to do so.
//
// It repeatedly sends a type A query for "unused." over UDP, about every
// 100ms, until it either gets a reply or times out. Note we do not do any
// validation of the reply: any successfully read message counts as the
// server being up.
func WaitForDNSServer(addr string) error {
	conn, err := dns.DialTimeout("udp", addr, 1*time.Second)
	if err != nil {
		return fmt.Errorf("dns.Dial error: %w", err)
	}
	defer conn.Close()

	m := &dns.Msg{}
	m.SetQuestion("unused.", dns.TypeA)

	deadline := time.Now().Add(5 * time.Second)

	// Use an explicit Ticker rather than time.Tick so it is stopped when we
	// return; before Go 1.23 a time.Tick ticker is never reclaimed.
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for (<-ticker.C).Before(deadline) {
		// Write/deadline errors are deliberately ignored: the server may
		// simply not be listening yet, and we retry on the next tick.
		conn.SetDeadline(time.Now().Add(1 * time.Second))
		conn.WriteMsg(m)
		if _, err := conn.ReadMsg(); err == nil {
			return nil
		}
	}

	return fmt.Errorf("timed out")
}

// WaitForHTTPServer waits roughly 5 seconds for an HTTP server at addr to
// start, and returns an error if it fails to do so.
//
// It repeatedly issues GET requests for /testpoke, about every 100ms, until it
// either gets a response or times out. The response status is not checked:
// any completed request counts as the server being up.
func WaitForHTTPServer(addr string) error {
	c := http.Client{
		Timeout: 100 * time.Millisecond,
	}

	deadline := time.Now().Add(5 * time.Second)

	// Explicit Ticker (instead of time.Tick) so it is released on return.
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	for (<-ticker.C).Before(deadline) {
		resp, err := c.Get("http://" + addr + "/testpoke")
		if err == nil {
			// A successful Get always has a non-nil Body, which must be
			// closed to release the underlying connection.
			resp.Body.Close()
			return nil
		}
	}

	return fmt.Errorf("timed out")
}

// GetFreePort returns the address ("host:port") of a currently free TCP port
// on localhost, found by briefly listening on port 0 and letting the operating
// system choose. This is hacky and not race-free (another process may take the
// port once it is released), but it works well enough for testing purposes.
//
// It panics if it cannot listen at all, since there is no error return and
// callers cannot proceed without an address.
func GetFreePort() string {
	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		// Fail loudly with the real cause instead of dereferencing a nil
		// listener below.
		panic(fmt.Sprintf("testutil.GetFreePort: listen failed: %v", err))
	}
	defer l.Close()
	return l.Addr().String()
}

// DNSQuery sends one query for addr with type qtype to the DNS server at srv.
// It returns the whole reply plus its first answer record; the record is nil
// when the reply has an empty answer section. If the exchange fails, only the
// error is non-nil.
func DNSQuery(srv, addr string, qtype uint16) (*dns.Msg, dns.RR, error) {
	query := new(dns.Msg)
	query.SetQuestion(addr, qtype)

	reply, err := dns.Exchange(query, srv)
	if err != nil {
		return nil, nil, err
	}

	var first dns.RR
	if len(reply.Answer) != 0 {
		first = reply.Answer[0]
	}
	return reply, first, nil
}

// TestResolver is a fake dnsserver.Resolver whose answers are entirely
// controlled by the test using it.
type TestResolver struct {
	// Initialized becomes true once Init has been called.
	Initialized bool

	// MaintainC receives a value each time Maintain runs.
	MaintainC chan bool

	// LastQuery is the most recent request handed to Query.
	LastQuery *dns.Msg

	// Response and RespError are returned verbatim by Query.
	Response  *dns.Msg
	RespError error
}

// NewTestResolver returns a TestResolver whose MaintainC can hold one
// pending notification.
func NewTestResolver() *TestResolver {
	r := new(TestResolver)
	r.MaintainC = make(chan bool, 1)
	return r
}

// Init marks the resolver as initialized. It never fails.
func (r *TestResolver) Init() error {
	r.Initialized = true
	return nil
}

// Maintain sends true on MaintainC. With the one-slot buffer created by
// NewTestResolver, this blocks only while an earlier signal is unconsumed.
func (r *TestResolver) Maintain() {
	r.MaintainC <- true
}

// Query records req as LastQuery and replies with the canned Response and
// RespError. A non-nil Response is modified in place first: its question
// section is replaced by req's and it is marked authoritative.
func (r *TestResolver) Query(req *dns.Msg, tr *trace.Trace) (*dns.Msg, error) {
	r.LastQuery = req

	resp := r.Response
	if resp == nil {
		return nil, r.RespError
	}
	resp.Question = req.Question
	resp.Authoritative = true
	return resp, r.RespError
}

// ServeTestDNSServer runs a fake UDP DNS server on addr that passes every
// request to handler. It does not return normally: whatever ListenAndServe
// returns is raised as a panic. Presumably callers start it in its own
// goroutine; confirm against call sites.
func ServeTestDNSServer(addr string, handler func(dns.ResponseWriter, *dns.Msg)) {
	srv := &dns.Server{
		Net:     "udp",
		Addr:    addr,
		Handler: dns.HandlerFunc(handler),
	}
	panic(srv.ListenAndServe())
}

// MakeStaticHandler returns a handler for the fake DNS server that answers
// every request with the same single record. The given answer must be a
// valid zone-file record. It is parsed once, when the handler is built, and a
// parse failure fails the test through tb.
func MakeStaticHandler(tb testing.TB, answer string) func(dns.ResponseWriter, *dns.Msg) {
	// Attribute parse failures to the calling test, not to this helper.
	tb.Helper()

	rr := NewRR(tb, answer)

	return func(w dns.ResponseWriter, r *dns.Msg) {
		m := &dns.Msg{}
		m.SetReply(r)
		m.Answer = append(m.Answer, rr)
		w.WriteMsg(m)
	}
}

// NewRR parses s as a resource record in zone-file syntax using dns.NewRR
// and returns it. If parsing fails, the test is stopped via tb.Fatalf.
func NewRR(tb testing.TB, s string) dns.RR {
	// Report a parse failure at the caller's line rather than inside this
	// helper.
	tb.Helper()

	rr, err := dns.NewRR(s)
	if err != nil {
		tb.Fatalf("Error parsing RR for testing: %v", err)
	}
	return rr
}