git » dnss » main » tree

[main] / internal / dnsserver / server.go

// Package dnsserver implements a DNS server that uses the given resolvers to
// handle requests.
package dnsserver

import (
	"crypto/rand"
	"encoding/binary"
	"fmt"
	"net"
	"sync"

	"blitiri.com.ar/go/dnss/internal/trace"

	"blitiri.com.ar/go/log"
	"blitiri.com.ar/go/systemd"
	"github.com/miekg/dns"
)

// newID hands out fresh DNS request IDs.
//
// init() starts a producer goroutine that keeps this channel filled with IDs
// drawn from crypto/rand, so the IDs we send upstream are hard to guess.
var newID chan uint16

func init() {
	// Keep up to 100 IDs ready so that callers rarely wait on crypto/rand.
	newID = make(chan uint16, 100)

	go func() {
		var buf [2]byte
		for {
			if _, err := rand.Read(buf[:]); err != nil {
				panic(fmt.Sprintf("error creating id: %v", err))
			}
			newID <- binary.LittleEndian.Uint16(buf[:])
		}
	}()
}

// Server implements a DNS proxy, which will (mostly) use the given resolver
// to resolve queries. Some queries bypass the resolver: those matching a
// server override, and unqualified ones when unqUpstream is set.
type Server struct {
	// Addr is the address to listen on, or the special value "systemd" to
	// serve on sockets passed in by systemd socket activation.
	Addr            string
	// unqUpstream, when not empty, is the server that receives queries for
	// names with at most one label (see Handler).
	unqUpstream     string
	// serverOverrides maps names to servers that matching queries are
	// forwarded to verbatim; lookups go through GetMostSpecific (presumably
	// a longest-suffix match — confirm in DomainMap).
	serverOverrides DomainMap
	// resolver answers every query not covered by the fields above; it is
	// also initialized and maintained by ListenAndServe.
	resolver        Resolver
}

// New returns a *Server that will listen on addr and answer queries using
// resolver. Queries for unqualified names go to unqUpstream (if not empty),
// and queries matching serverOverrides go to the corresponding override
// server instead.
func New(addr string, resolver Resolver, unqUpstream string, serverOverrides DomainMap) *Server {
	s := &Server{Addr: addr}
	s.resolver = resolver
	s.unqUpstream = unqUpstream
	s.serverOverrides = serverOverrides
	return s
}

// Handler answers a single incoming DNS query.
//
// Only single-question queries are accepted; anything else is answered via
// dns.HandleFailed. A valid query is routed as follows:
//  1. If the name has a server override, the query is relayed as-is to it.
//  2. Else, if an unqualified upstream is configured and the name has at
//     most one label, the query is relayed as-is to that upstream.
//  3. Otherwise it is sent to s.resolver under a freshly generated random
//     ID, and the client's original ID is put back on the reply.
//
// Any upstream error is answered via dns.HandleFailed.
func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
	remote := w.RemoteAddr()
	tr := trace.New("dnsserver.Handler", remote.Network()+" "+remote.String())
	defer tr.Finish()

	tr.Printf("id:%v", r.Id)
	tr.Question(r.Question)

	if len(r.Question) != 1 {
		tr.Printf("len(Q) != 1, failing")
		dns.HandleFailed(w, r)
		return
	}
	qname := r.Question[0].Name

	// Overrides take precedence over everything else.
	if override, found := s.serverOverrides.GetMostSpecific(qname); found {
		tr.Printf("override found: %q", override)
		reply, err := dns.Exchange(r, override)
		if err != nil {
			tr.Printf("override server returned error: %v", err)
			dns.HandleFailed(w, r)
			return
		}
		tr.Answer(reply)
		s.writeReply(tr, w, r, reply)
		return
	}

	// Unqualified names (at most one label) go to the dedicated upstream.
	if s.unqUpstream != "" && dns.CountLabel(qname) <= 1 {
		reply, err := dns.Exchange(r, s.unqUpstream)
		if err != nil {
			tr.Printf("unqualified upstream error: %v", err)
			dns.HandleFailed(w, r)
			return
		}
		tr.Printf("used unqualified upstream")
		tr.Answer(reply)
		s.writeReply(tr, w, r, reply)
		return
	}

	// Use our own random ID towards the resolver, so clients that happen to
	// pick the same ID cannot collide upstream. The client's ID is restored
	// on whatever we send back.
	clientID := r.Id
	r.Id = <-newID

	reply, err := s.resolver.Query(r, tr)
	if err != nil {
		log.Infof("resolver query error: %v", err)
		tr.Error(err)

		r.Id = clientID
		dns.HandleFailed(w, r)
		return
	}

	tr.Answer(reply)

	reply.Id = clientID
	s.writeReply(tr, w, r, reply)
}

// writeReply sends reply back to the client behind w.
//
// For UDP clients the reply is first truncated so it fits the client's
// buffer. That is 512 bytes by default, or the EDNS0 UDP payload size
// advertised in the request r. As RFC 6891 requires, advertised sizes below
// 512 are treated as 512. A failure to write is recorded in the trace,
// since there is no caller left to report it to.
func (s *Server) writeReply(tr *trace.Trace, w dns.ResponseWriter, r, reply *dns.Msg) {
	if w.RemoteAddr().Network() == "udp" {
		// Plain DNS over UDP is limited to 512 bytes; EDNS0 lets the client
		// advertise a larger buffer, but never a smaller one.
		const minUDPSize = 512
		size := minUDPSize
		if opt := r.IsEdns0(); opt != nil {
			if advertised := int(opt.UDPSize()); advertised > size {
				size = advertised
			}
		}
		reply.Truncate(size)
		tr.Printf("UDP max:%d truncated:%v", size, reply.Truncated)
	}

	if err := w.WriteMsg(reply); err != nil {
		tr.Error(fmt.Errorf("writing reply: %w", err))
	}
}

// ListenAndServe launches the DNS proxy.
//
// It initializes the resolver (a failure is fatal) and starts its Maintain
// loop in the background. It then serves either on the sockets passed by
// systemd (when Addr is "systemd") or on Addr over both UDP and TCP. Both
// serving modes end in log.Fatalf, so this is not expected to return.
func (s *Server) ListenAndServe() {
	if err := s.resolver.Init(); err != nil {
		log.Fatalf("Error initializing: %v", err)
	}

	go s.resolver.Maintain()

	switch s.Addr {
	case "systemd":
		s.systemdServe()
	default:
		s.classicServe()
	}
}

// classicServe listens on s.Addr over both UDP and TCP, one goroutine per
// transport, and blocks while they run. If either server stops, the error
// is reported through log.Fatalf.
func (s *Server) classicServe() {
	log.Infof("DNS listening on %s", s.Addr)

	handler := dns.HandlerFunc(s.Handler)
	transports := []struct {
		network string
		exitMsg string
	}{
		{"udp", "Exiting UDP: %v"},
		{"tcp", "Exiting TCP: %v"},
	}

	var wg sync.WaitGroup
	wg.Add(len(transports))
	for _, t := range transports {
		go func(network, exitMsg string) {
			defer wg.Done()
			err := dns.ListenAndServe(s.Addr, network, handler)
			log.Fatalf(exitMsg, err)
		}(t.network, t.exitMsg)
	}
	wg.Wait()
}

// systemdServe serves DNS on the sockets handed over by systemd socket
// activation.
//
// Each file passed by systemd is tried first as a stream listener (TCP) and
// then as a packet connection (UDP). Files that are neither are logged and
// skipped. Every passed file is closed once inspected. Each usable socket is
// served in its own goroutine, and a server that stops is reported through
// log.Fatalf. If no usable socket was found, this also ends in log.Fatalf.
func (s *Server) systemdServe() {
	fsMap, err := systemd.Files()
	if err != nil {
		log.Fatalf("Error getting systemd listeners: %v", err)
	}

	// We will usually have at least one TCP socket and one UDP socket.
	// PacketConns are UDP sockets, Listeners are TCP sockets.
	var pconns []net.PacketConn
	var listeners []net.Listener
	for _, files := range fsMap {
		for _, f := range files {
			if lis, err := net.FileListener(f); err == nil {
				listeners = append(listeners, lis)
			} else if pc, err := net.FilePacketConn(f); err == nil {
				pconns = append(pconns, pc)
			} else {
				log.Infof("Ignoring systemd file %s: not a usable socket: %v",
					f.Name(), err)
			}

			// FileListener and FilePacketConn work on a duplicate of the
			// descriptor, so the original file is no longer needed, whether
			// or not it turned out to be usable. Closing it unconditionally
			// avoids leaking descriptors we could not use.
			f.Close()
		}
	}

	var wg sync.WaitGroup

	for _, pconn := range pconns {
		wg.Add(1)
		go func(c net.PacketConn) {
			defer wg.Done()
			log.Infof("Activate on packet connection (UDP): %v", c.LocalAddr())
			err := dns.ActivateAndServe(nil, c, dns.HandlerFunc(s.Handler))
			log.Fatalf("Exiting UDP listener: %v", err)
		}(pconn)
	}

	for _, lis := range listeners {
		wg.Add(1)
		go func(l net.Listener) {
			defer wg.Done()
			log.Infof("Activate on listening socket (TCP): %v", l.Addr())
			err := dns.ActivateAndServe(l, nil, dns.HandlerFunc(s.Handler))
			log.Fatalf("Exiting TCP listener: %v", err)
		}(lis)
	}

	wg.Wait()

	// We should only get here if there were no useful sockets.
	log.Fatalf("No systemd sockets, did you forget the .socket?")
}