// Package dnsserver implements a DNS server, that uses the given resolvers to
// handle requests.
package dnsserver
import (
"crypto/rand"
"encoding/binary"
"fmt"
"net"
"strings"
"sync"
"github.com/miekg/dns"
"golang.org/x/net/trace"
"blitiri.com.ar/go/dnss/internal/util"
"blitiri.com.ar/go/log"
"blitiri.com.ar/go/systemd"
)
// newID is a channel used to generate new request IDs.
// There is a goroutine created at init() time that will get IDs randomly, to
// help prevent guesses.
var newID chan uint16
func init() {
// Buffer 100 numbers to avoid blocking on crypto rand.
newID = make(chan uint16, 100)
go func() {
var id uint16
var err error
for {
err = binary.Read(rand.Reader, binary.LittleEndian, &id)
if err != nil {
panic(fmt.Sprintf("error creating id: %v", err))
}
newID <- id
}
}()
}
// Server implements a DNS proxy, which will (mostly) use the given resolver
// to resolve queries.
type Server struct {
Addr string
unqUpstream string
resolver Resolver
fallbackDomains map[string]struct{}
fallbackUpstream string
}
// New *Server, which will listen on addr, use resolver as the backend
// resolver, and use unqUpstream to resolve unqualified queries.
func New(addr string, resolver Resolver, unqUpstream string) *Server {
return &Server{
Addr: addr,
resolver: resolver,
unqUpstream: unqUpstream,
fallbackDomains: map[string]struct{}{},
}
}
// SetFallback upstream server for the given domains.
func (s *Server) SetFallback(upstream string, domains []string) {
s.fallbackUpstream = upstream
for _, d := range domains {
s.fallbackDomains[d] = struct{}{}
}
}
// Handler for the incoming DNS queries.
func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
tr := trace.New("dnsserver", "Handler")
defer tr.Finish()
tr.LazyPrintf("from:%v id:%v", w.RemoteAddr(), r.Id)
util.TraceQuestion(tr, r.Question)
// We only support single-question queries.
if len(r.Question) != 1 {
tr.LazyPrintf("len(Q) != 1, failing")
dns.HandleFailed(w, r)
return
}
// Forward to the unqualified upstream server if:
// - We have one configured.
// - There's only one question in the request, to keep things simple.
// - The question is unqualified (only one '.' in the name).
useUnqUpstream := s.unqUpstream != "" &&
strings.Count(r.Question[0].Name, ".") <= 1
if useUnqUpstream {
u, err := dns.Exchange(r, s.unqUpstream)
if err == nil {
tr.LazyPrintf("used unqualified upstream")
util.TraceAnswer(tr, u)
w.WriteMsg(u)
} else {
tr.LazyPrintf("unqualified upstream error: %v", err)
dns.HandleFailed(w, r)
}
return
}
// Forward to the fallback server if the domain is on our list.
if _, ok := s.fallbackDomains[r.Question[0].Name]; ok {
u, err := dns.Exchange(r, s.fallbackUpstream)
if err == nil {
tr.LazyPrintf("used fallback upstream (%s)", s.fallbackUpstream)
util.TraceAnswer(tr, u)
w.WriteMsg(u)
} else {
tr.LazyPrintf("fallback upstream error: %v", err)
dns.HandleFailed(w, r)
}
return
}
// Create our own IDs, in case different users pick the same id and we
// pass that upstream.
oldid := r.Id
r.Id = <-newID
fromUp, err := s.resolver.Query(r, tr)
if err != nil {
log.Infof("resolver query error: %v", err)
tr.LazyPrintf(err.Error())
tr.SetError()
r.Id = oldid
dns.HandleFailed(w, r)
return
}
util.TraceAnswer(tr, fromUp)
fromUp.Id = oldid
w.WriteMsg(fromUp)
}
// ListenAndServe launches the DNS proxy.
func (s *Server) ListenAndServe() {
err := s.resolver.Init()
if err != nil {
log.Fatalf("Error initializing: %v", err)
}
go s.resolver.Maintain()
if s.Addr == "systemd" {
s.systemdServe()
} else {
s.classicServe()
}
}
func (s *Server) classicServe() {
log.Infof("DNS listening on %s", s.Addr)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err := dns.ListenAndServe(s.Addr, "udp", dns.HandlerFunc(s.Handler))
log.Fatalf("Exiting UDP: %v", err)
}()
wg.Add(1)
go func() {
defer wg.Done()
err := dns.ListenAndServe(s.Addr, "tcp", dns.HandlerFunc(s.Handler))
log.Fatalf("Exiting TCP: %v", err)
}()
wg.Wait()
}
func (s *Server) systemdServe() {
fsMap, err := systemd.Files()
if err != nil {
log.Fatalf("Error getting systemd listeners: %v", err)
}
// We will usually have at least one TCP socket and one UDP socket.
// PacketConns are UDP sockets, Listeners are TCP sockets.
pconns := []net.PacketConn{}
listeners := []net.Listener{}
for _, fs := range fsMap {
for _, f := range fs {
if lis, err := net.FileListener(f); err == nil {
listeners = append(listeners, lis)
f.Close()
} else if pc, err := net.FilePacketConn(f); err == nil {
pconns = append(pconns, pc)
f.Close()
}
}
}
var wg sync.WaitGroup
for _, pconn := range pconns {
if pconn == nil {
continue
}
wg.Add(1)
go func(c net.PacketConn) {
defer wg.Done()
log.Infof("Activate on packet connection (UDP): %v", c.LocalAddr())
err := dns.ActivateAndServe(nil, c, dns.HandlerFunc(s.Handler))
log.Fatalf("Exiting UDP listener: %v", err)
}(pconn)
}
for _, lis := range listeners {
if lis == nil {
continue
}
wg.Add(1)
go func(l net.Listener) {
defer wg.Done()
log.Infof("Activate on listening socket (TCP): %v", l.Addr())
err := dns.ActivateAndServe(l, nil, dns.HandlerFunc(s.Handler))
log.Fatalf("Exiting TCP listener: %v", err)
}(lis)
}
wg.Wait()
// We should only get here if there were no useful sockets.
log.Fatalf("No systemd sockets, did you forget the .socket?")
}