git » dnss » commit 5d88590

author Alberto Bertogli
2015-10-24 15:14:44 UTC
committer Alberto Bertogli
2015-10-25 00:31:20 UTC
parent a440db5cf5d32b9b305b9e117a0883ecfbae85ee

dnstogrpc: Add a caching resolver

This patch adds a caching resolver to dnstogrpc.

It will keep a cache of recent requests, and avoid the GRPC query when
possible. The implementation assumes dnss will be used as the local server for
a small network, and is likely to be unsuitable for more generic use.
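
For illustration only (not part of this patch): a minimal, hypothetical test-style sketch of the behaviour described above. It uses a made-up fakeResolver as the backing resolver to show that a repeated question is answered from the cache instead of reaching GRPC; the type, test name, record and address are invented for the example.

// Sketch only: a hypothetical in-package test exercising the caching resolver.
package dnstogrpc

import (
	"testing"

	"github.com/miekg/dns"
	"golang.org/x/net/trace"
)

// fakeResolver is a made-up backing resolver that counts the queries it sees.
type fakeResolver struct{ queries int }

func (f *fakeResolver) Init() error { return nil }
func (f *fakeResolver) Maintain()   {}

func (f *fakeResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
	f.queries++
	// 1h TTL: above minTTL, so the caching resolver should record the answer.
	rr, _ := dns.NewRR("example.com. 3600 IN A 192.0.2.1")
	reply := new(dns.Msg)
	reply.SetReply(r)
	reply.Answer = []dns.RR{rr}
	return reply, nil
}

func TestRepeatedQueryIsCached(t *testing.T) {
	back := &fakeResolver{}
	cr := NewCachingResolver(back)

	q := new(dns.Msg)
	q.SetQuestion("example.com.", dns.TypeA)

	tr := trace.New("test", "cache")
	defer tr.Finish()

	if _, err := cr.Query(q, tr); err != nil { // miss: forwarded and cached
		t.Fatal(err)
	}
	if _, err := cr.Query(q, tr); err != nil { // hit: served from the cache
		t.Fatal(err)
	}

	if back.queries != 1 {
		t.Errorf("backing resolver saw %d queries, expected 1", back.queries)
	}
}

Note the stored RRs already carry the overridden minTTL, so that is the TTL a cache hit returns; the real wiring against the GRPC resolver is in the dnss.go hunk below.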

dnss.go +2 -1
dnss_test.go +3 -2
dnstogrpc/dnstogrpc.go +3 -1
dnstogrpc/resolver.go +273 -2

diff --git a/dnss.go b/dnss.go
index c6527d9..ea6b659 100644
--- a/dnss.go
+++ b/dnss.go
@@ -81,7 +81,8 @@ func main() {
 	// DNS to GRPC.
 	if *enableDNStoGRPC {
 		r := dnstogrpc.NewGRPCResolver(*grpcUpstream, *grpcClientCAFile)
-		dtg := dnstogrpc.New(*dnsListenAddr, r, *dnsUnqualifiedUpstream)
+		cr := dnstogrpc.NewCachingResolver(r)
+		dtg := dnstogrpc.New(*dnsListenAddr, cr, *dnsUnqualifiedUpstream)
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
diff --git a/dnss_test.go b/dnss_test.go
index f0aab8a..391294e 100644
--- a/dnss_test.go
+++ b/dnss_test.go
@@ -230,8 +230,9 @@ func realMain(m *testing.M) int {
 	}
 
 	// DNS to GRPC server.
-	r := dnstogrpc.NewGRPCResolver(grpcToDnsAddr, tmpDir+"/cert.pem")
-	dtg := dnstogrpc.New(dnsToGrpcAddr, r, "")
+	gr := dnstogrpc.NewGRPCResolver(grpcToDnsAddr, tmpDir+"/cert.pem")
+	cr := dnstogrpc.NewCachingResolver(gr)
+	dtg := dnstogrpc.New(dnsToGrpcAddr, cr, "")
 	go dtg.ListenAndServe()
 
 	// GRPC to DNS server.
diff --git a/dnstogrpc/dnstogrpc.go b/dnstogrpc/dnstogrpc.go
index 8e1a490..11a7c65 100644
--- a/dnstogrpc/dnstogrpc.go
+++ b/dnstogrpc/dnstogrpc.go
@@ -90,7 +90,7 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
 	oldid := r.Id
 	r.Id = <-newId
 
-	from_up, err := s.resolver.Query(r)
+	from_up, err := s.resolver.Query(r, tr)
 	if err != nil {
 		glog.Infof(err.Error())
 		tr.LazyPrintf(err.Error())
@@ -113,6 +113,8 @@ func (s *Server) ListenAndServe() {
 		return
 	}
 
+	go s.resolver.Maintain()
+
 	glog.Infof("DNS listening on %s", s.Addr)
 
 	var wg sync.WaitGroup
diff --git a/dnstogrpc/resolver.go b/dnstogrpc/resolver.go
index c7f7d4f..656a1f6 100644
--- a/dnstogrpc/resolver.go
+++ b/dnstogrpc/resolver.go
@@ -1,13 +1,21 @@
 package dnstogrpc
 
 import (
+	"expvar"
+	"fmt"
+	"net/http"
+	"sync"
 	"time"
 
+	"github.com/golang/glog"
 	"github.com/miekg/dns"
 	"golang.org/x/net/context"
+	"golang.org/x/net/trace"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/credentials"
 
+	"bytes"
+
 	pb "blitiri.com.ar/go/dnss/internal/proto"
 )
 
@@ -21,7 +29,7 @@ type Resolver interface {
 	Maintain()
 
 	// Query responds to a DNS query.
-	Query(r *dns.Msg) (*dns.Msg, error)
+	Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error)
 }
 
 // grpcResolver implements the Resolver interface by querying a server via
@@ -63,7 +71,7 @@ func (g *grpcResolver) Init() error {
 func (g *grpcResolver) Maintain() {
 }
 
-func (g *grpcResolver) Query(r *dns.Msg) (*dns.Msg, error) {
+func (g *grpcResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
 	buf, err := r.Pack()
 	if err != nil {
 		return nil, err
@@ -83,3 +91,266 @@ func (g *grpcResolver) Query(r *dns.Msg) (*dns.Msg, error) {
 	err = m.Unpack(reply.Data)
 	return m, err
 }
+
+// cachingResolver implements a caching Resolver.
+// It is backed by another Resolver, but will cache results.
+type cachingResolver struct {
+	// Backing resolver.
+	back Resolver
+
+	// The cache where we keep the records.
+	answer  map[dns.Question][]dns.RR
+	expires map[dns.Question]time.Time
+
+	// mu protects both answer and expires.
+	mu *sync.RWMutex
+}
+
+func NewCachingResolver(back Resolver) *cachingResolver {
+	return &cachingResolver{
+		back:    back,
+		answer:  map[dns.Question][]dns.RR{},
+		expires: map[dns.Question]time.Time{},
+		mu:      &sync.RWMutex{},
+	}
+}
+
+const (
+	// Maximum number of entries we keep in the cache.
+	// 2k should be reasonable for a small network.
+	// Keep in mind that increasing this too much will interact negatively
+	// with Maintain().
+	maxCacheSize = 2000
+
+	// Minimum TTL for entries we consider for the cache.
+	minTTL = 2 * time.Minute
+
+	// Maximum TTL for our cache. We cap records that exceed this.
+	maxTTL = 2 * time.Hour
+
+	// How often to run GC on the cache.
+	// Must be < minTTL if we don't want to have entries stale for too long.
+	maintenancePeriod = 30 * time.Second
+)
+
+// Exported variables for statistics.
+// These are global and not per caching resolver, so if we have more than one,
+// the results will be mixed.
+var stats = struct {
+	// Total number of queries handled by the cache resolver.
+	cacheTotal *expvar.Int
+
+	// Queries that we passed directly through our back resolver.
+	cacheBypassed *expvar.Int
+
+	// Cache misses.
+	cacheMisses *expvar.Int
+
+	// Cache hits.
+	cacheHits *expvar.Int
+
+	// Entries we decided to record in the cache.
+	cacheRecorded *expvar.Int
+}{}
+
+func init() {
+	stats.cacheTotal = expvar.NewInt("cache-total")
+	stats.cacheBypassed = expvar.NewInt("cache-bypassed")
+	stats.cacheHits = expvar.NewInt("cache-hits")
+	stats.cacheMisses = expvar.NewInt("cache-misses")
+	stats.cacheRecorded = expvar.NewInt("cache-recorded")
+}
+
+func (c *cachingResolver) Init() error {
+	if err := c.back.Init(); err != nil {
+		return err
+	}
+
+	// We register the debug handlers.
+	// Note these are global by nature; if you create more than one resolver,
+	// the last one will prevail.
+	http.HandleFunc("/debug/dnstogrpc/cache/dump", c.DumpCache)
+	http.HandleFunc("/debug/dnstogrpc/cache/flush", c.FlushCache)
+	return nil
+}
+
+func (c *cachingResolver) DumpCache(w http.ResponseWriter, r *http.Request) {
+	buf := bytes.NewBuffer(nil)
+	now := time.Now().Truncate(time.Second)
+	var expires time.Time
+
+	c.mu.RLock()
+	for q, ans := range c.answer {
+		expires = c.expires[q].Truncate(time.Second)
+
+		// Only include names and records if we are running verbosely.
+		name := "<hidden>"
+		if glog.V(3) {
+			name = q.Name
+		}
+
+		fmt.Fprintf(buf, "Q: %s %s %s\n", name, dns.TypeToString[q.Qtype],
+			dns.ClassToString[q.Qclass])
+
+		fmt.Fprintf(buf, "   expires in %s (%s)\n", expires.Sub(now),
+			expires)
+
+		if glog.V(3) {
+			for _, rr := range ans {
+				fmt.Fprintf(buf, "   %s\n", rr.String())
+			}
+		} else {
+			fmt.Fprintf(buf, "   %d RRs in answer\n", len(ans))
+		}
+		fmt.Fprintf(buf, "\n\n")
+	}
+	c.mu.RUnlock()
+
+	buf.WriteTo(w)
+}
+
+func (c *cachingResolver) FlushCache(w http.ResponseWriter, r *http.Request) {
+
+	c.mu.Lock()
+	c.answer = map[dns.Question][]dns.RR{}
+	c.expires = map[dns.Question]time.Time{}
+	c.mu.Unlock()
+
+	w.Write([]byte("cache flush complete"))
+}
+
+func (c *cachingResolver) Maintain() {
+	go c.back.Maintain()
+
+	for now := range time.Tick(maintenancePeriod) {
+		tr := trace.New("dnstogrpc.Cache", "GC")
+		var total, expired int
+
+		c.mu.Lock()
+		total = len(c.expires)
+		for q, exp := range c.expires {
+			if now.Before(exp) {
+				continue
+			}
+
+			delete(c.answer, q)
+			delete(c.expires, q)
+			expired++
+		}
+		c.mu.Unlock()
+		tr.LazyPrintf("total: %d   expired: %d", total, expired)
+		tr.Finish()
+	}
+}
+
+func wantToCache(question dns.Question, reply *dns.Msg) error {
+	if reply.Rcode != dns.RcodeSuccess {
+		return fmt.Errorf("unsuccessful query")
+	} else if !reply.Response {
+		return fmt.Errorf("response = false")
+	} else if reply.Opcode != dns.OpcodeQuery {
+		return fmt.Errorf("opcode %d != query", reply.Opcode)
+	} else if len(reply.Answer) == 0 {
+		return fmt.Errorf("answer is empty")
+	} else if len(reply.Question) != 1 {
+		return fmt.Errorf("too many/few questions (%d)", len(reply.Question))
+	} else if reply.Question[0] != question {
+		return fmt.Errorf(
+			"reply question does not match: asked %v, got %v",
+			question, reply.Question[0])
+	}
+
+	return nil
+}
+
+func calculateTTL(answer []dns.RR) time.Duration {
+	// This assumes all RRs have the same TTL.  That may not be the case in
+	// theory, but we are OK with not handling that for now.
+	ttl := time.Duration(answer[0].Header().Ttl) * time.Second
+
+	// This helps prevent cache pollution due to unused but long entries, as
+	// we don't do usage-based caching yet.
+	if ttl > maxTTL {
+		ttl = maxTTL
+	}
+
+	return ttl
+}
+
+func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
+	stats.cacheTotal.Add(1)
+
+	// To keep it simple we only cache single-question queries.
+	if len(r.Question) != 1 {
+		tr.LazyPrintf("cache bypass: multi-question query")
+		stats.cacheBypassed.Add(1)
+		return c.back.Query(r, tr)
+	}
+
+	question := r.Question[0]
+
+	c.mu.RLock()
+	answer, hit := c.answer[question]
+	c.mu.RUnlock()
+
+	if hit {
+		tr.LazyPrintf("cache hit")
+		stats.cacheHits.Add(1)
+
+		reply := &dns.Msg{
+			MsgHdr: dns.MsgHdr{
+				Id:            r.Id,
+				Response:      true,
+				Authoritative: false,
+				Rcode:         dns.RcodeSuccess,
+			},
+			Question: r.Question,
+			Answer:   answer,
+		}
+
+		return reply, nil
+	}
+
+	tr.LazyPrintf("cache miss")
+	stats.cacheMisses.Add(1)
+
+	reply, err := c.back.Query(r, tr)
+	if err != nil {
+		return reply, err
+	}
+
+	if err = wantToCache(question, reply); err != nil {
+		tr.LazyPrintf("cache not recording reply: %v", err)
+		return reply, nil
+	}
+
+	answer = reply.Answer
+	ttl := calculateTTL(answer)
+	expires := time.Now().Add(ttl)
+
+	// Only store answers if they're going to stay around for a bit, as
+	// there's not much point in caching things we have to expire quickly.
+	if ttl > minTTL {
+		// Override the answer TTL to our minimum.
+		// Otherwise we'd be telling the clients high TTLs for as long as the
+		// entry is in our cache.
+		// This makes us very unsuitable as a proper DNS server, but it's
+		// useful when we're the last ones and in a small network where
+		// clients are unlikely to cache up to TTL anyway.
+		for _, rr := range answer {
+			rr.Header().Ttl = uint32(minTTL.Seconds())
+		}
+
+		// Store the answer in the cache, but don't exceed maxCacheSize entries.
+		// TODO: Do usage based eviction when we're approaching ~1.5k.
+		c.mu.Lock()
+		if len(c.answer) < maxCacheSize {
+			c.answer[question] = answer
+			c.expires[question] = expires
+			stats.cacheRecorded.Add(1)
+		}
+		c.mu.Unlock()
+	}
+
+	return reply, nil
+}
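
A note on the debug endpoints registered in Init(): they are added to http.DefaultServeMux (as are the expvar counters cache-total, cache-hits, etc.), so they are only reachable if the process serves that mux somewhere. This commit does not add an HTTP listener; a minimal, hypothetical sketch of what that would look like (the address is made up):

// Hypothetical sketch: serve http.DefaultServeMux, where the
// /debug/dnstogrpc/cache/{dump,flush} handlers and the expvar counters live.
package main

import (
	"log"
	"net/http"
)

func main() {
	// A nil handler means http.DefaultServeMux; the address is invented.
	log.Fatal(http.ListenAndServe("localhost:8081", nil))
}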