// Origin: chasquid git repository, "smarthost" tree:
// test/util/minidns/minidns.go

//go:build !coverage
// +build !coverage

// minidns is a trivial DNS server used for testing.
//
// It takes a zones file (given with the -zones flag, "-" meaning stdin) which
// contains lines with the following format:
//
//	<domain> <type> <value>
//
// For example:
//
//	blah A  1.2.3.4
//	blah MX mx1
//
// Supported types: A, AAAA, MX, TXT.
//
// It's only meant to be used for testing, so it's not robust, performant, or
// standards compliant.
package main

import (
	"bufio"
	"encoding/binary"
	"flag"
	"fmt"
	"net"
	"os"
	"regexp"
	"strings"
	"sync"

	"blitiri.com.ar/go/log"
	"golang.org/x/net/dns/dnsmessage"
)

// Command-line flags.
var (
	// addr is used for both the UDP and the TCP listeners.
	addr = flag.String("addr", ":53", "address to listen to (UDP and TCP)")

	// zonesPath is mandatory; "-" makes main read the zones from stdin.
	zonesPath = flag.String("zones", "", "file with the zones (\"-\" for stdin)")
)

func main() {
	flag.Parse()

	srv := &miniDNS{
		answers: map[string][]dnsmessage.Resource{},
	}

	if *zonesPath == "" {
		log.Fatalf("-zones must be given")
	}
	var zonesFile *os.File
	if *zonesPath == "-" {
		zonesFile = os.Stdin
	} else {
		var err error
		zonesFile, err = os.Open(*zonesPath)
		if err != nil {
			log.Fatalf("error opening %v: %v", *zonesPath, err)
		}
	}

	srv.loadZones(zonesFile)

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		srv.listenAndServeUDP(*addr)
	}()
	go func() {
		defer wg.Done()
		srv.listenAndServeTCP(*addr)
	}()
	wg.Wait()
}

// miniDNS holds the records loaded from the zones file, and serves them over
// UDP and TCP.
type miniDNS struct {
	// Domain -> Answers.
	// Keys are fully-qualified names (with the trailing dot), as stored by
	// loadZones. A query for a known name is answered with the records for
	// that name whose type matches the question type; see handle.
	answers map[string][]dnsmessage.Resource
}

// listenAndServeUDP listens for DNS queries on the given UDP address, and
// answers them using handle.
//
// Failing to listen, or to pack a reply, is fatal. Any other error is logged
// and the offending packet is skipped. The serving loop never returns.
func (m *miniDNS) listenAndServeUDP(addr string) {
	conn, err := net.ListenPacket("udp", addr)
	if err != nil {
		log.Fatalf("error listening UDP %q: %v", addr, err)
	}

	log.Infof("listening on %v", conn.LocalAddr())

	// Large enough for any UDP datagram; reused across iterations.
	buf := make([]byte, 64*1024)
	for {
		n, from, err := conn.ReadFrom(buf)
		if err != nil {
			log.Infof("error reading from udp: %v", err)
			continue
		}

		msg := &dnsmessage.Message{}
		if err := msg.Unpack(buf[:n]); err != nil {
			// Don't try to answer something we could not parse.
			log.Infof("%v error unpacking message: %v", from, err)
			continue
		}

		// handle only looks at the first question, so insist on exactly one.
		if lq := len(msg.Questions); lq != 1 {
			log.Infof("%v/%-5d  dropping packet with %d questions",
				from, msg.ID, lq)
			continue
		}
		q := msg.Questions[0]
		log.Infof("%v/%-5d   Q: %s %s %s",
			from, msg.ID, q.Name, q.Type, q.Class)

		reply := m.handle(msg)
		rbuf, err := reply.Pack()
		if err != nil {
			log.Fatalf("error packing reply: %v", err)
		}

		_, err = conn.WriteTo(rbuf, from)
		if err != nil {
			log.Infof("%v/%-5d  error writing: %v",
				from, msg.ID, err)
		}
	}
}

// listenAndServeTCP listens for DNS queries on the given TCP address.
//
// Failing to listen is fatal. Each accepted connection is served in its own
// goroutine, so a client that connects and stalls does not block the others.
// The accept loop never returns.
func (m *miniDNS) listenAndServeTCP(addr string) {
	ls, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("error listening TCP %q: %v", addr, err)
	}

	// Log the address we actually bound to (useful with port 0), same as
	// the UDP listener does.
	log.Infof("listening on %v", ls.Addr())

	for {
		conn, err := ls.Accept()
		if err != nil {
			log.Infof("error accepting: %v", err)
			continue
		}

		go m.serveTCPConn(conn)
	}
}

// serveTCPConn reads a single query from conn, writes the reply, and closes
// the connection. Errors are logged, tagged with the peer address.
func (m *miniDNS) serveTCPConn(conn net.Conn) {
	defer conn.Close()
	peer := conn.RemoteAddr()

	msg, err := readTCPMessage(conn)
	if err != nil {
		log.Infof("%v error reading message: %v", peer, err)
		return
	}

	// handle only looks at the first question, so insist on exactly one.
	if lq := len(msg.Questions); lq != 1 {
		log.Infof("%v/%-5d  dropping packet with %d questions",
			peer, msg.ID, lq)
		return
	}
	q := msg.Questions[0]
	log.Infof("%v/%-5d   Q: %s %s %s",
		peer, msg.ID, q.Name, q.Type, q.Class)

	reply := m.handle(msg)
	if err := writeTCPMessage(conn, reply); err != nil {
		log.Infof("error writing reply: %v", err)
	}
}

// readTCPMessage reads one DNS message from conn, using the TCP framing: a
// 2-byte big-endian length, followed by that many bytes of message.
//
// It returns the unpacked message, or an error if reading or unpacking
// failed.
func readTCPMessage(conn net.Conn) (*dnsmessage.Message, error) {
	// Read the 2-byte length first, then the message.
	var msgLen uint16
	if err := binary.Read(conn, binary.BigEndian, &msgLen); err != nil {
		return nil, err
	}

	data := make([]byte, msgLen)
	if err := binary.Read(conn, binary.BigEndian, data); err != nil {
		return nil, err
	}

	msg := &dnsmessage.Message{}
	if err := msg.Unpack(data); err != nil {
		// The caller knows (and logs) which connection this came from, so
		// there is no address to include here.
		return nil, fmt.Errorf("error unpacking message: %w", err)
	}

	return msg, nil
}

// writeTCPMessage packs msg and writes it to conn, using the TCP framing: a
// 2-byte big-endian length, followed by the message itself.
//
// It returns an error if the message can't be packed, doesn't fit in the
// 16-bit length prefix, or can't be written.
func writeTCPMessage(conn net.Conn, msg *dnsmessage.Message) error {
	// Largest message the 2-byte length prefix can describe.
	const maxLen = 1<<16 - 1

	rbuf, err := msg.Pack()
	if err != nil {
		return fmt.Errorf("error packing reply: %w", err)
	}
	if len(rbuf) > maxLen {
		// Converting to uint16 would silently truncate the length and
		// corrupt the stream, so refuse instead.
		return fmt.Errorf("reply too large for TCP framing: %d bytes",
			len(rbuf))
	}

	// Send the length and the message in a single write, so the peer does
	// not see the prefix arrive on its own.
	out := make([]byte, 2+len(rbuf))
	binary.BigEndian.PutUint16(out, uint16(len(rbuf)))
	copy(out[2:], rbuf)

	_, err = conn.Write(out)
	return err
}

// handle builds the reply for the given query.
//
// Only the first question is considered, so msg must have at least one. If
// the queried name is known, the reply carries every record for it whose type
// matches the question (possibly none); otherwise the reply is a name error.
func (m *miniDNS) handle(msg *dnsmessage.Message) *dnsmessage.Message {
	hdr := dnsmessage.Header{
		ID:       msg.ID,
		Response: true,
		RCode:    dnsmessage.RCodeSuccess,

		// Some client libraries complain unless either this or
		// RecursionAvailable is set; we are authoritative for the zones we
		// serve, so this is the one that applies.
		Authoritative: true,
	}
	reply := &dnsmessage.Message{Header: hdr, Questions: msg.Questions}

	question := msg.Questions[0]
	records, known := m.answers[question.Name.String()]
	if !known {
		log.Infof("-> NXERROR")
		reply.Header.RCode = dnsmessage.RCodeNameError
		return reply
	}

	for _, rec := range records {
		if rec.Header.Type != question.Type {
			continue
		}
		log.Infof("-> %s %v", question.Type, rec.Body)
		reply.Answers = append(reply.Answers, rec)
	}

	return reply
}

// loadZones reads the zones from f, and adds the records to m.answers.
//
// Each non-empty line that doesn't begin with "#" must have the form
// "<domain> <type> <value>", with type being one of A, AAAA, MX or TXT
// (case-insensitive). Domains and MX targets are made fully-qualified by
// appending a "." if it is missing. MX records always get preference 10.
//
// Any malformed line, or an error reading f, is fatal.
func (m *miniDNS) loadZones(f *os.File) {
	// Compiled once, instead of once per line.
	fieldSep := regexp.MustCompile(`\s+`)

	scanner := bufio.NewScanner(f)
	lineno := 0
	for scanner.Scan() {
		lineno++
		line := strings.TrimSpace(scanner.Text())
		if strings.HasPrefix(line, "#") || line == "" {
			continue
		}

		// Split in at most 3 fields, so the value can contain whitespace
		// (e.g. for TXT records).
		vs := fieldSep.Split(line, 3)
		if len(vs) != 3 {
			log.Fatalf("line %d: invalid format", lineno)
		}
		domain, t, value := vs[0], vs[1], vs[2]
		if !strings.HasSuffix(domain, ".") {
			domain += "."
		}

		var body dnsmessage.ResourceBody
		var qType dnsmessage.Type
		switch strings.ToLower(t) {
		case "a":
			qType = dnsmessage.TypeA
			ip := net.ParseIP(value).To4()
			if ip == nil {
				log.Fatalf("line %d: invalid IP %q", lineno, value)
			}
			a := &dnsmessage.AResource{}
			copy(a.A[:], ip[:4])
			body = a
		case "aaaa":
			qType = dnsmessage.TypeAAAA
			ip := net.ParseIP(value).To16()
			if ip == nil {
				log.Fatalf("line %d: invalid IP %q", lineno, value)
			}
			aaaa := &dnsmessage.AAAAResource{}
			copy(aaaa.AAAA[:], ip[:16])
			body = aaaa
		case "mx":
			qType = dnsmessage.TypeMX
			// Make the target fully-qualified. This has to check the suffix:
			// checking the prefix would turn an already-qualified "mx1."
			// into "mx1..", which can't be packed.
			if !strings.HasSuffix(value, ".") {
				value += "."
			}

			body = &dnsmessage.MXResource{
				Pref: 10,
				MX:   dnsmessage.MustNewName(value),
			}
		case "txt":
			qType = dnsmessage.TypeTXT
			body = &dnsmessage.TXTResource{
				TXT: []string{value},
			}
		default:
			log.Fatalf("line %d: unknown type %q", lineno, t)
		}

		answer := dnsmessage.Resource{
			Header: dnsmessage.ResourceHeader{
				Name:  dnsmessage.MustNewName(domain),
				Type:  qType,
				Class: dnsmessage.ClassINET,
			},
			Body: body,
		}
		m.answers[domain] = append(m.answers[domain], answer)
	}

	if err := scanner.Err(); err != nil {
		log.Fatalf("error reading zones: %v", err)
	}
}