package dnsserver
// Tests for the caching resolver.
import (
"fmt"
"reflect"
"strconv"
"strings"
"testing"
"time"
"blitiri.com.ar/go/dnss/internal/testutil"
"blitiri.com.ar/go/dnss/internal/trace"
"github.com/miekg/dns"
)
//
// === Tests ===
//
// TestBasic checks that Init initializes the backing resolver, that the
// first query for a name is a cache miss whose response is authoritative,
// and that an identical second query is a cache hit whose response is not
// authoritative.
func TestBasic(t *testing.T) {
	back := testutil.NewTestResolver()
	cache := NewCachingResolver(back)
	cache.Init()
	if !back.Initialized {
		t.Errorf("caching resolver did not initialize backing")
	}

	resetStats()

	// First lookup: nothing cached yet, so this must be a miss.
	first := queryA(t, cache, "test. A 1.2.3.4", "test.", "1.2.3.4")
	if !statsEquals(1, 0, 1) {
		t.Errorf("bad stats: %v", dumpStats())
	}
	if !first.Authoritative {
		t.Errorf("cache miss was not authoritative")
	}

	// Identical lookup: must now be served from the cache.
	second := queryA(t, cache, "", "test.", "1.2.3.4")
	if !statsEquals(2, 1, 1) {
		t.Errorf("bad stats: %v", dumpStats())
	}
	if second.Authoritative {
		t.Errorf("cache hit was authoritative")
	}
}
// TestTTL checks TTL handling. A record whose TTL (1 day) exceeds maxTTL must
// come back with its TTL capped to exactly maxTTL, both on the initial cache
// miss and on a subsequent cache hit. Once background maintenance is
// running, the TTL of the cached answer must drop to at most maxTTL-1s
// within about one second of polling.
//
// NOTE(review): this test assigns the package-level maintenancePeriod and
// launches `go c.Maintain()` with no visible way to stop it; neither is
// undone when the test returns. Confirm this cannot affect other tests
// (restoring maintenancePeriod via defer could race with that goroutine).
func TestTTL(t *testing.T) {
r := testutil.NewTestResolver()
c := NewCachingResolver(r)
c.Init()
resetStats()
// Note we don't start c.Maintain() yet, as we don't want the background
// TTL updater until later.
// Test a record with a larger-than-max TTL (1 day).
// The TTL of the response should be capped.
resp := queryA(t, c, "test. 86400 A 1.2.3.4", "test.", "1.2.3.4")
if !statsEquals(1, 0, 1) {
t.Errorf("bad stats: %v", dumpStats())
}
if ttl := getTTL(resp.Answer); ttl != maxTTL {
t.Errorf("expected max TTL (%v), got %v", maxTTL, ttl)
}
// Same query, should be cached, and TTL also capped.
// As we've not enabled cache maintenance, we can be sure TTL == maxTTL.
resp = queryA(t, c, "", "test.", "1.2.3.4")
if !statsEquals(2, 1, 1) {
t.Errorf("bad stats: %v", dumpStats())
}
if ttl := getTTL(resp.Answer); ttl != maxTTL {
t.Errorf("expected max TTL (%v), got %v", maxTTL, ttl)
}
// To test that the TTL is reduced appropriately, set a small maintenance
// period, and then repeatedly query the record. We should see its TTL
// shrinking down within 1s.
// Even though the TTL resolution in the protocol is in seconds, we don't
// need to wait that much "thanks" to rounding artifacts.
// The period must be set before the goroutine starts so that the write is
// ordered before anything Maintain() reads.
maintenancePeriod = 50 * time.Millisecond
go c.Maintain()
resetStats()
// Check that the back resolver's Maintain() is called: wait up to 1s for
// a value on the test resolver's MaintainC channel.
select {
case <-r.MaintainC:
t.Log("Maintain() called")
case <-time.After(1 * time.Second):
t.Errorf("back resolver Maintain() was not called")
}
// Poll the cached record once per maintenance period, for at most 1s,
// stopping as soon as its TTL has dropped by at least a second. After the
// loop, resp holds the last answer seen, whether or not the loop broke early.
start := time.Now()
for time.Since(start) < 1*time.Second {
resp = queryA(t, c, "", "test.", "1.2.3.4")
t.Logf("TTL %v", getTTL(resp.Answer))
if ttl := getTTL(resp.Answer); ttl <= (maxTTL - 1*time.Second) {
break
}
time.Sleep(maintenancePeriod)
}
// If the deadline expired without the TTL shrinking, report the last value.
if ttl := getTTL(resp.Answer); ttl > (maxTTL - 1*time.Second) {
t.Errorf("expected maxTTL-1s, got %v", ttl)
}
}
// TestFailedQueries checks that failed (NXDOMAIN) replies are not cached:
// repeating the same failing query must produce a cache miss every time.
func TestFailedQueries(t *testing.T) {
	cache := NewCachingResolver(testutil.NewTestResolver())
	cache.Init()
	resetStats()

	// After the n-th identical failing query we expect n lookups in total,
	// zero hits and n misses.
	for n := 1; n <= 2; n++ {
		queryFail(t, cache)
		if !statsEquals(n, 0, n) {
			t.Errorf("bad stats: %v", dumpStats())
		}
	}
}
// TestWantToCache checks that wantToCache accepts a successful A reply whose
// question matches the query. It also checks that wantToCache rejects copies
// of that reply that are:
//   - unsuccessful;
//   - not marked as a response;
//   - using a non-query opcode;
//   - carrying no answers;
//   - truncated;
//   - carrying anything but exactly one matching question.
//
// Each rejection is identified by a substring of the returned error.
func TestWantToCache(t *testing.T) {
	query := newQuery("test.", dns.TypeA)
	q := query.Question[0]
	reply := newReply(mustNewRR(t, "test. A 1.2.3.4"))
	reply.Question = []dns.Question{q}

	if err := wantToCache(q, reply); err != nil {
		t.Errorf("wantToCache failed on cacheable request: %v", err)
	}

	cases := []struct {
		// mutate turns a copy of the cacheable reply into an uncacheable one.
		mutate func(r *dns.Msg)
		// want is a substring expected in wantToCache's error.
		want string
	}{
		{func(r *dns.Msg) { r.Rcode = dns.RcodeBadName }, "unsuccessful query"},
		{func(r *dns.Msg) { r.Response = false }, "response = false"},
		{func(r *dns.Msg) { r.Opcode = dns.OpcodeUpdate }, "opcode"},
		{func(r *dns.Msg) { r.Answer = []dns.RR{} }, "answer is empty"},
		{func(r *dns.Msg) { r.Truncated = true }, "truncated reply"},
		{func(r *dns.Msg) { r.Question = []dns.Question{q, q} }, "too many/few questions (2)"},
		{func(r *dns.Msg) { r.Question = []dns.Question{} }, "too many/few questions (0)"},
		{func(r *dns.Msg) {
			// Keyed fields: go vet's composites check rejects unkeyed
			// literals of struct types from other packages.
			r.Question = []dns.Question{
				{Name: "other.", Qtype: dns.TypeMX, Qclass: dns.ClassINET}}
		}, "reply question does not match"},
	}
	for _, tc := range cases {
		r := reply.Copy()
		tc.mutate(r)
		checkWantToCache(t, q, r, tc.want)
	}
}
// TestCacheFull checks how the cache behaves once it fills up.
//
// Note this test is tied to the current behaviour of not doing any eviction
// when the cache is full, which is not ideal and will likely be changed in
// the future.
func TestCacheFull(t *testing.T) {
	r := testutil.NewTestResolver()
	c := NewCachingResolver(r)
	c.Init()
	resetStats()

	r.Response = newReply(mustNewRR(t, "test. A 1.2.3.4"))

	// Do maxCacheSize+1 different requests, all of which must miss.
	// The loops below run up to maxCacheSize+1 times. Once the stats drift
	// they stay wrong, so stop at the first mismatch (reporting which query
	// it was) instead of emitting one error per remaining iteration.
	for i := 0; i < maxCacheSize+1; i++ {
		queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4")
		if !statsEquals(i+1, 0, i+1) {
			t.Fatalf("fill query %d: bad stats: %v", i, dumpStats())
		}
	}

	// Query the first maxCacheSize names back; they should all be hits.
	resetStats()
	for i := 0; i < maxCacheSize; i++ {
		queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4")
		if !statsEquals(i+1, i+1, 0) {
			t.Fatalf("cached query %d: bad stats: %v", i, dumpStats())
		}
	}

	// The (maxCacheSize+1)-th name was queried when the cache was already
	// full, so it was not stored and querying it again must be a miss.
	resetStats()
	queryA(t, c, "", fmt.Sprintf("test%d.", maxCacheSize), "1.2.3.4")
	if !statsEquals(1, 0, 1) {
		t.Errorf("bad stats: %v", dumpStats())
	}
}
// Test behaviour when the size of the cache is 0 (so users can disable it
// that way).
func TestZeroSize(t *testing.T) {
r := testutil.NewTestResolver()
c := NewCachingResolver(r)
c.Init()
resetStats()
// Override the max cache size to 0.
prevMaxCacheSize := maxCacheSize
maxCacheSize = 0
defer func() { maxCacheSize = prevMaxCacheSize }()
r.Response = newReply(mustNewRR(t, "test. A 1.2.3.4"))
// Do 5 different requests.
for i := 0; i < 5; i++ {
queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4")
if !statsEquals(i+1, 0, i+1) {
t.Errorf("bad stats: %v", dumpStats())
}
}
// Query them back, they should all be misses.
resetStats()
for i := 0; i < 5; i++ {
queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4")
if !statsEquals(i+1, 0, i+1) {
t.Errorf("bad stats: %v", dumpStats())
}
}
}
//
// === Benchmarks ===
//
func BenchmarkCacheSimple(b *testing.B) {
var err error
r := testutil.NewTestResolver()
r.Response = newReply(mustNewRR(b, "test. A 1.2.3.4"))
c := NewCachingResolver(r)
c.Init()
tr := trace.New("test", "Benchmark")
defer tr.Finish()
req := newQuery("test.", dns.TypeA)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err = c.Query(req, tr)
if err != nil {
b.Errorf("query failed: %v", err)
}
}
}
//
// === Helpers ===
//
// resetStats sets the cache counters (total, bypassed, hits, misses,
// recorded) back to zero. Tests call it so their stats assertions only
// reflect the queries they issue themselves.
func resetStats() {
	stats.cacheHits.Set(0)
	stats.cacheMisses.Set(0)
	stats.cacheBypassed.Set(0)
	stats.cacheRecorded.Set(0)
	stats.cacheTotal.Set(0)
}
// statsEquals reports whether the total, hit and miss counters currently
// hold exactly the given values. Comparison is done on the counters'
// String() form against the decimal form of each expected value.
func statsEquals(total, hits, misses int) bool {
	if stats.cacheTotal.String() != strconv.Itoa(total) {
		return false
	}
	if stats.cacheHits.String() != strconv.Itoa(hits) {
		return false
	}
	return stats.cacheMisses.String() == strconv.Itoa(misses)
}
// dumpStats formats the total (t), hit (h) and miss (m) counters, the same
// three that statsEquals checks, for use in test failure messages.
func dumpStats() string {
	// The same verb is used for all three counters since they are of the
	// same kind.
	return fmt.Sprintf("(t:%v h:%v m:%v)",
		stats.cacheTotal, stats.cacheHits, stats.cacheMisses)
}
// queryA issues an A query for domain through c and checks the reply.
//
// If rr is non-empty, it is parsed and installed as the response of c's
// backing test resolver before querying. Otherwise, the backing resolver's
// current response is left in place.
//
// The test fails fatally if the query returns an error, or if the reply has
// no answers or its first answer is not an A record. It fails non-fatally
// if the address differs from expected, or if the reply's question section
// differs from the request's. The reply is returned for further checks.
func queryA(t *testing.T, c *cachingResolver, rr, domain, expected string) *dns.Msg {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	// Set up the response from the given RR (if any).
	if rr != "" {
		back := c.back.(*testutil.TestResolver)
		back.Response = newReply(mustNewRR(t, rr))
	}

	tr := trace.New("test", "queryA")
	defer tr.Finish()

	req := newQuery(domain, dns.TypeA)
	resp, err := c.Query(req, tr)
	if err != nil {
		t.Fatalf("query failed: %v", err)
	}

	// Fail with a clear message instead of an index-out-of-range or
	// type-assertion panic when the answer is missing or unexpected.
	if resp == nil || len(resp.Answer) == 0 {
		t.Fatalf("query for %q returned no answers: %v", domain, resp)
	}
	a, ok := resp.Answer[0].(*dns.A)
	if !ok {
		t.Fatalf("query for %q: expected an A record, got %v", domain, resp.Answer[0])
	}
	if a.A.String() != expected {
		t.Errorf("expected %s, got %v", expected, a.A)
	}

	if !reflect.DeepEqual(req.Question, resp.Question) {
		t.Errorf("question mis-match: request %v, response %v",
			req.Question, resp.Question)
	}
	return resp
}
// queryFail configures c's backing test resolver to answer with an empty
// NXDOMAIN (RcodeNameError) response. It then issues an A query for
// "doesnotexist." through c and returns the reply.
//
// The test fails fatally if the query itself returns an error; the reply's
// contents are not checked here.
func queryFail(t *testing.T, c *cachingResolver) *dns.Msg {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	back := c.back.(*testutil.TestResolver)
	back.Response = &dns.Msg{}
	back.Response.Response = true
	back.Response.Rcode = dns.RcodeNameError

	tr := trace.New("test", "queryFail")
	defer tr.Finish()

	req := newQuery("doesnotexist.", dns.TypeA)
	resp, err := c.Query(req, tr)
	if err != nil {
		t.Fatalf("query failed: %v", err)
	}
	return resp
}
// checkWantToCache asserts that wantToCache rejects reply r for question q
// with an error whose message contains exp.
func checkWantToCache(t *testing.T, q dns.Question, r *dns.Msg, exp string) {
	t.Helper()
	err := wantToCache(q, r)
	if err == nil {
		// Guard against calling Error() on a nil error, which would panic
		// instead of reporting which case was wrongly accepted.
		t.Errorf("q:%v r:%v expected error containing %q, got nil", q, r, exp)
		return
	}
	if !strings.Contains(err.Error(), exp) {
		t.Errorf("q:%v r:%v expected:%q got:%v", q, r, exp, err)
	}
}
// mustNewRR parses s as a resource record in zone-file text form using
// dns.NewRR. It fails the test or benchmark fatally if parsing returns an
// error.
func mustNewRR(tb testing.TB, s string) dns.RR {
	// Attribute a parse failure to the caller that supplied the bad RR text.
	tb.Helper()
	rr, err := dns.NewRR(s)
	if err != nil {
		tb.Fatalf("invalid RR %q: %v", s, err)
	}
	return rr
}
// newQuery builds a new query message with a single question for domain of
// record type t, using dns.Msg.SetQuestion.
func newQuery(domain string, t uint16) *dns.Msg {
	msg := new(dns.Msg)
	return msg.SetQuestion(domain, t)
}
// newReply builds a successful (RcodeSuccess), non-authoritative response
// message whose answer section holds only the given record. The question
// section is left empty; callers fill it in if needed.
func newReply(answer dns.RR) *dns.Msg {
	reply := new(dns.Msg)
	reply.Response = true
	reply.Authoritative = false
	reply.Rcode = dns.RcodeSuccess
	reply.Answer = []dns.RR{answer}
	return reply
}