author | Alberto Bertogli
<albertito@blitiri.com.ar> 2020-08-22 22:49:43 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2020-08-22 22:49:43 UTC |
parent | da7ed40160b1ce27621a955708562cfd9fb7a1f5 |
internal/httpresolver/resolver.go | +106 | -16 |
diff --git a/internal/httpresolver/resolver.go b/internal/httpresolver/resolver.go
index 665fa92..1e0426f 100644
--- a/internal/httpresolver/resolver.go
+++ b/internal/httpresolver/resolver.go
@@ -8,8 +8,10 @@ import (
 	"io"
 	"io/ioutil"
 	"mime"
+	"net"
 	"net/http"
 	"net/url"
+	"sync"
 	"time"
 
 	"blitiri.com.ar/go/dnss/internal/dnsserver"
@@ -22,9 +24,15 @@ import (
 // httpsResolver implements the dnsserver.Resolver interface by querying a
 // server via DNS over HTTPS (DoH, RFC 8484).
 type httpsResolver struct {
-	Upstream *url.URL
-	CAFile   string
+	Upstream  *url.URL
+	CAFile    string
+	tlsConfig *tls.Config
+
+	mu       sync.Mutex
 	client   *http.Client
+	firstErr time.Time
+
+	ev trace.EventLog
 }
 
 func loadCertPool(caFile string) (*x509.CertPool, error) {
@@ -51,12 +59,56 @@ func NewDoH(upstream *url.URL, caFile string) *httpsResolver {
 }
 
 func (r *httpsResolver) Init() error {
+	// If CAFile is empty, we're ok with the defaults (use the system default
+	// CA database).
+	if r.CAFile != "" {
+		pool, err := loadCertPool(r.CAFile)
+		if err != nil {
+			return err
+		}
+
+		r.tlsConfig = &tls.Config{
+			RootCAs: pool,
+		}
+	}
+
+	client, err := r.newClient()
+
+	r.mu.Lock()
+	r.client = client
+	r.mu.Unlock()
+
+	r.ev = trace.NewEventLog("httpresolver", r.Upstream.String())
+	r.ev.Printf("Init complete, client: %p", r.client)
+
+	return err
+}
+
+func (r *httpsResolver) newClient() (*http.Client, error) {
 	transport := &http.Transport{
+		TLSClientConfig: r.tlsConfig,
+
 		// Take the semi-standard proxy settings from the environment.
 		Proxy: http.ProxyFromEnvironment,
+
+		// Drop connections after 30s idle.
+		// This helps prevent connection pile-up on frequent client rotations,
+		// which can happen with intermittent network issues.
+		IdleConnTimeout: 30 * time.Second,
+
+		// Reasonable defaults, based on http.DefaultTransport.
+		DialContext: (&net.Dialer{
+			Timeout:   10 * time.Second,
+			KeepAlive: 1 * time.Second,
+			DualStack: true,
+		}).DialContext,
+		ForceAttemptHTTP2:     true,
+		MaxIdleConns:          10,
+		TLSHandshakeTimeout:   4 * time.Second,
+		ExpectContinueTimeout: 1 * time.Second,
 	}
 
-	r.client = &http.Client{
+	client := &http.Client{
 		// Give our HTTP requests 4 second timeouts: DNS usually doesn't wait
 		// that long anyway, but this helps with slow connections.
 		Timeout: 4 * time.Second,
@@ -64,23 +116,56 @@ func (r *httpsResolver) Init() error {
 		Transport: transport,
 	}
 
-	// If CAFile is empty, we're ok with the defaults (use the system default
-	// CA database).
-	if r.CAFile != "" {
-		pool, err := loadCertPool(r.CAFile)
-		if err != nil {
-			return err
-		}
+	return client, nil
+}
 
-		transport.TLSClientConfig = &tls.Config{
-			RootCAs: pool,
-		}
-	}
+func (r *httpsResolver) setClientError(err error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
 
-	return nil
+	if err == nil {
+		r.firstErr = time.Time{}
+	} else if r.firstErr.IsZero() {
+		r.firstErr = time.Now()
+	}
 }
 
 func (r *httpsResolver) Maintain() {
+	for range time.Tick(2 * time.Second) {
+		r.maybeRotateClient()
+	}
+}
+
+func (r *httpsResolver) maybeRotateClient() {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if r.firstErr.IsZero() {
+		return
+	}
+
+	// If we've seen errors for the last 10s, rotate the client.
+	// This is unfortunately needed because the Go HTTP/2 transport will
+	// insist on using a dead connection for a long time, and cannot be told
+	// to close it. This causes problems when the computer changes connections
+	// (e.g. switch wifi network) or is having intermittent network issues.
+	// This workaround works because a new client will initiate a new
+	// connection, and the old one will die in the background.
+	// The time chosen here combines with the transport timeouts set above, so
+	// we never have too many in-flight connections.
+	if time.Since(r.firstErr) > 10*time.Second {
+		r.ev.Printf("Rotating client after %s of errors: %p",
+			time.Since(r.firstErr), r.client)
+		client, err := r.newClient()
+		if err != nil {
+			r.ev.Errorf("Error creating new client: %v", err)
+			return
+		}
+
+		r.client = client
+		r.firstErr = time.Time{}
+		r.ev.Printf("Rotated client: %p", r.client)
+	}
 }
 
 func (r *httpsResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
@@ -95,10 +180,15 @@ func (r *httpsResolver) Query(req *dns.Msg, tr trace.Trace) (*dns.Msg, error) {
 
 	// TODO: Accept header.
 
-	hr, err := r.client.Post(
+	r.mu.Lock()
+	client := r.client
+	r.mu.Unlock()
+
+	hr, err := client.Post(
 		r.Upstream.String(),
 		"application/dns-message",
 		bytes.NewReader(packed))
+	r.setClientError(err)
 	if err != nil {
 		return nil, fmt.Errorf("POST failed: %v", err)
 	}