// git » dnss » main » tree
//
// [main] / internal / httpresolver / resolver.go

package httpresolver

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"io/ioutil"
	"mime"
	"net"
	"net/http"
	"net/url"
	"sync"
	"time"

	"blitiri.com.ar/go/dnss/internal/dnsserver"
	"blitiri.com.ar/go/dnss/internal/trace"

	"blitiri.com.ar/go/log"
	"github.com/miekg/dns"
)

// httpsResolver implements the dnsserver.Resolver interface by querying a
// server via DNS over HTTPS (DoH, RFC 8484).
//
// Init populates client and tr, which Query and setClientError rely on.
type httpsResolver struct {
	// URL of the DoH server that queries are POSTed to. Its string form is
	// also used to label the client traces.
	Upstream  *url.URL
	// Path to a PEM file with the CA certificates to trust when talking to
	// the upstream. If empty, no custom TLS configuration is created.
	CAFile    string
	// TLS configuration handed to the HTTP transport. Set by Init only when
	// CAFile is non-empty; nil otherwise.
	tlsConfig *tls.Config

	// net.Resolver that will contact the server at --fallback_upstream for
	// DNS resolutions.
	// Only set by NewDoH when a fallback address is given; nil otherwise.
	fallbackResolver *net.Resolver

	// mu protects client, firstErr and tr.
	mu       sync.Mutex
	// HTTP client currently used for queries; replaced by maybeRotateClient.
	client   *http.Client
	// Time of the first client error seen since the last successful request
	// or client rotation. The zero value means there are no pending errors.
	firstErr time.Time
	// Trace that records events for the current client (errors, rotations).
	tr       *trace.Trace
}

// errAppendingCerts is returned by loadCertPool when the file was read
// successfully, but no certificate could be added to the pool from it.
var errAppendingCerts = fmt.Errorf("error appending certificates")

// loadCertPool reads the PEM data in caFile and returns a new certificate
// pool containing the certificates found in it.
//
// It returns the read error as-is if the file cannot be read, and
// errAppendingCerts if the contents could not be appended to the pool.
func loadCertPool(caFile string) (*x509.CertPool, error) {
	pemBytes, readErr := ioutil.ReadFile(caFile)
	if readErr != nil {
		return nil, readErr
	}

	certs := x509.NewCertPool()
	if ok := certs.AppendCertsFromPEM(pemBytes); ok {
		return certs, nil
	}

	return nil, errAppendingCerts
}

// NewDoH creates a new DoH resolver, which uses the given upstream
// URL to resolve queries.
//
// caFile is stored for Init to load. If fallback is not empty, the resolver
// gets a dedicated net.Resolver whose connections always go to that address,
// regardless of the address it is asked to dial.
func NewDoH(upstream *url.URL, caFile, fallback string) *httpsResolver {
	r := &httpsResolver{Upstream: upstream, CAFile: caFile}
	if fallback == "" {
		return r
	}

	var d net.Dialer
	r.fallbackResolver = &net.Resolver{
		// Avoid the system resolver.
		PreferGo: true,

		// Ignore the requested address, and always contact the fallback
		// address instead.
		Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
			return d.DialContext(ctx, network, fallback)
		},
	}

	return r
}

// Init prepares the resolver for use: it loads the CA certificates from
// CAFile (if one was given), and creates the initial HTTP client and its
// trace.
//
// It returns an error if the CA file cannot be loaded, or the error from
// creating the client.
func (r *httpsResolver) Init() error {
	// With an empty CAFile we keep tlsConfig unset, which means using the
	// defaults (the system default CA database).
	if caPath := r.CAFile; caPath != "" {
		roots, err := loadCertPool(caPath)
		if err != nil {
			return err
		}

		r.tlsConfig = &tls.Config{RootCAs: roots}
	}

	c, err := r.newClient()

	r.mu.Lock()
	defer r.mu.Unlock()

	r.client = c
	r.tr = trace.New("httpresolver.Client", r.Upstream.String())
	r.tr.Printf("Init complete, client: %p", r.client)

	return err
}

// newClient builds a fresh HTTP client (with its own transport) configured
// for talking to the upstream DoH server. It uses the resolver's TLS
// configuration and fallback DNS resolver, either of which may be nil.
//
// The returned error is currently always nil.
func (r *httpsResolver) newClient() (*http.Client, error) {
	transport := &http.Transport{
		TLSClientConfig: r.tlsConfig,

		// Take the semi-standard proxy settings from the environment.
		Proxy: http.ProxyFromEnvironment,

		// Drop connections after 30s idle.
		// This helps prevent connection pile-up on frequent client rotations,
		// which can happen with intermittent network issues.
		IdleConnTimeout: 30 * time.Second,

		// Reasonable defaults, based on http.DefaultTransport.
		// Note net.Dialer.DualStack is intentionally not set: it is
		// deprecated and ignored since Go 1.12, where Fast Fallback is
		// enabled by default.
		DialContext: (&net.Dialer{
			Timeout:   10 * time.Second,
			KeepAlive: 1 * time.Second,
			Resolver:  r.fallbackResolver,
		}).DialContext,
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          10,
		TLSHandshakeTimeout:   4 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}

	client := &http.Client{
		// Give our HTTP requests 4 second timeouts: DNS usually doesn't wait
		// that long anyway, but this helps with slow connections.
		Timeout: 4 * time.Second,

		Transport: transport,
	}

	return client, nil
}

// setClientError records the outcome of a request made with the client.
//
// A nil err clears the pending error state. A non-nil err is written to the
// client trace, and if it is the first one since the state was last cleared,
// the current time is remembered in firstErr.
func (r *httpsResolver) setClientError(err error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if err == nil {
		r.firstErr = time.Time{}
		return
	}

	if r.firstErr.IsZero() {
		r.firstErr = time.Now()
	}
	r.tr.Printf("Client error: %v", err)
}

// Maintain checks every 2 seconds whether the HTTP client needs to be
// rotated. It loops forever and never returns.
func (r *httpsResolver) Maintain() {
	ticks := time.Tick(2 * time.Second)
	for {
		<-ticks
		r.maybeRotateClient()
	}
}

// maybeRotateClient replaces the HTTP client with a new one if errors have
// been pending for more than 10 seconds. Otherwise, it does nothing.
//
// On rotation the current trace is finished and a new one is started. If the
// new client cannot be created, the old client and the pending error time
// are left in place.
func (r *httpsResolver) maybeRotateClient() {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Only rotate once we've seen errors for the last 10s.
	// This is unfortunately needed because the Go HTTP/2 transport will
	// insist on using a dead connection for a long time, and cannot be told
	// to close it. This causes problems when the computer changes connections
	// (e.g. switch wifi network) or is having intermittent network issues.
	// This workaround works because a new client will initiate a new
	// connection, and the old one will die in the background.
	// The time chosen here combines with the transport timeouts set in
	// newClient, so we never have too many in-flight connections.
	if r.firstErr.IsZero() || time.Since(r.firstErr) <= 10*time.Second {
		return
	}

	// Finish the old trace and start a new one, which makes it easier to
	// analyze the client behaviour in the traces.
	r.tr.Errorf("Rotating client after %s of errors: %p",
		time.Since(r.firstErr), r.client)
	r.tr.Finish()
	r.tr = trace.New("httpresolver.Client", r.Upstream.String())

	fresh, err := r.newClient()
	if err != nil {
		r.tr.Errorf("Error creating new client: %v", err)
		return
	}

	r.client, r.firstErr = fresh, time.Time{}
	r.tr.Printf("Rotated client: %p", r.client)
}

// Query resolves req by sending it to the upstream server in an HTTP POST
// (RFC 8484), and returns the DNS message parsed from the response.
//
// The outcome of the POST itself (nil on success) is reported through
// setClientError, which feeds the client rotation logic. Failures in later
// stages (unexpected status, content type, or body) are returned to the
// caller, but are not reported through setClientError.
//
// Returned errors wrap the underlying cause, so callers can inspect it with
// errors.Is and errors.As.
func (r *httpsResolver) Query(req *dns.Msg, tr *trace.Trace) (*dns.Msg, error) {
	// Media type for DNS messages; used for both the request body and the
	// response we are willing to accept.
	const dnsMessageType = "application/dns-message"

	packed, err := req.Pack()
	if err != nil {
		return nil, fmt.Errorf("cannot pack query: %w", err)
	}

	if log.V(1) {
		tr.Printf("DoH POST %v", r.Upstream)
	}

	// Take a snapshot of the client, as it can be rotated at any time.
	r.mu.Lock()
	client := r.client
	r.mu.Unlock()

	// Build the request by hand (instead of using client.Post) so we can
	// include the Accept header, which RFC 8484 says clients SHOULD send.
	var hr *http.Response
	hreq, err := http.NewRequest(
		http.MethodPost, r.Upstream.String(), bytes.NewReader(packed))
	if err == nil {
		hreq.Header.Set("Content-Type", dnsMessageType)
		hreq.Header.Set("Accept", dnsMessageType)
		hr, err = client.Do(hreq)
	}
	r.setClientError(err)
	if err != nil {
		return nil, fmt.Errorf("POST failed: %w", err)
	}
	tr.Printf("%s  %s", hr.Proto, hr.Status)
	defer hr.Body.Close()

	if hr.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Response status: %s", hr.Status)
	}

	// Read the HTTPS response, and parse the message.
	ct, _, err := mime.ParseMediaType(hr.Header.Get("Content-Type"))
	if err != nil {
		return nil, fmt.Errorf("failed to parse content type: %w", err)
	}

	if ct != dnsMessageType {
		return nil, fmt.Errorf("unknown response content type %q", ct)
	}

	// Read at most 64 KiB of the body.
	respRaw, err := ioutil.ReadAll(io.LimitReader(hr.Body, 64*1024))
	if err != nil {
		return nil, fmt.Errorf("error reading from body: %w", err)
	}

	respDNS := &dns.Msg{}
	err = respDNS.Unpack(respRaw)
	if err != nil {
		return nil, fmt.Errorf("error unpacking response: %w", err)
	}

	return respDNS, nil
}

// Statically assert that *httpsResolver satisfies dnsserver.Resolver.
var _ dnsserver.Resolver = (*httpsResolver)(nil)