author | Alberto Bertogli
<albertito@blitiri.com.ar> 2016-11-26 12:29:15 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2016-11-26 12:29:15 UTC |
.travis.yml | +15 | -0 |
LICENSE | +219 | -0 |
README.md | +74 | -0 |
dnss.go | +222 | -0 |
etc/systemd/dns-to-grpc/dnss.service | +35 | -0 |
etc/systemd/dns-to-grpc/dnss.socket | +11 | -0 |
etc/systemd/dns-to-https/dnss.service | +32 | -0 |
etc/systemd/dns-to-https/dnss.socket | +11 | -0 |
etc/systemd/grpc-to-dns/dnss.service | +24 | -0 |
glide.lock | +39 | -0 |
glide.yaml | +18 | -0 |
internal/dnstox/caching_test.go | +363 | -0 |
internal/dnstox/dnstox.go | +233 | -0 |
internal/dnstox/resolver.go | +548 | -0 |
internal/grpctodns/grpctodns.go | +109 | -0 |
internal/proto/dnss.pb.go | +134 | -0 |
internal/proto/dnss.proto | +15 | -0 |
internal/proto/dummy.go | +4 | -0 |
internal/util/strings.go | +30 | -0 |
monitoring_test.go | +48 | -0 |
testing/grpc/grpc.go | +4 | -0 |
testing/grpc/grpc_test.go | +243 | -0 |
testing/https/https.go | +4 | -0 |
testing/https/https_test.go | +163 | -0 |
testing/util/util.go | +84 | -0 |
tools/bench | +119 | -0 |
diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..1d3f7ea --- /dev/null +++ b/.travis.yml @@ -0,0 +1,15 @@ +# Configuration for https://travis-ci.org/ + +language: go +go_import_path: blitiri.com.ar/go/dnss + +go: + - 1.6 + - 1.7 + - tip + +script: + - go test ./... + - go test -bench . ./... + - go test -race -bench . ./... + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..6f1ab52 --- /dev/null +++ b/LICENSE @@ -0,0 +1,219 @@ + +Copyright 2016 + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + + +-------------------------------------------------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..4cbc71f --- /dev/null +++ b/README.md @@ -0,0 +1,74 @@ + +# dnss + +dnss is a tool for encapsulating DNS over more secure protocols, like HTTPS or +GRPC. + +## Quick start + +If you want to set up dnss quickly, in DNS-over-HTTPS mode and using +https://dns.google.com as a server, you can run the following: + +``` +# If you have Go installed but no environment prepared, do: +mkdir /tmp/dnss; export GOPATH=/tmp/dnss; cd $GOPATH + +# Download and build the binary. +go get blitiri.com.ar/go/dnss + +# Copy the binary to a system-wide location. +sudo cp $GOPATH/bin/dnss /usr/local/bin + +# Set it up in systemd. +sudo cp $GOPATH/src/blitiri.com.ar/go/dnss/etc/systemd/dns-to-https/* \ + /etc/systemd/system/ + +sudo systemctl enable dnss +``` + + +## DNS over HTTPS + +dnss can act as a DNS-over-HTTPS proxy, using https://dns.google.com as a +server. + +``` ++--------+ +----------------+ +----------------+ +| | | dnss | | | +| client +-------> (dns-to-https) +--------> dns.google.com | +| | DNS | | | | ++--------+ UDP +----------------+ HTTP +----------------+ + SSL + TCP +``` + + +## DNS over GRPC + +dnss can encapsulate DNS over GRPC. + +It can be useful when you want to use a particular DNS server, but don't want +some parts of the network in between to be able to see your traffic.
+ + +``` ++--------+ +---------------+ +---------------+ +------------+ +| | | dnss | | dnss | | | +| client +-------> (dns-to-grpc) +--------> (grpc-to-dns) +------> DNS server | +| | DNS | | DNS | | DNS | | ++--------+ UDP +---------------+ GRPC +---------------+ UDP +------------+ + SSL + TCP +``` + +In "dns-to-grpc" mode, it listens for DNS requests and passes them on to a server +using GRPC. It also has a small cache. + +In "grpc-to-dns" mode, it receives DNS requests via GRPC, and resolves them +using a normal, fixed DNS server. + + +## Alternatives + +https://dnscrypt.org/ is a great, more end-to-end alternative to dnss. + diff --git a/dnss.go b/dnss.go new file mode 100644 index 0000000..3912fa0 --- /dev/null +++ b/dnss.go @@ -0,0 +1,222 @@ +package main + +import ( + "flag" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "blitiri.com.ar/go/dnss/internal/dnstox" + "blitiri.com.ar/go/dnss/internal/grpctodns" + + "github.com/golang/glog" + "google.golang.org/grpc" + + // Register pprof handlers for monitoring and debugging. + _ "net/http/pprof" + + // Make GRPC log to glog.
+ _ "google.golang.org/grpc/grpclog/glogger" +) + +var ( + dnsListenAddr = flag.String("dns_listen_addr", ":53", + "address to listen on for DNS") + + dnsUnqualifiedUpstream = flag.String("dns_unqualified_upstream", "", + "DNS server to forward unqualified requests to") + + fallbackUpstream = flag.String("fallback_upstream", "8.8.8.8:53", + "DNS server to resolve domains in --fallback_domains") + fallbackDomains = flag.String("fallback_domains", "dns.google.com.", + "Domains we resolve via DNS, using --fallback_upstream"+ + " (space-separated list)") + + enableDNStoGRPC = flag.Bool("enable_dns_to_grpc", false, + "enable DNS-to-GRPC server") + grpcUpstream = flag.String("grpc_upstream", "localhost:9953", + "address of the upstream GRPC server") + grpcClientCAFile = flag.String("grpc_client_cafile", "", + "CA file to use for the GRPC client") + + enableGRPCtoDNS = flag.Bool("enable_grpc_to_dns", false, + "enable GRPC-to-DNS server") + grpcListenAddr = flag.String("grpc_listen_addr", ":9953", + "address to listen on for GRPC") + dnsUpstream = flag.String("dns_upstream", "8.8.8.8:53", + "address of the upstream DNS server") + + enableDNStoHTTPS = flag.Bool("enable_dns_to_https", false, + "enable DNS-to-HTTPS proxy") + httpsUpstream = flag.String("https_upstream", + "https://dns.google.com/resolve", + "URL of upstream DNS-to-HTTP server") + httpsClientCAFile = flag.String("https_client_cafile", "", + "CA file to use for the HTTPS client") + + grpcCert = flag.String("grpc_cert", "", + "certificate file for the GRPC server") + grpcKey = flag.String("grpc_key", "", + "key file for the GRPC server") + + logFlushEvery = flag.Duration("log_flush_every", 30*time.Second, + "how often to flush logs") + monitoringListenAddr = flag.String("monitoring_listen_addr", "", + "address to listen on for monitoring HTTP requests") +) + +func flushLogs() { + c := time.Tick(*logFlushEvery) + for range c { + glog.Flush() + } +} + +func main() { + defer glog.Flush() + + flag.Parse() + + go 
flushLogs() + + grpc.EnableTracing = false + if *monitoringListenAddr != "" { + launchMonitoringServer(*monitoringListenAddr) + } + + if !*enableDNStoGRPC && !*enableGRPCtoDNS && !*enableDNStoHTTPS { + glog.Error("Need to set one of the following:") + glog.Error(" --enable_dns_to_https") + glog.Error(" --enable_dns_to_grpc") + glog.Error(" --enable_grpc_to_dns") + glog.Fatal("") + } + + if *enableDNStoGRPC && *enableDNStoHTTPS { + glog.Error("The following options cannot be set at the same time:") + glog.Error(" --enable_dns_to_grpc and --enable_dns_to_https") + glog.Fatal("") + } + + var wg sync.WaitGroup + + // DNS to GRPC. + if *enableDNStoGRPC { + r := dnstox.NewGRPCResolver(*grpcUpstream, *grpcClientCAFile) + cr := dnstox.NewCachingResolver(r) + cr.RegisterDebugHandlers() + dtg := dnstox.New(*dnsListenAddr, cr, *dnsUnqualifiedUpstream) + dtg.SetFallback( + *fallbackUpstream, strings.Split(*fallbackDomains, " ")) + wg.Add(1) + go func() { + defer wg.Done() + dtg.ListenAndServe() + }() + } + + // GRPC to DNS. + if *enableGRPCtoDNS { + gtd := &grpctodns.Server{ + Addr: *grpcListenAddr, + Upstream: *dnsUpstream, + CertFile: *grpcCert, + KeyFile: *grpcKey, + } + wg.Add(1) + go func() { + defer wg.Done() + gtd.ListenAndServe() + }() + } + + // DNS to HTTPS. 
+ if *enableDNStoHTTPS { + r := dnstox.NewHTTPSResolver(*httpsUpstream, *httpsClientCAFile) + cr := dnstox.NewCachingResolver(r) + cr.RegisterDebugHandlers() + dth := dnstox.New(*dnsListenAddr, cr, *dnsUnqualifiedUpstream) + dth.SetFallback( + *fallbackUpstream, strings.Split(*fallbackDomains, " ")) + wg.Add(1) + go func() { + defer wg.Done() + dth.ListenAndServe() + }() + } + + wg.Wait() +} + +func launchMonitoringServer(addr string) { + glog.Infof("Monitoring HTTP server listening on %s", addr) + grpc.EnableTracing = true + + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.NotFound(w, r) + return + } + w.Write([]byte(monitoringHTMLIndex)) + }) + + flags := dumpFlags() + http.HandleFunc("/debug/flags", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(flags)) + }) + + go http.ListenAndServe(addr, nil) +} + +// Static index for the monitoring website. +const monitoringHTMLIndex = `<!DOCTYPE html> +<html> + <head> + <title>dnss monitoring</title> + </head> + <body> + <h1>dnss monitoring</h1> + <ul> + <li><a href="/debug/requests">requests</a> + <small><a href="https://godoc.org/golang.org/x/net/trace"> + (ref)</a></small> + <ul> + <li><a href="/debug/requests?fam=dnstox&b=11">dnstox latency</a> + <li><a href="/debug/requests?fam=dnstox&b=0&exp=1">dnstox trace</a> + </ul> + <li><a href="/debug/dnstox/cache/dump">cache dump</a> + <li><a href="/debug/pprof">pprof</a> + <small><a href="https://golang.org/pkg/net/http/pprof/"> + (ref)</a></small> + <ul> + <li><a href="/debug/pprof/goroutine?debug=1">goroutines</a> + </ul> + <li><a href="/debug/flags">flags</a> + <li><a href="/debug/vars">public variables</a> + </ul> + </body> +</html> +` + +// dumpFlags to a string, for troubleshooting purposes. +func dumpFlags() string { + s := "" + visited := make(map[string]bool) + + // Print set flags first, then the rest. 
+ flag.Visit(func(f *flag.Flag) { + s += fmt.Sprintf("-%s=%s\n", f.Name, f.Value.String()) + visited[f.Name] = true + }) + + s += "\n" + flag.VisitAll(func(f *flag.Flag) { + if !visited[f.Name] { + s += fmt.Sprintf("-%s=%s\n", f.Name, f.Value.String()) + } + }) + + return s +} diff --git a/etc/systemd/dns-to-grpc/dnss.service b/etc/systemd/dns-to-grpc/dnss.service new file mode 100644 index 0000000..bf4e3d8 --- /dev/null +++ b/etc/systemd/dns-to-grpc/dnss.service @@ -0,0 +1,35 @@ +[Unit] +Description = dnss daemon - DNS to GRPC mode + +# Note we get the sockets via systemd, see the matching .socket configuration. +Requires=dnss.socket + + +[Service] +ExecStart = /usr/bin/dnss \ + --dns_listen_addr=systemd \ + --logtostderr \ + --monitoring_listen_addr=127.0.0.1:8081 \ + --grpc_upstream=1.2.3.4:9953 \ + --grpc_client_cafile=/etc/ssl/dnss/1.2.3.4-cert.pem \ + --enable_dns_to_grpc + + +Type = simple +Restart = always + +# The user can be created with no permissions using: +# +# sudo useradd -U dnss -M -d /nonexistent -s /bin/false +User = dnss +Group = dnss + +# Simple security measures just in case. +CapabilityBoundingSet = CAP_NET_BIND_SERVICE +ProtectSystem = full + + +[Install] +Also=dnss.socket +WantedBy = multi-user.target + diff --git a/etc/systemd/dns-to-grpc/dnss.socket b/etc/systemd/dns-to-grpc/dnss.socket new file mode 100644 index 0000000..b73523c --- /dev/null +++ b/etc/systemd/dns-to-grpc/dnss.socket @@ -0,0 +1,11 @@ +# Sockets for dnss. +# +# This lets dnss run unprivileged. +# We typically want one UDP and one TCP socket. + +[Socket] +ListenDatagram=53 +ListenStream=53 + +[Install] +WantedBy=sockets.target diff --git a/etc/systemd/dns-to-https/dnss.service b/etc/systemd/dns-to-https/dnss.service new file mode 100644 index 0000000..262092a --- /dev/null +++ b/etc/systemd/dns-to-https/dnss.service @@ -0,0 +1,32 @@ +[Unit] +Description = dnss daemon - DNS over HTTPS mode + +# Note we get the sockets via systemd, see dnss.socket.
+Requires=dnss.socket + +[Service] +ExecStart=/usr/local/bin/dnss \ + --dns_listen_addr=systemd \ + --logtostderr \ + --monitoring_listen_addr=127.0.0.1:8081 \ + --enable_dns_to_https + + +Type = simple +Restart = always + +# The user can be created with no permissions using: +# +# sudo useradd -U dnss -M -d /nonexistent -s /bin/false +User = dnss +Group = dnss + +# Simple security measures just in case. +CapabilityBoundingSet = CAP_NET_BIND_SERVICE +ProtectSystem=full + + +[Install] +Also=dnss.socket +WantedBy = multi-user.target + diff --git a/etc/systemd/dns-to-https/dnss.socket b/etc/systemd/dns-to-https/dnss.socket new file mode 100644 index 0000000..b73523c --- /dev/null +++ b/etc/systemd/dns-to-https/dnss.socket @@ -0,0 +1,11 @@ +# Sockets for dnss. +# +# This lets dnss run unprivileged. +# We typically want one UDP and one TCP socket. + +[Socket] +ListenDatagram=53 +ListenStream=53 + +[Install] +WantedBy=sockets.target diff --git a/etc/systemd/grpc-to-dns/dnss.service b/etc/systemd/grpc-to-dns/dnss.service new file mode 100644 index 0000000..8fa8086 --- /dev/null +++ b/etc/systemd/grpc-to-dns/dnss.service @@ -0,0 +1,24 @@ +[Unit] +Description = dnss daemon - GRPC to DNS mode + +[Service] +ExecStart = /usr/bin/dnss --enable_grpc_to_dns \ + --grpc_key=/etc/ssl/dnss/key.pem \ + --grpc_cert=/etc/ssl/dnss/cert.pem \ + --monitoring_listen_addr=127.0.0.1:9981 \ + --logtostderr + +Type = simple +Restart = always + +User = dnss +Group = dnss + +# Simple security measures just in case.
+CapabilityBoundingSet = +ProtectSystem = full + + +[Install] +WantedBy = multi-user.target + diff --git a/glide.lock b/glide.lock new file mode 100644 index 0000000..2a18567 --- /dev/null +++ b/glide.lock @@ -0,0 +1,39 @@ +hash: 7c3cf09373a10a9df1e6cd5211536c7df297cec6a21fabe892e1fe6bb0ae9299 +updated: 2016-11-08T13:43:03.40490918Z +imports: +- name: github.com/coreos/go-systemd + version: d659fb6603e3f57bf01c293c5ffda0debc726d8d + subpackages: + - activation +- name: github.com/golang/glog + version: 23def4e6c14b4da8ac2ed8007337bc5eb5007998 +- name: github.com/golang/protobuf + version: 2bc9827a78f95c6665b5fe0abd1fd66b496ae2d8 + subpackages: + - proto +- name: github.com/miekg/dns + version: 58f52c57ce9df13460ac68200cef30a008b9c468 +- name: golang.org/x/net + version: 87635b2611e4683a6f0aa595c36683021a00b2e4 + subpackages: + - context + - http2 + - http2/hpack + - idna + - internal/timeseries + - lex/httplex + - trace +- name: google.golang.org/grpc + version: eddbd11c7ecd001d06777cb87b5ea7af54227907 + subpackages: + - codes + - credentials + - grpclog + - grpclog/glogger + - internal + - metadata + - naming + - peer + - tap + - transport +testImports: [] diff --git a/glide.yaml b/glide.yaml new file mode 100644 index 0000000..74eea02 --- /dev/null +++ b/glide.yaml @@ -0,0 +1,18 @@ +package: blitiri.com.ar/go/dnss +import: +- package: github.com/coreos/go-systemd + subpackages: + - activation +- package: github.com/golang/glog +- package: github.com/golang/protobuf + subpackages: + - proto +- package: github.com/miekg/dns +- package: golang.org/x/net + subpackages: + - context + - trace +- package: google.golang.org/grpc + subpackages: + - credentials + - grpclog/glogger diff --git a/internal/dnstox/caching_test.go b/internal/dnstox/caching_test.go new file mode 100644 index 0000000..778972d --- /dev/null +++ b/internal/dnstox/caching_test.go @@ -0,0 +1,363 @@ +package dnstox + +// Tests for the caching resolver. 
+// Note the other resolvers have more functional tests in the testing/ +// directory. + +import ( + "fmt" + "reflect" + "strconv" + "testing" + "time" + + "blitiri.com.ar/go/dnss/testing/util" + + "github.com/miekg/dns" + "golang.org/x/net/trace" +) + +// A test resolver that we use as backing for the caching resolver under test. +type TestResolver struct { + // Has this resolver been initialized? + init bool + + // Maintain() sends a value over this channel. + maintain chan bool + + // The last query we've seen. + lastQuery *dns.Msg + + // What we will respond to queries. + response *dns.Msg + respError error +} + +func NewTestResolver() *TestResolver { + return &TestResolver{ + maintain: make(chan bool, 1), + } +} + +func (r *TestResolver) Init() error { + r.init = true + return nil +} + +func (r *TestResolver) Maintain() { + r.maintain <- true +} + +func (r *TestResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) { + r.lastQuery = req + if r.response != nil { + r.response.Question = req.Question + r.response.Authoritative = true + } + return r.response, r.respError +} + +// +// === Tests === +// + +// Test basic functionality. +func TestBasic(t *testing.T) { + r := NewTestResolver() + + c := NewCachingResolver(r) + + c.Init() + if !r.init { + t.Errorf("caching resolver did not initialize backing") + } + + resetStats() + + resp := queryA(t, c, "test. A 1.2.3.4", "test.", "1.2.3.4") + if !statsEquals(1, 0, 1) { + t.Errorf("bad stats: %v", dumpStats()) + } + if !resp.Authoritative { + t.Errorf("cache miss was not authoritative") + } + + // Same query, should be cached. + resp = queryA(t, c, "", "test.", "1.2.3.4") + if !statsEquals(2, 1, 1) { + t.Errorf("bad stats: %v", dumpStats()) + } + if resp.Authoritative { + t.Errorf("cache hit was authoritative") + } +} + +// Test TTL handling. 
+func TestTTL(t *testing.T) { + r := NewTestResolver() + c := NewCachingResolver(r) + c.Init() + resetStats() + + // Note we don't start c.Maintain() yet, as we don't want the background + // TTL updater until later. + + // Test a record with a larger-than-max TTL (1 day). + // The TTL of the response should be capped. + resp := queryA(t, c, "test. 86400 A 1.2.3.4", "test.", "1.2.3.4") + if !statsEquals(1, 0, 1) { + t.Errorf("bad stats: %v", dumpStats()) + } + if ttl := getTTL(resp.Answer); ttl != maxTTL { + t.Errorf("expected max TTL (%v), got %v", maxTTL, ttl) + } + + // Same query, should be cached, and TTL also capped. + // As we've not enabled cache maintenance, we can be sure TTL == maxTTL. + resp = queryA(t, c, "", "test.", "1.2.3.4") + if !statsEquals(2, 1, 1) { + t.Errorf("bad stats: %v", dumpStats()) + } + if ttl := getTTL(resp.Answer); ttl != maxTTL { + t.Errorf("expected max TTL (%v), got %v", maxTTL, ttl) + } + + // To test that the TTL is reduced appropriately, set a small maintenance + // period, and then repeatedly query the record. We should see its TTL + // shrinking down within 1s. + // Even though the TTL resolution in the protocol is in seconds, we don't + // need to wait that much "thanks" to rounding artifacts. + maintenancePeriod = 50 * time.Millisecond + go c.Maintain() + resetStats() + + // Check that the back resolver's Maintain() is called. + select { + case <-r.maintain: + t.Log("Maintain() called") + case <-time.After(1 * time.Second): + t.Errorf("back resolver Maintain() was not called") + } + + start := time.Now() + for time.Since(start) < 1*time.Second { + resp = queryA(t, c, "", "test.", "1.2.3.4") + t.Logf("TTL %v", getTTL(resp.Answer)) + if ttl := getTTL(resp.Answer); ttl <= (maxTTL - 1*time.Second) { + break + } + time.Sleep(maintenancePeriod) + } + if ttl := getTTL(resp.Answer); ttl > (maxTTL - 1*time.Second) { + t.Errorf("expected maxTTL-1s, got %v", ttl) + } +} + +// Test that we don't cache failed queries. 
+func TestFailedQueries(t *testing.T) { + r := NewTestResolver() + c := NewCachingResolver(r) + c.Init() + resetStats() + + // Do two failed identical queries, check that both are cache misses. + queryFail(t, c) + if !statsEquals(1, 0, 1) { + t.Errorf("bad stats: %v", dumpStats()) + } + + queryFail(t, c) + if !statsEquals(2, 0, 2) { + t.Errorf("bad stats: %v", dumpStats()) + } +} + +// Test that we handle the cache filling up. +// Note this test is tied to the current behaviour of not doing any eviction +// when we're full, which is not ideal and will likely be changed in the +// future. +func TestCacheFull(t *testing.T) { + r := NewTestResolver() + c := NewCachingResolver(r) + c.Init() + resetStats() + + r.response = newReply(mustNewRR(t, "test. A 1.2.3.4")) + + // Do maxCacheSize+1 different requests. + for i := 0; i < maxCacheSize+1; i++ { + queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4") + if !statsEquals(i+1, 0, i+1) { + t.Errorf("bad stats: %v", dumpStats()) + } + } + + // Query up to maxCacheSize, they should all be hits. + resetStats() + for i := 0; i < maxCacheSize; i++ { + queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4") + if !statsEquals(i+1, i+1, 0) { + t.Errorf("bad stats: %v", dumpStats()) + } + } + + // Querying maxCacheSize+1 should be a miss, because the cache was full. + resetStats() + queryA(t, c, "", fmt.Sprintf("test%d.", maxCacheSize), "1.2.3.4") + if !statsEquals(1, 0, 1) { + t.Errorf("bad stats: %v", dumpStats()) + } +} + +// Test behaviour when the size of the cache is 0 (so users can disable it +// that way). +func TestZeroSize(t *testing.T) { + r := NewTestResolver() + c := NewCachingResolver(r) + c.Init() + resetStats() + + // Override the max cache size to 0. + prevMaxCacheSize := maxCacheSize + maxCacheSize = 0 + defer func() { maxCacheSize = prevMaxCacheSize }() + + r.response = newReply(mustNewRR(t, "test. A 1.2.3.4")) + + // Do 5 different requests. 
+ for i := 0; i < 5; i++ { + queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4") + if !statsEquals(i+1, 0, i+1) { + t.Errorf("bad stats: %v", dumpStats()) + } + } + + // Query them back, they should all be misses. + resetStats() + for i := 0; i < 5; i++ { + queryA(t, c, "", fmt.Sprintf("test%d.", i), "1.2.3.4") + if !statsEquals(i+1, 0, i+1) { + t.Errorf("bad stats: %v", dumpStats()) + } + } +} + +// +// === Benchmarks === +// + +func BenchmarkCacheSimple(b *testing.B) { + var err error + + r := NewTestResolver() + r.response = newReply(mustNewRR(b, "test. A 1.2.3.4")) + + c := NewCachingResolver(r) + c.Init() + + tr := &util.NullTrace{} + req := newQuery("test.", dns.TypeA) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err = c.Query(req, tr) + if err != nil { + b.Errorf("query failed: %v", err) + } + } +} + +// +// === Helpers === +// + +func resetStats() { + stats.cacheTotal.Set(0) + stats.cacheBypassed.Set(0) + stats.cacheHits.Set(0) + stats.cacheMisses.Set(0) + stats.cacheRecorded.Set(0) +} + +func statsEquals(total, hits, misses int) bool { + return (stats.cacheTotal.String() == strconv.Itoa(total) && + stats.cacheHits.String() == strconv.Itoa(hits) && + stats.cacheMisses.String() == strconv.Itoa(misses)) +} + +func dumpStats() string { + return fmt.Sprintf("(t:%v h:%s m:%v)", + stats.cacheTotal, stats.cacheHits, stats.cacheMisses) +} + +func queryA(t *testing.T, c *cachingResolver, rr, domain, expected string) *dns.Msg { + // Set up the response from the given RR (if any). 
+ if rr != "" { + back := c.back.(*TestResolver) + back.response = newReply(mustNewRR(t, rr)) + } + + tr := util.NewTestTrace(t) + defer tr.Finish() + + req := newQuery(domain, dns.TypeA) + resp, err := c.Query(req, tr) + if err != nil { + t.Fatalf("query failed: %v", err) + } + + a := resp.Answer[0].(*dns.A) + if a.A.String() != expected { + t.Errorf("expected %s, got %v", expected, a.A) + } + + if !reflect.DeepEqual(req.Question, resp.Question) { + t.Errorf("question mis-match: request %v, response %v", + req.Question, resp.Question) + } + + return resp +} + +func queryFail(t *testing.T, c *cachingResolver) *dns.Msg { + back := c.back.(*TestResolver) + back.response = &dns.Msg{} + back.response.Response = true + back.response.Rcode = dns.RcodeNameError + + tr := util.NewTestTrace(t) + defer tr.Finish() + + req := newQuery("doesnotexist.", dns.TypeA) + resp, err := c.Query(req, tr) + if err != nil { + t.Fatalf("query failed: %v", err) + } + + return resp +} + +func mustNewRR(tb testing.TB, s string) dns.RR { + rr, err := dns.NewRR(s) + if err != nil { + tb.Fatalf("invalid RR %q: %v", s, err) + } + return rr +} + +func newQuery(domain string, t uint16) *dns.Msg { + m := &dns.Msg{} + m.SetQuestion(domain, t) + return m +} + +func newReply(answer dns.RR) *dns.Msg { + return &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Response: true, + Authoritative: false, + Rcode: dns.RcodeSuccess, + }, + Answer: []dns.RR{answer}, + } +} diff --git a/internal/dnstox/dnstox.go b/internal/dnstox/dnstox.go new file mode 100644 index 0000000..3fce7c4 --- /dev/null +++ b/internal/dnstox/dnstox.go @@ -0,0 +1,233 @@ +// DNS to GRPC. + +package dnstox + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "net" + "strings" + "sync" + + "github.com/coreos/go-systemd/activation" + "github.com/golang/glog" + "github.com/miekg/dns" + "golang.org/x/net/trace" + + "blitiri.com.ar/go/dnss/internal/util" +) + +// newID is a channel used to generate new request IDs. 
+// There is a goroutine created at init() time that will get IDs randomly, to +// help prevent guesses. +var newId chan uint16 + +func init() { + // Buffer 100 numbers to avoid blocking on crypto rand. + newId = make(chan uint16, 100) + + go func() { + var id uint16 + var err error + + for { + err = binary.Read(rand.Reader, binary.LittleEndian, &id) + if err != nil { + panic(fmt.Sprintf("error creating id: %v", err)) + } + + newId <- id + } + + }() +} + +type Server struct { + Addr string + unqUpstream string + resolver Resolver + + fallbackDomains map[string]struct{} + fallbackUpstream string +} + +func New(addr string, resolver Resolver, unqUpstream string) *Server { + return &Server{ + Addr: addr, + resolver: resolver, + unqUpstream: unqUpstream, + fallbackDomains: map[string]struct{}{}, + } +} + +func (s *Server) SetFallback(upstream string, domains []string) { + s.fallbackUpstream = upstream + for _, d := range domains { + s.fallbackDomains[d] = struct{}{} + } +} + +func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) { + tr := trace.New("dnstox", "Handler") + defer tr.Finish() + + tr.LazyPrintf("from:%v id:%v", w.RemoteAddr(), r.Id) + + if glog.V(3) { + tr.LazyPrintf(util.QuestionsToString(r.Question)) + } + + // We only support single-question queries. + if len(r.Question) != 1 { + tr.LazyPrintf("len(Q) != 1, failing") + dns.HandleFailed(w, r) + return + } + + // Forward to the unqualified upstream server if: + // - We have one configured. + // - There's only one question in the request, to keep things simple. + // - The question is unqualified (only one '.' in the name). 
+ useUnqUpstream := s.unqUpstream != "" && + strings.Count(r.Question[0].Name, ".") <= 1 + if useUnqUpstream { + u, err := dns.Exchange(r, s.unqUpstream) + if err == nil { + tr.LazyPrintf("used unqualified upstream") + if glog.V(3) { + util.TraceAnswer(tr, u) + } + w.WriteMsg(u) + } else { + tr.LazyPrintf("unqualified upstream error: %v", err) + dns.HandleFailed(w, r) + } + + return + } + + // Forward to the fallback server if the domain is on our list. + if _, ok := s.fallbackDomains[r.Question[0].Name]; ok { + u, err := dns.Exchange(r, s.fallbackUpstream) + if err == nil { + tr.LazyPrintf("used fallback upstream (%s)", s.fallbackUpstream) + if glog.V(3) { + util.TraceAnswer(tr, u) + } + w.WriteMsg(u) + } else { + tr.LazyPrintf("fallback upstream error: %v", err) + dns.HandleFailed(w, r) + } + + return + } + + // Create our own IDs, in case different users pick the same id and we + // pass that upstream. + oldid := r.Id + r.Id = <-newId + + from_up, err := s.resolver.Query(r, tr) + if err != nil { + glog.Infof(err.Error()) + tr.LazyPrintf(err.Error()) + tr.SetError() + return + } + + if glog.V(3) { + util.TraceAnswer(tr, from_up) + } + + from_up.Id = oldid + w.WriteMsg(from_up) +} + +func (s *Server) ListenAndServe() { + err := s.resolver.Init() + if err != nil { + glog.Fatalf("Error initializing: %v", err) + } + + go s.resolver.Maintain() + + if s.Addr == "systemd" { + s.systemdServe() + } else { + s.classicServe() + } +} + +func (s *Server) classicServe() { + glog.Infof("DNS listening on %s", s.Addr) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err := dns.ListenAndServe(s.Addr, "udp", dns.HandlerFunc(s.Handler)) + glog.Fatalf("Exiting UDP: %v", err) + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := dns.ListenAndServe(s.Addr, "tcp", dns.HandlerFunc(s.Handler)) + glog.Fatalf("Exiting TCP: %v", err) + }() + + wg.Wait() +} + +func (s *Server) systemdServe() { + // We will usually have at least one TCP socket and one UDP socket. 
+ // PacketConns are UDP sockets, Listeners are TCP sockets. + // To make things more annoying, both can (and usually will) have nil + // entries for the file descriptors that don't match. + pconns, err := activation.PacketConns(false) + if err != nil { + glog.Fatalf("Error getting systemd packet conns: %v", err) + } + + listeners, err := activation.Listeners(false) + if err != nil { + glog.Fatalf("Error getting systemd listeners: %v", err) + } + + var wg sync.WaitGroup + + for _, pconn := range pconns { + if pconn == nil { + continue + } + + wg.Add(1) + go func(c net.PacketConn) { + defer wg.Done() + glog.Infof("Activate on packet connection (UDP)") + err := dns.ActivateAndServe(nil, c, dns.HandlerFunc(s.Handler)) + glog.Fatalf("Exiting UDP listener: %v", err) + }(pconn) + } + + for _, lis := range listeners { + if lis == nil { + continue + } + + wg.Add(1) + go func(l net.Listener) { + defer wg.Done() + glog.Infof("Activate on listening socket (TCP)") + err := dns.ActivateAndServe(l, nil, dns.HandlerFunc(s.Handler)) + glog.Fatalf("Exiting TCP listener: %v", err) + }(lis) + } + + wg.Wait() + + // We should only get here if there were no useful sockets. + glog.Fatalf("No systemd sockets, did you forget the .socket?") +} diff --git a/internal/dnstox/resolver.go b/internal/dnstox/resolver.go new file mode 100644 index 0000000..165a1be --- /dev/null +++ b/internal/dnstox/resolver.go @@ -0,0 +1,548 @@ +package dnstox + +import ( + "crypto/tls" + "crypto/x509" + "encoding/json" + "expvar" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "sync" + "time" + + "github.com/golang/glog" + "github.com/miekg/dns" + "golang.org/x/net/context" + "golang.org/x/net/trace" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + "bytes" + + pb "blitiri.com.ar/go/dnss/internal/proto" +) + +// Interface for DNS resolvers that can answer queries. +type Resolver interface { + // Initialize the resolver. + Init() error + + // Maintain performs resolver maintenance. 
It's expected to run + // indefinitely, but may return early if appropriate. + Maintain() + + // Query responds to a DNS query. + Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) +} + +/////////////////////////////////////////////////////////////////////////// +// GRPC resolver. + +// grpcResolver implements the Resolver interface by querying a server via +// GRPC. +type grpcResolver struct { + Upstream string + CAFile string + client pb.DNSServiceClient +} + +func NewGRPCResolver(upstream, caFile string) *grpcResolver { + return &grpcResolver{ + Upstream: upstream, + CAFile: caFile, + } +} + +func (g *grpcResolver) Init() error { + var err error + var creds credentials.TransportCredentials + if g.CAFile == "" { + creds = credentials.NewClientTLSFromCert(nil, "") + } else { + creds, err = credentials.NewClientTLSFromFile(g.CAFile, "") + if err != nil { + return err + } + } + + conn, err := grpc.Dial(g.Upstream, grpc.WithTransportCredentials(creds)) + if err != nil { + return err + } + + g.client = pb.NewDNSServiceClient(conn) + return nil +} + +func (g *grpcResolver) Maintain() { +} + +func (g *grpcResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) { + buf, err := r.Pack() + if err != nil { + return nil, err + } + + // Give our RPCs 2 second timeouts: DNS usually doesn't wait that long + // anyway. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + reply, err := g.client.Query(ctx, &pb.RawMsg{Data: buf}) + if err != nil { + return nil, err + } + + m := &dns.Msg{} + err = m.Unpack(reply.Data) + return m, err +} + +/////////////////////////////////////////////////////////////////////////// +// HTTPS resolver. + +// httpsResolver implements the Resolver interface by querying a server via +// DNS-over-HTTPS (like https://dns.google.com). 
+type httpsResolver struct { + Upstream string + CAFile string + client *http.Client +} + +func loadCertPool(caFile string) (*x509.CertPool, error) { + pemData, err := ioutil.ReadFile(caFile) + if err != nil { + return nil, err + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(pemData) { + return nil, fmt.Errorf("Error appending certificates") + } + + return pool, nil +} + +func NewHTTPSResolver(upstream, caFile string) *httpsResolver { + return &httpsResolver{ + Upstream: upstream, + CAFile: caFile, + } +} + +func (r *httpsResolver) Init() error { + r.client = &http.Client{ + // Give our HTTP requests 4 second timeouts: DNS usually doesn't wait + // that long anyway, but this helps with slow connections. + Timeout: 4 * time.Second, + } + + if r.CAFile == "" { + return nil + } + + pool, err := loadCertPool(r.CAFile) + if err != nil { + return err + } + + r.client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + ClientCAs: pool, + }, + } + + return nil +} + +func (r *httpsResolver) Maintain() { +} + +// Structure for parsing JSON responses. +type jsonResponse struct { + Status int + TC bool + RD bool + RA bool + AD bool + CD bool + Question []jsonRR + Answer []jsonRR +} + +type jsonRR struct { + Name string `json:name` + Type uint16 `json:type` + TTL uint32 `json:TTL` + Data string `json:data` +} + +func (r *httpsResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) { + // Only answer single-question queries. + // In practice, these are all we get, and almost no server supports + // multi-question requests anyway. + if len(req.Question) != 1 { + return nil, fmt.Errorf("multi-question query") + } + + question := req.Question[0] + // Only answer IN-class queries, which are the ones used in practice. + if question.Qclass != dns.ClassINET { + return nil, fmt.Errorf("query class != IN") + } + + // Build the query and send the request. 
+ v := url.Values{} + v.Set("name", question.Name) + v.Set("type", dns.TypeToString[question.Qtype]) + // TODO: add random_padding. + + url := r.Upstream + "?" + v.Encode() + if glog.V(3) { + tr.LazyPrintf("GET %q", url) + } + + hr, err := r.client.Get(url) + if err != nil { + return nil, fmt.Errorf("GET failed: %v", err) + } + tr.LazyPrintf("%s %s", hr.Proto, hr.Status) + defer hr.Body.Close() + + if hr.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Response status: %s", hr.Status) + } + + // Read the HTTPS response, and parse the JSON. + body, err := ioutil.ReadAll(hr.Body) + if err != nil { + return nil, fmt.Errorf("Failed to read body: %v", err) + } + + jr := &jsonResponse{} + err = json.Unmarshal(body, jr) + if err != nil { + return nil, fmt.Errorf("Failed to unmarshall: %v", err) + } + + if len(jr.Question) != 1 { + return nil, fmt.Errorf("Wrong number of questions in the response") + } + + // Build the DNS response. + resp := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: req.Id, + Response: true, + Opcode: req.Opcode, + Rcode: jr.Status, + + Truncated: jr.TC, + RecursionDesired: jr.RD, + RecursionAvailable: jr.RA, + AuthenticatedData: jr.AD, + CheckingDisabled: jr.CD, + }, + Question: []dns.Question{ + dns.Question{ + Name: jr.Question[0].Name, + Qtype: jr.Question[0].Type, + Qclass: dns.ClassINET, + }}, + } + + for _, answer := range jr.Answer { + // TODO: This "works" but is quite hacky. Is there a better way, + // without doing lots of data parsing? + s := fmt.Sprintf("%s %d IN %s %s", + answer.Name, answer.TTL, + dns.TypeToString[answer.Type], answer.Data) + rr, err := dns.NewRR(s) + if err != nil { + return nil, fmt.Errorf("Error parsing answer: %v", err) + } + + resp.Answer = append(resp.Answer, rr) + } + + return resp, nil +} + +/////////////////////////////////////////////////////////////////////////// +// Caching resolver. + +// cachingResolver implements a caching Resolver. +// It is backed by another Resolver, but will cache results. 
+type cachingResolver struct { + // Backing resolver. + back Resolver + + // The cache where we keep the records. + answer map[dns.Question][]dns.RR + + // mu protects the answer map. + mu *sync.RWMutex +} + +func NewCachingResolver(back Resolver) *cachingResolver { + return &cachingResolver{ + back: back, + answer: map[dns.Question][]dns.RR{}, + mu: &sync.RWMutex{}, + } +} + +// Constants that tune the cache. +// They are declared as variables so we can tweak them for testing. +var ( + // Maximum number of entries we keep in the cache. + // 2k should be reasonable for a small network. + // Keep in mind that increasing this too much will interact negatively + // with Maintain(). + maxCacheSize = 2000 + + // Minimum TTL for entries we consider for the cache. + minTTL = 2 * time.Minute + + // Maximum TTL for our cache. We cap records that exceed this. + maxTTL = 2 * time.Hour + + // How often to run GC on the cache. + // Must be < minTTL if we don't want to have entries stale for too long. + maintenancePeriod = 30 * time.Second +) + +// Exported variables for statistics. +// These are global and not per caching resolver, so if we have more than once +// the results will be mixed. +var stats = struct { + // Total number of queries handled by the cache resolver. + cacheTotal *expvar.Int + + // Queries that we passed directly through our back resolver. + cacheBypassed *expvar.Int + + // Cache misses. + cacheMisses *expvar.Int + + // Cache hits. + cacheHits *expvar.Int + + // Entries we decided to record in the cache. 
+ cacheRecorded *expvar.Int +}{} + +func init() { + stats.cacheTotal = expvar.NewInt("cache-total") + stats.cacheBypassed = expvar.NewInt("cache-bypassed") + stats.cacheHits = expvar.NewInt("cache-hits") + stats.cacheMisses = expvar.NewInt("cache-misses") + stats.cacheRecorded = expvar.NewInt("cache-recorded") +} + +func (c *cachingResolver) Init() error { + return c.back.Init() +} + +// RegisterDebugHandlers registers http debug handlers, which can be accessed +// from the monitoring server. +// Note these are global by nature, if you try to register them multiple +// times, you will get a panic. +func (c *cachingResolver) RegisterDebugHandlers() { + http.HandleFunc("/debug/dnstox/cache/dump", c.DumpCache) + http.HandleFunc("/debug/dnstox/cache/flush", c.FlushCache) +} + +func (c *cachingResolver) DumpCache(w http.ResponseWriter, r *http.Request) { + buf := bytes.NewBuffer(nil) + + c.mu.RLock() + for q, ans := range c.answer { + // Only include names and records if we are running verbosily. 
+ name := "<hidden>" + if glog.V(3) { + name = q.Name + } + + fmt.Fprintf(buf, "Q: %s %s %s\n", name, dns.TypeToString[q.Qtype], + dns.ClassToString[q.Qclass]) + + ttl := getTTL(ans) + fmt.Fprintf(buf, " expires in %s (%s)\n", ttl, time.Now().Add(ttl)) + + if glog.V(3) { + for _, rr := range ans { + fmt.Fprintf(buf, " %s\n", rr.String()) + } + } else { + fmt.Fprintf(buf, " %d RRs in answer\n", len(ans)) + } + fmt.Fprintf(buf, "\n\n") + } + c.mu.RUnlock() + + buf.WriteTo(w) +} + +func (c *cachingResolver) FlushCache(w http.ResponseWriter, r *http.Request) { + c.mu.Lock() + c.answer = map[dns.Question][]dns.RR{} + c.mu.Unlock() + + w.Write([]byte("cache flush complete")) +} + +func (c *cachingResolver) Maintain() { + go c.back.Maintain() + + for range time.Tick(maintenancePeriod) { + tr := trace.New("dnstox.Cache", "GC") + var total, expired int + + c.mu.Lock() + total = len(c.answer) + for q, ans := range c.answer { + newTTL := getTTL(ans) - maintenancePeriod + if newTTL > 0 { + // Don't modify in place, create a copy and override. + // That way, we avoid races with users that have gotten a + // cached answer and are returning it. 
+ newans := copyRRSlice(ans) + setTTL(newans, newTTL) + c.answer[q] = newans + continue + } + + delete(c.answer, q) + expired++ + } + c.mu.Unlock() + tr.LazyPrintf("total: %d expired: %d", total, expired) + tr.Finish() + } +} + +func wantToCache(question dns.Question, reply *dns.Msg) error { + if reply.Rcode != dns.RcodeSuccess { + return fmt.Errorf("unsuccessful query") + } else if !reply.Response { + return fmt.Errorf("response = false") + } else if reply.Opcode != dns.OpcodeQuery { + return fmt.Errorf("opcode %d != query", reply.Opcode) + } else if len(reply.Answer) == 0 { + return fmt.Errorf("answer is empty") + } else if len(reply.Question) != 1 { + return fmt.Errorf("too many/few questions (%d)", len(reply.Question)) + } else if reply.Question[0] != question { + return fmt.Errorf( + "reply question does not match: asked %v, got %v", + question, reply.Question[0]) + } + + return nil +} + +func limitTTL(answer []dns.RR) time.Duration { + // This assumes all RRs have the same TTL. That may not be the case in + // theory, but we are ok not caring for this for now. + ttl := time.Duration(answer[0].Header().Ttl) * time.Second + + // This helps prevent cache pollution due to unused but long entries, as + // we don't do usage-based caching yet. + if ttl > maxTTL { + ttl = maxTTL + } + + return ttl +} + +func getTTL(answer []dns.RR) time.Duration { + // This assumes all RRs have the same TTL. That may not be the case in + // theory, but we are ok not caring for this for now. 
+ return time.Duration(answer[0].Header().Ttl) * time.Second +} + +func setTTL(answer []dns.RR, newTTL time.Duration) { + for _, rr := range answer { + rr.Header().Ttl = uint32(newTTL.Seconds()) + } +} + +func copyRRSlice(a []dns.RR) []dns.RR { + b := make([]dns.RR, 0, len(a)) + for _, rr := range a { + b = append(b, dns.Copy(rr)) + } + return b +} + +func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) { + stats.cacheTotal.Add(1) + + // To keep it simple we only cache single-question queries. + if len(r.Question) != 1 { + tr.LazyPrintf("cache bypass: multi-question query") + stats.cacheBypassed.Add(1) + return c.back.Query(r, tr) + } + + question := r.Question[0] + + c.mu.RLock() + answer, hit := c.answer[question] + c.mu.RUnlock() + + if hit { + tr.LazyPrintf("cache hit") + stats.cacheHits.Add(1) + + reply := &dns.Msg{ + MsgHdr: dns.MsgHdr{ + Id: r.Id, + Response: true, + Authoritative: false, + Rcode: dns.RcodeSuccess, + }, + Question: r.Question, + Answer: answer, + } + + return reply, nil + } + + tr.LazyPrintf("cache miss") + stats.cacheMisses.Add(1) + + reply, err := c.back.Query(r, tr) + if err != nil { + return reply, err + } + + if err = wantToCache(question, reply); err != nil { + tr.LazyPrintf("cache not recording reply: %v", err) + return reply, nil + } + + answer = reply.Answer + ttl := limitTTL(answer) + + // Only store answers if they're going to stay around for a bit, + // there's not much point in caching things we have to expire quickly. + if ttl < minTTL { + return reply, nil + } + + // Store the answer in the cache, but don't exceed 2k entries. + // TODO: Do usage based eviction when we're approaching ~1.5k. 
+ c.mu.Lock() + if len(c.answer) < maxCacheSize { + setTTL(answer, ttl) + c.answer[question] = answer + stats.cacheRecorded.Add(1) + } + c.mu.Unlock() + + return reply, nil +} diff --git a/internal/grpctodns/grpctodns.go b/internal/grpctodns/grpctodns.go new file mode 100644 index 0000000..3ddf5a8 --- /dev/null +++ b/internal/grpctodns/grpctodns.go @@ -0,0 +1,109 @@ +// GRPC to DNS. + +package grpctodns + +import ( + "fmt" + "net" + "strings" + + pb "blitiri.com.ar/go/dnss/internal/proto" + "blitiri.com.ar/go/dnss/internal/util" + "github.com/golang/glog" + "github.com/miekg/dns" + "golang.org/x/net/context" + "golang.org/x/net/trace" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +func questionsToString(qs []dns.Question) string { + var s []string + for _, q := range qs { + s = append(s, fmt.Sprintf("(%s %s %s)", q.Name, + dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass])) + } + return "Q[" + strings.Join(s, " ") + "]" +} + +func rrsToString(rrs []dns.RR) string { + var s []string + for _, rr := range rrs { + s = append(s, fmt.Sprintf("(%s)", rr)) + } + return "RR[" + strings.Join(s, " ") + "]" + +} + +type Server struct { + Addr string + Upstream string + CertFile string + KeyFile string +} + +func (s *Server) Query(ctx context.Context, in *pb.RawMsg) (*pb.RawMsg, error) { + tr := trace.New("grpctodns", "Query") + defer tr.Finish() + + r := &dns.Msg{} + err := r.Unpack(in.Data) + if err != nil { + return nil, err + } + + if glog.V(3) { + tr.LazyPrintf(util.QuestionsToString(r.Question)) + } + + // TODO: we should create our own IDs, in case different users pick the + // same id and we pass that upstream. 
+ from_up, err := dns.Exchange(r, s.Upstream) + if err != nil { + msg := fmt.Sprintf("dns exchange error: %v", err) + glog.Info(msg) + tr.LazyPrintf(msg) + tr.SetError() + return nil, err + } + + if from_up == nil { + err = fmt.Errorf("no response from upstream") + tr.LazyPrintf(err.Error()) + tr.SetError() + return nil, err + } + + if glog.V(3) { + util.TraceAnswer(tr, from_up) + } + + buf, err := from_up.Pack() + if err != nil { + glog.Infof(" error packing: %v", err) + tr.LazyPrintf("error packing: %v", err) + tr.SetError() + return nil, err + } + + return &pb.RawMsg{Data: buf}, nil +} + +func (s *Server) ListenAndServe() { + lis, err := net.Listen("tcp", s.Addr) + if err != nil { + glog.Fatalf("failed to listen: %v", err) + } + + ta, err := credentials.NewServerTLSFromFile(s.CertFile, s.KeyFile) + if err != nil { + glog.Fatalf("failed to create TLS transport auth: %v", err) + } + + grpcServer := grpc.NewServer(grpc.Creds(ta)) + pb.RegisterDNSServiceServer(grpcServer, s) + + glog.Infof("GRPC listening on %s", s.Addr) + err = grpcServer.Serve(lis) + glog.Fatalf("GRPC exiting: %s", err) +} diff --git a/internal/proto/dnss.pb.go b/internal/proto/dnss.pb.go new file mode 100644 index 0000000..78db3a0 --- /dev/null +++ b/internal/proto/dnss.pb.go @@ -0,0 +1,134 @@ +// Code generated by protoc-gen-go. +// source: dnss.proto +// DO NOT EDIT! + +/* +Package dnss is a generated protocol buffer package. + +It is generated from these files: + dnss.proto + +It has these top-level messages: + RawMsg +*/ +package dnss + +import proto "github.com/golang/protobuf/proto" +import fmt "fmt" +import math "math" + +import ( + context "golang.org/x/net/context" + grpc "google.golang.org/grpc" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. 
+// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type RawMsg struct { + // DNS-encoded message. + // A horrible hack, but will do for now. + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` +} + +func (m *RawMsg) Reset() { *m = RawMsg{} } +func (m *RawMsg) String() string { return proto.CompactTextString(m) } +func (*RawMsg) ProtoMessage() {} +func (*RawMsg) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} } + +func init() { + proto.RegisterType((*RawMsg)(nil), "dnss.RawMsg") +} + +// Reference imports to suppress errors if they are not otherwise used. +var _ context.Context +var _ grpc.ClientConn + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +const _ = grpc.SupportPackageIsVersion4 + +// Client API for DNSService service + +type DNSServiceClient interface { + Query(ctx context.Context, in *RawMsg, opts ...grpc.CallOption) (*RawMsg, error) +} + +type dNSServiceClient struct { + cc *grpc.ClientConn +} + +func NewDNSServiceClient(cc *grpc.ClientConn) DNSServiceClient { + return &dNSServiceClient{cc} +} + +func (c *dNSServiceClient) Query(ctx context.Context, in *RawMsg, opts ...grpc.CallOption) (*RawMsg, error) { + out := new(RawMsg) + err := grpc.Invoke(ctx, "/dnss.DNSService/Query", in, out, c.cc, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +// Server API for DNSService service + +type DNSServiceServer interface { + Query(context.Context, *RawMsg) (*RawMsg, error) +} + +func RegisterDNSServiceServer(s *grpc.Server, srv DNSServiceServer) { + s.RegisterService(&_DNSService_serviceDesc, srv) +} + +func _DNSService_Query_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RawMsg) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DNSServiceServer).Query(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/dnss.DNSService/Query", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DNSServiceServer).Query(ctx, req.(*RawMsg)) + } + return interceptor(ctx, in, info, handler) +} + +var _DNSService_serviceDesc = grpc.ServiceDesc{ + ServiceName: "dnss.DNSService", + HandlerType: (*DNSServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Query", + Handler: _DNSService_Query_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "dnss.proto", +} + +func init() { proto.RegisterFile("dnss.proto", fileDescriptor0) } + +var fileDescriptor0 = []byte{ + // 105 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xe2, 0x4a, 0xc9, 0x2b, 0x2e, + 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x01, 0xb1, 0x95, 0x64, 0xb8, 0xd8, 0x82, 0x12, + 0xcb, 0x7d, 0x8b, 0xd3, 0x85, 0x84, 0xb8, 0x58, 0x52, 0x12, 0x4b, 0x12, 0x25, 0x18, 0x15, 0x18, + 0x35, 0x78, 0x82, 0xc0, 0x6c, 0x23, 0x43, 0x2e, 0x2e, 0x17, 0xbf, 0xe0, 0xe0, 0xd4, 0xa2, 0xb2, + 0xcc, 0xe4, 0x54, 0x21, 0x65, 0x2e, 0xd6, 0xc0, 0xd2, 0xd4, 0xa2, 0x4a, 0x21, 0x1e, 0x3d, 0xb0, + 0x39, 0x10, 0x8d, 0x52, 0x28, 0xbc, 0x24, 0x36, 0xb0, 0xe9, 0xc6, 0x80, 0x00, 0x00, 0x00, 0xff, + 0xff, 0xc0, 0xc7, 0x8e, 0xd6, 0x6b, 0x00, 0x00, 
0x00, +} diff --git a/internal/proto/dnss.proto b/internal/proto/dnss.proto new file mode 100644 index 0000000..8f982ea --- /dev/null +++ b/internal/proto/dnss.proto @@ -0,0 +1,15 @@ + +syntax = "proto3"; + +package dnss; + +message RawMsg { + // DNS-encoded message. + // A horrible hack, but will do for now. + bytes data = 1; +} + +service DNSService { + rpc Query(RawMsg) returns (RawMsg); +} + diff --git a/internal/proto/dummy.go b/internal/proto/dummy.go new file mode 100644 index 0000000..5136454 --- /dev/null +++ b/internal/proto/dummy.go @@ -0,0 +1,4 @@ +package dnss + +// Generate the protobuf+grpc service. +//go:generate protoc --go_out=plugins=grpc:. dnss.proto diff --git a/internal/util/strings.go b/internal/util/strings.go new file mode 100644 index 0000000..5624cc3 --- /dev/null +++ b/internal/util/strings.go @@ -0,0 +1,30 @@ +package util + +// Utility functions for logging DNS messages. + +import ( + "fmt" + "strings" + + "github.com/miekg/dns" + "golang.org/x/net/trace" +) + +func QuestionsToString(qs []dns.Question) string { + var s []string + for _, q := range qs { + s = append(s, fmt.Sprintf("(%s %s %s)", q.Name, + dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass])) + } + return "Q: " + strings.Join(s, " ; ") +} + +func TraceAnswer(tr trace.Trace, m *dns.Msg) { + if m.Rcode != dns.RcodeSuccess { + rcode := dns.RcodeToString[m.Rcode] + tr.LazyPrintf(rcode) + } + for _, rr := range m.Answer { + tr.LazyPrintf(rr.String()) + } +} diff --git a/monitoring_test.go b/monitoring_test.go new file mode 100644 index 0000000..e198e0b --- /dev/null +++ b/monitoring_test.go @@ -0,0 +1,48 @@ +package main + +// Tests for the monitoring server. +// +// Note that functional tests for dnss are in the testing/ directory, here we +// only test the monitoring server created in dnss.go. + +import ( + "net/http" + "testing" + + "google.golang.org/grpc" +) + +func TestMonitoringServer(t *testing.T) { + // TODO: Don't hard-code this. 
+ const addr = "localhost:19395" + launchMonitoringServer(addr) + + if !grpc.EnableTracing { + t.Errorf("grpc tracing is disabled") + } + + checkGet(t, "http://"+addr+"/") + checkGet(t, "http://"+addr+"/debug/requests") + checkGet(t, "http://"+addr+"/debug/pprof/goroutine") + checkGet(t, "http://"+addr+"/debug/flags") + checkGet(t, "http://"+addr+"/debug/vars") + + // Check that we emit 404 for non-existing paths. + r, _ := http.Get("http://" + addr + "/doesnotexist") + if r.StatusCode != 404 { + t.Errorf("expected 404, got %s", r.Status) + } +} + +func checkGet(t *testing.T, url string) { + r, err := http.Get(url) + if err != nil { + t.Error(err) + return + } + + if r.StatusCode != 200 { + t.Errorf("%q - invalid status: %s", url, r.Status) + } + +} diff --git a/testing/grpc/grpc.go b/testing/grpc/grpc.go new file mode 100644 index 0000000..268365b --- /dev/null +++ b/testing/grpc/grpc.go @@ -0,0 +1,4 @@ +package grpc + +// Dummy file so "go build ./..." does not complain about the directory not +// having buildable files. diff --git a/testing/grpc/grpc_test.go b/testing/grpc/grpc_test.go new file mode 100644 index 0000000..0172b9a --- /dev/null +++ b/testing/grpc/grpc_test.go @@ -0,0 +1,243 @@ +// Tests for dnss in GRPC modes. +package grpc + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "flag" + "fmt" + "io/ioutil" + "math/big" + "net" + "os" + "testing" + "time" + + "blitiri.com.ar/go/dnss/internal/dnstox" + "blitiri.com.ar/go/dnss/internal/grpctodns" + "blitiri.com.ar/go/dnss/testing/util" + + "github.com/golang/glog" + "github.com/miekg/dns" +) + +// Addresses to use for testing. These will be picked at initialization time, +// see init(). 
+var dnsToGrpcAddr, grpcToDnsAddr, dnsSrvAddr string + +func init() { + dnsToGrpcAddr = util.GetFreePort() + grpcToDnsAddr = util.GetFreePort() + dnsSrvAddr = util.GetFreePort() +} + +// +// === Tests === +// + +func dnsQuery(conn *dns.Conn) error { + m := &dns.Msg{} + m.SetQuestion("ca.chai.", dns.TypeMX) + + conn.WriteMsg(m) + _, err := conn.ReadMsg() + return err +} + +func TestSimple(t *testing.T) { + conn, err := dns.DialTimeout("udp", dnsToGrpcAddr, 1*time.Second) + if err != nil { + t.Fatalf("dns.Dial error: %v", err) + } + defer conn.Close() + + err = dnsQuery(conn) + if err != nil { + t.Errorf("dns query returned error: %v", err) + } +} + +// +// === Benchmarks === +// + +func manyDNSQueries(b *testing.B, addr string) { + conn, err := dns.DialTimeout("udp", addr, 1*time.Second) + if err != nil { + b.Fatalf("dns.Dial error: %v", err) + } + defer conn.Close() + + for i := 0; i < b.N; i++ { + err = dnsQuery(conn) + if err != nil { + b.Errorf("dns query returned error: %v", err) + } + } +} + +func BenchmarkGRPCDirect(b *testing.B) { + manyDNSQueries(b, dnsSrvAddr) +} + +func BenchmarkGRPCWithProxy(b *testing.B) { + manyDNSQueries(b, dnsToGrpcAddr) +} + +// +// === Test environment === +// + +// dnsServer implements a DNS server for testing. +// It always gives the same reply, regardless of the query. +type dnsServer struct { + Addr string + srv *dns.Server + answerRR dns.RR +} + +func (s *dnsServer) Handler(w dns.ResponseWriter, r *dns.Msg) { + // Building the reply (and setting the corresponding id) is cheaper than + // copying a "master" message. 
+ m := &dns.Msg{} + m.Id = r.Id + m.Response = true + m.Authoritative = true + m.Rcode = dns.RcodeSuccess + m.Answer = append(m.Answer, s.answerRR) + w.WriteMsg(m) +} + +func (s *dnsServer) ListenAndServe() { + var err error + + s.answerRR, err = dns.NewRR("test.blah A 1.2.3.4") + if err != nil { + panic(err) + } + + s.srv = &dns.Server{ + Addr: s.Addr, + Net: "udp", + Handler: dns.HandlerFunc(s.Handler), + } + err = s.srv.ListenAndServe() + if err != nil { + panic(err) + } +} + +// generateCert generates a new, INSECURE self-signed certificate and writes +// it to a pair of (cert.pem, key.pem) files to the given path. +// Note the certificate is only useful for testing purposes. +func generateCert(path string) error { + tmpl := x509.Certificate{ + SerialNumber: big.NewInt(1234), + Subject: pkix.Name{ + Organization: []string{"dnss testing"}, + }, + + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + + NotBefore: time.Now(), + NotAfter: time.Now().Add(30 * time.Minute), + + KeyUsage: x509.KeyUsageKeyEncipherment | + x509.KeyUsageDigitalSignature | + x509.KeyUsageCertSign, + + BasicConstraintsValid: true, + IsCA: true, + } + + priv, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + return err + } + + derBytes, err := x509.CreateCertificate( + rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) + if err != nil { + return err + } + + certOut, err := os.Create(path + "/cert.pem") + if err != nil { + return err + } + defer certOut.Close() + pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + + keyOut, err := os.OpenFile( + path+"/key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return err + } + defer keyOut.Close() + + block := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(priv), + } + pem.Encode(keyOut, block) + return nil +} + +// realMain is the real main function, which returns the value to pass to +// os.Exit(). We have to do this so we can use defer. 
+func realMain(m *testing.M) int { + flag.Parse() + defer glog.Flush() + + // Generate certificates in a temporary directory. + tmpDir, err := ioutil.TempDir("", "dnss_test:") + if err != nil { + fmt.Printf("Failed to create temp dir: %v\n", tmpDir) + return 1 + } + defer os.RemoveAll(tmpDir) + + err = generateCert(tmpDir) + if err != nil { + fmt.Printf("Failed to generate cert for testing: %v\n", err) + return 1 + } + + // DNS to GRPC server. + gr := dnstox.NewGRPCResolver(grpcToDnsAddr, tmpDir+"/cert.pem") + cr := dnstox.NewCachingResolver(gr) + dtg := dnstox.New(dnsToGrpcAddr, cr, "") + go dtg.ListenAndServe() + + // GRPC to DNS server. + gtd := &grpctodns.Server{ + Addr: grpcToDnsAddr, + Upstream: dnsSrvAddr, + CertFile: tmpDir + "/cert.pem", + KeyFile: tmpDir + "/key.pem", + } + go gtd.ListenAndServe() + + // DNS test server. + dnsSrv := dnsServer{ + Addr: dnsSrvAddr, + } + go dnsSrv.ListenAndServe() + + // Wait for the servers to start up. + err = util.WaitForDNSServer(dnsToGrpcAddr) + if err != nil { + fmt.Printf("Error waiting for the test servers to start: %v\n", err) + fmt.Printf("Check the INFO logs for more details\n") + return 1 + } + + return m.Run() +} + +func TestMain(m *testing.M) { + os.Exit(realMain(m)) +} diff --git a/testing/https/https.go b/testing/https/https.go new file mode 100644 index 0000000..cc8acf1 --- /dev/null +++ b/testing/https/https.go @@ -0,0 +1,4 @@ +package https + +// Dummy file so "go build ./..." does not complain about the directory not +// having buildable files. diff --git a/testing/https/https_test.go b/testing/https/https_test.go new file mode 100644 index 0000000..b984ea2 --- /dev/null +++ b/testing/https/https_test.go @@ -0,0 +1,163 @@ +// Tests for dnss in HTTPS mode. 
+package https + +import ( + "flag" + "fmt" + "net/http" + "net/http/httptest" + "os" + "testing" + + "blitiri.com.ar/go/dnss/internal/dnstox" + "blitiri.com.ar/go/dnss/testing/util" + + "github.com/golang/glog" + "github.com/miekg/dns" +) + +// +// === Tests === +// +func dnsQuery(addr string, qtype uint16) (*dns.Msg, dns.RR, error) { + m := new(dns.Msg) + m.SetQuestion(addr, qtype) + in, err := dns.Exchange(m, DNSAddr) + + if err != nil { + return nil, nil, err + } else if len(in.Answer) > 0 { + return in, in.Answer[0], nil + } else { + return in, nil, nil + } +} + +func TestSimple(t *testing.T) { + _, ans, err := dnsQuery("test.blah.", dns.TypeA) + if err != nil { + t.Errorf("dns query returned error: %v", err) + } + if ans.(*dns.A).A.String() != "1.2.3.4" { + t.Errorf("unexpected result: %q", ans) + } + + _, ans, err = dnsQuery("test.blah.", dns.TypeMX) + if err != nil { + t.Errorf("dns query returned error: %v", err) + } + if ans.(*dns.MX).Mx != "mail.test.blah." { + t.Errorf("unexpected result: %q", ans.(*dns.MX).Mx) + } + + in, _, err := dnsQuery("unknown.", dns.TypeA) + if err != nil { + t.Errorf("dns query returned error: %v", err) + } + if in.Rcode != dns.RcodeNameError { + t.Errorf("unexpected result: %q", in) + } +} + +// +// === Benchmarks === +// + +func BenchmarkHTTPSimple(b *testing.B) { + var err error + for i := 0; i < b.N; i++ { + _, _, err = dnsQuery("test.blah.", dns.TypeA) + if err != nil { + b.Errorf("dns query returned error: %v", err) + } + } +} + +// +// === Test environment === +// + +// DNSHandler handles DNS-over-HTTP requests, and returns json data. +// This is used as the test server for our resolver. +func DNSHandler(w http.ResponseWriter, r *http.Request) { + err := r.ParseForm() + if err != nil { + panic(err) + } + + w.Header().Set("Content-Type", "text/json") + + resp := jsonNXDOMAIN + + if r.Form["name"][0] == "test.blah." 
{ + switch r.Form["type"][0] { + case "1", "A": + resp = jsonA + case "15", "MX": + resp = jsonMX + default: + resp = jsonNXDOMAIN + } + } + + w.Write([]byte(resp)) +} + +// A record. +const jsonA = ` { + "Status": 0, "TC": false, "RD": true, "RA": true, "AD": false, "CD": false, + "Question": [ { "name": "test.blah.", "type": 1 } + ], + "Answer": [ { "name": "test.blah.", "type": 1, "TTL": 21599, + "data": "1.2.3.4" } ] } +` + +// MX record. +const jsonMX = ` { + "Status": 0, "TC": false, "RD": true, "RA": true, "AD": false, "CD": false, + "Question": [ { "name": "test.blah.", "type": 15 } ], + "Answer": [ { "name": "test.blah.", "type": 15, "TTL": 21599, + "data": "10 mail.test.blah." } ] } +` + +// NXDOMAIN error. +const jsonNXDOMAIN = ` { + "Status": 3, "TC": false, "RD": true, "RA": true, "AD": true, "CD": false, + "Question": [ { "name": "doesnotexist.", "type": 15 } ], + "Authority": [ { "name": ".", "type": 6, "TTL": 1798, + "data": "root. nstld. 2016052201 1800 900 604800 86400" } ] } +` + +// Address where we will set up the DNS server. +var DNSAddr string + +// realMain is the real main function, which returns the value to pass to +// os.Exit(). We have to do this so we can use defer. +func realMain(m *testing.M) int { + flag.Parse() + defer glog.Flush() + + DNSAddr = util.GetFreePort() + + // Test http server. + httpsrv := httptest.NewServer(http.HandlerFunc(DNSHandler)) + + // DNS to HTTPS server. + r := dnstox.NewHTTPSResolver(httpsrv.URL, "") + dth := dnstox.New(DNSAddr, r, "") + go dth.ListenAndServe() + + // Wait for the servers to start up. 
+ err := util.WaitForDNSServer(DNSAddr) + if err != nil { + fmt.Printf("Error waiting for the test servers to start: %v\n", err) + fmt.Printf("Check the INFO logs for more details\n") + return 1 + } + + return m.Run() +} + +func TestMain(m *testing.M) { + os.Exit(realMain(m)) +} diff --git a/testing/util/util.go b/testing/util/util.go new file mode 100644 index 0000000..52a080f --- /dev/null +++ b/testing/util/util.go @@ -0,0 +1,84 @@ +// Package util implements common testing utilities. +package util + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/miekg/dns" +) + +// WaitForDNSServer waits 5 seconds for a DNS server to start, and returns an +// error if it fails to do so. +// It does this by repeatedly querying the DNS server until it either replies +// or times out. Note we do not do any validation of the reply. +func WaitForDNSServer(addr string) error { + conn, err := dns.DialTimeout("udp", addr, 1*time.Second) + if err != nil { + return fmt.Errorf("dns.Dial error: %v", err) + } + defer conn.Close() + + m := &dns.Msg{} + m.SetQuestion("unused.", dns.TypeA) + + deadline := time.Now().Add(5 * time.Second) + tick := time.Tick(100 * time.Millisecond) + + for (<-tick).Before(deadline) { + conn.SetDeadline(time.Now().Add(1 * time.Second)) + conn.WriteMsg(m) + _, err := conn.ReadMsg() + if err == nil { + return nil + } + } + + return fmt.Errorf("timed out") +} + +// Get a free (TCP) port. This is hacky and not race-free, but it works well +// enough for testing purposes. +func GetFreePort() string { + l, _ := net.Listen("tcp", "localhost:0") + defer l.Close() + return l.Addr().String() +} + +// TestTrace implements the tracer.Trace interface, but prints using the test +// logging infrastructure. 
+type TestTrace struct { + T *testing.T +} + +func NewTestTrace(t *testing.T) *TestTrace { + return &TestTrace{t} +} + +func (t *TestTrace) LazyLog(x fmt.Stringer, sensitive bool) { + t.T.Logf("trace %p (%b): %s", t, sensitive, x) +} + +func (t *TestTrace) LazyPrintf(format string, a ...interface{}) { + prefix := fmt.Sprintf("trace %p: ", t) + t.T.Logf(prefix+format, a...) +} + +func (t *TestTrace) SetError() {} +func (t *TestTrace) SetRecycler(f func(interface{})) {} +func (t *TestTrace) SetTraceInfo(traceID, spanID uint64) {} +func (t *TestTrace) SetMaxEvents(m int) {} +func (t *TestTrace) Finish() {} + +// NullTrace implements the tracer.Trace interface, but discards everything. +type NullTrace struct{} + +func (t *NullTrace) LazyLog(x fmt.Stringer, sensitive bool) {} +func (t *NullTrace) LazyPrintf(format string, a ...interface{}) {} +func (t *NullTrace) SetError() {} +func (t *NullTrace) SetRecycler(f func(interface{})) {} +func (t *NullTrace) SetTraceInfo(traceID, spanID uint64) {} +func (t *NullTrace) SetMaxEvents(m int) {} +func (t *NullTrace) Finish() {} diff --git a/tools/bench b/tools/bench new file mode 100755 index 0000000..45a007e --- /dev/null +++ b/tools/bench @@ -0,0 +1,119 @@ +#!/bin/bash +# +# This is a small utility that helps run and diff benchmarks, using +# "go test -bench" and "benchcmp". +# +# It's only used for development and not meant to be portable, or have a +# stable interface. +# +# Examples: +# # Run the benchmarks, recording the output IFF the tree is not dirty. +# ./tools/bench +# +# # Diff between two recorded commits. +# ./tools/bench diff 8b25916 HEAD +# +# # Run the benchmarks without recording, and compare against a commit. +# ./tools/bench rundiff 8b25916 +# + +set -e + +cd "$(git rev-parse --show-toplevel)" + +BDIR=".bench-history" + +# Get a filename based on the current commit. 
# The optional argument is the commit to use (HEAD if omitted or empty); the
# name has the form "<commit date>__<short hash>__<sanitized subject>".
commit_fname() {
	git log --date=format:"%F-%H:%M" --pretty=format:"%cd__%h__%f" \
		-1 "${1:-HEAD}"
}


MODE=bench
RUN_COUNT=3
BEST=
NO_RECORD=

# Don't record results for a dirty tree.
# Note this tool is explicitly excluded so we can easily test old commits.
# (grep -c is not used: it exits non-zero on a zero count, which would make
# set -e abort the script on a clean tree.)
DIRTY=$(git status --porcelain | grep -v tools/bench | grep -v "^??" | wc -l)
if [ "$DIRTY" -gt 0 ]; then
	echo "Dirty tree, not recording results"
	NO_RECORD=1
fi

while getopts "m:c:1rbn" OPT ; do
	case $OPT in
	m)
		MODE=$OPTARG
		;;
	1)
		RUN_COUNT=1
		;;
	c)
		RUN_COUNT=$OPTARG
		;;
	b)
		BEST="-best"
		;;
	n)
		NO_RECORD=1
		;;
	\?)
		exit 1
		;;
	esac
done

shift $((OPTIND-1))

# An optional positional argument overrides the mode.
if [ -n "${1:-}" ]; then
	MODE=$1
	shift
fi

case "$MODE" in
bench)
	NAME=$(commit_fname)
	FNAME=$BDIR/$NAME
	RAWFNAME=$BDIR/.$NAME.raw

	if [ -n "$NO_RECORD" ]; then
		go test -run=NONE -bench=. -benchmem ./...
		exit
	fi

	# Create the history directory in case this is the first recorded run;
	# otherwise writing the results below fails.
	mkdir -p "$BDIR"

	echo -n "Running: "
	: > "$RAWFNAME"
	for ((i = 1; i <= RUN_COUNT; i++)); do
		go test -run=NONE -bench=. -benchmem ./... >> "$RAWFNAME"
		echo -n "$i "
	done
	echo

	# Filter and sort the results to make them more succinct and easier to
	# compare.
	grep allocs "$RAWFNAME" | sort > "$FNAME"

	cat "$FNAME"
	;;

diff)
	F1=$BDIR/$(commit_fname "$1")
	F2=$BDIR/$(commit_fname "$2")
	# BEST is either empty (pass no argument at all) or "-best".
	benchcmp ${BEST:+"$BEST"} "$F1" "$F2"
	;;

rundiff)
	TMPF=$(mktemp)
	# Clean up even when "go test" or benchcmp fail and set -e aborts.
	trap 'rm -f -- "$TMPF"' EXIT
	F1=$BDIR/$(commit_fname "$1")

	go test -run=NONE -bench=. -benchmem ./... > "$TMPF"
	benchcmp -best "$F1" "$TMPF"
	;;

ls)
	cd "$BDIR"
	ls -1
	;;

*)
	echo "Unknown mode $MODE" >&2
	exit 1
	;;
esac