author | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-03-04 01:44:02 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-03-05 20:09:05 UTC |
parent | c98400d6b0e57683c5121f91930aff198a34857c |
dnss.go | +4 | -34 |
dnss_test.go | +5 | -36 |
internal/dnsserver/server.go | +3 | -30 |
internal/dnsserver/server_test.go | +0 | -11 |
internal/httpresolver/resolver.go | +24 | -2 |
tests/external.sh | +24 | -9 |
tests/minidns.go | +305 | -0 |
tests/testzones | +3 | -0 |
diff --git a/dnss.go b/dnss.go index 2c72257..c79aa94 100644 --- a/dnss.go +++ b/dnss.go @@ -16,11 +16,8 @@ import ( "fmt" "net/http" "net/url" - "strings" "sync" - "golang.org/x/net/http/httpproxy" - "blitiri.com.ar/go/dnss/internal/dnsserver" "blitiri.com.ar/go/dnss/internal/httpresolver" "blitiri.com.ar/go/dnss/internal/httpserver" @@ -38,10 +35,8 @@ var ( "DNS server to forward unqualified requests to") fallbackUpstream = flag.String("fallback_upstream", "8.8.8.8:53", - "DNS server to resolve domains in --fallback_domains") - fallbackDomains = flag.String("fallback_domains", "dns.google.", - "Domains we resolve via DNS, using --fallback_upstream"+ - " (space-separated list)") + "DNS server used to resolve domains in -https_upstream"+ + " (including proxy if needed)") enableDNStoHTTPS = flag.Bool("enable_dns_to_https", false, "enable DNS-to-HTTPS proxy") @@ -74,6 +69,7 @@ var ( _ = flag.Duration("log_flush_every", 0, "deprecated, will be removed") _ = flag.Bool("logtostderr", false, "deprecated, will be removed") _ = flag.String("force_mode", "", "deprecated, will be removed") + _ = flag.String("fallback_domains", "", "deprecated, will be removed") ) func main() { @@ -101,7 +97,7 @@ func main() { } var resolver dnsserver.Resolver - resolver = httpresolver.NewDoH(upstream, *httpsClientCAFile) + resolver = httpresolver.NewDoH(upstream, *httpsClientCAFile, *fallbackUpstream) if *enableCache { cr := dnsserver.NewCachingResolver(resolver) @@ -110,15 +106,6 @@ func main() { } dth := dnsserver.New(*dnsListenAddr, resolver, *dnsUnqualifiedUpstream) - // If we're using an HTTP proxy, add the name to the fallback domain - // so we don't have problems resolving it. - fallbackDoms := strings.Split(*fallbackDomains, " ") - if proxyDomain := proxyServerDomain(); proxyDomain != "" { - log.Infof("Adding proxy %q to fallback domains", proxyDomain) - fallbackDoms = append(fallbackDoms, proxyDomain) - } - - dth.SetFallback(*fallbackUpstream, fallbackDoms) wg.Add(1) go func() { defer wg.Done() @@ -146,23 +133,6 @@ func main() { wg.Wait() } -// proxyServerDomain checks if we're using an HTTP proxy server, and if so -// returns its domain. -func proxyServerDomain() string { - url, err := url.Parse(*httpsUpstream) - if err != nil { - return "" - } - - proxyFunc := httpproxy.FromEnvironment().ProxyFunc() - proxyURL, err := proxyFunc(url) - if err != nil || proxyURL == nil { - return "" - } - - return proxyURL.Hostname() -} - func launchMonitoringServer(addr string) { log.Infof("Monitoring HTTP server listening on %s", addr) diff --git a/dnss_test.go b/dnss_test.go index ec8d1b4..6d03939 100644 --- a/dnss_test.go +++ b/dnss_test.go @@ -61,7 +61,11 @@ func Setup(tb testing.TB) string { tb.Fatalf("invalid URL: %v", err) } - r := httpresolver.NewDoH(HTTPSToDNSURL, "") + // Create the DoH resolver and DNS server backed by it. + // Note that we use an invalid address as fallback resolver - since we use + // IP addresses directly in the http requests, the fallback resolver + // should not be needed. + r := httpresolver.NewDoH(HTTPSToDNSURL, "", "0.0.0.0:0") dtoh := dnsserver.New(DNSToHTTPSAddr, r, "") go dtoh.ListenAndServe() @@ -185,41 +189,6 @@ func BenchmarkSimple(b *testing.B) { ///////////////////////////////////////////////////////////////////// // Tests for main-specific helpers -func TestProxyServerDomain(t *testing.T) { - prevProxy, wasSet := os.LookupEnv("HTTPS_PROXY") - - // Valid case, proxy set. - os.Setenv("HTTPS_PROXY", "http://proxy:1234/p") - *httpsUpstream = "https://montoto/xyz" - if got := proxyServerDomain(); got != "proxy" { - t.Errorf("got %q, expected 'proxy'", got) - } - - // Valid case, proxy not set. - os.Unsetenv("HTTPS_PROXY") - *httpsUpstream = "https://montoto/xyz" - if got := proxyServerDomain(); got != "" { - t.Errorf("got %q, expected ''", got) - } - - // Invalid upstream URL. - *httpsUpstream = "in%20valid:url" - if got := proxyServerDomain(); got != "" { - t.Errorf("got %q, expected ''", got) - } - - // Invalid proxy. - os.Setenv("HTTPS_PROXY", "invalid value") - *httpsUpstream = "https://montoto/xyz" - if got := proxyServerDomain(); got != "" { - t.Errorf("got %q, expected ''", got) - } - - if wasSet { - os.Setenv("HTTPS_PROXY", prevProxy) - } -} - func TestMonitoringServer(t *testing.T) { addr := testutil.GetFreePort() launchMonitoringServer(addr) diff --git a/internal/dnsserver/server.go b/internal/dnsserver/server.go index 9532b4b..4c47690 100644 --- a/internal/dnsserver/server.go +++ b/internal/dnsserver/server.go @@ -49,27 +49,15 @@ type Server struct { Addr string unqUpstream string resolver Resolver - - fallbackDomains map[string]struct{} - fallbackUpstream string } // New *Server, which will listen on addr, use resolver as the backend // resolver, and use unqUpstream to resolve unqualified queries. func New(addr string, resolver Resolver, unqUpstream string) *Server { return &Server{ - Addr: addr, - resolver: resolver, - unqUpstream: unqUpstream, - fallbackDomains: map[string]struct{}{}, - } -} - -// SetFallback upstream server for the given domains. -func (s *Server) SetFallback(upstream string, domains []string) { - s.fallbackUpstream = upstream - for _, d := range domains { - s.fallbackDomains[d] = struct{}{} + Addr: addr, + resolver: resolver, + unqUpstream: unqUpstream, } } @@ -109,21 +97,6 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) { return } - // Forward to the fallback server if the domain is on our list. - if _, ok := s.fallbackDomains[r.Question[0].Name]; ok { - u, err := dns.Exchange(r, s.fallbackUpstream) - if err == nil { - tr.LazyPrintf("used fallback upstream (%s)", s.fallbackUpstream) - util.TraceAnswer(tr, u) - w.WriteMsg(u) - } else { - tr.LazyPrintf("fallback upstream error: %v", err) - dns.HandleFailed(w, r) - } - - return - } - // Create our own IDs, in case different users pick the same id and we // pass that upstream. oldid := r.Id diff --git a/internal/dnsserver/server_test.go b/internal/dnsserver/server_test.go index a83e544..28820aa 100644 --- a/internal/dnsserver/server_test.go +++ b/internal/dnsserver/server_test.go @@ -21,19 +21,12 @@ func TestServe(t *testing.T) { go testutil.ServeTestDNSServer(unqUpstreamAddr, testutil.MakeStaticHandler(t, "unq. A 2.2.2.2")) - fallbackAddr := testutil.GetFreePort() - go testutil.ServeTestDNSServer(fallbackAddr, - testutil.MakeStaticHandler(t, "fallback. A 3.3.3.3")) - srv := New(testutil.GetFreePort(), res, unqUpstreamAddr) - srv.SetFallback(fallbackAddr, []string{"one.fallback.", "two.fallback."}) go srv.ListenAndServe() testutil.WaitForDNSServer(srv.Addr) query(t, srv.Addr, "response.test.", "1.1.1.1") query(t, srv.Addr, "unqualified.", "2.2.2.2") - query(t, srv.Addr, "one.fallback.", "3.3.3.3") - query(t, srv.Addr, "two.fallback.", "3.3.3.3") } func query(t *testing.T, srv, domain, expected string) { @@ -56,17 +49,13 @@ func TestBadUpstreams(t *testing.T) { // Get addresses but don't start the servers, so we get an error when // trying to reach them. unqUpstreamAddr := testutil.GetFreePort() - fallbackAddr := testutil.GetFreePort() srv := New(testutil.GetFreePort(), res, unqUpstreamAddr) - srv.SetFallback(fallbackAddr, []string{"one.fallback.", "two.fallback."}) go srv.ListenAndServe() testutil.WaitForDNSServer(srv.Addr) queryFailure(t, srv.Addr, "response.test.") queryFailure(t, srv.Addr, "unqualified.") - queryFailure(t, srv.Addr, "one.fallback.") - queryFailure(t, srv.Addr, "two.fallback.") } func queryFailure(t *testing.T, srv, domain string) { diff --git a/internal/httpresolver/resolver.go b/internal/httpresolver/resolver.go index 1e0426f..a74987d 100644 --- a/internal/httpresolver/resolver.go +++ b/internal/httpresolver/resolver.go @@ -2,6 +2,7 @@ package httpresolver import ( "bytes" + "context" "crypto/tls" "crypto/x509" "fmt" @@ -28,6 +29,10 @@ type httpsResolver struct { CAFile string tlsConfig *tls.Config + // net.Resolver that will contact the server at --fallback_upstream for + // DNS resolutions. + fallbackResolver *net.Resolver + mu sync.Mutex client *http.Client firstErr time.Time @@ -51,11 +56,27 @@ func loadCertPool(caFile string) (*x509.CertPool, error) { // NewDoH creates a new DoH resolver, which uses the given upstream // URL to resolve queries. -func NewDoH(upstream *url.URL, caFile string) *httpsResolver { - return &httpsResolver{ +func NewDoH(upstream *url.URL, caFile, fallback string) *httpsResolver { + r := &httpsResolver{ Upstream: upstream, CAFile: caFile, } + + if fallback != "" { + // Dial function that will always use the fallback address to contact + // DNS. + dialer := net.Dialer{} + dialFallback := func(ctx context.Context, network, address string) (net.Conn, error) { + return dialer.DialContext(ctx, network, fallback) + } + + r.fallbackResolver = &net.Resolver{ + PreferGo: true, // Avoid the system resolver. + Dial: dialFallback, + } + } + + return r } func (r *httpsResolver) Init() error { @@ -101,6 +122,7 @@ func (r *httpsResolver) newClient() (*http.Client, error) { Timeout: 10 * time.Second, KeepAlive: 1 * time.Second, DualStack: true, + Resolver: r.fallbackResolver, }).DialContext, ForceAttemptHTTP2: true, MaxIdleConns: 10, diff --git a/tests/external.sh b/tests/external.sh index 732e3b8..99114cb 100755 --- a/tests/external.sh +++ b/tests/external.sh @@ -44,6 +44,15 @@ function dnss() { PID=$! } +# Run minidns in the background (sets $MINIDNS_PID to its process id). +function minidns() { + go run tests/minidns.go \ + -addr ":1953" \ + -zones tests/testzones \ + > .minidns.log 2>&1 & + MINIDNS_PID=$! +} + # Wait until there's something listening on the given port. function wait_until_ready() { PROTO=$1 @@ -88,11 +97,11 @@ function get() { } function generate_certs() { - mkdir -p .certs/localhost + mkdir -p .certs/$1 ( - cd .certs/localhost + cd .certs/$1 go run ../../tests/generate_cert.go \ - -ca -duration=1h --host=localhost + -ca -duration=1h --host=$1 ) } @@ -105,6 +114,9 @@ if wait $PID; then exit 1 fi +echo "## Launching minidns for testing" +minidns +wait_until_ready tcp 1953 echo "## Launching HTTPS server" dnss -enable_https_to_dns \ @@ -126,7 +138,8 @@ fi echo "## DoH against dnss" dnss -enable_dns_to_https -dns_listen_addr "localhost:1053" \ - -https_upstream "http://localhost:1999/dns-query" + -fallback_upstream "127.0.0.1:1953" \ + -https_upstream "http://upstream:1999/dns-query" # Exercise DoH via GET (dnss always uses POST). get "http://localhost:1999/resolve?&dns=q80BAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB" @@ -148,23 +161,25 @@ kill $HTTP_PID echo "## HTTPS with custom certificates" -generate_certs +generate_certs upstream dnss -enable_https_to_dns \ - -https_key .certs/localhost/privkey.pem \ - -https_cert .certs/localhost/fullchain.pem \ + -https_key .certs/upstream/privkey.pem \ + -https_cert .certs/upstream/fullchain.pem \ -https_server_addr "localhost:1999" HTTP_PID=$PID mv .dnss.log .dnss.http.log wait_until_ready tcp 1999 dnss -enable_dns_to_https -dns_listen_addr "localhost:1053" \ - -https_client_cafile .certs/localhost/fullchain.pem \ - -https_upstream "https://localhost:1999/dns-query" + -fallback_upstream "127.0.0.1:1953" \ + -https_client_cafile .certs/upstream/fullchain.pem \ + -https_upstream "https://upstream:1999/dns-query" resolve kill $PID kill $HTTP_PID +kill $MINIDNS_PID # DoH integration test against some publicly available servers. diff --git a/tests/minidns.go b/tests/minidns.go new file mode 100644 index 0000000..9c3aeb4 --- /dev/null +++ b/tests/minidns.go @@ -0,0 +1,305 @@ +// +build ignore + +// minidns is a trivial DNS server used for testing. +// +// It takes an "answers" file which contains lines with the following format: +// +// <domain> <type> <value> +// +// For example: +// +// blah A 1.2.3.4 +// blah MX mx1 +// +// Supported types: A, AAAA, MX, TXT. +// +// It's only meant to be used for testing, so it's not robust, performant, or +// standards compliant. +// +package main + +import ( + "bufio" + "encoding/binary" + "flag" + "fmt" + "net" + "os" + "regexp" + "strings" + "sync" + + "blitiri.com.ar/go/log" + "golang.org/x/net/dns/dnsmessage" +) + +var ( + addr = flag.String("addr", ":53", "address to listen to (UDP)") + zonesPath = flag.String("zones", "", "file with the zones") +) + +func main() { + flag.Parse() + + srv := &miniDNS{ + answers: map[string][]dnsmessage.Resource{}, + } + + if *zonesPath == "" { + log.Fatalf("-zones must be given") + } + var zonesFile *os.File + if *zonesPath == "-" { + zonesFile = os.Stdin + } else { + var err error + zonesFile, err = os.Open(*zonesPath) + if err != nil { + log.Fatalf("error opening %v: %v", *zonesPath, err) + } + } + + srv.loadZones(zonesFile) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + srv.listenAndServeUDP(*addr) + }() + go func() { + defer wg.Done() + srv.listenAndServeTCP(*addr) + }() + wg.Wait() +} + +type miniDNS struct { + // Domain -> Answers. + // We always respond the same regardless of the query. + // Not great, but does the trick. + answers map[string][]dnsmessage.Resource +} + +func (m *miniDNS) listenAndServeUDP(addr string) { + conn, err := net.ListenPacket("udp", addr) + if err != nil { + log.Fatalf("error listening UDP %q: %v", addr, err) + } + + log.Infof("listening on %v", conn.LocalAddr()) + + buf := make([]byte, 64*1024) + for { + n, addr, err := conn.ReadFrom(buf) + if err != nil { + log.Infof("error reading from udp: %v", err) + continue + } + + msg := &dnsmessage.Message{} + err = msg.Unpack(buf[:n]) + if err != nil { + log.Infof("%v error unpacking message: %v", addr, err) + } + + if lq := len(msg.Questions); lq != 1 { + log.Infof("%v/%-5d dropping packet with %d questions", + addr, msg.ID, lq) + continue + } + q := msg.Questions[0] + log.Infof("%v/%-5d Q: %s %s %s", + addr, msg.ID, q.Name, q.Type, q.Class) + + reply := m.handle(msg) + rbuf, err := reply.Pack() + if err != nil { + log.Fatalf("error packing reply: %v", err) + } + + conn.WriteTo(rbuf, addr) + } +} + +func (m *miniDNS) listenAndServeTCP(addr string) { + ls, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("error listening TCP %q: %v", addr, err) + } + + log.Infof("listening on %v", addr) + + for { + conn, err := ls.Accept() + if err != nil { + log.Infof("error accepting: %v", err) + continue + } + + msg, err := readTCPMessage(conn) + if err != nil { + log.Infof("%v error reading message: %v", addr, err) + conn.Close() + continue + } + + if lq := len(msg.Questions); lq != 1 { + log.Infof("%v/%-5d dropping packet with %d questions", + addr, msg.ID, lq) + conn.Close() + continue + } + q := msg.Questions[0] + log.Infof("%v/%-5d Q: %s %s %s", + addr, msg.ID, q.Name, q.Type, q.Class) + + reply := m.handle(msg) + err = writeTCPMessage(conn, reply) + if err != nil { + log.Infof("error writing reply: %v", err) + } + + conn.Close() + } +} + +func readTCPMessage(conn net.Conn) (*dnsmessage.Message, error) { + // Read the 2-byte length first, then the message. + lenHdr := struct{ Len uint16 }{} + err := binary.Read(conn, binary.BigEndian, &lenHdr) + if err != nil { + return nil, err + } + + data := make([]byte, lenHdr.Len) + err = binary.Read(conn, binary.BigEndian, &data) + if err != nil { + return nil, err + } + + msg := &dnsmessage.Message{} + err = msg.Unpack(data) + if err != nil { + return nil, fmt.Errorf("%v error unpacking message: %v", addr, err) + } + + return msg, nil +} + +func writeTCPMessage(conn net.Conn, msg *dnsmessage.Message) error { + rbuf, err := msg.Pack() + if err != nil { + return fmt.Errorf("error packing reply: %v", err) + } + + lenHdr := struct{ Len uint16 }{Len: uint16(len(rbuf))} + err = binary.Write(conn, binary.BigEndian, lenHdr) + if err != nil { + return err + } + + _, err = conn.Write(rbuf) + return err +} + +func (m *miniDNS) handle(msg *dnsmessage.Message) *dnsmessage.Message { + reply := &dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: msg.ID, + Response: true, + RCode: dnsmessage.RCodeSuccess, + }, + Questions: msg.Questions, + } + + q := msg.Questions[0] + if answers, ok := m.answers[q.Name.String()]; ok { + for _, ans := range answers { + if q.Type == ans.Header.Type { + log.Infof("-> %s %v", q.Type, ans.Body) + reply.Answers = append(reply.Answers, ans) + } + } + } else { + log.Infof("-> NXERROR") + reply.Header.RCode = dnsmessage.RCodeNameError + } + + return reply +} + +func (m *miniDNS) loadZones(f *os.File) { + scanner := bufio.NewScanner(f) + lineno := 0 + for scanner.Scan() { + lineno++ + line := strings.TrimSpace(scanner.Text()) + if strings.HasPrefix(line, "#") || line == "" { + continue + } + + vs := regexp.MustCompile("\\s+").Split(line, 3) + if len(vs) != 3 { + log.Fatalf("line %d: invalid format", lineno) + } + domain, t, value := vs[0], vs[1], vs[2] + if !strings.HasSuffix(domain, ".") { + domain += "." + } + + var body dnsmessage.ResourceBody + var qType dnsmessage.Type + switch strings.ToLower(t) { + case "a": + qType = dnsmessage.TypeA + ip := net.ParseIP(value).To4() + if ip == nil { + log.Fatalf("line %d: invalid IP %q", lineno, value) + } + a := &dnsmessage.AResource{} + copy(a.A[:], ip[:4]) + body = a + case "aaaa": + qType = dnsmessage.TypeAAAA + ip := net.ParseIP(value).To16() + if ip == nil { + log.Fatalf("line %d: invalid IP %q", lineno, value) + } + aaaa := &dnsmessage.AAAAResource{} + copy(aaaa.AAAA[:], ip[:16]) + body = aaaa + case "mx": + qType = dnsmessage.TypeMX + if !strings.HasPrefix(value, ".") { + value += "." + } + + body = &dnsmessage.MXResource{ + Pref: 10, + MX: dnsmessage.MustNewName(value), + } + case "txt": + qType = dnsmessage.TypeTXT + body = &dnsmessage.TXTResource{ + TXT: []string{value}, + } + default: + log.Fatalf("line %d: unknown type %q", lineno, t) + } + + answer := dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName(domain), + Type: qType, + Class: dnsmessage.ClassINET, + }, + Body: body, + } + m.answers[domain] = append(m.answers[domain], answer) + } + + if err := scanner.Err(); err != nil { + log.Fatalf("error reading zones: %v", err) + } +} diff --git a/tests/testzones b/tests/testzones new file mode 100644 index 0000000..f8b9479 --- /dev/null +++ b/tests/testzones @@ -0,0 +1,3 @@ +# Zones for minidns to use in the tests. +upstream A 127.0.0.1 +upstream AAAA ::1