author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-04-07 08:42:42 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-04-08 08:33:40 UTC |
parent | d4c2d8ed25355c5161dcdcc1305c3830b85471d1 |
dnss.go | +4 | -3 |
dnss_test.go | +2 | -1 |
internal/{dnstohttps => dnsserver}/caching_test.go | +1 | -1 |
internal/dnsserver/resolver.go | +309 | -0 |
internal/{dnstohttps => dnsserver}/server.go | +3 | -3 |
internal/dnstohttps/https_test.go | +2 | -1 |
internal/dnstohttps/resolver.go | +5 | -300 |
diff --git a/dnss.go b/dnss.go index 569a8bb..7c484ef 100644 --- a/dnss.go +++ b/dnss.go @@ -19,6 +19,7 @@ import ( "sync" "time" + "blitiri.com.ar/go/dnss/internal/dnsserver" "blitiri.com.ar/go/dnss/internal/dnstohttps" "blitiri.com.ar/go/dnss/internal/httpstodns" @@ -104,14 +105,14 @@ func main() { // DNS to HTTPS. if *enableDNStoHTTPS { - var resolver dnstohttps.Resolver = dnstohttps.NewHTTPSResolver( + var resolver dnsserver.Resolver = dnstohttps.NewHTTPSResolver( *httpsUpstream, *httpsClientCAFile) if *enableCache { - cr := dnstohttps.NewCachingResolver(resolver) + cr := dnsserver.NewCachingResolver(resolver) cr.RegisterDebugHandlers() resolver = cr } - dth := dnstohttps.New(*dnsListenAddr, resolver, *dnsUnqualifiedUpstream) + dth := dnsserver.New(*dnsListenAddr, resolver, *dnsUnqualifiedUpstream) // If we're using an HTTP proxy, add the name to the fallback domain // so we don't have problems resolving it. diff --git a/dnss_test.go b/dnss_test.go index e1fa284..e1ab69b 100644 --- a/dnss_test.go +++ b/dnss_test.go @@ -10,6 +10,7 @@ import ( "sync" "testing" + "blitiri.com.ar/go/dnss/internal/dnsserver" "blitiri.com.ar/go/dnss/internal/dnstohttps" "blitiri.com.ar/go/dnss/internal/httpstodns" "blitiri.com.ar/go/dnss/internal/testutil" @@ -46,7 +47,7 @@ func realMain(m *testing.M) int { // DNS to HTTPS server. r := dnstohttps.NewHTTPSResolver("http://"+HTTPSToDNSAddr+"/resolve", "") - dtoh := dnstohttps.New(DNSToHTTPSAddr, r, "") + dtoh := dnsserver.New(DNSToHTTPSAddr, r, "") go dtoh.ListenAndServe() // HTTPS to DNS server. diff --git a/internal/dnstohttps/caching_test.go b/internal/dnsserver/caching_test.go similarity index 99% rename from internal/dnstohttps/caching_test.go rename to internal/dnsserver/caching_test.go index 4a59555..bb0e559 100644 --- a/internal/dnstohttps/caching_test.go +++ b/internal/dnsserver/caching_test.go @@ -1,4 +1,4 @@ -package dnstohttps +package dnsserver // Tests for the caching resolver. // Note the other resolvers have more functional tests in the testing/ diff --git a/internal/dnsserver/resolver.go b/internal/dnsserver/resolver.go new file mode 100644 index 0000000..8ca4130 --- /dev/null +++ b/internal/dnsserver/resolver.go @@ -0,0 +1,309 @@ +package dnsserver + +import ( + "bytes" + "expvar" + "fmt" + "net/http" + "sync" + "time" + + "github.com/golang/glog" + "github.com/miekg/dns" + "golang.org/x/net/trace" +) + +// Resolver is the interface for DNS resolvers that can answer queries. +type Resolver interface { + // Initialize the resolver. + Init() error + + // Maintain performs resolver maintenance. It's expected to run + // indefinitely, but may return early if appropriate. + Maintain() + + // Query responds to a DNS query. + Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) +} + +/////////////////////////////////////////////////////////////////////////// +// Caching resolver. + +// cachingResolver implements a caching Resolver. +// It is backed by another Resolver, but will cache results. +type cachingResolver struct { + // Backing resolver. + back Resolver + + // The cache where we keep the records. + answer map[dns.Question][]dns.RR + + // mu protects the answer map. + mu *sync.RWMutex +} + +// NewCachingResolver returns a new resolver which implements a cache on top +// of the given one. +func NewCachingResolver(back Resolver) *cachingResolver { + return &cachingResolver{ + back: back, + answer: map[dns.Question][]dns.RR{}, + mu: &sync.RWMutex{}, + } +} + +// Constants that tune the cache. +// They are declared as variables so we can tweak them for testing. +var ( + // Maximum number of entries we keep in the cache. + // 2k should be reasonable for a small network. + // Keep in mind that increasing this too much will interact negatively + // with Maintain(). + maxCacheSize = 2000 + + // Minimum TTL for entries we consider for the cache. + minTTL = 2 * time.Minute + + // Maximum TTL for our cache. We cap records that exceed this. + maxTTL = 2 * time.Hour + + // How often to run GC on the cache. + // Must be < minTTL if we don't want to have entries stale for too long. + maintenancePeriod = 30 * time.Second +) + +// Exported variables for statistics. +// These are global and not per caching resolver, so if we have more than once +// the results will be mixed. +var stats = struct { + // Total number of queries handled by the cache resolver. + cacheTotal *expvar.Int + + // Queries that we passed directly through our back resolver. + cacheBypassed *expvar.Int + + // Cache misses. + cacheMisses *expvar.Int + + // Cache hits. + cacheHits *expvar.Int + + // Entries we decided to record in the cache. + cacheRecorded *expvar.Int +}{} + +func init() { + stats.cacheTotal = expvar.NewInt("cache-total") + stats.cacheBypassed = expvar.NewInt("cache-bypassed") + stats.cacheHits = expvar.NewInt("cache-hits") + stats.cacheMisses = expvar.NewInt("cache-misses") + stats.cacheRecorded = expvar.NewInt("cache-recorded") +} + +func (c *cachingResolver) Init() error { + return c.back.Init() +} + +// RegisterDebugHandlers registers http debug handlers, which can be accessed +// from the monitoring server. +// Note these are global by nature, if you try to register them multiple +// times, you will get a panic. +func (c *cachingResolver) RegisterDebugHandlers() { + http.HandleFunc("/debug/dnstohttps/cache/dump", c.DumpCache) + http.HandleFunc("/debug/dnstohttps/cache/flush", c.FlushCache) +} + +func (c *cachingResolver) DumpCache(w http.ResponseWriter, r *http.Request) { + buf := bytes.NewBuffer(nil) + + c.mu.RLock() + for q, ans := range c.answer { + // Only include names and records if we are running verbosily. + name := "<hidden>" + if glog.V(3) { + name = q.Name + } + + fmt.Fprintf(buf, "Q: %s %s %s\n", name, dns.TypeToString[q.Qtype], + dns.ClassToString[q.Qclass]) + + ttl := getTTL(ans) + fmt.Fprintf(buf, " expires in %s (%s)\n", ttl, time.Now().Add(ttl)) + + if glog.V(3) { + for _, rr := range ans { + fmt.Fprintf(buf, " %s\n", rr.String()) + } + } else { + fmt.Fprintf(buf, " %d RRs in answer\n", len(ans)) + } + fmt.Fprintf(buf, "\n\n") + } + c.mu.RUnlock() + + buf.WriteTo(w) +} + +func (c *cachingResolver) FlushCache(w http.ResponseWriter, r *http.Request) { + c.mu.Lock() + c.answer = map[dns.Question][]dns.RR{} + c.mu.Unlock() + + w.Write([]byte("cache flush complete")) +} + +func (c *cachingResolver) Maintain() { + go c.back.Maintain() + + for range time.Tick(maintenancePeriod) { + tr := trace.New("dnstohttps.Cache", "GC") + var total, expired int + + c.mu.Lock() + total = len(c.answer) + for q, ans := range c.answer { + newTTL := getTTL(ans) - maintenancePeriod + if newTTL > 0 { + // Don't modify in place, create a copy and override. + // That way, we avoid races with users that have gotten a + // cached answer and are returning it. + newans := copyRRSlice(ans) + setTTL(newans, newTTL) + c.answer[q] = newans + continue + } + + delete(c.answer, q) + expired++ + } + c.mu.Unlock() + tr.LazyPrintf("total: %d expired: %d", total, expired) + tr.Finish() + } +} + +func wantToCache(question dns.Question, reply *dns.Msg) error { + if reply.Rcode != dns.RcodeSuccess { + return fmt.Errorf("unsuccessful query") + } else if !reply.Response { + return fmt.Errorf("response = false") + } else if reply.Opcode != dns.OpcodeQuery { + return fmt.Errorf("opcode %d != query", reply.Opcode) + } else if len(reply.Answer) == 0 { + return fmt.Errorf("answer is empty") + } else if len(reply.Question) != 1 { + return fmt.Errorf("too many/few questions (%d)", len(reply.Question)) + } else if reply.Question[0] != question { + return fmt.Errorf( + "reply question does not match: asked %v, got %v", + question, reply.Question[0]) + } + + return nil +} + +func limitTTL(answer []dns.RR) time.Duration { + // This assumes all RRs have the same TTL. That may not be the case in + // theory, but we are ok not caring for this for now. + ttl := time.Duration(answer[0].Header().Ttl) * time.Second + + // This helps prevent cache pollution due to unused but long entries, as + // we don't do usage-based caching yet. + if ttl > maxTTL { + ttl = maxTTL + } + + return ttl +} + +func getTTL(answer []dns.RR) time.Duration { + // This assumes all RRs have the same TTL. That may not be the case in + // theory, but we are ok not caring for this for now. + return time.Duration(answer[0].Header().Ttl) * time.Second +} + +func setTTL(answer []dns.RR, newTTL time.Duration) { + for _, rr := range answer { + rr.Header().Ttl = uint32(newTTL.Seconds()) + } +} + +func copyRRSlice(a []dns.RR) []dns.RR { + b := make([]dns.RR, 0, len(a)) + for _, rr := range a { + b = append(b, dns.Copy(rr)) + } + return b +} + +func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) { + stats.cacheTotal.Add(1) + + // To keep it simple we only cache single-question queries. + if len(r.Question) != 1 { + tr.LazyPrintf("cache bypass: multi-question query") + stats.cacheBypassed.Add(1) + return c.back.Query(r, tr) + } + + question := r.Question[0] + + c.mu.RLock() + answer, hit := c.answer[question] + c.mu.RUnlock() + + if hit { + tr.LazyPrintf("cache hit") + stats.cacheHits.Add(1) + + reply := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: r.Id, + Response: true, + Authoritative: false, + Rcode: dns.RcodeSuccess, + }, + Question: r.Question, + Answer: answer, + } + + return reply, nil + } + + tr.LazyPrintf("cache miss") + stats.cacheMisses.Add(1) + + reply, err := c.back.Query(r, tr) + if err != nil { + return reply, err + } + + if err = wantToCache(question, reply); err != nil { + tr.LazyPrintf("cache not recording reply: %v", err) + return reply, nil + } + + answer = reply.Answer + ttl := limitTTL(answer) + + // Only store answers if they're going to stay around for a bit, + // there's not much point in caching things we have to expire quickly. + if ttl < minTTL { + return reply, nil + } + + // Store the answer in the cache, but don't exceed 2k entries. + // TODO: Do usage based eviction when we're approaching ~1.5k. + c.mu.Lock() + if len(c.answer) < maxCacheSize { + setTTL(answer, ttl) + c.answer[question] = answer + stats.cacheRecorded.Add(1) + } + c.mu.Unlock() + + return reply, nil +} + +// Compile-time check that the implementation matches the interface. +var _ Resolver = &cachingResolver{} diff --git a/internal/dnstohttps/server.go b/internal/dnsserver/server.go similarity index 97% rename from internal/dnstohttps/server.go rename to internal/dnsserver/server.go index 3df711a..4e9fcce 100644 --- a/internal/dnstohttps/server.go +++ b/internal/dnsserver/server.go @@ -1,6 +1,6 @@ -// Package dnstohttps implements a DNS proxy that uses HTTPS to resolve the -// requests. -package dnstohttps +// Package dnsserver implements a DNS server, that uses the given resolvers to +// handle requests. +package dnsserver import ( "crypto/rand" diff --git a/internal/dnstohttps/https_test.go b/internal/dnstohttps/https_test.go index 973ee84..3da1e20 100644 --- a/internal/dnstohttps/https_test.go +++ b/internal/dnstohttps/https_test.go @@ -9,6 +9,7 @@ import ( "os" "testing" + "blitiri.com.ar/go/dnss/internal/dnsserver" "blitiri.com.ar/go/dnss/internal/testutil" "github.com/golang/glog" @@ -130,7 +131,7 @@ func realMain(m *testing.M) int { // DNS to HTTPS server. r := NewHTTPSResolver(httpsrv.URL, "") - dth := New(DNSAddr, r, "") + dth := dnsserver.New(DNSAddr, r, "") go dth.ListenAndServe() // Wait for the servers to start up. diff --git a/internal/dnstohttps/resolver.go b/internal/dnstohttps/resolver.go index b0376d9..e464a8e 100644 --- a/internal/dnstohttps/resolver.go +++ b/internal/dnstohttps/resolver.go @@ -4,41 +4,22 @@ import ( "crypto/tls" "crypto/x509" "encoding/json" - "expvar" "fmt" "io/ioutil" "net/http" "net/url" - "sync" "time" "blitiri.com.ar/go/dnss/internal/dnsjson" + "blitiri.com.ar/go/dnss/internal/dnsserver" "github.com/golang/glog" "github.com/miekg/dns" "golang.org/x/net/trace" - - "bytes" ) -// Resolver is the interface for DNS resolvers that can answer queries. -type Resolver interface { - // Initialize the resolver. - Init() error - - // Maintain performs resolver maintenance. It's expected to run - // indefinitely, but may return early if appropriate. - Maintain() - - // Query responds to a DNS query. - Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) -} - -/////////////////////////////////////////////////////////////////////////// -// HTTPS resolver. - -// httpsResolver implements the Resolver interface by querying a server via -// DNS-over-HTTPS (like https://dns.google.com). +// httpsResolver implements the dnsserver.Resolver interface by querying a +// server via DNS-over-HTTPS (like https://dns.google.com). type httpsResolver struct { Upstream string CAFile string @@ -194,281 +175,5 @@ func (r *httpsResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) { return resp, nil } -/////////////////////////////////////////////////////////////////////////// -// Caching resolver. - -// cachingResolver implements a caching Resolver. -// It is backed by another Resolver, but will cache results. -type cachingResolver struct { - // Backing resolver. - back Resolver - - // The cache where we keep the records. - answer map[dns.Question][]dns.RR - - // mu protects the answer map. - mu *sync.RWMutex -} - -// NewCachingResolver returns a new resolver which implements a cache on top -// of the given one. -func NewCachingResolver(back Resolver) *cachingResolver { - return &cachingResolver{ - back: back, - answer: map[dns.Question][]dns.RR{}, - mu: &sync.RWMutex{}, - } -} - -// Constants that tune the cache. -// They are declared as variables so we can tweak them for testing. -var ( - // Maximum number of entries we keep in the cache. - // 2k should be reasonable for a small network. - // Keep in mind that increasing this too much will interact negatively - // with Maintain(). - maxCacheSize = 2000 - - // Minimum TTL for entries we consider for the cache. - minTTL = 2 * time.Minute - - // Maximum TTL for our cache. We cap records that exceed this. - maxTTL = 2 * time.Hour - - // How often to run GC on the cache. - // Must be < minTTL if we don't want to have entries stale for too long. - maintenancePeriod = 30 * time.Second -) - -// Exported variables for statistics. -// These are global and not per caching resolver, so if we have more than once -// the results will be mixed. -var stats = struct { - // Total number of queries handled by the cache resolver. - cacheTotal *expvar.Int - - // Queries that we passed directly through our back resolver. - cacheBypassed *expvar.Int - - // Cache misses. - cacheMisses *expvar.Int - - // Cache hits. - cacheHits *expvar.Int - - // Entries we decided to record in the cache. - cacheRecorded *expvar.Int -}{} - -func init() { - stats.cacheTotal = expvar.NewInt("cache-total") - stats.cacheBypassed = expvar.NewInt("cache-bypassed") - stats.cacheHits = expvar.NewInt("cache-hits") - stats.cacheMisses = expvar.NewInt("cache-misses") - stats.cacheRecorded = expvar.NewInt("cache-recorded") -} - -func (c *cachingResolver) Init() error { - return c.back.Init() -} - -// RegisterDebugHandlers registers http debug handlers, which can be accessed -// from the monitoring server. -// Note these are global by nature, if you try to register them multiple -// times, you will get a panic. -func (c *cachingResolver) RegisterDebugHandlers() { - http.HandleFunc("/debug/dnstohttps/cache/dump", c.DumpCache) - http.HandleFunc("/debug/dnstohttps/cache/flush", c.FlushCache) -} - -func (c *cachingResolver) DumpCache(w http.ResponseWriter, r *http.Request) { - buf := bytes.NewBuffer(nil) - - c.mu.RLock() - for q, ans := range c.answer { - // Only include names and records if we are running verbosily. - name := "<hidden>" - if glog.V(3) { - name = q.Name - } - - fmt.Fprintf(buf, "Q: %s %s %s\n", name, dns.TypeToString[q.Qtype], - dns.ClassToString[q.Qclass]) - - ttl := getTTL(ans) - fmt.Fprintf(buf, " expires in %s (%s)\n", ttl, time.Now().Add(ttl)) - - if glog.V(3) { - for _, rr := range ans { - fmt.Fprintf(buf, " %s\n", rr.String()) - } - } else { - fmt.Fprintf(buf, " %d RRs in answer\n", len(ans)) - } - fmt.Fprintf(buf, "\n\n") - } - c.mu.RUnlock() - - buf.WriteTo(w) -} - -func (c *cachingResolver) FlushCache(w http.ResponseWriter, r *http.Request) { - c.mu.Lock() - c.answer = map[dns.Question][]dns.RR{} - c.mu.Unlock() - - w.Write([]byte("cache flush complete")) -} - -func (c *cachingResolver) Maintain() { - go c.back.Maintain() - - for range time.Tick(maintenancePeriod) { - tr := trace.New("dnstohttps.Cache", "GC") - var total, expired int - - c.mu.Lock() - total = len(c.answer) - for q, ans := range c.answer { - newTTL := getTTL(ans) - maintenancePeriod - if newTTL > 0 { - // Don't modify in place, create a copy and override. - // That way, we avoid races with users that have gotten a - // cached answer and are returning it. - newans := copyRRSlice(ans) - setTTL(newans, newTTL) - c.answer[q] = newans - continue - } - - delete(c.answer, q) - expired++ - } - c.mu.Unlock() - tr.LazyPrintf("total: %d expired: %d", total, expired) - tr.Finish() - } -} - -func wantToCache(question dns.Question, reply *dns.Msg) error { - if reply.Rcode != dns.RcodeSuccess { - return fmt.Errorf("unsuccessful query") - } else if !reply.Response { - return fmt.Errorf("response = false") - } else if reply.Opcode != dns.OpcodeQuery { - return fmt.Errorf("opcode %d != query", reply.Opcode) - } else if len(reply.Answer) == 0 { - return fmt.Errorf("answer is empty") - } else if len(reply.Question) != 1 { - return fmt.Errorf("too many/few questions (%d)", len(reply.Question)) - } else if reply.Question[0] != question { - return fmt.Errorf( - "reply question does not match: asked %v, got %v", - question, reply.Question[0]) - } - - return nil -} - -func limitTTL(answer []dns.RR) time.Duration { - // This assumes all RRs have the same TTL. That may not be the case in - // theory, but we are ok not caring for this for now. - ttl := time.Duration(answer[0].Header().Ttl) * time.Second - - // This helps prevent cache pollution due to unused but long entries, as - // we don't do usage-based caching yet. - if ttl > maxTTL { - ttl = maxTTL - } - - return ttl -} - -func getTTL(answer []dns.RR) time.Duration { - // This assumes all RRs have the same TTL. That may not be the case in - // theory, but we are ok not caring for this for now. - return time.Duration(answer[0].Header().Ttl) * time.Second -} - -func setTTL(answer []dns.RR, newTTL time.Duration) { - for _, rr := range answer { - rr.Header().Ttl = uint32(newTTL.Seconds()) - } -} - -func copyRRSlice(a []dns.RR) []dns.RR { - b := make([]dns.RR, 0, len(a)) - for _, rr := range a { - b = append(b, dns.Copy(rr)) - } - return b -} - -func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) { - stats.cacheTotal.Add(1) - - // To keep it simple we only cache single-question queries. - if len(r.Question) != 1 { - tr.LazyPrintf("cache bypass: multi-question query") - stats.cacheBypassed.Add(1) - return c.back.Query(r, tr) - } - - question := r.Question[0] - - c.mu.RLock() - answer, hit := c.answer[question] - c.mu.RUnlock() - - if hit { - tr.LazyPrintf("cache hit") - stats.cacheHits.Add(1) - - reply := &dns.Msg{ - MsgHdr: dns.MsgHdr{ - Id: r.Id, - Response: true, - Authoritative: false, - Rcode: dns.RcodeSuccess, - }, - Question: r.Question, - Answer: answer, - } - - return reply, nil - } - - tr.LazyPrintf("cache miss") - stats.cacheMisses.Add(1) - - reply, err := c.back.Query(r, tr) - if err != nil { - return reply, err - } - - if err = wantToCache(question, reply); err != nil { - tr.LazyPrintf("cache not recording reply: %v", err) - return reply, nil - } - - answer = reply.Answer - ttl := limitTTL(answer) - - // Only store answers if they're going to stay around for a bit, - // there's not much point in caching things we have to expire quickly. - if ttl < minTTL { - return reply, nil - } - - // Store the answer in the cache, but don't exceed 2k entries. - // TODO: Do usage based eviction when we're approaching ~1.5k. - c.mu.Lock() - if len(c.answer) < maxCacheSize { - setTTL(answer, ttl) - c.answer[question] = answer - stats.cacheRecorded.Add(1) - } - c.mu.Unlock() - - return reply, nil -} +// Compile-time check that the implementation matches the interface. +var _ dnsserver.Resolver = &httpsResolver{}