git » dnss » commit 8a32217

dnsserver: Add support for overriding DNS server for specific domains

author Alberto Bertogli
2023-01-26 11:17:04 UTC
committer Alberto Bertogli
2023-01-27 18:10:09 UTC
parent cda7f0a1c825f26cf35cc69a8f816338f49ea518

dnsserver: Add support for overriding DNS server for specific domains

On small networks, it can be useful to say "for these specific domains,
use this local DNS server instead of DoH".

This patch implements that feature via a new command-line flag, -dns_server_for_domain.

README.md +4 -0
dnss.go +11 -1
dnss_test.go +1 -1
go.mod +1 -0
go.sum +2 -0
internal/dnsserver/domainmap.go +69 -0
internal/dnsserver/domainmap_test.go +89 -0
internal/dnsserver/server.go +25 -7
internal/dnsserver/server_test.go +28 -2

diff --git a/README.md b/README.md
index 7d3c2fa..c57102f 100644
--- a/README.md
+++ b/README.md
@@ -75,6 +75,10 @@ dnss -enable_dns_to_https -https_upstream="https://1.1.1.1/dns-query"
 
 # Use Google's dns.google:
 dnss -enable_dns_to_https -https_upstream="https://dns.google/dns-query"
+
+# Use the default HTTPS URL for all resolutions, except for the domain
+# "myhome" and its subdomains, which are resolved via a local DNS server.
+dnss -enable_dns_to_https -dns_server_for_domain="myhome:10.0.1.1:53"
 ```
 
 ### HTTPS server
diff --git a/dnss.go b/dnss.go
index c79aa94..8f5f685 100644
--- a/dnss.go
+++ b/dnss.go
@@ -33,6 +33,9 @@ var (
 
 	dnsUnqualifiedUpstream = flag.String("dns_unqualified_upstream", "",
 		"DNS server to forward unqualified requests to")
+	dnsServerForDomain = flag.String("dns_server_for_domain", "",
+		"DNS server to use for a specific domain, "+
+			`in the form of "domain1:addr1, domain2:addr, ..."`)
 
 	fallbackUpstream = flag.String("fallback_upstream", "8.8.8.8:53",
 		"DNS server used to resolve domains in -https_upstream"+
@@ -104,7 +107,14 @@ func main() {
 			cr.RegisterDebugHandlers()
 			resolver = cr
 		}
-		dth := dnsserver.New(*dnsListenAddr, resolver, *dnsUnqualifiedUpstream)
+
+		overrides, err := dnsserver.DomainMapFromString(*dnsServerForDomain)
+		if err != nil {
+			log.Fatalf("-dns_server_for_domain is not valid: %v", err)
+		}
+
+		dth := dnsserver.New(*dnsListenAddr, resolver,
+			*dnsUnqualifiedUpstream, overrides)
 
 		wg.Add(1)
 		go func() {
diff --git a/dnss_test.go b/dnss_test.go
index 6d03939..18c5771 100644
--- a/dnss_test.go
+++ b/dnss_test.go
@@ -66,7 +66,7 @@ func Setup(tb testing.TB) string {
 	// IP addresses directly in the http requests, the fallback resolver
 	// should not be needed.
 	r := httpresolver.NewDoH(HTTPSToDNSURL, "", "0.0.0.0:0")
-	dtoh := dnsserver.New(DNSToHTTPSAddr, r, "")
+	dtoh := dnsserver.New(DNSToHTTPSAddr, r, "", nil)
 	go dtoh.ListenAndServe()
 
 	if err := testutil.WaitForDNSServer(DNSToHTTPSAddr); err != nil {
diff --git a/go.mod b/go.mod
index 44a3c7c..5911a9f 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.15
 require (
 	blitiri.com.ar/go/log v1.1.0
 	blitiri.com.ar/go/systemd v1.1.0
+	github.com/google/go-cmp v0.5.9
 	github.com/miekg/dns v1.1.48
 	golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4
 	golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect
diff --git a/go.sum b/go.sum
index 331c62d..9ce8252 100644
--- a/go.sum
+++ b/go.sum
@@ -2,6 +2,8 @@ blitiri.com.ar/go/log v1.1.0 h1:prKXp2hnYXRamcrYaCajq1SdQYvHU852lY7QStHyuaw=
 blitiri.com.ar/go/log v1.1.0/go.mod h1:CobnZ0FcxCAWHnkPCVtNPmj8AGiW9aNLKd/E7tI43Sw=
 blitiri.com.ar/go/systemd v1.1.0 h1:AMr7Ce/5CkvLZvGxsn/ZOagzFf3zU13rcgWdlbWMQ+Y=
 blitiri.com.ar/go/systemd v1.1.0/go.mod h1:0D9Ttrh+TX+WuKQ/dJpdhFND7NYy505v6jhsWrihmPY=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/miekg/dns v1.1.48 h1:Ucfr7IIVyMBz4lRE8qmGUuZ4Wt3/ZGu9hmcMT3Uu4tQ=
 github.com/miekg/dns v1.1.48/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME=
 github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
diff --git a/internal/dnsserver/domainmap.go b/internal/dnsserver/domainmap.go
new file mode 100644
index 0000000..b4212d3
--- /dev/null
+++ b/internal/dnsserver/domainmap.go
@@ -0,0 +1,69 @@
+package dnsserver
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/miekg/dns"
+)
+
+// DomainMap maps a DNS name to an arbitrary string (in this package, the
+// address of the DNS server to use for that name).
+//
+// Lookups canonicalize the queried name before indexing the map, so keys are
+// expected in canonical form (lowercase, with a trailing dot, e.g. "a.com.").
+// Set produces that form automatically; when building a DomainMap literal,
+// write the keys that way yourself.
+type DomainMap map[string]string
+
+// Set stores value under the canonical form of domain (lowercased and fully
+// qualified), replacing any value previously stored for the same name.
+func (m DomainMap) Set(domain, value string) {
+	key := dns.CanonicalName(domain)
+	m[key] = value
+}
+
+// GetExact returns the value stored for domain and whether one was found.
+// The lookup is exact: after canonicalization, domain must be the very name
+// that was Set. Parent or child domains of a stored name do not match.
+func (m DomainMap) GetExact(domain string) (string, bool) {
+	if v, found := m[dns.CanonicalName(domain)]; found {
+		return v, true
+	}
+	return "", false
+}
+
+// GetMostSpecific returns the value of the entry closest to domain, and
+// whether any entry matched. An entry matches when it is domain itself or one
+// of its parent domains. Among the matches, the one with the most labels (the
+// most specific) wins. For example, with entries for "a.com" and "x.a.com", a
+// lookup of "z.x.a.com" returns the "x.a.com" value, and "b.a.com" returns
+// the "a.com" value.
+//
+// A root entry (".") matches every domain, and is used only when nothing more
+// specific does.
+func (m DomainMap) GetMostSpecific(domain string) (string, bool) {
+	domain = dns.CanonicalName(domain)
+
+	best := ""
+	// -1 means "no match yet". It must be below zero so that the root entry,
+	// which has zero labels, can still be selected.
+	bestLabels := -1
+	for d, v := range m {
+		if !dns.IsSubDomain(d, domain) {
+			continue
+		}
+
+		// Keep the match with the most labels (the most specific).
+		if c := dns.CountLabel(d); c > bestLabels {
+			bestLabels = c
+			best = v
+		}
+	}
+
+	return best, bestLabels >= 0
+}
+
+// DomainMapFromString parses a comma-separated list of "domain:value" pairs,
+// such as "domain1:addr1,domain2:addr2,...", into a DomainMap like
+// {"domain1.": "addr1", "domain2.": "addr2", ...}.
+//
+// Only the first ':' of each pair separates the domain from the value, so the
+// value may itself contain colons (e.g. "1.1.1.1:53").
+//
+// Whitespace around pairs, domains and values is ignored, and so are empty
+// pairs. Domains are stored in canonical form via Set. If the same domain
+// appears more than once, the last value wins.
+//
+// A non-empty pair without any ':' produces an error wrapping
+// errInvalidFormat.
+func DomainMapFromString(s string) (DomainMap, error) {
+	result := DomainMap{}
+	entries := strings.Split(s, ",")
+	for i := range entries {
+		entry := strings.TrimSpace(entries[i])
+		if len(entry) == 0 {
+			continue
+		}
+
+		sep := strings.Index(entry, ":")
+		if sep < 0 {
+			return nil, fmt.Errorf("%q: %w", entry, errInvalidFormat)
+		}
+		domain := strings.TrimSpace(entry[:sep])
+		value := strings.TrimSpace(entry[sep+1:])
+		result.Set(domain, value)
+	}
+	return result, nil
+}
+
+var errInvalidFormat = fmt.Errorf("entry does not have a ':'")
diff --git a/internal/dnsserver/domainmap_test.go b/internal/dnsserver/domainmap_test.go
new file mode 100644
index 0000000..1a3386f
--- /dev/null
+++ b/internal/dnsserver/domainmap_test.go
@@ -0,0 +1,89 @@
+package dnsserver
+
+import (
+	"errors"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+)
+
+func TestDomainMap(t *testing.T) {
+	// "a.com" is set twice, and the second value must win. "x.A.com" uses
+	// mixed case and must be canonicalized.
+	m := DomainMap{}
+	m.Set("a.com", "valuex")
+	m.Set("a.com", "valueA")
+	m.Set("x.A.com", "valueX")
+	m.Set("y.a.com", "valueY")
+
+	type lookup struct {
+		req     string
+		wantVal string
+		wantOK  bool
+	}
+
+	// GetExact only finds names that were Set, ignoring case and the
+	// trailing dot. Parents and children of stored names do not match.
+	exactLookups := []lookup{
+		{"a.com", "valueA", true},
+		{"A.cOm", "valueA", true},
+		{"A.COM.", "valueA", true},
+		{"x.a.com", "valueX", true},
+		{"y.a.com", "valueY", true},
+		{"com", "", false},
+		{"b.a.com", "", false},
+	}
+	for idx := range exactLookups {
+		l := exactLookups[idx]
+		gotVal, gotOK := m.GetExact(l.req)
+		if gotVal == l.wantVal && gotOK == l.wantOK {
+			continue
+		}
+		t.Errorf("case %d: GetExact(%q) expected (%q, %v), got (%q, %v)",
+			idx, l.req, l.wantVal, l.wantOK, gotVal, gotOK)
+	}
+
+	// GetMostSpecific falls back to the closest stored ancestor: "b.a.com"
+	// resolves via "a.com", and "z.x.a.com" via "x.a.com".
+	specificLookups := []lookup{
+		{"a.com", "valueA", true},
+		{"x.a.com", "valueX", true},
+		{"y.a.com", "valueY", true},
+		{"b.a.com", "valueA", true},
+		{"z.x.a.com", "valueX", true},
+		{"com", "", false},
+	}
+	for idx := range specificLookups {
+		l := specificLookups[idx]
+		gotVal, gotOK := m.GetMostSpecific(l.req)
+		if gotVal == l.wantVal && gotOK == l.wantOK {
+			continue
+		}
+		t.Errorf("case %d: GetMostSpecific(%q) expected (%q, %v), got (%q, %v)",
+			idx, l.req, l.wantVal, l.wantOK, gotVal, gotOK)
+	}
+}
+
+// TestDomainMapFromString checks parsing of the -dns_server_for_domain
+// syntax: canonicalized keys, whitespace and empty-pair tolerance, values
+// containing ':', last-wins duplicates, and rejection of pairs without ':'.
+func TestDomainMapFromString(t *testing.T) {
+	cases := []struct {
+		s   string
+		m   DomainMap
+		err error
+	}{
+		{"", DomainMap{}, nil},
+		{" , ,", DomainMap{}, nil},
+		{"d1:1.1.1.1:1111", DomainMap{"d1.": "1.1.1.1:1111"}, nil},
+		{"Do-Main:1.1.1.1:1111", DomainMap{"do-main.": "1.1.1.1:1111"}, nil},
+		{"d1:a,D1.:b", DomainMap{"d1.": "b"}, nil},
+		{
+			"d1:1.1.1.1:1111, d2.: 2.2.2.2:2222 ,,d3 : 3.3.3.3:3333, d4:",
+			DomainMap{
+				"d1.": "1.1.1.1:1111",
+				"d2.": "2.2.2.2:2222",
+				"d3.": "3.3.3.3:3333",
+				"d4.": "",
+			},
+			nil,
+		},
+		{"abc", nil, errInvalidFormat},
+		{"abc:def,xyz", nil, errInvalidFormat},
+	}
+	for i, c := range cases {
+		m, err := DomainMapFromString(c.s)
+		if diff := cmp.Diff(c.m, m); diff != "" {
+			t.Errorf("%d: DomainMapFromString(%q) mismatch (-want +got):\n%s", i, c.s, diff)
+		}
+		// errors.Is with a nil target reports whether err is nil, so this also
+		// verifies that valid inputs return no error. Use %v rather than %q:
+		// %q on a nil error prints the garbled "%!q(<nil>)".
+		if !errors.Is(err, c.err) {
+			t.Errorf("%d: DomainMapFromString(%q) unexpected error: "+
+				"want: %v ; got: %v", i, c.s, c.err, err)
+		}
+	}
+}
diff --git a/internal/dnsserver/server.go b/internal/dnsserver/server.go
index 395a2f7..f78556d 100644
--- a/internal/dnsserver/server.go
+++ b/internal/dnsserver/server.go
@@ -44,18 +44,20 @@ func init() {
 // Server implements a DNS proxy, which will (mostly) use the given resolver
 // to resolve queries.
 type Server struct {
-	Addr        string
-	unqUpstream string
-	resolver    Resolver
+	Addr            string
+	unqUpstream     string
+	serverOverrides DomainMap
+	resolver        Resolver
 }
 
 // New *Server, which will listen on addr, use resolver as the backend
 // resolver, and use unqUpstream to resolve unqualified queries.
-func New(addr string, resolver Resolver, unqUpstream string) *Server {
+func New(addr string, resolver Resolver, unqUpstream string, serverOverrides DomainMap) *Server {
 	return &Server{
-		Addr:        addr,
-		resolver:    resolver,
-		unqUpstream: unqUpstream,
+		Addr:            addr,
+		resolver:        resolver,
+		unqUpstream:     unqUpstream,
+		serverOverrides: serverOverrides,
 	}
 }
 
@@ -74,6 +76,22 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
 		return
 	}
 
+	// If the domain has a server override, forward to it instead.
+	override, ok := s.serverOverrides.GetMostSpecific(r.Question[0].Name)
+	if ok {
+		tr.Printf("override found: %q", override)
+		u, err := dns.Exchange(r, override)
+		if err == nil {
+			tr.Answer(u)
+			s.writeReply(tr, w, r, u)
+		} else {
+			tr.Printf("override server returned error: %v", err)
+			dns.HandleFailed(w, r)
+		}
+
+		return
+	}
+
 	// Forward to the unqualified upstream server if:
 	//  - We have one configured.
 	//  - There's only one question in the request, to keep things simple.
diff --git a/internal/dnsserver/server_test.go b/internal/dnsserver/server_test.go
index 28820aa..2a939c5 100644
--- a/internal/dnsserver/server_test.go
+++ b/internal/dnsserver/server_test.go
@@ -21,12 +21,32 @@ func TestServe(t *testing.T) {
 	go testutil.ServeTestDNSServer(unqUpstreamAddr,
 		testutil.MakeStaticHandler(t, "unq. A 2.2.2.2"))
 
-	srv := New(testutil.GetFreePort(), res, unqUpstreamAddr)
+	overrideAddr3 := testutil.GetFreePort()
+	go testutil.ServeTestDNSServer(overrideAddr3,
+		testutil.MakeStaticHandler(t, "a.ov3. A 3.3.3.3"))
+
+	overrideAddr4 := testutil.GetFreePort()
+	go testutil.ServeTestDNSServer(overrideAddr4,
+		testutil.MakeStaticHandler(t, "b.ov4. A 4.4.4.4"))
+
+	overrides := DomainMap{
+		"ov3.":   overrideAddr3,
+		"a.ov4.": overrideAddr4,
+	}
+
+	srv := New(testutil.GetFreePort(), res, unqUpstreamAddr, overrides)
 	go srv.ListenAndServe()
 	testutil.WaitForDNSServer(srv.Addr)
 
 	query(t, srv.Addr, "response.test.", "1.1.1.1")
 	query(t, srv.Addr, "unqualified.", "2.2.2.2")
+
+	query(t, srv.Addr, "ov3.", "3.3.3.3")
+	query(t, srv.Addr, "x.ov3.", "3.3.3.3")
+	query(t, srv.Addr, "y.x.OV3.", "3.3.3.3")
+	query(t, srv.Addr, "A.ov4.", "4.4.4.4")
+	query(t, srv.Addr, "z.a.ov4.", "4.4.4.4")
+	query(t, srv.Addr, "b.ov4.", "1.1.1.1") // Not overridden.
 }
 
 func query(t *testing.T, srv, domain, expected string) {
@@ -49,13 +69,19 @@ func TestBadUpstreams(t *testing.T) {
 	// Get addresses but don't start the servers, so we get an error when
 	// trying to reach them.
 	unqUpstreamAddr := testutil.GetFreePort()
+	overrideAddr1 := testutil.GetFreePort()
+
+	overrides := DomainMap{
+		"ov1.": overrideAddr1,
+	}
 
-	srv := New(testutil.GetFreePort(), res, unqUpstreamAddr)
+	srv := New(testutil.GetFreePort(), res, unqUpstreamAddr, overrides)
 	go srv.ListenAndServe()
 	testutil.WaitForDNSServer(srv.Addr)
 
 	queryFailure(t, srv.Addr, "response.test.")
 	queryFailure(t, srv.Addr, "unqualified.")
+	queryFailure(t, srv.Addr, "ov1.")
 }
 
 func queryFailure(t *testing.T, srv, domain string) {