author | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-06-12 13:34:59 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2021-06-12 14:19:16 UTC |
parent | f0c8cdccc0a9ccdc189310a82c6b4ca8436503d6 |
internal/dnsserver/caching_test.go | +6 | -3 |
internal/dnsserver/resolver.go | +9 | -9 |
internal/dnsserver/server.go | +10 | -13 |
internal/httpresolver/resolver.go | +14 | -12 |
internal/httpresolver/resolver_test.go | +6 | -2 |
internal/httpserver/server.go | +17 | -17 |
internal/testutil/testutil.go | +2 | -38 |
internal/trace/trace.go | +121 | -0 |
internal/util/trace.go | +0 | -55 |
diff --git a/internal/dnsserver/caching_test.go b/internal/dnsserver/caching_test.go index 5d0f6e5..8e948df 100644 --- a/internal/dnsserver/caching_test.go +++ b/internal/dnsserver/caching_test.go @@ -10,6 +10,7 @@ import ( "time" "blitiri.com.ar/go/dnss/internal/testutil" + "blitiri.com.ar/go/dnss/internal/trace" "github.com/miekg/dns" ) @@ -212,7 +213,9 @@ func BenchmarkCacheSimple(b *testing.B) { c := NewCachingResolver(r) c.Init() - tr := &testutil.NullTrace{} + tr := trace.New("test", "Benchmark") + defer tr.Finish() + req := newQuery("test.", dns.TypeA) b.ResetTimer() @@ -254,7 +257,7 @@ func queryA(t *testing.T, c *cachingResolver, rr, domain, expected string) *dns. back.Response = newReply(mustNewRR(t, rr)) } - tr := testutil.NewTestTrace(t) + tr := trace.New("test", "queryA") defer tr.Finish() req := newQuery(domain, dns.TypeA) @@ -282,7 +285,7 @@ func queryFail(t *testing.T, c *cachingResolver) *dns.Msg { back.Response.Response = true back.Response.Rcode = dns.RcodeNameError - tr := testutil.NewTestTrace(t) + tr := trace.New("test", "queryFail") defer tr.Finish() req := newQuery("doesnotexist.", dns.TypeA) diff --git a/internal/dnsserver/resolver.go b/internal/dnsserver/resolver.go index 0b588f9..73666cb 100644 --- a/internal/dnsserver/resolver.go +++ b/internal/dnsserver/resolver.go @@ -8,10 +8,10 @@ import ( "sync" "time" - "blitiri.com.ar/go/log" + "blitiri.com.ar/go/dnss/internal/trace" + "blitiri.com.ar/go/log" "github.com/miekg/dns" - "golang.org/x/net/trace" ) // Resolver is the interface for DNS resolvers that can answer queries. @@ -24,7 +24,7 @@ type Resolver interface { Maintain() // Query responds to a DNS query. - Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) + Query(r *dns.Msg, tr *trace.Trace) (*dns.Msg, error) } /////////////////////////////////////////////////////////////////////////// @@ -178,7 +178,7 @@ func (c *cachingResolver) Maintain() { expired++ } c.mu.Unlock() - tr.LazyPrintf("total: %d expired: %d", total, expired) + tr.Printf("total: %d expired: %d", total, expired) tr.Finish() } } @@ -237,12 +237,12 @@ func copyRRSlice(a []dns.RR) []dns.RR { return b } -func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) { +func (c *cachingResolver) Query(r *dns.Msg, tr *trace.Trace) (*dns.Msg, error) { stats.cacheTotal.Add(1) // To keep it simple we only cache single-question queries. if len(r.Question) != 1 { - tr.LazyPrintf("cache bypass: multi-question query") + tr.Printf("cache bypass: multi-question query") stats.cacheBypassed.Add(1) return c.back.Query(r, tr) } @@ -254,7 +254,7 @@ func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) { c.mu.RUnlock() if hit { - tr.LazyPrintf("cache hit") + tr.Printf("cache hit") stats.cacheHits.Add(1) reply := &dns.Msg{ @@ -271,7 +271,7 @@ func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) { return reply, nil } - tr.LazyPrintf("cache miss") + tr.Printf("cache miss") stats.cacheMisses.Add(1) reply, err := c.back.Query(r, tr) @@ -280,7 +280,7 @@ func (c *cachingResolver) Query(r *dns.Msg, tr trace.Trace) (*dns.Msg, error) { } if err = wantToCache(question, reply); err != nil { - tr.LazyPrintf("cache not recording reply: %v", err) + tr.Printf("cache not recording reply: %v", err) return reply, nil } diff --git a/internal/dnsserver/server.go b/internal/dnsserver/server.go index 4c47690..71d61f6 100644 --- a/internal/dnsserver/server.go +++ b/internal/dnsserver/server.go @@ -10,12 +10,11 @@ import ( "strings" "sync" - "github.com/miekg/dns" - "golang.org/x/net/trace" + "blitiri.com.ar/go/dnss/internal/trace" - "blitiri.com.ar/go/dnss/internal/util" "blitiri.com.ar/go/log" "blitiri.com.ar/go/systemd" + "github.com/miekg/dns" ) // newID is a channel used to generate new request IDs. @@ -66,13 +65,12 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) { tr := trace.New("dnsserver", "Handler") defer tr.Finish() - tr.LazyPrintf("from:%v id:%v", w.RemoteAddr(), r.Id) - - util.TraceQuestion(tr, r.Question) + tr.Printf("from:%v id:%v", w.RemoteAddr(), r.Id) + tr.Question(r.Question) // We only support single-question queries. if len(r.Question) != 1 { - tr.LazyPrintf("len(Q) != 1, failing") + tr.Printf("len(Q) != 1, failing") dns.HandleFailed(w, r) return } @@ -86,11 +84,11 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) { if useUnqUpstream { u, err := dns.Exchange(r, s.unqUpstream) if err == nil { - tr.LazyPrintf("used unqualified upstream") - util.TraceAnswer(tr, u) + tr.Printf("used unqualified upstream") + tr.Answer(u) w.WriteMsg(u) } else { - tr.LazyPrintf("unqualified upstream error: %v", err) + tr.Printf("unqualified upstream error: %v", err) dns.HandleFailed(w, r) } @@ -105,15 +103,14 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) { fromUp, err := s.resolver.Query(r, tr) if err != nil { log.Infof("resolver query error: %v", err) - tr.LazyPrintf(err.Error()) - tr.SetError() + tr.Error(err) r.Id = oldid dns.HandleFailed(w, r) return } - util.TraceAnswer(tr, fromUp) + tr.Answer(fromUp) fromUp.Id = oldid w.WriteMsg(fromUp) diff --git a/internal/httpresolver/resolver.go b/internal/httpresolver/resolver.go index 03905a2..9df8b08 100644 --- a/internal/httpresolver/resolver.go +++ b/internal/httpresolver/resolver.go @@ -16,10 +16,10 @@ import ( "time" "blitiri.com.ar/go/dnss/internal/dnsserver" - "blitiri.com.ar/go/log" + "blitiri.com.ar/go/dnss/internal/trace" + "blitiri.com.ar/go/log" "github.com/miekg/dns" - "golang.org/x/net/trace" ) // httpsResolver implements the dnsserver.Resolver interface by querying a @@ -36,8 +36,6 @@ type httpsResolver struct { mu sync.Mutex client *http.Client firstErr time.Time - - ev trace.EventLog } var errAppendingCerts = fmt.Errorf("error appending certificates") @@ -101,8 +99,9 @@ func (r *httpsResolver) Init() error { r.client = client r.mu.Unlock() - r.ev = trace.NewEventLog("httpresolver", r.Upstream.String()) - r.ev.Printf("Init complete, client: %p", r.client) + tr := trace.New("httpresolver", r.Upstream.String()) + tr.Printf("Init complete, client: %p", r.client) + tr.Finish() return err } @@ -178,28 +177,31 @@ func (r *httpsResolver) maybeRotateClient() { // The time chosen here combines with the transport timeouts set above, so // we never have too many in-flight connections. if time.Since(r.firstErr) > 10*time.Second { - r.ev.Printf("Rotating client after %s of errors: %p", + tr := trace.New("httpresolver", r.Upstream.String()) + defer tr.Finish() + + tr.Printf("Rotating client after %s of errors: %p", time.Since(r.firstErr), r.client) client, err := r.newClient() if err != nil { - r.ev.Errorf("Error creating new client: %v", err) + tr.Errorf("Error creating new client: %v", err) return } r.client = client r.firstErr = time.Time{} - r.ev.Printf("Rotated client: %p", r.client) + tr.Printf("Rotated client: %p", r.client) } } -func (r *httpsResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) { +func (r *httpsResolver) Query(req *dns.Msg, tr *trace.Trace) (*dns.Msg, error) { packed, err := req.Pack() if err != nil { return nil, fmt.Errorf("cannot pack query: %v", err) } if log.V(3) { - tr.LazyPrintf("DoH POST %v", r.Upstream) + tr.Printf("DoH POST %v", r.Upstream) } // TODO: Accept header. @@ -216,7 +218,7 @@ func (r *httpsResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) { if err != nil { return nil, fmt.Errorf("POST failed: %v", err) } - tr.LazyPrintf("%s %s", hr.Proto, hr.Status) + tr.Printf("%s %s", hr.Proto, hr.Status) defer hr.Body.Close() if hr.StatusCode != http.StatusOK { diff --git a/internal/httpresolver/resolver_test.go b/internal/httpresolver/resolver_test.go index 29f5258..c847b28 100644 --- a/internal/httpresolver/resolver_test.go +++ b/internal/httpresolver/resolver_test.go @@ -11,6 +11,7 @@ import ( "testing" "blitiri.com.ar/go/dnss/internal/testutil" + "blitiri.com.ar/go/dnss/internal/trace" "github.com/miekg/dns" ) @@ -37,7 +38,9 @@ func mustNewDoH(t *testing.T, urlS string) *httpsResolver { func query(t *testing.T, r *httpsResolver, req string) (dns.RR, error) { t.Helper() - tr := testutil.NewTestTrace(t) + tr := trace.New("test", "query") + defer tr.Finish() + dr := new(dns.Msg) dr.SetQuestion(req, dns.TypeA) resp, err := r.Query(dr, tr) @@ -161,7 +164,8 @@ func TestBadRequest(t *testing.T) { defer ts.Close() r := mustNewDoH(t, ts.URL) - tr := testutil.NewTestTrace(t) + tr := trace.New("test", "TestBadRequest") + defer tr.Finish() // Construct a request that cannot be packed, in this case the Rcode is // invalid. diff --git a/internal/httpserver/server.go b/internal/httpserver/server.go index e439385..b9205dd 100644 --- a/internal/httpserver/server.go +++ b/internal/httpserver/server.go @@ -12,10 +12,10 @@ import ( "mime" "net/http" - "blitiri.com.ar/go/dnss/internal/util" + "blitiri.com.ar/go/dnss/internal/trace" + "blitiri.com.ar/go/log" "github.com/miekg/dns" - "golang.org/x/net/trace" ) // Server is an HTTPS server that implements DNS over HTTPS, see the @@ -52,8 +52,8 @@ func (s *Server) ListenAndServe() { func (s *Server) Resolve(w http.ResponseWriter, req *http.Request) { tr := trace.New("httpserver", "/resolve") defer tr.Finish() - tr.LazyPrintf("from:%v", req.RemoteAddr) - tr.LazyPrintf("method:%v", req.Method) + tr.Printf("from:%v", req.RemoteAddr) + tr.Printf("method:%v", req.Method) req.ParseForm() @@ -61,11 +61,11 @@ func (s *Server) Resolve(w http.ResponseWriter, req *http.Request) { // - GET requests have a "dns=" query parameter. // - POST requests have a content-type = application/dns-message. if req.Method == "GET" && req.FormValue("dns") != "" { - tr.LazyPrintf("DoH:GET") + tr.Printf("DoH:GET") dnsQuery, err := base64.RawURLEncoding.DecodeString( req.FormValue("dns")) if err != nil { - util.TraceError(tr, err) + tr.Error(err) http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -77,17 +77,17 @@ func (s *Server) Resolve(w http.ResponseWriter, req *http.Request) { if req.Method == "POST" { ct, _, err := mime.ParseMediaType(req.Header.Get("Content-Type")) if err != nil { - util.TraceError(tr, err) + tr.Error(err) http.Error(w, err.Error(), http.StatusBadRequest) return } if ct == "application/dns-message" { - tr.LazyPrintf("DoH:POST") + tr.Printf("DoH:POST") // Limit the size of request to 4k. dnsQuery, err := ioutil.ReadAll(io.LimitReader(req.Body, 4092)) if err != nil { - util.TraceError(tr, err) + tr.Error(err) http.Error(w, err.Error(), http.StatusBadRequest) return } @@ -98,41 +98,41 @@ func (s *Server) Resolve(w http.ResponseWriter, req *http.Request) { } // Could not found how to handle this request. - util.TraceErrorf(tr, "unknown request type") + tr.Errorf("unknown request type") http.Error(w, "unknown request type", http.StatusUnsupportedMediaType) } // Resolve DNS over HTTPS requests, as specified in RFC 8484. -func (s *Server) resolveDoH(tr trace.Trace, w http.ResponseWriter, dnsQuery []byte) { +func (s *Server) resolveDoH(tr *trace.Trace, w http.ResponseWriter, dnsQuery []byte) { r := &dns.Msg{} err := r.Unpack(dnsQuery) if err != nil { - util.TraceError(tr, err) + tr.Error(err) http.Error(w, err.Error(), http.StatusBadRequest) return } - util.TraceQuestion(tr, r.Question) + tr.Question(r.Question) // Do the DNS request, get the reply. fromUp, err := dns.Exchange(r, s.Upstream) if err != nil { - err = util.TraceErrorf(tr, "dns exchange error: %v", err) + err = tr.Errorf("dns exchange error: %v", err) http.Error(w, err.Error(), http.StatusFailedDependency) return } if fromUp == nil { - err = util.TraceErrorf(tr, "no response from upstream") + err = tr.Errorf("no response from upstream") http.Error(w, err.Error(), http.StatusRequestTimeout) return } - util.TraceAnswer(tr, fromUp) + tr.Answer(fromUp) packed, err := fromUp.Pack() if err != nil { - err = util.TraceErrorf(tr, "cannot pack reply: %v", err) + err = tr.Errorf("cannot pack reply: %v", err) http.Error(w, err.Error(), http.StatusFailedDependency) return } diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 07bada6..80fdb84 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "golang.org/x/net/trace" + "blitiri.com.ar/go/dnss/internal/trace" "github.com/miekg/dns" ) @@ -123,7 +123,7 @@ func (r *TestResolver) Maintain() { } // Query handles the given query, returning the pre-recorded response. -func (r *TestResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) { +func (r *TestResolver) Query(req *dns.Msg, tr *trace.Trace) (*dns.Msg, error) { r.LastQuery = req if r.Response != nil { r.Response.Question = req.Question @@ -163,39 +163,3 @@ func NewRR(tb testing.TB, s string) dns.RR { } return rr } - -// TestTrace implements the tracer.Trace interface, but prints using the test -// logging infrastructure. -type TestTrace struct { - T *testing.T -} - -func NewTestTrace(t *testing.T) *TestTrace { - return &TestTrace{t} -} - -func (t *TestTrace) LazyLog(x fmt.Stringer, sensitive bool) { - t.T.Logf("trace %p (%t): %s", t, sensitive, x) -} - -func (t *TestTrace) LazyPrintf(format string, a ...interface{}) { - prefix := fmt.Sprintf("trace %p: ", t) - t.T.Logf(prefix+format, a...) -} - -func (t *TestTrace) SetError() {} -func (t *TestTrace) SetRecycler(f func(interface{})) {} -func (t *TestTrace) SetTraceInfo(traceID, spanID uint64) {} -func (t *TestTrace) SetMaxEvents(m int) {} -func (t *TestTrace) Finish() {} - -// NullTrace implements the tracer.Trace interface, but discards everything. -type NullTrace struct{} - -func (t *NullTrace) LazyLog(x fmt.Stringer, sensitive bool) {} -func (t *NullTrace) LazyPrintf(format string, a ...interface{}) {} -func (t *NullTrace) SetError() {} -func (t *NullTrace) SetRecycler(f func(interface{})) {} -func (t *NullTrace) SetTraceInfo(traceID, spanID uint64) {} -func (t *NullTrace) SetMaxEvents(m int) {} -func (t *NullTrace) Finish() {} diff --git a/internal/trace/trace.go b/internal/trace/trace.go new file mode 100644 index 0000000..ad63270 --- /dev/null +++ b/internal/trace/trace.go @@ -0,0 +1,121 @@ +// Package trace extends golang.org/x/net/trace. +package trace + +import ( + "fmt" + "net/http" + "strconv" + "strings" + + "blitiri.com.ar/go/log" + "github.com/miekg/dns" + + nettrace "golang.org/x/net/trace" +) + +func init() { + // golang.org/x/net/trace has its own authorization which by default only + // allows localhost. This can be confusing and limiting in environments + // which access the monitoring server remotely. + nettrace.AuthRequest = func(req *http.Request) (any, sensitive bool) { + return true, true + } +} + +// A Trace represents an active request. +type Trace struct { + family string + title string + t nettrace.Trace +} + +// New trace. +func New(family, title string) *Trace { + t := &Trace{family, title, nettrace.New(family, title)} + + // The default for max events is 10, which is a bit short for our uses. + // Expand it to 30 which should be large enough to keep most of the + // traces. + t.t.SetMaxEvents(30) + return t +} + +// Printf adds this message to the trace's log. +func (t *Trace) Printf(format string, a ...interface{}) { + t.printf(1, format, a...) +} + +func (t *Trace) printf(n int, format string, a ...interface{}) { + t.t.LazyPrintf(format, a...) + + log.Log(log.Debug, n+1, "%s %s: %s", t.family, t.title, + quote(fmt.Sprintf(format, a...))) +} + +// Errorf adds this message to the trace's log, with an error level. +func (t *Trace) Errorf(format string, a ...interface{}) error { + // Note we can't just call t.Error here, as it breaks caller logging. + err := fmt.Errorf(format, a...) + t.t.SetError() + t.t.LazyPrintf("error: %v", err) + + log.Log(log.Info, 1, "%s %s: error: %s", t.family, t.title, + quote(err.Error())) + return err +} + +// Error marks the trace as having seen an error, and also logs it to the +// trace's log. +func (t *Trace) Error(err error) error { + t.t.SetError() + t.t.LazyPrintf("error: %v", err) + + log.Log(log.Info, 1, "%s %s: error: %s", t.family, t.title, + quote(err.Error())) + + return err +} + +// Finish the trace. It should not be changed after this is called. +func (t *Trace) Finish() { + t.t.Finish() +} + +//////////////////////////////////////////////////////////// +// DNS specific extensions +// + +// Question adds the given question to the trace. +func (t *Trace) Question(qs []dns.Question) { + if !log.V(3) { + return + } + + t.printf(1, questionsToString(qs)) +} + +func questionsToString(qs []dns.Question) string { + var s []string + for _, q := range qs { + s = append(s, fmt.Sprintf("(%s %s %s)", q.Name, + dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass])) + } + return "Q: " + strings.Join(s, " ; ") +} + +// Answer adds the given DNS answer to the trace. +func (t *Trace) Answer(m *dns.Msg) { + if !log.V(3) { + return + } + + t.printf(1, m.MsgHdr.String()) + for _, rr := range m.Answer { + t.printf(1, rr.String()) + } +} + +func quote(s string) string { + qs := strconv.Quote(s) + return qs[1 : len(qs)-1] +} diff --git a/internal/util/trace.go b/internal/util/trace.go deleted file mode 100644 index 3ce1bf7..0000000 --- a/internal/util/trace.go +++ /dev/null @@ -1,55 +0,0 @@ -package util - -import ( - "fmt" - "strings" - - "blitiri.com.ar/go/log" - - "github.com/miekg/dns" - "golang.org/x/net/trace" -) - -// TraceQuestion adds the given question to the trace. -func TraceQuestion(tr trace.Trace, qs []dns.Question) { - if !log.V(3) { - return - } - - tr.LazyPrintf(questionsToString(qs)) -} - -func questionsToString(qs []dns.Question) string { - var s []string - for _, q := range qs { - s = append(s, fmt.Sprintf("(%s %s %s)", q.Name, - dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass])) - } - return "Q: " + strings.Join(s, " ; ") -} - -// TraceAnswer adds the given DNS answer to the trace. -func TraceAnswer(tr trace.Trace, m *dns.Msg) { - if !log.V(3) { - return - } - - tr.LazyPrintf(m.MsgHdr.String()) - for _, rr := range m.Answer { - tr.LazyPrintf(rr.String()) - } -} - -// TraceError adds the given error to the trace. -func TraceError(tr trace.Trace, err error) { - log.Infof(err.Error()) - tr.LazyPrintf(err.Error()) - tr.SetError() -} - -// TraceErrorf adds an error message to the trace. -func TraceErrorf(tr trace.Trace, format string, a ...interface{}) error { - err := fmt.Errorf(format, a...) - TraceError(tr, err) - return err -}