git » debian:dnss » commit 9aee35b

dnstox: Return more accurate TTLs

author Alberto Bertogli
2016-05-23 01:41:18 UTC
committer Alberto Bertogli
2016-05-23 01:48:21 UTC
parent 448f59f862a74c2e6c9ba8e88c1100e3e0dab925

dnstox: Return more accurate TTLs

This patch makes the cache resolver keep and return more accurate TTLs.

Previously, we always returned the minimum TTL to clients, which was not too
bad but caused a bit of waste, and introduced noise if dnss was not the last
server in the chain.

With this change, dnss will return the real TTL (accurate to within our
maintenance period of 30s), capped to prevent really long entries from
monopolizing the cache.

internal/dnstox/resolver.go +38 -42

diff --git a/internal/dnstox/resolver.go b/internal/dnstox/resolver.go
index 3db5e0f..a25a289 100644
--- a/internal/dnstox/resolver.go
+++ b/internal/dnstox/resolver.go
@@ -275,19 +275,17 @@ type cachingResolver struct {
 	back Resolver
 
 	// The cache where we keep the records.
-	answer  map[dns.Question][]dns.RR
-	expires map[dns.Question]time.Time
+	answer map[dns.Question][]dns.RR
 
-	// mu protects both answer and expires.
+	// mu protects the answer map.
 	mu *sync.RWMutex
 }
 
 func NewCachingResolver(back Resolver) *cachingResolver {
 	return &cachingResolver{
-		back:    back,
-		answer:  map[dns.Question][]dns.RR{},
-		expires: map[dns.Question]time.Time{},
-		mu:      &sync.RWMutex{},
+		back:   back,
+		answer: map[dns.Question][]dns.RR{},
+		mu:     &sync.RWMutex{},
 	}
 }
 
@@ -352,13 +350,9 @@ func (c *cachingResolver) Init() error {
 
 func (c *cachingResolver) DumpCache(w http.ResponseWriter, r *http.Request) {
 	buf := bytes.NewBuffer(nil)
-	now := time.Now().Truncate(time.Second)
-	var expires time.Time
 
 	c.mu.RLock()
 	for q, ans := range c.answer {
-		expires = c.expires[q].Truncate(time.Second)
-
 		// Only include names and records if we are running verbosily.
 		name := "<hidden>"
 		if glog.V(3) {
@@ -368,8 +362,8 @@ func (c *cachingResolver) DumpCache(w http.ResponseWriter, r *http.Request) {
 		fmt.Fprintf(buf, "Q: %s %s %s\n", name, dns.TypeToString[q.Qtype],
 			dns.ClassToString[q.Qclass])
 
-		fmt.Fprintf(buf, "   expires in %s (%s)\n", expires.Sub(now),
-			expires)
+		ttl := getTTL(ans)
+		fmt.Fprintf(buf, "   expires in %s (%s)\n", ttl, time.Now().Add(ttl))
 
 		if glog.V(3) {
 			for _, rr := range ans {
@@ -386,10 +380,8 @@ func (c *cachingResolver) DumpCache(w http.ResponseWriter, r *http.Request) {
 }
 
 func (c *cachingResolver) FlushCache(w http.ResponseWriter, r *http.Request) {
-
 	c.mu.Lock()
 	c.answer = map[dns.Question][]dns.RR{}
-	c.expires = map[dns.Question]time.Time{}
 	c.mu.Unlock()
 
 	w.Write([]byte("cache flush complete"))
@@ -398,19 +390,20 @@ func (c *cachingResolver) FlushCache(w http.ResponseWriter, r *http.Request) {
 func (c *cachingResolver) Maintain() {
 	go c.back.Maintain()
 
-	for now := range time.Tick(maintenancePeriod) {
+	for range time.Tick(maintenancePeriod) {
 		tr := trace.New("dnstox.Cache", "GC")
 		var total, expired int
 
 		c.mu.Lock()
-		total = len(c.expires)
-		for q, exp := range c.expires {
-			if now.Before(exp) {
+		total = len(c.answer)
+		for q, ans := range c.answer {
+			newTTL := getTTL(ans) - maintenancePeriod
+			if newTTL > 0 {
+				setTTL(ans, newTTL)
 				continue
 			}
 
 			delete(c.answer, q)
-			delete(c.expires, q)
 			expired++
 		}
 		c.mu.Unlock()
@@ -439,7 +432,7 @@ func wantToCache(question dns.Question, reply *dns.Msg) error {
 	return nil
 }
 
-func calculateTTL(answer []dns.RR) time.Duration {
+func limitTTL(answer []dns.RR) time.Duration {
 	// This assumes all RRs have the same TTL.  That may not be the case in
 	// theory, but we are ok not caring for this for now.
 	ttl := time.Duration(answer[0].Header().Ttl) * time.Second
@@ -453,6 +446,18 @@ func calculateTTL(answer []dns.RR) time.Duration {
 	return ttl
 }
 
+func getTTL(answer []dns.RR) time.Duration {
+	// This assumes all RRs have the same TTL.  That may not be the case in
+	// theory, but we are ok not caring for this for now.
+	return time.Duration(answer[0].Header().Ttl) * time.Second
+}
+
+func setTTL(answer []dns.RR, newTTL time.Duration) {
+	for _, rr := range answer {
+		rr.Header().Ttl = uint32(newTTL.Seconds())
+	}
+}
+
 func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
 	stats.cacheTotal.Add(1)
 
@@ -501,32 +506,23 @@ func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
 	}
 
 	answer = reply.Answer
-	ttl := calculateTTL(answer)
-	expires := time.Now().Add(ttl)
+	ttl := limitTTL(answer)
 
 	// Only store answers if they're going to stay around for a bit,
 	// there's not much point in caching things we have to expire quickly.
-	if ttl > minTTL {
-		// Override the answer TTL to our minimum.
-		// Otherwise we'd be telling the clients high TTLs for as long as the
-		// entry is in our cache.
-		// This makes us very unsuitable as a proper DNS server, but it's
-		// useful when we're the last ones and in a small network where
-		// clients are unlikely to cache up to TTL anyway.
-		for _, rr := range answer {
-			rr.Header().Ttl = uint32(minTTL.Seconds())
-		}
+	if ttl < minTTL {
+		return reply, nil
+	}
 
-		// Store the answer in the cache, but don't exceed 2k entries.
-		// TODO: Do usage based eviction when we're approaching ~1.5k.
-		c.mu.Lock()
-		if len(c.answer) < maxCacheSize {
-			c.answer[question] = answer
-			c.expires[question] = expires
-			stats.cacheRecorded.Add(1)
-		}
-		c.mu.Unlock()
+	// Store the answer in the cache, but don't exceed 2k entries.
+	// TODO: Do usage based eviction when we're approaching ~1.5k.
+	c.mu.Lock()
+	if len(c.answer) < maxCacheSize {
+		setTTL(answer, ttl)
+		c.answer[question] = answer
+		stats.cacheRecorded.Add(1)
 	}
+	c.mu.Unlock()
 
 	return reply, nil
 }