author | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-11-25 00:50:10 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2018-11-30 10:03:48 UTC |
parent | 661f759c0c4b51580bff73002bb4c60509fc84a6 |
test/util/lib.sh | +4 | -0 |
test/util/minidns.go | +305 | -0 |
diff --git a/test/util/lib.sh b/test/util/lib.sh index 3c86ac5..38ed40b 100644 --- a/test/util/lib.sh +++ b/test/util/lib.sh @@ -107,6 +107,10 @@ function conngen() { go run ${UTILDIR}/conngen.go "$@" } +function minidns() { + go run ${UTILDIR}/minidns.go "$@" +} + function success() { echo success } diff --git a/test/util/minidns.go b/test/util/minidns.go new file mode 100644 index 0000000..9c3aeb4 --- /dev/null +++ b/test/util/minidns.go @@ -0,0 +1,305 @@ +// +build ignore + +// minidns is a trivial DNS server used for testing. +// +// It takes an "answers" file which contains lines with the following format: +// +// <domain> <type> <value> +// +// For example: +// +// blah A 1.2.3.4 +// blah MX mx1 +// +// Supported types: A, AAAA, MX, TXT. +// +// It's only meant to be used for testing, so it's not robust, performant, or +// standards compliant. +// +package main + +import ( + "bufio" + "encoding/binary" + "flag" + "fmt" + "net" + "os" + "regexp" + "strings" + "sync" + + "blitiri.com.ar/go/log" + "golang.org/x/net/dns/dnsmessage" +) + +var ( + addr = flag.String("addr", ":53", "address to listen to (UDP)") + zonesPath = flag.String("zones", "", "file with the zones") +) + +func main() { + flag.Parse() + + srv := &miniDNS{ + answers: map[string][]dnsmessage.Resource{}, + } + + if *zonesPath == "" { + log.Fatalf("-zones must be given") + } + var zonesFile *os.File + if *zonesPath == "-" { + zonesFile = os.Stdin + } else { + var err error + zonesFile, err = os.Open(*zonesPath) + if err != nil { + log.Fatalf("error opening %v: %v", *zonesPath, err) + } + } + + srv.loadZones(zonesFile) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + srv.listenAndServeUDP(*addr) + }() + go func() { + defer wg.Done() + srv.listenAndServeTCP(*addr) + }() + wg.Wait() +} + +type miniDNS struct { + // Domain -> Answers. + // We always respond the same regardless of the query. + // Not great, but does the trick. + answers map[string][]dnsmessage.Resource +} + +func (m *miniDNS) listenAndServeUDP(addr string) { + conn, err := net.ListenPacket("udp", addr) + if err != nil { + log.Fatalf("error listening UDP %q: %v", addr, err) + } + + log.Infof("listening on %v", conn.LocalAddr()) + + buf := make([]byte, 64*1024) + for { + n, addr, err := conn.ReadFrom(buf) + if err != nil { + log.Infof("error reading from udp: %v", err) + continue + } + + msg := &dnsmessage.Message{} + err = msg.Unpack(buf[:n]) + if err != nil { + log.Infof("%v error unpacking message: %v", addr, err) + } + + if lq := len(msg.Questions); lq != 1 { + log.Infof("%v/%-5d dropping packet with %d questions", + addr, msg.ID, lq) + continue + } + q := msg.Questions[0] + log.Infof("%v/%-5d Q: %s %s %s", + addr, msg.ID, q.Name, q.Type, q.Class) + + reply := m.handle(msg) + rbuf, err := reply.Pack() + if err != nil { + log.Fatalf("error packing reply: %v", err) + } + + conn.WriteTo(rbuf, addr) + } +} + +func (m *miniDNS) listenAndServeTCP(addr string) { + ls, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("error listening TCP %q: %v", addr, err) + } + + log.Infof("listening on %v", addr) + + for { + conn, err := ls.Accept() + if err != nil { + log.Infof("error accepting: %v", err) + continue + } + + msg, err := readTCPMessage(conn) + if err != nil { + log.Infof("%v error reading message: %v", addr, err) + conn.Close() + continue + } + + if lq := len(msg.Questions); lq != 1 { + log.Infof("%v/%-5d dropping packet with %d questions", + addr, msg.ID, lq) + conn.Close() + continue + } + q := msg.Questions[0] + log.Infof("%v/%-5d Q: %s %s %s", + addr, msg.ID, q.Name, q.Type, q.Class) + + reply := m.handle(msg) + err = writeTCPMessage(conn, reply) + if err != nil { + log.Infof("error writing reply: %v", err) + } + + conn.Close() + } +} + +func readTCPMessage(conn net.Conn) (*dnsmessage.Message, error) { + // Read the 2-byte length first, then the message. + lenHdr := struct{ Len uint16 }{} + err := binary.Read(conn, binary.BigEndian, &lenHdr) + if err != nil { + return nil, err + } + + data := make([]byte, lenHdr.Len) + err = binary.Read(conn, binary.BigEndian, &data) + if err != nil { + return nil, err + } + + msg := &dnsmessage.Message{} + err = msg.Unpack(data) + if err != nil { + return nil, fmt.Errorf("%v error unpacking message: %v", addr, err) + } + + return msg, nil +} + +func writeTCPMessage(conn net.Conn, msg *dnsmessage.Message) error { + rbuf, err := msg.Pack() + if err != nil { + return fmt.Errorf("error packing reply: %v", err) + } + + lenHdr := struct{ Len uint16 }{Len: uint16(len(rbuf))} + err = binary.Write(conn, binary.BigEndian, lenHdr) + if err != nil { + return err + } + + _, err = conn.Write(rbuf) + return err +} + +func (m *miniDNS) handle(msg *dnsmessage.Message) *dnsmessage.Message { + reply := &dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: msg.ID, + Response: true, + RCode: dnsmessage.RCodeSuccess, + }, + Questions: msg.Questions, + } + + q := msg.Questions[0] + if answers, ok := m.answers[q.Name.String()]; ok { + for _, ans := range answers { + if q.Type == ans.Header.Type { + log.Infof("-> %s %v", q.Type, ans.Body) + reply.Answers = append(reply.Answers, ans) + } + } + } else { + log.Infof("-> NXERROR") + reply.Header.RCode = dnsmessage.RCodeNameError + } + + return reply +} + +func (m *miniDNS) loadZones(f *os.File) { + scanner := bufio.NewScanner(f) + lineno := 0 + for scanner.Scan() { + lineno++ + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "#") || line == "" { + continue + } + + vs := regexp.MustCompile("\\s+").Split(line, 3) + if len(vs) != 3 { + log.Fatalf("line %d: invalid format", lineno) + } + domain, t, value := vs[0], vs[1], vs[2] + if !strings.HasSuffix(domain, ".") { + domain += "." + } + + var body dnsmessage.ResourceBody + var qType dnsmessage.Type + switch strings.ToLower(t) { + case "a": + qType = dnsmessage.TypeA + ip := net.ParseIP(value).To4() + if ip == nil { + log.Fatalf("line %d: invalid IP %q", lineno, value) + } + a := &dnsmessage.AResource{} + copy(a.A[:], ip[:4]) + body = a + case "aaaa": + qType = dnsmessage.TypeAAAA + ip := net.ParseIP(value).To16() + if ip == nil { + log.Fatalf("line %d: invalid IP %q", lineno, value) + } + aaaa := &dnsmessage.AAAAResource{} + copy(aaaa.AAAA[:], ip[:16]) + body = aaaa + case "mx": + qType = dnsmessage.TypeMX + if !strings.HasPrefix(value, ".") { + value += "." + } + + body = &dnsmessage.MXResource{ + Pref: 10, + MX: dnsmessage.MustNewName(value), + } + case "txt": + qType = dnsmessage.TypeTXT + body = &dnsmessage.TXTResource{ + TXT: []string{value}, + } + default: + log.Fatalf("line %d: unknown type %q", lineno, t) + } + + answer := dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(domain), + Type: qType, + Class: dnsmessage.ClassINET, + }, + Body: body, + } + m.answers[domain] = append(m.answers[domain], answer) + } + + if err := scanner.Err(); err != nil { + log.Fatalf("error reading zones: %v", err) + } +}