author | Alberto Bertogli
<albertito@blitiri.com.ar> 2016-05-28 12:58:19 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2016-05-28 16:55:03 UTC |
parent | 4e21a7186431b64a0dbe3359c967da3678952104 |
internal/dnstox/caching_test.go | +339 | -0 |
internal/dnstox/resolver.go | +3 | -1 |
testing/util/util.go | +26 | -0 |
diff --git a/internal/dnstox/caching_test.go b/internal/dnstox/caching_test.go
new file mode 100644
index 0000000..f798f2b
--- /dev/null
+++ b/internal/dnstox/caching_test.go
@@ -0,0 +1,339 @@
+package dnstox
+
+// Tests for the caching resolver.
+// Note the other resolvers have more functional tests in the testing/
+// directory.
+
+import (
+	"fmt"
+	"reflect"
+	"strconv"
+	"testing"
+	"time"
+
+	"blitiri.com.ar/go/dnss/testing/util"
+
+	"github.com/miekg/dns"
+	"golang.org/x/net/trace"
+)
+
+// A test resolver that we use as backing for the caching resolver under test.
+type TestResolver struct {
+	// Has this resolver been initialized?
+	init bool
+
+	// Maintain() sends a value over this channel.
+	maintain chan bool
+
+	// The last query we've seen.
+	lastQuery *dns.Msg
+
+	// What we will respond to queries.
+	response  *dns.Msg
+	respError error
+}
+
+func NewTestResolver() *TestResolver {
+	return &TestResolver{
+		maintain: make(chan bool, 1),
+	}
+}
+
+func (r *TestResolver) Init() error {
+	r.init = true
+	return nil
+}
+
+func (r *TestResolver) Maintain() {
+	r.maintain <- true
+}
+
+func (r *TestResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
+	r.lastQuery = req
+	if r.response != nil {
+		r.response.Question = req.Question
+		r.response.Authoritative = true
+	}
+	return r.response, r.respError
+}
+
+//
+// === Tests ===
+//
+
+// Test basic functionality.
+func TestBasic(t *testing.T) {
+	r := NewTestResolver()
+
+	c := NewCachingResolver(r)
+
+	c.Init()
+	if !r.init {
+		t.Errorf("caching resolver did not initialize backing")
+	}
+	go c.Maintain()
+
+	// Check that the back resolver's Maintain() is called.
+	select {
+	case <-r.maintain:
+		t.Log("Maintain() called")
+	case <-time.After(1 * time.Second):
+		t.Errorf("back resolver Maintain() was not called")
+	}
+
+	resetStats()
+
+	resp := queryA(t, c, "test. A 1.2.3.4", "test.", "1.2.3.4")
+	if !statsEquals(1, 0, 1) {
+		t.Errorf("bad stats: %v", dumpStats())
+	}
+	if !resp.Authoritative {
+		t.Errorf("cache miss was not authoritative")
+	}
+
+	// Same query, should be cached.
+	resp = queryA(t, c, "", "test.", "1.2.3.4")
+	if !statsEquals(2, 1, 1) {
+		t.Errorf("bad stats: %v", dumpStats())
+	}
+	if resp.Authoritative {
+		t.Errorf("cache hit was authoritative")
+	}
+}
+
+// Test TTL handling.
+func TestTTL(t *testing.T) {
+	r := NewTestResolver()
+	c := NewCachingResolver(r)
+	c.Init()
+	resetStats()
+
+	// Note we don't start c.Maintain() yet, as we don't want the background
+	// TTL updater until later.
+
+	// Test a record with a larger-than-max TTL (1 day).
+	// The TTL of the response should be capped.
+	resp := queryA(t, c, "test. 86400 A 1.2.3.4", "test.", "1.2.3.4")
+	if !statsEquals(1, 0, 1) {
+		t.Errorf("bad stats: %v", dumpStats())
+	}
+	if ttl := getTTL(resp.Answer); ttl != maxTTL {
+		t.Errorf("expected max TTL (%v), got %v", maxTTL, ttl)
+	}
+
+	// Same query, should be cached, and TTL also capped.
+	// As we've not enabled cache maintenance, we can be sure TTL == maxTTL.
+	resp = queryA(t, c, "", "test.", "1.2.3.4")
+	if !statsEquals(2, 1, 1) {
+		t.Errorf("bad stats: %v", dumpStats())
+	}
+	if ttl := getTTL(resp.Answer); ttl != maxTTL {
+		t.Errorf("expected max TTL (%v), got %v", maxTTL, ttl)
+	}
+
+	// To test that the TTL is reduced appropriately, set a small maintenance
+	// period, and then repeatedly query the record. We should see its TTL
+	// shrinking down within 1s.
+	// Even though the TTL resolution in the protocol is in seconds, we don't
+	// need to wait that much "thanks" to rounding artifacts.
+	maintenancePeriod = 50 * time.Millisecond
+	go c.Maintain()
+	resetStats()
+
+	start := time.Now()
+	for time.Since(start) < 1*time.Second {
+		resp = queryA(t, c, "", "test.", "1.2.3.4")
+		t.Logf("TTL %v", getTTL(resp.Answer))
+		if ttl := getTTL(resp.Answer); ttl <= (maxTTL - 1*time.Second) {
+			break
+		}
+		time.Sleep(maintenancePeriod)
+	}
+	if ttl := getTTL(resp.Answer); ttl > (maxTTL - 1*time.Second) {
+		t.Errorf("expected maxTTL-1s, got %v", ttl)
+	}
+}
+
+// Test that we don't cache failed queries.
+func TestFailedQueries(t *testing.T) {
+	r := NewTestResolver()
+	c := NewCachingResolver(r)
+	c.Init()
+	resetStats()
+
+	// Do two failed identical queries, check that both are cache misses.
+	queryFail(t, c)
+	if !statsEquals(1, 0, 1) {
+		t.Errorf("bad stats: %v", dumpStats())
+	}
+
+	queryFail(t, c)
+	if !statsEquals(2, 0, 2) {
+		t.Errorf("bad stats: %v", dumpStats())
+	}
+}
+
+// Test that we handle the cache filling up.
+// Note this test is tied to the current behaviour of not doing any eviction
+// when we're full, which is not ideal and will likely be changed in the
+// future.
+func TestCacheFull(t *testing.T) {
+	r := NewTestResolver()
+	c := NewCachingResolver(r)
+	c.Init()
+	resetStats()
+
+	r.response = newReply(mustNewRR(t, "test. A 1.2.3.4"))
+
+	// Do maxCacheSize+1 different requests.
+	for i := 0; i < maxCacheSize+1; i++ {
+		queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4")
+		if !statsEquals(i+1, 0, i+1) {
+			t.Errorf("bad stats: %v", dumpStats())
+		}
+	}
+
+	// Query up to maxCacheSize, they should all be hits.
+	resetStats()
+	for i := 0; i < maxCacheSize; i++ {
+		queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4")
+		if !statsEquals(i+1, i+1, 0) {
+			t.Errorf("bad stats: %v", dumpStats())
+		}
+	}
+
+	// Querying maxCacheSize+1 should be a miss, because the cache was full.
+	resetStats()
+	queryA(t, c, "", fmt.Sprintf("test%d.", maxCacheSize), "1.2.3.4")
+	if !statsEquals(1, 0, 1) {
+		t.Errorf("bad stats: %v", dumpStats())
+	}
+}
+
+// Test behaviour when the size of the cache is 0 (so users can disable it
+// that way).
+func TestZeroSize(t *testing.T) {
+	r := NewTestResolver()
+	c := NewCachingResolver(r)
+	c.Init()
+	resetStats()
+
+	// Override the max cache size to 0.
+	prevMaxCacheSize := maxCacheSize
+	maxCacheSize = 0
+	defer func() { maxCacheSize = prevMaxCacheSize }()
+
+	r.response = newReply(mustNewRR(t, "test. A 1.2.3.4"))
+
+	// Do 5 different requests.
+	for i := 0; i < 5; i++ {
+		queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4")
+		if !statsEquals(i+1, 0, i+1) {
+			t.Errorf("bad stats: %v", dumpStats())
+		}
+	}
+
+	// Query them back, they should all be misses.
+	resetStats()
+	for i := 0; i < 5; i++ {
+		queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4")
+		if !statsEquals(i+1, 0, i+1) {
+			t.Errorf("bad stats: %v", dumpStats())
+		}
+	}
+}
+
+//
+// === Helpers ===
+//
+
+func resetStats() {
+	stats.cacheTotal.Set(0)
+	stats.cacheBypassed.Set(0)
+	stats.cacheHits.Set(0)
+	stats.cacheMisses.Set(0)
+	stats.cacheRecorded.Set(0)
+}
+
+func statsEquals(total, hits, misses int) bool {
+	return (stats.cacheTotal.String() == strconv.Itoa(total) &&
+		stats.cacheHits.String() == strconv.Itoa(hits) &&
+		stats.cacheMisses.String() == strconv.Itoa(misses))
+}
+
+func dumpStats() string {
+	return fmt.Sprintf("(t:%v h:%v m:%v)",
+		stats.cacheTotal, stats.cacheHits, stats.cacheMisses)
+}
+
+func queryA(t *testing.T, c *cachingResolver, rr, domain, expected string) *dns.Msg {
+	// Set up the response from the given RR (if any).
+	if rr != "" {
+		back := c.back.(*TestResolver)
+		back.response = newReply(mustNewRR(t, rr))
+	}
+
+	tr := util.NewTestTrace(t)
+	defer tr.Finish()
+
+	req := newQuery(domain, dns.TypeA)
+	resp, err := c.Query(req, tr)
+	if err != nil {
+		t.Fatalf("query failed: %v", err)
+	}
+
+	a := resp.Answer[0].(*dns.A)
+	if a.A.String() != expected {
+		t.Errorf("expected %s, got %v", expected, a.A)
+	}
+
+	if !reflect.DeepEqual(req.Question, resp.Question) {
+		t.Errorf("question mis-match: request %v, response %v",
+			req.Question, resp.Question)
+	}
+
+	return resp
+}
+
+func queryFail(t *testing.T, c *cachingResolver) *dns.Msg {
+	back := c.back.(*TestResolver)
+	back.response = &dns.Msg{}
+	back.response.Response = true
+	back.response.Rcode = dns.RcodeNameError
+
+	tr := util.NewTestTrace(t)
+	defer tr.Finish()
+
+	req := newQuery("doesnotexist.", dns.TypeA)
+	resp, err := c.Query(req, tr)
+	if err != nil {
+		t.Fatalf("query failed: %v", err)
+	}
+
+	return resp
+}
+
+func mustNewRR(t *testing.T, s string) dns.RR {
+	rr, err := dns.NewRR(s)
+	if err != nil {
+		t.Fatalf("invalid RR %q: %v", s, err)
+	}
+	return rr
+}
+
+func newQuery(domain string, t uint16) *dns.Msg {
+	m := &dns.Msg{}
+	m.SetQuestion(domain, t)
+	return m
+}
+
+func newReply(answer dns.RR) *dns.Msg {
+	return &dns.Msg{
+		MsgHdr: dns.MsgHdr{
+			Response:      true,
+			Authoritative: false,
+			Rcode:         dns.RcodeSuccess,
+		},
+		Answer: []dns.RR{answer},
+	}
+}
diff --git a/internal/dnstox/resolver.go b/internal/dnstox/resolver.go
index 831d8b4..a90326b 100644
--- a/internal/dnstox/resolver.go
+++ b/internal/dnstox/resolver.go
@@ -289,7 +289,9 @@ func NewCachingResolver(back Resolver) *cachingResolver {
 	}
 }
 
-const (
+// Constants that tune the cache.
+// They are declared as variables so we can tweak them for testing.
+var (
 	// Maximum number of entries we keep in the cache.
 	// 2k should be reasonable for a small network.
 	// Keep in mind that increasing this too much will interact negatively
diff --git a/testing/util/util.go b/testing/util/util.go
index 9964691..1f046a2 100644
--- a/testing/util/util.go
+++ b/testing/util/util.go
@@ -3,6 +3,7 @@ package util
 
 import (
 	"fmt"
+	"testing"
 	"time"
 
 	"github.com/miekg/dns"
@@ -38,3 +39,28 @@ func WaitForDNSServer(addr string) error {
 
 	return fmt.Errorf("not reachable")
 }
+
+// TestTrace implements the trace.Trace interface, but prints using the test
+// logging infrastructure.
+type TestTrace struct {
+	T *testing.T
+}
+
+func NewTestTrace(t *testing.T) *TestTrace {
+	return &TestTrace{t}
+}
+
+func (t *TestTrace) LazyLog(x fmt.Stringer, sensitive bool) {
+	t.T.Logf("trace %p (%t): %s", t, sensitive, x)
+}
+
+func (t *TestTrace) LazyPrintf(format string, a ...interface{}) {
+	prefix := fmt.Sprintf("trace %p: ", t)
+	t.T.Logf(prefix+format, a...)
+}
+
+func (t *TestTrace) SetError()                           {}
+func (t *TestTrace) SetRecycler(f func(interface{}))     {}
+func (t *TestTrace) SetTraceInfo(traceID, spanID uint64) {}
+func (t *TestTrace) SetMaxEvents(m int)                  {}
+func (t *TestTrace) Finish()                             {}