package dnsserver

import (
	"bytes"
	"expvar"
	"fmt"
	"net/http"
	"sort"
	"sync"
	"time"
	"blitiri.com.ar/go/dnss/internal/trace"
	"blitiri.com.ar/go/log"
	"github.com/miekg/dns"
)

// Resolver is the interface for DNS resolvers that can answer queries.
type Resolver interface {
	// Initialize the resolver.
	Init() error
	// Maintain performs resolver maintenance. It's expected to run
	// indefinitely, but may return early if appropriate.
	Maintain()
	// Query responds to a DNS query.
	Query(r *dns.Msg, tr *trace.Trace) (*dns.Msg, error)
}

///////////////////////////////////////////////////////////////////////////
// Caching resolver.

// cachingResolver implements a caching Resolver.
// It is backed by another Resolver, but will cache results.
type cachingResolver struct {
	// Backing resolver.
	back Resolver
	// The cache where we keep the records.
	answer map[dns.Question][]dns.RR
	// mu protects the answer map.
	mu *sync.RWMutex
}

// NewCachingResolver returns a new resolver which implements a cache on top
// of the given one.
func NewCachingResolver(back Resolver) *cachingResolver {
	return &cachingResolver{
		back:   back,
		answer: map[dns.Question][]dns.RR{},
		mu:     &sync.RWMutex{},
	}
}
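
// A minimal usage sketch (the concrete back resolver, its constructor, and
// the query/trace values are assumptions, not part of this package):
//
//	back := newUpstreamResolver(addr) // hypothetical concrete Resolver
//	res := NewCachingResolver(back)
//	if err := res.Init(); err != nil {
//		return err
//	}
//	go res.Maintain()
//	reply, err := res.Query(msg, tr)
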
// Constants that tune the cache.
// They are declared as variables so we can tweak them for testing.
var (
	// Maximum number of entries we keep in the cache.
	// 2k should be reasonable for a small network.
	// Keep in mind that increasing this too much will interact negatively
	// with Maintain().
	maxCacheSize = 2000
	// Minimum TTL for entries we consider for the cache.
	minTTL = 2 * time.Minute
	// Maximum TTL for our cache. We cap records that exceed this.
	maxTTL = 2 * time.Hour
	// How often to run GC on the cache.
	// Must be < minTTL if we don't want entries to stay stale for too long.
	maintenancePeriod = 30 * time.Second
)
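
// For example, a test that wants to exercise eviction and expiry quickly
// could override them (a sketch, not an existing test):
//
//	func TestSmallCache(t *testing.T) {
//		maxCacheSize = 10
//		minTTL = 1 * time.Second
//		maxTTL = 5 * time.Second
//		maintenancePeriod = 50 * time.Millisecond
//		// ... exercise the caching resolver ...
//	}
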
// Exported variables for statistics.
// These are global and not per caching resolver, so if we have more than
// one, the results will be mixed.
var stats = struct {
	// Total number of queries handled by the cache resolver.
	cacheTotal *expvar.Int
	// Queries that we passed directly through our back resolver.
	cacheBypassed *expvar.Int
	// Cache misses.
	cacheMisses *expvar.Int
	// Cache hits.
	cacheHits *expvar.Int
	// Entries we decided to record in the cache.
	cacheRecorded *expvar.Int
}{}

func init() {
	stats.cacheTotal = expvar.NewInt("cache-total")
	stats.cacheBypassed = expvar.NewInt("cache-bypassed")
	stats.cacheHits = expvar.NewInt("cache-hits")
	stats.cacheMisses = expvar.NewInt("cache-misses")
	stats.cacheRecorded = expvar.NewInt("cache-recorded")
}

func (c *cachingResolver) Init() error {
	return c.back.Init()
}

// RegisterDebugHandlers registers HTTP debug handlers, which can be accessed
// from the monitoring server.
// Note these are global by nature; if you try to register them multiple
// times, you will get a panic.
func (c *cachingResolver) RegisterDebugHandlers() {
	http.HandleFunc("/debug/dnsserver/cache/dump", c.DumpCache)
	http.HandleFunc("/debug/dnsserver/cache/flush", c.FlushCache)
}
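
// For example, assuming the monitoring HTTP server listens on localhost:8081
// (an assumed address, not configured here), the cache can be inspected and
// cleared with:
//
//	curl http://localhost:8081/debug/dnsserver/cache/dump
//	curl http://localhost:8081/debug/dnsserver/cache/flush

// DumpCache writes a human-readable dump of the cache to the HTTP response,
// sorted by expiration.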
func (c *cachingResolver) DumpCache(w http.ResponseWriter, r *http.Request) {
	buf := bytes.NewBuffer(nil)
	c.mu.RLock()
	// Sort output by expiration, so it is somewhat consistent and practical
	// to read.
	qs := []dns.Question{}
	for q := range c.answer {
		qs = append(qs, q)
	}
	sort.Slice(qs, func(i, j int) bool {
		return getTTL(c.answer[qs[i]]) < getTTL(c.answer[qs[j]])
	})
	// Go through the sorted list and dump the entries.
	for _, q := range qs {
		ans := c.answer[q]
		// Only include names and records if we are running verbosely.
		name := "<hidden>"
		if log.V(1) {
			name = q.Name
		}
		fmt.Fprintf(buf, "Q: %s %s %s\n", name, dns.TypeToString[q.Qtype],
			dns.ClassToString[q.Qclass])
		ttl := getTTL(ans)
		fmt.Fprintf(buf, "   expires in %s (%s)\n", ttl, time.Now().Add(ttl))
		if log.V(1) {
			for _, rr := range ans {
				fmt.Fprintf(buf, "   %s\n", rr.String())
			}
		} else {
			fmt.Fprintf(buf, "   %d RRs in answer\n", len(ans))
		}
		fmt.Fprintf(buf, "\n\n")
	}
	c.mu.RUnlock()
	buf.WriteTo(w)
}
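
// FlushCache removes all entries from the cache.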
func (c *cachingResolver) FlushCache(w http.ResponseWriter, r *http.Request) {
	c.mu.Lock()
	c.answer = map[dns.Question][]dns.RR{}
	c.mu.Unlock()
	w.Write([]byte("cache flush complete"))
}
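
// Maintain runs the backing resolver's maintenance in a goroutine, and then
// periodically garbage-collects the cache: cached answers get their TTLs
// decremented, and entries that have expired are removed.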
func (c *cachingResolver) Maintain() {
	go c.back.Maintain()
	for range time.Tick(maintenancePeriod) {
		tr := trace.New("dnsserver.Cache", "GC")
		var total, expired int
		c.mu.Lock()
		total = len(c.answer)
		for q, ans := range c.answer {
			newTTL := getTTL(ans) - maintenancePeriod
			if newTTL > 0 {
				// Don't modify in place; create a copy and overwrite the
				// entry. That way, we avoid races with users that have
				// gotten a cached answer and are returning it.
				newans := copyRRSlice(ans)
				setTTL(newans, newTTL)
				c.answer[q] = newans
				continue
			}
			delete(c.answer, q)
			expired++
		}
		c.mu.Unlock()
		tr.Printf("total: %d   expired: %d", total, expired)
		tr.Finish()
	}
}
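
// wantToCache returns nil if the reply to the given question is suitable for
// caching, or an error describing why it is not.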
func wantToCache(question dns.Question, reply *dns.Msg) error {
	if reply.Rcode != dns.RcodeSuccess {
		return fmt.Errorf("unsuccessful query")
	} else if !reply.Response {
		return fmt.Errorf("response = false")
	} else if reply.Opcode != dns.OpcodeQuery {
		return fmt.Errorf("opcode %d != query", reply.Opcode)
	} else if len(reply.Answer) == 0 {
		return fmt.Errorf("answer is empty")
	} else if len(reply.Question) != 1 {
		return fmt.Errorf("too many/few questions (%d)", len(reply.Question))
	} else if reply.Truncated {
		return fmt.Errorf("truncated reply")
	} else if reply.Question[0] != question {
		return fmt.Errorf(
			"reply question does not match: asked %v, got %v",
			question, reply.Question[0])
	}
	return nil
}
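
// limitTTL returns the answer's TTL, capped at maxTTL.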
func limitTTL(answer []dns.RR) time.Duration {
	// This assumes all RRs have the same TTL. That may not be the case in
	// theory, but we are okay with not handling that for now.
	ttl := time.Duration(answer[0].Header().Ttl) * time.Second
	// This helps prevent cache pollution due to unused but long entries, as
	// we don't do usage-based caching yet.
	if ttl > maxTTL {
		ttl = maxTTL
	}
	return ttl
}
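
// getTTL returns the answer's TTL, taken from its first RR.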
func getTTL(answer []dns.RR) time.Duration {
	// This assumes all RRs have the same TTL. That may not be the case in
	// theory, but we are okay with not handling that for now.
	return time.Duration(answer[0].Header().Ttl) * time.Second
}
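
// setTTL sets the given TTL on every RR in the answer.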
func setTTL(answer []dns.RR, newTTL time.Duration) {
	for _, rr := range answer {
		rr.Header().Ttl = uint32(newTTL.Seconds())
	}
}
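
// copyRRSlice returns a deep copy of the given RR slice.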
func copyRRSlice(a []dns.RR) []dns.RR {
	b := make([]dns.RR, 0, len(a))
	for _, rr := range a {
		b = append(b, dns.Copy(rr))
	}
	return b
}
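
// Query answers the given DNS query, serving from the cache when possible;
// otherwise it forwards the query to the backing resolver and records the
// reply if it is suitable for caching.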
func (c *cachingResolver) Query(r *dns.Msg, tr *trace.Trace) (*dns.Msg, error) {
	stats.cacheTotal.Add(1)
	// To keep it simple, we only cache single-question queries.
	if len(r.Question) != 1 {
		tr.Printf("cache bypass: multi-question query")
		stats.cacheBypassed.Add(1)
		return c.back.Query(r, tr)
	}
	question := r.Question[0]
	c.mu.RLock()
	answer, hit := c.answer[question]
	c.mu.RUnlock()
	if hit {
		tr.Printf("cache hit")
		stats.cacheHits.Add(1)
		reply := &dns.Msg{
			MsgHdr: dns.MsgHdr{
				Id:            r.Id,
				Response:      true,
				Authoritative: false,
				Rcode:         dns.RcodeSuccess,
			},
			Question: r.Question,
			Answer:   answer,
		}
		return reply, nil
	}
	tr.Printf("cache miss")
	stats.cacheMisses.Add(1)
	reply, err := c.back.Query(r, tr)
	if err != nil {
		return reply, err
	}
	if err = wantToCache(question, reply); err != nil {
		tr.Printf("cache not recording reply: %v", err)
		return reply, nil
	}
	answer = reply.Answer
	ttl := limitTTL(answer)
	// Only store answers if they're going to stay around for a bit;
	// there's not much point in caching things we have to expire quickly.
	if ttl < minTTL {
		return reply, nil
	}
	// Store the answer in the cache, but don't exceed maxCacheSize entries.
	// TODO: Do usage based eviction when we're approaching ~1.5k.
	c.mu.Lock()
	if len(c.answer) < maxCacheSize {
		setTTL(answer, ttl)
		c.answer[question] = answer
		stats.cacheRecorded.Add(1)
	}
	c.mu.Unlock()
	return reply, nil
}

// Compile-time check that the implementation matches the interface.
var _ Resolver = &cachingResolver{}