// Package util implements some common utilities.
package util
import (
"context"
"crypto/tls"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"sync/atomic"
"blitiri.com.ar/go/gofer/config"
"blitiri.com.ar/go/gofer/trace"
"blitiri.com.ar/go/log"
"golang.org/x/crypto/acme"
"golang.org/x/crypto/acme/autocert"
)
// LoadCertsForHTTPS returns a TLS configuration based on the given HTTPS
// config.
//
// If conf.Certs is set, the certificates are loaded from that directory (see
// LoadCertsFromDir) and ALPN is configured for "h2" and "http/1.1".
// Otherwise, certificates are obtained automatically through an
// autocert.Manager configured from conf.AutoCerts.
//
// In both cases, if conf.InsecureKeyLogFile is set, TLS session keys are
// written to that file (created or truncated). This is insecure and intended
// only for debugging.
//
// On error, the returned *tls.Config is always nil.
func LoadCertsForHTTPS(conf config.HTTPS) (*tls.Config, error) {
	if conf.Certs != "" {
		tlsConfig, err := LoadCertsFromDir(conf.Certs)
		if err != nil {
			return nil, err
		}

		// We need to set the NextProtos manually before creating the TLS
		// listener, the library cannot help us with this.
		// For autocert, this is not needed because autocert.Manager does it
		// for us.
		tlsConfig.NextProtos = append(tlsConfig.NextProtos,
			"h2", "http/1.1")

		if err = enableInsecureKeyLog(tlsConfig, conf.InsecureKeyLogFile); err != nil {
			return nil, err
		}
		return tlsConfig, nil
	}

	m := &autocert.Manager{
		// As indicated in the documentation, configuring autocerts
		// implies accepting the CA's TOS.
		Prompt:     autocert.AcceptTOS,
		Email:      conf.AutoCerts.Email,
		HostPolicy: autocert.HostWhitelist(conf.AutoCerts.Hosts...),
		Cache:      autocert.DirCache(cachePath(conf.AutoCerts.CacheDir)),
	}

	// Make sure we can write to the cache, to make it easier to detect and
	// troubleshoot permission issues.
	if err := m.Cache.Put(context.Background(), "__gofer_check", []byte("test")); err != nil {
		return nil, fmt.Errorf("error writing to the autocert cache %q: %v",
			m.Cache, err)
	}

	if conf.AutoCerts.AcmeURL != "" {
		m.Client = &acme.Client{
			DirectoryURL: conf.AutoCerts.AcmeURL,
			// Note that Key is generated by the Manager, we don't need to
			// fill it in here.
		}
	}

	// Wrap the TLSConfig.GetCertificate so we can log errors, otherwise
	// they're invisible and difficult to debug.
	tlsConf := m.TLSConfig()
	getCert := tlsConf.GetCertificate
	tlsConf.GetCertificate = func(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
		tr := trace.New("autocerts", h.Conn.RemoteAddr().String())
		defer tr.Finish()

		cert, err := getCert(h)
		if err != nil {
			// We want to mark this as an error so it's easy to find in the
			// traces, but don't want to log it as such, because these can
			// also be harmless and add a lot of noise (e.g. a user requesting
			// a non-whitelisted domain).
			tr.Printf("request for %q -> %v", h.ServerName, err)
			tr.SetError()
		}
		return cert, err
	}

	if err := enableInsecureKeyLog(tlsConf, conf.InsecureKeyLogFile); err != nil {
		return nil, err
	}
	return tlsConf, nil
}

// enableInsecureKeyLog makes tlsConf write its TLS session keys to the file at
// path, creating or truncating it. It does nothing if path is empty.
//
// On failure tlsConf is left untouched. Only a successfully opened file is
// ever assigned to KeyLogWriter, which avoids storing a typed-nil *os.File in
// the io.Writer interface.
func enableInsecureKeyLog(tlsConf *tls.Config, path string) error {
	if path == "" {
		return nil
	}
	log.Infof("INSECURE TLS key log is enabled, writing to %q", path)
	f, err := os.Create(path)
	if err != nil {
		return fmt.Errorf("error creating TLS key log file %q: %v", path, err)
	}
	tlsConf.KeyLogWriter = f
	return nil
}
// cachePath picks the directory used for the autocert cache.
//
// The first applicable option wins:
//   - confDir, when it is non-empty;
//   - "gofer-autocert-cache" inside $CACHE_DIRECTORY (systemd sets this
//     variable when the unit uses CacheDirectory=);
//   - "gofer-autocert-cache" inside os.UserCacheDir() (e.g. $HOME/.cache/);
//   - the relative path "gofer-autocert-cache", as a last resort.
func cachePath(confDir string) string {
	if confDir != "" {
		return confDir
	}

	const leaf = "gofer-autocert-cache"

	if systemdDir := os.Getenv("CACHE_DIRECTORY"); systemdDir != "" {
		return filepath.Join(systemdDir, leaf)
	}
	if userDir, err := os.UserCacheDir(); err == nil {
		return filepath.Join(userDir, leaf)
	}
	return leaf
}
// LoadCertsFromDir loads certificates from the given directory, and returns a
// TLS config including them.
//
// Each entry of certDir that is a directory (symlinks are followed) and
// contains both "fullchain.pem" and "privkey.pem" is loaded as one
// certificate/key pair. Entries missing either file, and non-directories, are
// skipped. It is an error if a pair fails to load, or if no pair is found at
// all.
func LoadCertsFromDir(certDir string) (*tls.Config, error) {
	tlsConfig := &tls.Config{}

	infos, err := ioutil.ReadDir(certDir)
	if err != nil {
		return nil, fmt.Errorf("ReadDir(%q): %v", certDir, err)
	}

	for _, info := range infos {
		dir := filepath.Join(certDir, info.Name())

		// Use os.Stat instead of info.IsDir() so that symlinks to
		// directories are followed. Entries whose Stat fails are not
		// skipped here; the per-file checks below decide what to do.
		if fi, err := os.Stat(dir); err == nil && !fi.IsDir() {
			continue
		}

		certPath := filepath.Join(dir, "fullchain.pem")
		if _, err := os.Stat(certPath); os.IsNotExist(err) {
			continue
		}
		keyPath := filepath.Join(dir, "privkey.pem")
		if _, err := os.Stat(keyPath); os.IsNotExist(err) {
			continue
		}

		cert, err := tls.LoadX509KeyPair(certPath, keyPath)
		if err != nil {
			return nil, fmt.Errorf("error loading pair (%q, %q): %v",
				certPath, keyPath, err)
		}
		tlsConfig.Certificates = append(tlsConfig.Certificates, cert)
	}

	if len(tlsConfig.Certificates) == 0 {
		return nil, fmt.Errorf("no certificates found in %q", certDir)
	}

	// NameToCertificate is intentionally left nil, because
	// BuildNameToCertificate is deprecated. crypto/tls then selects, on each
	// handshake, the first certificate in Certificates that is compatible
	// with the client. Unlike the name map, this also works when several
	// certificates share a name (e.g. RSA and ECDSA).
	return tlsConfig, nil
}
// BidirCopy copies data in both directions between src and dst, running each
// direction in its own goroutine. It returns as soon as either direction's
// io.Copy has finished.
//
// The returned count only includes bytes from directions that had already
// finished when BidirCopy returned. Bytes still being copied in the other
// direction are not counted. Copy errors are ignored.
//
// The goroutine for the unfinished direction stays alive after BidirCopy
// returns. It is up to the caller to create the conditions for it to complete
// (e.g. by closing one of the sides).
func BidirCopy(src, dst io.ReadWriter) int64 {
	var copied int64

	// Room for both signals, so the goroutine that finishes last never
	// blocks on its send even though nobody receives it.
	finished := make(chan struct{}, 2)

	pump := func(to io.Writer, from io.Reader) {
		n, _ := io.Copy(to, from)
		atomic.AddInt64(&copied, n)
		finished <- struct{}{}
	}
	go pump(src, dst)
	go pump(dst, src)

	<-finished
	return atomic.LoadInt64(&copied)
}