git » gofer » master » tree

[master] / util / util.go

// Package util implements some common utilities.
package util

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"sync/atomic"

	"blitiri.com.ar/go/gofer/config"
	"blitiri.com.ar/go/gofer/trace"
	"golang.org/x/crypto/acme"
	"golang.org/x/crypto/acme/autocert"
)

// LoadCertsForHTTPS returns a TLS configuration based on the given HTTPS
// config.
//
// If conf.Certs is set, certificates are loaded from that directory via
// LoadCertsFromDir, and "h2" and "http/1.1" are appended to NextProtos.
//
// Otherwise certificates are obtained through an autocert.Manager built from
// conf.AutoCerts, which implicitly accepts the CA's terms of service. The
// autocert cache is probed with a test write before returning, so permission
// problems are reported here rather than during a TLS handshake. The
// returned config's GetCertificate records every certificate lookup in a
// trace, and marks failed lookups as errors.
func LoadCertsForHTTPS(conf config.HTTPS) (*tls.Config, error) {
	if conf.Certs != "" {
		tlsConfig, err := LoadCertsFromDir(conf.Certs)
		if err != nil {
			return nil, err
		}

		// We need to set the NextProtos manually before creating the TLS
		// listener, the library cannot help us with this.
		// For autocert, this is not needed because autocert.Manager does it
		// for us.
		tlsConfig.NextProtos = append(tlsConfig.NextProtos,
			"h2", "http/1.1")
		return tlsConfig, nil
	}

	cacheDir := cachePath(conf.AutoCerts.CacheDir)
	m := &autocert.Manager{
		// As indicated in the documentation, configuring autocerts
		// implies accepting the CA's TOS.
		Prompt:     autocert.AcceptTOS,
		Email:      conf.AutoCerts.Email,
		HostPolicy: autocert.HostWhitelist(conf.AutoCerts.Hosts...),
		Cache:      autocert.DirCache(cacheDir),
	}

	// Make sure we can write to the cache, to make it easier to detect and
	// troubleshoot permission issues. The test entry is left in place.
	err := m.Cache.Put(context.Background(), "__gofer_check", []byte("test"))
	if err != nil {
		return nil, fmt.Errorf("error writing to the autocert cache %q: %v",
			cacheDir, err)
	}

	if conf.AutoCerts.AcmeURL != "" {
		m.Client = &acme.Client{
			DirectoryURL: conf.AutoCerts.AcmeURL,
			// Note that Key is generated by the Manager, we don't need to
			// fill it in here.
		}
	}

	// Wrap the TLSConfig.GetCertificate so we can log errors, otherwise
	// they're invisible and difficult to debug.
	tlsConf := m.TLSConfig()
	getCert := tlsConf.GetCertificate
	tlsConf.GetCertificate = func(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
		// h.Conn is not guaranteed to be set (it may be nil, e.g. for QUIC
		// handshakes); the trace title is only diagnostic, so it must not
		// be able to crash the certificate lookup.
		remote := "<unknown>"
		if h.Conn != nil {
			remote = h.Conn.RemoteAddr().String()
		}
		tr := trace.New("autocerts", remote)
		defer tr.Finish()

		cert, err := getCert(h)
		if err != nil {
			// We want to mark this as an error so it's easy to find in the
			// traces, but don't want to log it as such, because these can
			// also be harmless and add a lot of noise (e.g. a user requesting
			// a non-whitelisted domain).
			tr.Printf("request for %q -> %v", h.ServerName, err)
			tr.SetError()
		}
		return cert, err
	}

	return tlsConf, nil
}

// cachePath returns the directory to use for the autocert cache.
//
// An explicitly configured confDir always wins. Otherwise the cache lives in
// a "gofer-autocert-cache" subdirectory of, in order of preference: the
// directory in $CACHE_DIRECTORY (which systemd sets when the unit has
// CacheDirectory=), the user cache directory reported by os.UserCacheDir, or
// the current working directory (as a relative path) if neither is
// available.
func cachePath(confDir string) string {
	if confDir != "" {
		return confDir
	}

	const subdir = "gofer-autocert-cache"

	parent := os.Getenv("CACHE_DIRECTORY")
	if parent == "" {
		userDir, err := os.UserCacheDir()
		if err != nil {
			// Nothing better available: fall back to a relative path.
			return subdir
		}
		parent = userDir
	}
	return filepath.Join(parent, subdir)
}

// LoadCertsFromDir loads certificates from the given directory, and returns a
// TLS config including them.
//
// Each subdirectory of certDir (symlinks to directories included) that holds
// a "fullchain.pem" and a "privkey.pem" contributes one certificate, in
// directory-listing order. Non-directory entries, and subdirectories where
// either file is known not to exist, are ignored. Any other failure to load
// a pair is returned as an error, as is finding no certificates at all.
// The returned config has NameToCertificate populated.
func LoadCertsFromDir(certDir string) (*tls.Config, error) {
	entries, err := ioutil.ReadDir(certDir)
	if err != nil {
		return nil, fmt.Errorf("ReadDir(%q): %v", certDir, err)
	}

	// missing reports whether path definitely does not exist. Other stat
	// errors are deliberately not treated as "missing", so that problems
	// such as bad permissions surface when the pair is loaded below.
	missing := func(path string) bool {
		_, statErr := os.Stat(path)
		return os.IsNotExist(statErr)
	}

	var certs []tls.Certificate
	for _, entry := range entries {
		dir := filepath.Join(certDir, entry.Name())

		// os.Stat (rather than entry.IsDir) so that symlinks are followed.
		// Entries that cannot be stat'ed are not skipped here.
		fi, statErr := os.Stat(dir)
		if statErr == nil && !fi.IsDir() {
			continue
		}

		certPath := filepath.Join(dir, "fullchain.pem")
		keyPath := filepath.Join(dir, "privkey.pem")
		if missing(certPath) || missing(keyPath) {
			continue
		}

		pair, loadErr := tls.LoadX509KeyPair(certPath, keyPath)
		if loadErr != nil {
			return nil, fmt.Errorf("error loading pair (%q, %q): %v",
				certPath, keyPath, loadErr)
		}
		certs = append(certs, pair)
	}

	if len(certs) == 0 {
		return nil, fmt.Errorf("no certificates found in %q", certDir)
	}

	tlsConfig := &tls.Config{Certificates: certs}
	tlsConfig.BuildNameToCertificate()
	return tlsConfig, nil
}

// BidirCopy copies data in both directions between src and dst (dst into
// src, and src into dst), each direction in its own goroutine, and returns
// as soon as either direction's io.Copy returns. Copy errors are discarded.
//
// The returned count is the total number of bytes transferred as observed
// at that moment: it always includes the direction that finished, and may or
// may not include the other one, which is typically still running. That
// other goroutine is not stopped here; it is up to the caller to create the
// conditions for it to complete (e.g. by closing one of the sides).
func BidirCopy(src, dst io.ReadWriter) int64 {
	var total int64

	// Room for both senders, so whichever goroutine finishes last never
	// blocks on the send after we've stopped receiving.
	finished := make(chan bool, 2)

	pump := func(to io.Writer, from io.Reader) {
		n, _ := io.Copy(to, from)
		atomic.AddInt64(&total, n)
		finished <- true
	}
	go pump(src, dst)
	go pump(dst, src)

	<-finished
	return atomic.LoadInt64(&total)
}