author | Alberto Bertogli <albertito@blitiri.com.ar> | 2023-01-26 11:17:04 UTC |
committer | Alberto Bertogli <albertito@blitiri.com.ar> | 2023-01-27 18:10:09 UTC |
parent | cda7f0a1c825f26cf35cc69a8f816338f49ea518 |
README.md | +4 | -0 |
dnss.go | +11 | -1 |
dnss_test.go | +1 | -1 |
go.mod | +1 | -0 |
go.sum | +2 | -0 |
internal/dnsserver/domainmap.go | +69 | -0 |
internal/dnsserver/domainmap_test.go | +89 | -0 |
internal/dnsserver/server.go | +25 | -7 |
internal/dnsserver/server_test.go | +28 | -2 |
diff --git a/README.md b/README.md index 7d3c2fa..c57102f 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,10 @@ dnss -enable_dns_to_https -https_upstream="https://1.1.1.1/dns-query" # Use Google's dns.google: dnss -enable_dns_to_https -https_upstream="https://dns.google/dns-query" + +# Use the default HTTPS URL for all resolutions, except for domain "myhome" +# which is resolved via a local DNS server. +dnss -enable_dns_to_https -dns_server_for_domain="myhome:10.0.1.1:53" ``` ### HTTPS server diff --git a/dnss.go b/dnss.go index c79aa94..8f5f685 100644 --- a/dnss.go +++ b/dnss.go @@ -33,6 +33,9 @@ var ( dnsUnqualifiedUpstream = flag.String("dns_unqualified_upstream", "", "DNS server to forward unqualified requests to") + dnsServerForDomain = flag.String("dns_server_for_domain", "", + "DNS server to use for a specific domain, "+ + `in the form of "domain1:addr1, domain2:addr, ..."`) fallbackUpstream = flag.String("fallback_upstream", "8.8.8.8:53", "DNS server used to resolve domains in -https_upstream"+ @@ -104,7 +107,14 @@ func main() { cr.RegisterDebugHandlers() resolver = cr } - dth := dnsserver.New(*dnsListenAddr, resolver, *dnsUnqualifiedUpstream) + + overrides, err := dnsserver.DomainMapFromString(*dnsServerForDomain) + if err != nil { + log.Fatalf("-dns_server_for_domain is not valid: %v", err) + } + + dth := dnsserver.New(*dnsListenAddr, resolver, + *dnsUnqualifiedUpstream, overrides) wg.Add(1) go func() { diff --git a/dnss_test.go b/dnss_test.go index 6d03939..18c5771 100644 --- a/dnss_test.go +++ b/dnss_test.go @@ -66,7 +66,7 @@ func Setup(tb testing.TB) string { // IP addresses directly in the http requests, the fallback resolver // should not be needed. 
r := httpresolver.NewDoH(HTTPSToDNSURL, "", "0.0.0.0:0") - dtoh := dnsserver.New(DNSToHTTPSAddr, r, "") + dtoh := dnsserver.New(DNSToHTTPSAddr, r, "", nil) go dtoh.ListenAndServe() if err := testutil.WaitForDNSServer(DNSToHTTPSAddr); err != nil { diff --git a/go.mod b/go.mod index 44a3c7c..5911a9f 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.15 require ( blitiri.com.ar/go/log v1.1.0 blitiri.com.ar/go/systemd v1.1.0 + github.com/google/go-cmp v0.5.9 // indirect github.com/miekg/dns v1.1.48 golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 golang.org/x/sys v0.0.0-20220422013727-9388b58f7150 // indirect diff --git a/go.sum b/go.sum index 331c62d..9ce8252 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ blitiri.com.ar/go/log v1.1.0 h1:prKXp2hnYXRamcrYaCajq1SdQYvHU852lY7QStHyuaw= blitiri.com.ar/go/log v1.1.0/go.mod h1:CobnZ0FcxCAWHnkPCVtNPmj8AGiW9aNLKd/E7tI43Sw= blitiri.com.ar/go/systemd v1.1.0 h1:AMr7Ce/5CkvLZvGxsn/ZOagzFf3zU13rcgWdlbWMQ+Y= blitiri.com.ar/go/systemd v1.1.0/go.mod h1:0D9Ttrh+TX+WuKQ/dJpdhFND7NYy505v6jhsWrihmPY= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/miekg/dns v1.1.48 h1:Ucfr7IIVyMBz4lRE8qmGUuZ4Wt3/ZGu9hmcMT3Uu4tQ= github.com/miekg/dns v1.1.48/go.mod h1:e3IlAVfNqAllflbibAZEWOXOQ+Ynzk/dDozDxY7XnME= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= diff --git a/internal/dnsserver/domainmap.go b/internal/dnsserver/domainmap.go new file mode 100644 index 0000000..b4212d3 --- /dev/null +++ b/internal/dnsserver/domainmap.go @@ -0,0 +1,69 @@ +package dnsserver + +import ( + "fmt" + "strings" + + "github.com/miekg/dns" +) + +// DomainMap maps a DNS name to an arbitrary string. +type DomainMap map[string]string + +// Set the value for the given domain. 
+func (m DomainMap) Set(domain, value string) { + m[dns.CanonicalName(domain)] = value +} + +// GetExact value for the given domain, using an exact lookup (the domain must +// match exactly what was set). +func (m DomainMap) GetExact(domain string) (string, bool) { + v, ok := m[dns.CanonicalName(domain)] + return v, ok +} + +// GetMostSpecific value for the given domain, using a most-specific lookup +// (we pick the map entry that is closest to the domain). +func (m DomainMap) GetMostSpecific(domain string) (string, bool) { + domain = dns.CanonicalName(domain) + mc := 0 + mv := "" + ok := false + for d, v := range m { + if !dns.IsSubDomain(d, domain) { + continue + } + + // Keep the match with the most labels (the most specific). + c := dns.CountLabel(d) + if c > mc { + mc = c + mv = v + ok = true + } + } + + return mv, ok +} + +// DomainMapFromString takes a string in the form of +// "domain1:addr1,domain2:addr2,..." and returns a dnsserver.DomainMap like +// {"domain1": "addr1", "domain2": "addr2", ...}. 
+func DomainMapFromString(s string) (DomainMap, error) { + m := DomainMap{} + for _, pair := range strings.Split(s, ",") { + pair = strings.TrimSpace(pair) + if pair == "" { + continue + } + + xs := strings.SplitN(pair, ":", 2) + if len(xs) != 2 { + return nil, fmt.Errorf("%q: %w", pair, errInvalidFormat) + } + m.Set(strings.TrimSpace(xs[0]), strings.TrimSpace(xs[1])) + } + return m, nil +} + +var errInvalidFormat = fmt.Errorf("entry does not have a ':'") diff --git a/internal/dnsserver/domainmap_test.go b/internal/dnsserver/domainmap_test.go new file mode 100644 index 0000000..1a3386f --- /dev/null +++ b/internal/dnsserver/domainmap_test.go @@ -0,0 +1,89 @@ +package dnsserver + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestDomainMap(t *testing.T) { + m := DomainMap{} + m.Set("a.com", "valuex") + m.Set("a.com", "valueA") + m.Set("x.A.com", "valueX") + m.Set("y.a.com", "valueY") + + type tcase struct { + req string + val string + ok bool + } + + cases := []tcase{ + {"a.com", "valueA", true}, + {"A.cOm", "valueA", true}, + {"A.COM.", "valueA", true}, + {"x.a.com", "valueX", true}, + {"y.a.com", "valueY", true}, + {"com", "", false}, + {"b.a.com", "", false}, + } + for i, c := range cases { + val, ok := m.GetExact(c.req) + if val != c.val || ok != c.ok { + t.Errorf("case %d: GetExact(%q) expected (%q, %v), got (%q, %v)", + i, c.req, c.val, c.ok, val, ok) + } + } + + cases = []tcase{ + {"a.com", "valueA", true}, + {"x.a.com", "valueX", true}, + {"y.a.com", "valueY", true}, + {"b.a.com", "valueA", true}, + {"z.x.a.com", "valueX", true}, + {"com", "", false}, + } + for i, c := range cases { + val, ok := m.GetMostSpecific(c.req) + if val != c.val || ok != c.ok { + t.Errorf("case %d: GetMostSpecific(%q) expected (%q, %v), got (%q, %v)", + i, c.req, c.val, c.ok, val, ok) + } + } +} + +func TestDomainMapFromString(t *testing.T) { + cases := []struct { + s string + m DomainMap + err error + }{ + {"", DomainMap{}, nil}, + 
{"d1:1.1.1.1:1111", DomainMap{"d1.": "1.1.1.1:1111"}, nil}, + {"Do-Main:1.1.1.1:1111", DomainMap{"do-main.": "1.1.1.1:1111"}, nil}, + { + "d1:1.1.1.1:1111, d2.: 2.2.2.2:2222 ,,d3 : 3.3.3.3:3333, d4:", + DomainMap{ + "d1.": "1.1.1.1:1111", + "d2.": "2.2.2.2:2222", + "d3.": "3.3.3.3:3333", + "d4.": "", + }, + nil, + }, + {"abc", nil, errInvalidFormat}, + {"abc:def,xyz", nil, errInvalidFormat}, + } + for i, c := range cases { + m, err := DomainMapFromString(c.s) + if diff := cmp.Diff(c.m, m); diff != "" { + t.Errorf("%d: DomainMapFromString(%q) mismatch (-want +got):\n%s", i, c.s, diff) + } + if !errors.Is(err, c.err) { + t.Errorf("%d: DomainMapFromString(%q) unexpected error: "+ + "want:%q ; got:%q", i, c.s, c.err, err) + } + } +} diff --git a/internal/dnsserver/server.go b/internal/dnsserver/server.go index 395a2f7..f78556d 100644 --- a/internal/dnsserver/server.go +++ b/internal/dnsserver/server.go @@ -44,18 +44,20 @@ func init() { // Server implements a DNS proxy, which will (mostly) use the given resolver // to resolve queries. type Server struct { - Addr string - unqUpstream string - resolver Resolver + Addr string + unqUpstream string + serverOverrides DomainMap + resolver Resolver } // New *Server, which will listen on addr, use resolver as the backend // resolver, and use unqUpstream to resolve unqualified queries. -func New(addr string, resolver Resolver, unqUpstream string) *Server { +func New(addr string, resolver Resolver, unqUpstream string, serverOverrides DomainMap) *Server { return &Server{ - Addr: addr, - resolver: resolver, - unqUpstream: unqUpstream, + Addr: addr, + resolver: resolver, + unqUpstream: unqUpstream, + serverOverrides: serverOverrides, } } @@ -74,6 +76,22 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) { return } + // If the domain has a server override, forward to it instead. 
+ override, ok := s.serverOverrides.GetMostSpecific(r.Question[0].Name) + if ok { + tr.Printf("override found: %q", override) + u, err := dns.Exchange(r, override) + if err == nil { + tr.Answer(u) + s.writeReply(tr, w, r, u) + } else { + tr.Printf("override server returned error: %v", err) + dns.HandleFailed(w, r) + } + + return + } + // Forward to the unqualified upstream server if: // - We have one configured. // - There's only one question in the request, to keep things simple. diff --git a/internal/dnsserver/server_test.go b/internal/dnsserver/server_test.go index 28820aa..2a939c5 100644 --- a/internal/dnsserver/server_test.go +++ b/internal/dnsserver/server_test.go @@ -21,12 +21,32 @@ func TestServe(t *testing.T) { go testutil.ServeTestDNSServer(unqUpstreamAddr, testutil.MakeStaticHandler(t, "unq. A 2.2.2.2")) - srv := New(testutil.GetFreePort(), res, unqUpstreamAddr) + overrideAddr3 := testutil.GetFreePort() + go testutil.ServeTestDNSServer(overrideAddr3, + testutil.MakeStaticHandler(t, "a.ov3. A 3.3.3.3")) + + overrideAddr4 := testutil.GetFreePort() + go testutil.ServeTestDNSServer(overrideAddr4, + testutil.MakeStaticHandler(t, "b.ov4. A 4.4.4.4")) + + overrides := DomainMap{ + "ov3.": overrideAddr3, + "a.ov4.": overrideAddr4, + } + + srv := New(testutil.GetFreePort(), res, unqUpstreamAddr, overrides) go srv.ListenAndServe() testutil.WaitForDNSServer(srv.Addr) query(t, srv.Addr, "response.test.", "1.1.1.1") query(t, srv.Addr, "unqualified.", "2.2.2.2") + + query(t, srv.Addr, "ov3.", "3.3.3.3") + query(t, srv.Addr, "x.ov3.", "3.3.3.3") + query(t, srv.Addr, "y.x.OV3.", "3.3.3.3") + query(t, srv.Addr, "A.ov4.", "4.4.4.4") + query(t, srv.Addr, "z.a.ov4.", "4.4.4.4") + query(t, srv.Addr, "b.ov4.", "1.1.1.1") // Not overridden. } func query(t *testing.T, srv, domain, expected string) { @@ -49,13 +69,19 @@ func TestBadUpstreams(t *testing.T) { // Get addresses but don't start the servers, so we get an error when // trying to reach them. 
unqUpstreamAddr := testutil.GetFreePort() + overrideAddr1 := testutil.GetFreePort() + + overrides := DomainMap{ + "ov1.": overrideAddr1, + } - srv := New(testutil.GetFreePort(), res, unqUpstreamAddr) + srv := New(testutil.GetFreePort(), res, unqUpstreamAddr, overrides) go srv.ListenAndServe() testutil.WaitForDNSServer(srv.Addr) queryFailure(t, srv.Addr, "response.test.") queryFailure(t, srv.Addr, "unqualified.") + queryFailure(t, srv.Addr, "ov1.") } func queryFailure(t *testing.T, srv, domain string) {