// chasquid is an SMTP (email) server, with a focus on simplicity, security, // and ease of operation. // // See https://blitiri.com.ar/p/chasquid for more details. package main import ( "context" "flag" "fmt" "math/rand" "net" "os" "os/signal" "path" "path/filepath" "strings" "syscall" "time" "blitiri.com.ar/go/chasquid/internal/config" "blitiri.com.ar/go/chasquid/internal/courier" "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/dovecot" "blitiri.com.ar/go/chasquid/internal/localrpc" "blitiri.com.ar/go/chasquid/internal/maillog" "blitiri.com.ar/go/chasquid/internal/normalize" "blitiri.com.ar/go/chasquid/internal/smtpsrv" "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/log" "blitiri.com.ar/go/systemd" ) // Command-line flags. var ( configDir = flag.String("config_dir", "/etc/chasquid", "configuration directory") configOverrides = flag.String("config_overrides", "", "override configuration values (in text protobuf format)") showVer = flag.Bool("version", false, "show version and exit") ) func main() { flag.Parse() log.Init() parseVersionInfo() if *showVer { fmt.Printf("chasquid %s (source date: %s)\n", version, sourceDate) return } log.Infof("chasquid starting (version %s)", version) // Seed the PRNG, just to prevent for it to be totally predictable. rand.Seed(time.Now().UnixNano()) conf, err := config.Load(*configDir+"/chasquid.conf", *configOverrides) if err != nil { log.Fatalf("Error loading config: %v", err) } config.LogConfig(conf) // Change to the config dir. // This allow us to use relative paths for configuration directories. // It also can be useful in unusual environments and for testing purposes, // where paths inside the configuration itself could be relative, and this // fixes the point of reference. 
err = os.Chdir(*configDir) if err != nil { log.Fatalf("Error changing to config dir %q: %v", *configDir, err) } initMailLog(conf.MailLogPath) if conf.MonitoringAddress != "" { go launchMonitoringServer(conf) } s := smtpsrv.NewServer() s.Hostname = conf.Hostname s.MaxDataSize = conf.MaxDataSizeMb * 1024 * 1024 s.HookPath = "hooks/" s.HAProxyEnabled = conf.HaproxyIncoming s.SetAliasesConfig(*conf.SuffixSeparators, *conf.DropCharacters) if conf.DovecotAuth { loadDovecot(s, conf.DovecotUserdbPath, conf.DovecotClientPath) } // Load certificates from "certs/<directory>/{fullchain,privkey}.pem". // The structure matches letsencrypt's, to make it easier for that case. log.Infof("Loading certificates:") for _, info := range mustReadDir("certs/") { if info.Type().IsRegular() { // Ignore regular files, we only care about directories. continue } name := info.Name() dir := filepath.Join("certs/", name) loadCert(name, dir, s) } // Load domains from "domains/". log.Infof("Loading domains:") for _, info := range mustReadDir("domains/") { domain, err := normalize.Domain(info.Name()) if err != nil { log.Fatalf("Invalid name %+q: %v", info.Name(), err) } if info.Type().IsRegular() { // Ignore regular files, we only care about directories. continue } dir := filepath.Join("domains", info.Name()) loadDomain(domain, dir, s) } // Always include localhost as local domain. // This can prevent potential trouble if we were to accidentally treat it // as a remote domain (for loops, alias resolutions, etc.). 
s.AddDomain("localhost") dinfo, err := domaininfo.New(conf.DataDir + "/domaininfo") if err != nil { log.Fatalf("Error opening domain info database: %v", err) } s.SetDomainInfo(dinfo) stsCache, err := sts.NewCache(conf.DataDir + "/sts-cache") if err != nil { log.Fatalf("Failed to initialize STS cache: %v", err) } go stsCache.PeriodicallyRefresh(context.Background()) localC := &courier.MDA{ Binary: conf.MailDeliveryAgentBin, Args: conf.MailDeliveryAgentArgs, Timeout: 30 * time.Second, } remoteC := &courier.SMTP{ HelloDomain: conf.Hostname, Dinfo: dinfo, STSCache: stsCache, } s.InitQueue(conf.DataDir+"/queue", localC, remoteC) // Load the addresses and listeners. systemdLs, err := systemd.Listeners() if err != nil { log.Fatalf("Error getting systemd listeners: %v", err) } naddr := loadAddresses(s, conf.SmtpAddress, systemdLs["smtp"], smtpsrv.ModeSMTP) naddr += loadAddresses(s, conf.SubmissionAddress, systemdLs["submission"], smtpsrv.ModeSubmission) naddr += loadAddresses(s, conf.SubmissionOverTlsAddress, systemdLs["submission_tls"], smtpsrv.ModeSubmissionTLS) if naddr == 0 { log.Fatalf("No address to listen on") } go localrpc.DefaultServer.ListenAndServe(conf.DataDir + "/localrpc-v1") go signalHandler(dinfo, s) s.ListenAndServe() } func loadAddresses(srv *smtpsrv.Server, addrs []string, ls []net.Listener, mode smtpsrv.SocketMode) int { naddr := 0 for _, addr := range addrs { if addr == "" { // An empty address is invalid, to prevent accidental // misconfiguration. log.Errorf("Invalid empty listening address for %v", mode) log.Fatalf("If you want to disable %v, remove it from the config", mode) } else if addr == "systemd" { // The "systemd" address indicates we get listeners via systemd. 
srv.AddListeners(ls, mode) naddr += len(ls) } else { srv.AddAddr(addr, mode) naddr++ } } if naddr == 0 { log.Errorf("Warning: No %v addresses/listeners", mode) log.Errorf("If using systemd, check that you named the sockets") } return naddr } func initMailLog(path string) { var err error switch path { case "<syslog>": maillog.Default, err = maillog.NewSyslog() case "<stdout>": maillog.Default = maillog.New(os.Stdout) case "<stderr>": maillog.Default = maillog.New(os.Stderr) default: _ = os.MkdirAll(filepath.Dir(path), 0775) maillog.Default, err = maillog.NewFile(path) } if err != nil { log.Fatalf("Error opening mail log: %v", err) } } func signalHandler(dinfo *domaininfo.DB, srv *smtpsrv.Server) { var err error signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGHUP, syscall.SIGTERM, syscall.SIGINT) for { switch sig := <-signals; sig { case syscall.SIGHUP: log.Infof("Received SIGHUP, reloading") // SIGHUP triggers a reopen of the log files. This is used for log // rotation. err = log.Default.Reopen() if err != nil { log.Fatalf("Error reopening log: %v", err) } err = maillog.Default.Reopen() if err != nil { log.Fatalf("Error reopening maillog: %v", err) } // We don't want to reload the domain info database periodically, // as it can be expensive, and it is not expected that the user // changes this behind chasquid's back. err = dinfo.Reload() if err != nil { log.Fatalf("Error reloading domain info: %v", err) } // Also trigger a server reload. srv.Reload() case syscall.SIGTERM, syscall.SIGINT: log.Fatalf("Got signal to exit: %v", sig) default: log.Errorf("Unexpected signal %v", sig) } } } // Helper to load a single certificate configuration into the server. func loadCert(name, dir string, s *smtpsrv.Server) { log.Infof(" %s", name) // Ignore directories that don't have both keys. // We warn about this because it can be hard to debug otherwise. 
certPath := filepath.Join(dir, "fullchain.pem") if _, err := os.Stat(certPath); err != nil { log.Infof(" skipping: %v", err) return } keyPath := filepath.Join(dir, "privkey.pem") if _, err := os.Stat(keyPath); err != nil { log.Infof(" skipping: %v", err) return } err := s.AddCerts(certPath, keyPath) if err != nil { log.Fatalf(" %v", err) } } // Helper to load a single domain configuration into the server. func loadDomain(name, dir string, s *smtpsrv.Server) { log.Infof(" %s", name) s.AddDomain(name) nu, err := s.AddUserDB(name, dir+"/users") if err != nil { // If there is an error loading users, fail hard to make sure this is // noticed and fixed as soon as it happens. log.Fatalf(" users file error: %v", err) } na, err := s.AddAliasesFile(name, dir+"/aliases") if err != nil { // If there's an error loading aliases, fail hard to make sure this is // noticed and fixed as soon as it happens. log.Fatalf(" aliases file error: %v", err) } nd, err := loadDKIM(name, dir, s) if err != nil { // DKIM errors are fatal because if the user set DKIM up, then we // don't want it to be failing silently, as that could cause // deliverability issues. 
log.Fatalf(" DKIM loading error: %v", err) } log.Infof(" %d users, %d aliases, %d DKIM keys", nu, na, nd) } func loadDovecot(s *smtpsrv.Server, userdb, client string) { a := dovecot.NewAuth(userdb, client) s.SetAuthFallback(a) log.Infof("Fallback authenticator: %v", a) if err := a.Check(); err != nil { log.Errorf("Warning: Dovecot auth is not responding: %v", err) } } func loadDKIM(domain, dir string, s *smtpsrv.Server) (int, error) { glob := path.Clean(dir + "/dkim:*.pem") pems, err := filepath.Glob(glob) if err != nil { return 0, err } for _, pem := range pems { base := filepath.Base(pem) selector := strings.TrimPrefix(base, "dkim:") selector = strings.TrimSuffix(selector, ".pem") err = s.AddDKIMSigner(domain, selector, pem) if err != nil { return 0, err } } return len(pems), nil } // Read a directory, which must have at least some entries. func mustReadDir(path string) []os.DirEntry { dirs, err := os.ReadDir(path) if err != nil { log.Fatalf("Error reading %q directory: %v", path, err) } if len(dirs) == 0 { log.Fatalf("No entries found in %q", path) } return dirs }
// chasquid-util is a command-line utility for chasquid-related operations.
package main

import (
	"bytes"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"syscall"

	"blitiri.com.ar/go/chasquid/internal/config"
	"blitiri.com.ar/go/chasquid/internal/envelope"
	"blitiri.com.ar/go/chasquid/internal/localrpc"
	"blitiri.com.ar/go/chasquid/internal/normalize"
	"blitiri.com.ar/go/chasquid/internal/userdb"
	"golang.org/x/term"
	"google.golang.org/protobuf/encoding/prototext"
)

// Usage text shown to users on --help or on invocation errors.
//
// On invocation errors it is passed to Fatalf as the format string, so it
// must not contain '%' characters.
const usage = `
Usage:
  chasquid-util [options] user-add <user@domain> [--password=<password>] [--receive_only]
      Add a user to the userdb.
  chasquid-util [options] user-remove <user@domain>
      Remove a user from the userdb.
  chasquid-util [options] authenticate <user@domain> [--password=<password>]
      Authenticate a user.
  chasquid-util [options] check-userdb <domain>
      Check if the userdb for the given domain is accessible.
  chasquid-util [options] aliases-resolve <address>
      Resolve an address. Talks to the running chasquid.
  chasquid-util [options] domaininfo-remove <domain>
      Remove domaininfo for the given domain. Talks to the running chasquid.
  chasquid-util [options] print-config
      Print the current chasquid configuration.
  chasquid-util [options] dkim-keygen <domain> [<selector> <private-key.pem>] [--algo=rsa3072|rsa4096|ed25519]
      Generate a new DKIM key pair for the domain.
  chasquid-util [options] dkim-dns <domain> [<selector> <private-key.pem>]
      Print the DNS TXT record to use for the domain, selector and private key.

Options:
  -C=<path>, --configdir=<path>  Configuration directory
  -v                             Verbose mode
`

// Command-line arguments, as parsed by parseArgs.
// Arguments starting with "-" will be parsed as key-value pairs, and
// positional arguments will appear as "$POS" -> value.
//
// For example, "--abc=def x y -p=q -r" will result in:
// {"--abc": "def", "$1": "x", "$2": "y", "-p": "q", "-r": ""}
var args map[string]string

// Globals, loaded from top-level options.
var ( configDir = "/etc/chasquid" ) func main() { args = parseArgs(usage) if _, ok := args["--help"]; ok { fmt.Print(usage) return } // Load globals. if d, ok := args["--configdir"]; ok { configDir = d } if d, ok := args["-C"]; ok { configDir = d } commands := map[string]func(){ "user-add": userAdd, "user-remove": userRemove, "authenticate": authenticate, "check-userdb": checkUserDB, "aliases-resolve": aliasesResolve, "print-config": printConfig, "domaininfo-remove": domaininfoRemove, "dkim-keygen": dkimKeygen, "dkim-dns": dkimDNS, // These exist for testing purposes and may be removed in the future. // Do not rely on them. "dkim-verify": dkimVerify, "dkim-sign": dkimSign, } cmd := args["$1"] if f, ok := commands[cmd]; ok { f() } else { fmt.Printf("Unknown argument %q\n", cmd) Fatalf(usage) } } // Fatalf prints the given message to stderr, then exits the program with an // error code. func Fatalf(s string, arg ...interface{}) { fmt.Fprintf(os.Stderr, s+"\n", arg...) os.Exit(1) } func userDBForDomain(domain string) string { if domain == "" { domain = args["$2"] } return configDir + "/domains/" + domain + "/users" } func userDBFromArgs(create bool) (string, string, *userdb.DB) { username := args["$2"] user, domain := envelope.Split(username) if domain == "" { Fatalf("Domain missing, username should be of the form 'user@domain'") } db, err := userdb.Load(userDBForDomain(domain)) if err != nil { if create && os.IsNotExist(err) { fmt.Println("Creating database") err = os.MkdirAll(filepath.Dir(userDBForDomain(domain)), 0755) if err != nil { Fatalf("Error creating database dir: %v", err) } } else { Fatalf("Error loading database: %v", err) } } user, err = normalize.User(user) if err != nil { Fatalf("Error normalizing user: %v", err) } return user, domain, db } // chasquid-util check-userdb <domain> func checkUserDB() { path := userDBForDomain("") // Check if the file exists. This is because userdb.Load does not consider // it an error. 
if _, err := os.Stat(path); os.IsNotExist(err) { Fatalf("Error: file %q does not exist", path) } udb, err := userdb.Load(path) if err != nil { Fatalf("Error loading database: %v", err) } fmt.Printf("Database loaded (%d users)\n", udb.Len()) } // chasquid-util user-add <user@domain> [--password=<password>] [--receive_only] func userAdd() { user, _, db := userDBFromArgs(true) _, recvOnly := args["--receive_only"] _, hasPassword := args["--password"] if recvOnly && hasPassword { Fatalf("Cannot specify both --receive_only and --password") } var err error if recvOnly { err = db.AddDeniedUser(user) } else { password := getPassword() err = db.AddUser(user, password) } if err != nil { Fatalf("Error adding user: %v", err) } err = db.Write() if err != nil { Fatalf("Error writing database: %v", err) } fmt.Println("Added user") } // chasquid-util authenticate <user@domain> [--password=<password>] func authenticate() { user, _, db := userDBFromArgs(false) password := getPassword() ok := db.Authenticate(user, password) if ok { fmt.Println("Authentication succeeded") } else { Fatalf("Authentication failed") } } func getPassword() string { password, ok := args["--password"] if ok { return password } fmt.Printf("Password: ") p1, err := term.ReadPassword(syscall.Stdin) fmt.Printf("\n") if err != nil { Fatalf("Error reading password: %v\n", err) } fmt.Printf("Confirm password: ") p2, err := term.ReadPassword(syscall.Stdin) fmt.Printf("\n") if err != nil { Fatalf("Error reading password: %v", err) } if !bytes.Equal(p1, p2) { Fatalf("Passwords don't match") } return string(p1) } // chasquid-util user-remove <user@domain> func userRemove() { user, _, db := userDBFromArgs(false) present := db.RemoveUser(user) if !present { Fatalf("Unknown user") } err := db.Write() if err != nil { Fatalf("Error writing database: %v", err) } fmt.Println("Removed user") } // chasquid-util aliases-resolve <address> func aliasesResolve() { conf, err := config.Load(configDir+"/chasquid.conf", "") if err != 
nil { Fatalf("Error loading config: %v", err) } c := localrpc.NewClient(conf.DataDir + "/localrpc-v1") vs, err := c.Call("AliasResolve", "Address", args["$2"]) if err != nil { Fatalf("Error resolving: %v", err) } // Result is a map of type -> []addresses. // Sort the types for deterministic output. ts := []string{} for t := range vs { ts = append(ts, t) } sort.Strings(ts) for _, t := range ts { for _, a := range vs[t] { fmt.Printf("%v %s\n", t, a) } } } // chasquid-util print-config func printConfig() { conf, err := config.Load(configDir+"/chasquid.conf", "") if err != nil { Fatalf("Error loading config: %v", err) } fmt.Println(prototext.Format(conf)) } // chasquid-util domaininfo-remove <domain> func domaininfoRemove() { conf, err := config.Load(configDir+"/chasquid.conf", "") if err != nil { Fatalf("Error loading config: %v", err) } c := localrpc.NewClient(conf.DataDir + "/localrpc-v1") _, err = c.Call("DomaininfoClear", "Domain", args["$2"]) if err != nil { Fatalf("Error removing domaininfo entry: %v", err) } } // parseArgs parses the command line arguments, and returns a map. // // Arguments starting with "-" will be parsed as key-value pairs, and // positional arguments will appear as "$POS" -> value. // // For example, "--abc=def x y -p=q -r" will result in: // {"--abc": "def", "$1": "x", "$2": "y", "-p": "q", "-r": ""} func parseArgs(usage string) map[string]string { args := map[string]string{} pos := 1 for _, a := range os.Args[1:] { // Note: Consider handling end of args marker "--" explicitly in // the future if needed. if strings.HasPrefix(a, "-") { sp := strings.SplitN(a, "=", 2) if len(sp) < 2 { args[a] = "" } else { args[sp[0]] = sp[1] } } else { args["$"+strconv.Itoa(pos)] = a pos++ } } return args }
package main import ( "bytes" "context" "crypto" "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/base64" "encoding/pem" "fmt" "io" "net/mail" "os" "path" "path/filepath" "strings" "time" "blitiri.com.ar/go/chasquid/internal/dkim" "blitiri.com.ar/go/chasquid/internal/envelope" "blitiri.com.ar/go/chasquid/internal/normalize" ) func dkimSign() { domain := args["$2"] selector := args["$3"] keyPath := args["$4"] msg, err := io.ReadAll(os.Stdin) if err != nil { Fatalf("%v", err) } msg = normalize.ToCRLF(msg) if domain == "" { domain = getDomainFromMsg(msg) } if selector == "" { selector = findSelectorForDomain(domain) } if keyPath == "" { keyPath = keyPathFor(domain, selector) } signer := &dkim.Signer{ Domain: domain, Selector: selector, Signer: loadPrivateKey(keyPath), } ctx := context.Background() if _, verbose := args["-v"]; verbose { ctx = dkim.WithTraceFunc(ctx, func(format string, args ...interface{}) { fmt.Fprintf(os.Stderr, format+"\n", args...) }) } header, err := signer.Sign(ctx, string(msg)) if err != nil { Fatalf("Error signing message: %v", err) } fmt.Printf("DKIM-Signature: %s\r\n", strings.ReplaceAll(header, "\r\n", "\r\n\t")) } func dkimVerify() { msg, err := io.ReadAll(os.Stdin) if err != nil { Fatalf("%v", err) } msg = normalize.ToCRLF(msg) ctx := context.Background() if _, verbose := args["-v"]; verbose { ctx = dkim.WithTraceFunc(ctx, func(format string, args ...interface{}) { fmt.Fprintf(os.Stderr, format+"\n", args...) 
}) } if txt, ok := args["--txt"]; ok { ctx = dkim.WithLookupTXTFunc(ctx, func(ctx context.Context, domain string) ([]string, error) { return []string{txt}, nil }) } results, err := dkim.VerifyMessage(ctx, string(msg)) if err != nil { Fatalf("Error verifying message: %v", err) } hostname, _ := os.Hostname() ar := "Authentication-Results: " + hostname + "\r\n\t" ar += strings.ReplaceAll( results.AuthenticationResults(), "\r\n", "\r\n\t") fmt.Println(ar) } func dkimDNS() { domain := args["$2"] selector := args["$3"] keyPath := args["$4"] if domain == "" { Fatalf("Error: missing domain parameter") } if selector == "" { selector = findSelectorForDomain(domain) } if keyPath == "" { keyPath = keyPathFor(domain, selector) } fmt.Println(dnsRecordFor(domain, selector, loadPrivateKey(keyPath))) } func dnsRecordFor(domain, selector string, private crypto.Signer) string { public := private.Public() var err error algoStr := "" pubBytes := []byte{} switch private.(type) { case *rsa.PrivateKey: algoStr = "rsa" pubBytes, err = x509.MarshalPKIXPublicKey(public) case ed25519.PrivateKey: algoStr = "ed25519" pubBytes = public.(ed25519.PublicKey) } if err != nil { Fatalf("Error marshaling public key: %v", err) } return fmt.Sprintf( "%s._domainkey.%s\tTXT\t\"v=DKIM1; k=%s; p=%s\"", selector, domain, algoStr, base64.StdEncoding.EncodeToString(pubBytes)) } func dkimKeygen() { domain := args["$2"] selector := args["$3"] keyPath := args["$4"] algo := args["--algo"] if domain == "" { Fatalf("Error: missing domain parameter") } if selector == "" { selector = time.Now().UTC().Format("20060102") } if keyPath == "" { keyPath = keyPathFor(domain, selector) } if _, err := os.Stat(keyPath); !os.IsNotExist(err) { Fatalf("Error: key already exists at %q", keyPath) } var private crypto.Signer var err error switch algo { case "", "rsa3072": private, err = rsa.GenerateKey(rand.Reader, 3072) case "rsa4096": private, err = rsa.GenerateKey(rand.Reader, 4096) case "ed25519": _, private, err = 
ed25519.GenerateKey(rand.Reader) default: Fatalf("Error: unsupported algorithm %q", algo) } if err != nil { Fatalf("Error generating key: %v", err) } privB, err := x509.MarshalPKCS8PrivateKey(private) if err != nil { Fatalf("Error marshaling private key: %v", err) } f, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0660) if err != nil { Fatalf("Error creating key file %q: %v", keyPath, err) } block := &pem.Block{ Type: "PRIVATE KEY", Bytes: privB, } if err := pem.Encode(f, block); err != nil { Fatalf("Error PEM-encoding key: %v", err) } f.Close() fmt.Printf("Key written to %q\n\n", keyPath) fmt.Println(dnsRecordFor(domain, selector, private)) } func keyPathFor(domain, selector string) string { return path.Clean(fmt.Sprintf("%s/domains/%s/dkim:%s.pem", configDir, domain, selector)) } func getDomainFromMsg(msg []byte) string { m, err := mail.ReadMessage(bytes.NewReader(msg)) if err != nil { Fatalf("Error parsing message: %v", err) } addr, err := mail.ParseAddress(m.Header.Get("From")) if err != nil { Fatalf("Error parsing From: header: %v", err) } return envelope.DomainOf(addr.Address) } func findSelectorForDomain(domain string) string { glob := path.Clean(configDir + "/domains/" + domain + "/dkim:*.pem") ms, err := filepath.Glob(glob) if err != nil { Fatalf("Error finding DKIM keys: %v", err) } for _, m := range ms { base := filepath.Base(m) selector := strings.TrimPrefix(base, "dkim:") selector = strings.TrimSuffix(selector, ".pem") return selector } Fatalf("No DKIM keys found in %q", glob) return "" } func loadPrivateKey(path string) crypto.Signer { key, err := os.ReadFile(path) if err != nil { Fatalf("Error reading private key from %q: %v", path, err) } block, _ := pem.Decode(key) if block == nil { Fatalf("Error decoding PEM block") } switch strings.ToUpper(block.Type) { case "PRIVATE KEY": k, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { Fatalf("Error parsing private key: %v", err) } return k.(crypto.Signer) default: 
Fatalf("Unsupported key type: %s", block.Type) return nil } }
// CLI used for testing the dovecot authentication package. // // NOT for production use. package main import ( "flag" "fmt" "os" "blitiri.com.ar/go/chasquid/internal/dovecot" ) const help = ` Usage: dovecot-auth-cli <path prefix> exists user@domain dovecot-auth-cli <path prefix> auth user@domain password Example: dovecot-auth-cli /var/run/dovecot/auth-chasquid exists user@domain dovecot-auth-cli /var/run/dovecot/auth-chasquid auth user@domain password ` func main() { flag.Parse() if len(flag.Args()) < 3 { fmt.Fprint(os.Stderr, help) fmt.Print("no: invalid arguments\n") return } a := dovecot.NewAuth(flag.Arg(0)+"-userdb", flag.Arg(0)+"-client") var ok bool var err error switch flag.Arg(1) { case "exists": ok, err = a.Exists(flag.Arg(2)) case "auth": ok, err = a.Authenticate(flag.Arg(2), flag.Arg(3)) default: err = fmt.Errorf("unknown subcommand %q", flag.Arg(1)) } if ok { fmt.Print("yes\n") return } fmt.Printf("no: %v\n", err) }
// Support for overriding DNS lookups, for testing purposes. // This is only used in tests, when the "dnsoverride" tag is active. // It requires Go >= 1.8. // //go:build dnsoverride // +build dnsoverride package main import ( "context" "flag" "net" "time" ) var ( dnsAddr = flag.String("testing__dns_addr", "127.0.0.1:9053", "DNS server address to use, for testing purposes only") ) var dialer = &net.Dialer{ // We're going to talk to localhost, so have a short timeout so we fail // fast. Otherwise the callers might hang indefinitely when trying to // dial the DNS server. Timeout: 2 * time.Second, } func dial(ctx context.Context, network, address string) (net.Conn, error) { return dialer.DialContext(ctx, network, *dnsAddr) } func init() { // Override the resolver to talk with our local server for testing. net.DefaultResolver.PreferGo = true net.DefaultResolver.Dial = dial }
// Package aliases implements an email aliases resolver.
//
// The resolver can parse many files for different domains, and perform
// lookups to resolve the aliases.
//
// # File format
//
// It generally follows the traditional aliases format used by sendmail and
// exim.
//
// The file can contain lines of the form:
//
//	user: address, address
//	user: | command
//
// Lines starting with "#" are ignored, as well as empty lines.
// User names cannot contain spaces, ":" or commas, for parsing reasons. This
// is a tradeoff between flexibility and keeping the file format easy to edit
// for people.
//
// User names will be normalized internally to lower-case.
//
// Usually there will be one database per domain, so there's no need to
// include the "@domain" part in the user. In fact, users containing "@" are
// not allowed, and such lines are ignored.
//
// If the user is the string "*", then it is considered a "catch-all alias":
// emails that don't match any known users or other aliases will be sent here.
//
// # Recipients
//
// Recipients can be of different types:
//   - Email: the usual user@domain we all know and love, this is the default.
//   - Pipe: if the right side starts with "|" (conventionally written as
//     "| command"), the rest of the line specifies a command to pipe the
//     email through.
//     Command and arguments are space separated. No quoting, escaping, or
//     replacements of any kind.
//
// # Lookups
//
// The resolver will perform lookups recursively, until it finds all the final
// recipients.
//
// There are recursion limits to avoid alias loops. If the limit is reached,
// the entire resolution will fail.
//
// # Suffix removal
//
// The resolver can also remove suffixes from emails, and drop characters
// completely. This can be used to turn "user+blah@domain" into "user@domain",
// and "us.er@domain" into "user@domain".
//
// Both are optional, and the characters are configurable globally.
//
// There are more complex semantics around handling of drop characters and
// suffixes, see the documentation for more details.
package aliases

import (
	"bufio"
	"context"
	"fmt"
	"io"
	"os"
	"os/exec"
	"strings"
	"sync"
	"time"

	"blitiri.com.ar/go/chasquid/internal/envelope"
	"blitiri.com.ar/go/chasquid/internal/expvarom"
	"blitiri.com.ar/go/chasquid/internal/normalize"
	"blitiri.com.ar/go/chasquid/internal/trace"
)

// Exported variables (published for monitoring via expvarom).
var (
	// Count of resolve hook runs, keyed by result ("resolve:notset",
	// "resolve:skip", "resolve:fail", "resolve:success").
	hookResults = expvarom.NewMap("chasquid/aliases/hookResults",
		"result", "count of aliases hook results, by hook and result")
)

// Recipient represents a single recipient, after resolving aliases.
// They don't have any special interface; callers tell them apart by
// switching on the Type field.
type Recipient struct {
	Addr string
	Type RType
}

// RType represents a recipient type, see the constants below for valid values.
type RType string

// Valid recipient types.
const (
	EMAIL RType = "(email)"
	PIPE  RType = "(pipe)"
)

var (
	// ErrRecursionLimitExceeded is returned when the resolving lookup
	// exceeded the recursion limit. Usually caused by aliases loops.
	ErrRecursionLimitExceeded = fmt.Errorf("recursion limit exceeded")

	// How many levels of recursions we allow during lookups.
	// We don't expect much recursion, so keeping this low to catch errors
	// quickly.
	recursionLimit = 10
)

// existsFn is the type of the "does this user exist" function, for
// convenience.
type existsFn func(tr *trace.Trace, user, domain string) (bool, error)

// Resolver represents the aliases resolver.
type Resolver struct {
	// Suffix separator, to perform suffix removal.
	SuffixSep string

	// Characters to drop from the user part.
	DropChars string

	// Path to the resolve hook.
	ResolveHook string

	// Function to check if a user exists in the userdb.
	userExistsInDB existsFn

	// Map of domain -> alias files for that domain.
	// We keep track of them for reloading purposes.
	files   map[string][]string
	domains map[string]bool

	// Map of address -> aliases.
	aliases map[string][]Recipient

	// Mutex protecting the maps above.
	mu sync.Mutex
}

// NewResolver returns a new, empty Resolver.
func NewResolver(userExists existsFn) *Resolver { return &Resolver{ files: map[string][]string{}, domains: map[string]bool{}, aliases: map[string][]Recipient{}, userExistsInDB: userExists, } } // Resolve the given address, returning the list of corresponding recipients // (if any). func (v *Resolver) Resolve(tr *trace.Trace, addr string) ([]Recipient, error) { tr = tr.NewChild("Alias.Resolve", addr) defer tr.Finish() return v.resolve(0, addr, tr) } // Exists check that the address exists in the database. It must only be // called for local addresses. func (v *Resolver) Exists(tr *trace.Trace, addr string) bool { tr = tr.NewChild("Alias.Exists", addr) defer tr.Finish() // First, see if there's an exact match in the database. // This allows us to have aliases that include suffixes in them, and have // them take precedence. rcpts, _ := v.lookup(addr, tr) if len(rcpts) > 0 { return true } // "Clean" the address, removing drop characters and suffixes, and try // again. addr = v.RemoveDropsAndSuffix(addr) rcpts, _ = v.lookup(addr, tr) if len(rcpts) > 0 { return true } domain := envelope.DomainOf(addr) catchAll, _ := v.lookup("*@"+domain, tr) if len(catchAll) > 0 { return true } return false } func (v *Resolver) lookup(addr string, tr *trace.Trace) ([]Recipient, error) { // Do a lookup in the aliases map. Note we remove drop characters first, // which matches what we did at parsing time. Suffixes, if any, are left // as-is; that is handled by the callers. clean := v.RemoveDropCharacters(addr) v.mu.Lock() rcpts := v.aliases[clean] v.mu.Unlock() // Augment with the hook results. // Note we use the original address, to give maximum flexibility to the // hooks. 
hr, err := v.runResolveHook(tr, addr) if err != nil { tr.Debugf("lookup(%q) hook error: %v", addr, err) return nil, err } tr.Debugf("lookup(%q) -> %v + %v", addr, rcpts, hr) return append(rcpts, hr...), nil } func (v *Resolver) resolve(rcount int, addr string, tr *trace.Trace) ([]Recipient, error) { tr.Debugf("%d| resolve(%d, %q)", rcount, rcount, addr) if rcount >= recursionLimit { return nil, ErrRecursionLimitExceeded } // If the address is not local, we return it as-is, so delivery is // attempted against it. // Example: an alias that resolves to a non-local address. user, domain := envelope.Split(addr) if _, ok := v.domains[domain]; !ok { tr.Debugf("%d| non-local domain, returning %q", rcount, addr) return []Recipient{{addr, EMAIL}}, nil } // First, see if there's an exact match in the database. // This allows us to have aliases that include suffixes in them, and have // them take precedence. rcpts, err := v.lookup(addr, tr) if err != nil { tr.Debugf("%d| error in lookup: %v", rcount, err) return nil, err } if len(rcpts) == 0 { // Retry after removing drop characters and suffixes. // This also means that we will return the clean version if there's no // match, which our callers can rely upon. addr = v.RemoveDropsAndSuffix(addr) rcpts, err = v.lookup(addr, tr) if err != nil { tr.Debugf("%d| error in lookup: %v", rcount, err) return nil, err } } // No alias for this local address. if len(rcpts) == 0 { tr.Debugf("%d| no alias found", rcount) // If the user exists, then use it as-is, no need to recurse further. 
ok, err := v.userExistsInDB(tr, user, domain) if err != nil { tr.Debugf("%d| error checking if user exists: %v", rcount, err) return nil, err } if ok { tr.Debugf("%d| user exists, returning %q", rcount, addr) return []Recipient{{addr, EMAIL}}, nil } catchAll, err := v.lookup("*@"+domain, tr) if err != nil { tr.Debugf("%d| error in catchall lookup: %v", rcount, err) return nil, err } if len(catchAll) > 0 { // If there's a catch-all, then use it and keep resolving // recursively (since the catch-all destination could be an // alias). tr.Debugf("%d| using catch-all: %v", rcount, catchAll) rcpts = catchAll } else { // Otherwise, return the original address unchanged. // The caller will handle that situation, and we don't need to // invalidate the whole resolution (there could be other valid // aliases). // The queue will attempt delivery against this local (but // evidently non-existing) address, and the courier will emit a // clearer failure, re-using the existing codepaths and // simplifying the logic. tr.Debugf("%d| no catch-all, returning %q", rcount, addr) return []Recipient{{addr, EMAIL}}, nil } } ret := []Recipient{} for _, r := range rcpts { // Only recurse for email recipients. if r.Type != EMAIL { ret = append(ret, r) continue } ar, err := v.resolve(rcount+1, r.Addr, tr) if err != nil { tr.Debugf("%d| resolve(%q) returned error: %v", rcount, r.Addr, err) return nil, err } ret = append(ret, ar...) } tr.Debugf("%d| returning %v", rcount, ret) return ret, nil } // Remove drop characters, but only up to the first suffix separator. func (v *Resolver) RemoveDropCharacters(addr string) string { user, domain := envelope.Split(addr) // Remove drop characters up to the first suffix separator. firstSuffixSep := strings.IndexAny(user, v.SuffixSep) if firstSuffixSep == -1 { firstSuffixSep = len(user) } nu := "" for _, c := range user[:firstSuffixSep] { if !strings.ContainsRune(v.DropChars, c) { nu += string(c) } } // Copy any remaining suffix as-is. 
if firstSuffixSep < len(user) { nu += user[firstSuffixSep:] } nu, _ = normalize.User(nu) return nu + "@" + domain } func (v *Resolver) RemoveDropsAndSuffix(addr string) string { user, domain := envelope.Split(addr) user = removeAllAfter(user, v.SuffixSep) user = removeChars(user, v.DropChars) user, _ = normalize.User(user) return user + "@" + domain } // AddDomain to the resolver, registering its existence. func (v *Resolver) AddDomain(domain string) { v.mu.Lock() v.domains[domain] = true v.mu.Unlock() } // AddAliasesFile to the resolver. The file will be parsed, and an error // returned if it does not parse correctly. Note that the file not existing // does NOT result in an error. func (v *Resolver) AddAliasesFile(domain, path string) (int, error) { // We unconditionally add the domain and file on our list. // Even if the file does not exist now, it may later. This makes it be // consider when doing Reload. // Adding it to the domains mean that we will do drop character and suffix // manipulation even if there are no aliases for it. v.mu.Lock() v.files[domain] = append(v.files[domain], path) v.domains[domain] = true v.mu.Unlock() aliases, err := v.parseFile(domain, path) if os.IsNotExist(err) { return 0, nil } if err != nil { return 0, err } // Add the aliases to the resolver, overriding any previous values. v.mu.Lock() for addr, rs := range aliases { v.aliases[addr] = rs } v.mu.Unlock() return len(aliases), nil } // AddAliasForTesting adds an alias to the resolver, for testing purposes. // Not for use in production code. func (v *Resolver) AddAliasForTesting(addr, rcpt string, rType RType) { v.aliases[addr] = append(v.aliases[addr], Recipient{rcpt, rType}) } // Reload aliases files for all known domains. 
func (v *Resolver) Reload() error { newAliases := map[string][]Recipient{} for domain, paths := range v.files { for _, path := range paths { aliases, err := v.parseFile(domain, path) if os.IsNotExist(err) { continue } if err != nil { return fmt.Errorf("error parsing %q: %v", path, err) } // Add the aliases to the resolver, overriding any previous values. for addr, rs := range aliases { newAliases[addr] = rs } } } v.mu.Lock() v.aliases = newAliases v.mu.Unlock() return nil } func (v *Resolver) parseFile(domain, path string) (map[string][]Recipient, error) { f, err := os.Open(path) if err != nil { return nil, err } defer f.Close() aliases, err := v.parseReader(domain, f) if err != nil { return nil, fmt.Errorf("reading %q: %v", path, err) } return aliases, nil } func (v *Resolver) parseReader(domain string, r io.Reader) (map[string][]Recipient, error) { aliases := map[string][]Recipient{} scanner := bufio.NewScanner(r) for i := 1; scanner.Scan(); i++ { line := strings.TrimSpace(scanner.Text()) if strings.HasPrefix(line, "#") { continue } sp := strings.SplitN(line, ":", 2) if len(sp) != 2 { continue } addr, rawalias := strings.TrimSpace(sp[0]), strings.TrimSpace(sp[1]) if len(addr) == 0 || len(rawalias) == 0 { continue } if strings.Contains(addr, "@") { // It's invalid for lhs addresses to contain @ (for now). continue } // We remove DropChars from the address, but leave the suffixes (if // any). This matches the behaviour expected by Exists and Resolve, // see the documentation for more details. addr = addr + "@" + domain addr = v.RemoveDropCharacters(addr) addr, _ = normalize.Addr(addr) rs := parseRHS(rawalias, domain) aliases[addr] = rs } return aliases, scanner.Err() } func parseRHS(rawalias, domain string) []Recipient { if len(rawalias) == 0 { return nil } if rawalias[0] == '|' { cmd := strings.TrimSpace(rawalias[1:]) if cmd == "" { // A pipe alias without a command is invalid. 
return nil } return []Recipient{{cmd, PIPE}} } rs := []Recipient{} for _, a := range strings.Split(rawalias, ",") { a = strings.TrimSpace(a) if a == "" { continue } // Addresses with no domain get the current one added, so it's // easier to share alias files. if !strings.Contains(a, "@") { a = a + "@" + domain } a, _ = normalize.Addr(a) rs = append(rs, Recipient{a, EMAIL}) } return rs } // removeAllAfter removes everything from s that comes after the separators, // including them. func removeAllAfter(s, seps string) string { for _, c := range strings.Split(seps, "") { if c == "" { continue } i := strings.Index(s, c) if i == -1 { continue } s = s[:i] } return s } // removeChars removes the runes in "chars" from s. func removeChars(s, chars string) string { for _, c := range strings.Split(chars, "") { s = strings.Replace(s, c, "", -1) } return s } func (v *Resolver) runResolveHook(tr *trace.Trace, addr string) ([]Recipient, error) { if v.ResolveHook == "" { hookResults.Add("resolve:notset", 1) return nil, nil } // TODO: check if the file is executable. if _, err := os.Stat(v.ResolveHook); os.IsNotExist(err) { hookResults.Add("resolve:skip", 1) return nil, nil } // TODO: this should be done via a context propagated all the way through. tr = tr.NewChild("Hook.Alias-Resolve", addr) defer tr.Finish() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() cmd := exec.CommandContext(ctx, v.ResolveHook, addr) outb, err := cmd.Output() out := string(outb) tr.Debugf("stdout: %q", out) if err != nil { hookResults.Add("resolve:fail", 1) tr.Error(err) return nil, err } // Extract recipients from the output. // Same format as the right hand side of aliases file, see parseRHS. domain := envelope.DomainOf(addr) raw := strings.TrimSpace(out) rs := parseRHS(raw, domain) tr.Debugf("recipients: %v", rs) hookResults.Add("resolve:success", 1) return rs, nil }
// Package auth implements authentication services for chasquid. package auth import ( "bytes" "encoding/base64" "errors" "fmt" "math/rand" "strings" "time" "blitiri.com.ar/go/chasquid/internal/normalize" "blitiri.com.ar/go/chasquid/internal/trace" ) // Backend is the common interface for all authentication backends. type Backend interface { Authenticate(user, password string) (bool, error) Exists(user string) (bool, error) Reload() error } // NoErrorBackend is the interface for authentication backends that don't need // to emit errors. This allows backends to avoid unnecessary complexity, in // exchange for a bit more here. // They can be converted to normal Backend using WrapNoErrorBackend (defined // below). type NoErrorBackend interface { Authenticate(user, password string) bool Exists(user string) bool Reload() error } // Authenticator tracks the backends for each domain, and allows callers to // query them with a more practical API. type Authenticator struct { // Registered backends, map of domain (string) -> Backend. // Backend operations will _not_ include the domain in the username. backends map[string]Backend // Fallback backend, to use when backends[domain] (which may not exist) // did not yield a positive result. // Note that this backend gets the user with the domain included, of the // form "user@domain" (if available). Fallback Backend // How long Authenticate calls should last, approximately. // This will be applied both for successful and unsuccessful attempts. // We will increase this number by 0-20%. AuthDuration time.Duration } // NewAuthenticator returns a new Authenticator with no backends. func NewAuthenticator() *Authenticator { return &Authenticator{ backends: map[string]Backend{}, AuthDuration: 100 * time.Millisecond, } } // Register a backend to use for the given domain. func (a *Authenticator) Register(domain string, be Backend) { a.backends[domain] = be } // Authenticate the user@domain with the given password. 
func (a *Authenticator) Authenticate(tr *trace.Trace, user, domain, password string) (bool, error) { tr = tr.NewChild("Auth.Authenticate", user+"@"+domain) defer tr.Finish() // Make sure the call takes a.AuthDuration + 0-20% regardless of the // outcome, to prevent basic timing attacks. defer func(start time.Time) { elapsed := time.Since(start) delay := a.AuthDuration - elapsed if delay > 0 { maxDelta := int64(float64(delay) * 0.2) delay += time.Duration(rand.Int63n(maxDelta)) time.Sleep(delay) } }(time.Now()) if be, ok := a.backends[domain]; ok { ok, err := be.Authenticate(user, password) tr.Debugf("Backend: %v %v", ok, err) if ok || err != nil { return ok, err } } if a.Fallback != nil { id := user if domain != "" { id = user + "@" + domain } ok, err := a.Fallback.Authenticate(id, password) tr.Debugf("Fallback: %v %v", ok, err) return ok, err } tr.Debugf("Rejected by default") return false, nil } // Exists checks that user@domain exists. func (a *Authenticator) Exists(tr *trace.Trace, user, domain string) (bool, error) { tr = tr.NewChild("Auth.Exists", user+"@"+domain) defer tr.Finish() if be, ok := a.backends[domain]; ok { ok, err := be.Exists(user) tr.Debugf("Backend: %v %v", ok, err) if ok || err != nil { return ok, err } } if a.Fallback != nil { id := user if domain != "" { id = user + "@" + domain } ok, err := a.Fallback.Exists(id) tr.Debugf("Fallback: %v %v", ok, err) return ok, err } tr.Debugf("Rejected by default") return false, nil } // Reload the registered backends. 
func (a *Authenticator) Reload() error { msgs := []string{} for domain, be := range a.backends { tr := trace.New("Auth.Reload", domain) err := be.Reload() if err != nil { tr.Error(err) msgs = append(msgs, fmt.Sprintf("%q: %v", domain, err)) } tr.Finish() } if a.Fallback != nil { tr := trace.New("Auth.Reload", "<fallback>") err := a.Fallback.Reload() if err != nil { tr.Error(err) msgs = append(msgs, fmt.Sprintf("<fallback>: %v", err)) } tr.Finish() } if len(msgs) > 0 { return errors.New(strings.Join(msgs, " ; ")) } return nil } // DecodeResponse decodes a plain auth response. // // It must be a a base64-encoded string of the form: // // <authorization id> NUL <authentication id> NUL <password> // // https://tools.ietf.org/html/rfc4954#section-4.1. // // Either both IDs match, or one of them is empty. // // We split the id into user@domain, since in most cases we expect that to be // the used form, and normalize them. If there is no domain, we just return // "" for it. The rest of the stack will know how to handle it. func DecodeResponse(response string) (user, domain, passwd string, err error) { buf, err := base64.StdEncoding.DecodeString(response) if err != nil { return } bufsp := bytes.SplitN(buf, []byte{0}, 3) if len(bufsp) != 3 { err = fmt.Errorf("response pieces != 3, as per RFC") return } identity := "" passwd = string(bufsp[2]) { // We don't make the distinction between the two IDs, as long as one is // empty, or they're the same. z := string(bufsp[0]) c := string(bufsp[1]) // If neither is empty, then they must be the same. if (z != "" && c != "") && (z != c) { err = fmt.Errorf("auth IDs do not match") return } if z != "" { identity = z } if c != "" { identity = c } } if identity == "" { err = fmt.Errorf("empty identity, must be in the form user@domain") return } // Split identity into "user@domain", if possible. user = identity idsp := strings.SplitN(identity, "@", 2) if len(idsp) >= 2 { user = idsp[0] domain = idsp[1] } // Normalize the user and domain. 
This is so users can write the username // in their own style and still can log in. For the domain, we use IDNA // and relevant transformations to turn it to utf8 which is what we use // internally. user, err = normalize.User(user) if err != nil { return } domain, err = normalize.Domain(domain) if err != nil { return } return } // WrapNoErrorBackend wraps a NoErrorBackend, converting it into a valid // Backend. This is normally used in Auth.Register calls, to register no-error // backends. func WrapNoErrorBackend(be NoErrorBackend) Backend { return &wrapNoErrorBackend{be} } type wrapNoErrorBackend struct { be NoErrorBackend } func (w *wrapNoErrorBackend) Authenticate(user, password string) (bool, error) { return w.be.Authenticate(user, password), nil } func (w *wrapNoErrorBackend) Exists(user string) (bool, error) { return w.be.Exists(user), nil } func (w *wrapNoErrorBackend) Reload() error { return w.be.Reload() }
// Package config implements the chasquid configuration. package config // Generate the config protobuf. //go:generate protoc --go_out=. --go_opt=paths=source_relative --experimental_allow_proto3_optional config.proto import ( "fmt" "os" "blitiri.com.ar/go/log" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) var defaultConfig = &Config{ MaxDataSizeMb: 50, SmtpAddress: []string{"systemd"}, SubmissionAddress: []string{"systemd"}, SubmissionOverTlsAddress: []string{"systemd"}, MailDeliveryAgentBin: "maildrop", MailDeliveryAgentArgs: []string{"-f", "%from%", "-d", "%to_user%"}, DataDir: "/var/lib/chasquid", SuffixSeparators: proto.String("+"), DropCharacters: proto.String("."), MailLogPath: "<syslog>", } // Load the config from the given file, with the given overrides. func Load(path, overrides string) (*Config, error) { // Start with a copy of the default config. c := proto.Clone(defaultConfig).(*Config) // Load from the path. buf, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("failed to read config at %q: %v", path, err) } fromFile := &Config{} err = prototext.Unmarshal(buf, fromFile) if err != nil { return nil, fmt.Errorf("parsing config: %v", err) } override(c, fromFile) // Handle command line overrides. fromOverrides := &Config{} err = prototext.Unmarshal([]byte(overrides), fromOverrides) if err != nil { return nil, fmt.Errorf("parsing override: %v", err) } override(c, fromOverrides) // Handle hostname separate, because if it is set, we don't need to call // os.Hostname which can fail. if c.Hostname == "" { c.Hostname, err = os.Hostname() if err != nil { return nil, fmt.Errorf("could not get hostname: %v", err) } } return c, nil } // Override fields in `c` that are set in `o`. We can't use proto.Merge // because the semantics would not be convenient for overriding. 
func override(c, o *Config) { if o.Hostname != "" { c.Hostname = o.Hostname } if o.MaxDataSizeMb > 0 { c.MaxDataSizeMb = o.MaxDataSizeMb } if len(o.SmtpAddress) > 0 { c.SmtpAddress = o.SmtpAddress } if len(o.SubmissionAddress) > 0 { c.SubmissionAddress = o.SubmissionAddress } if len(o.SubmissionOverTlsAddress) > 0 { c.SubmissionOverTlsAddress = o.SubmissionOverTlsAddress } if o.MonitoringAddress != "" { c.MonitoringAddress = o.MonitoringAddress } if o.MailDeliveryAgentBin != "" { c.MailDeliveryAgentBin = o.MailDeliveryAgentBin } if len(o.MailDeliveryAgentArgs) > 0 { c.MailDeliveryAgentArgs = o.MailDeliveryAgentArgs } if o.DataDir != "" { c.DataDir = o.DataDir } if o.SuffixSeparators != nil { c.SuffixSeparators = o.SuffixSeparators } if o.DropCharacters != nil { c.DropCharacters = o.DropCharacters } if o.MailLogPath != "" { c.MailLogPath = o.MailLogPath } if o.DovecotAuth { c.DovecotAuth = true } if o.DovecotUserdbPath != "" { c.DovecotUserdbPath = o.DovecotUserdbPath } if o.DovecotClientPath != "" { c.DovecotClientPath = o.DovecotClientPath } if o.HaproxyIncoming { c.HaproxyIncoming = true } } // LogConfig logs the given configuration, in a human-friendly way. 
func LogConfig(c *Config) { log.Infof("Configuration:") log.Infof(" Hostname: %q", c.Hostname) log.Infof(" Max data size (MB): %d", c.MaxDataSizeMb) log.Infof(" SMTP Addresses: %q", c.SmtpAddress) log.Infof(" Submission Addresses: %q", c.SubmissionAddress) log.Infof(" Submission+TLS Addresses: %q", c.SubmissionOverTlsAddress) log.Infof(" Monitoring address: %q", c.MonitoringAddress) log.Infof(" MDA: %q %q", c.MailDeliveryAgentBin, c.MailDeliveryAgentArgs) log.Infof(" Data directory: %q", c.DataDir) if c.SuffixSeparators == nil { log.Infof(" Suffix separators: nil") } else { log.Infof(" Suffix separators: %q", *c.SuffixSeparators) } if c.DropCharacters == nil { log.Infof(" Drop characters: nil") } else { log.Infof(" Drop characters: %q", *c.DropCharacters) } log.Infof(" Mail log: %q", c.MailLogPath) log.Infof(" Dovecot auth: %v (%q, %q)", c.DovecotAuth, c.DovecotUserdbPath, c.DovecotClientPath) log.Infof(" HAProxy incoming: %v", c.HaproxyIncoming) }
package courier import ( "bytes" "context" "fmt" "os/exec" "strings" "syscall" "time" "unicode" "blitiri.com.ar/go/chasquid/internal/envelope" "blitiri.com.ar/go/chasquid/internal/normalize" "blitiri.com.ar/go/chasquid/internal/trace" ) var ( errTimeout = fmt.Errorf("operation timed out") ) // MDA delivers local mail by executing a local binary, like procmail or // maildrop. It works with any binary that: // - Receives the email to deliver via stdin. // - Exits with code EX_TEMPFAIL (75) for transient issues. type MDA struct { Binary string // Path to the binary. Args []string // Arguments to pass. Timeout time.Duration // Timeout for each invocation. } // Deliver an email. On failures, returns an error, and whether or not it is // permanent. func (p *MDA) Deliver(from string, to string, data []byte) (error, bool) { tr := trace.New("Courier.MDA", to) defer tr.Finish() // Sanitize, just in case. from = sanitizeForMDA(from) to = sanitizeForMDA(to) tr.Debugf("%s -> %s", from, to) // Prepare the command, replacing the necessary arguments. replacer := strings.NewReplacer( "%from%", from, "%from_user%", envelope.UserOf(from), "%from_domain%", envelope.DomainOf(from), "%to%", to, "%to_user%", envelope.UserOf(to), "%to_domain%", envelope.DomainOf(to), ) args := []string{} for _, a := range p.Args { args = append(args, replacer.Replace(a)) } tr.Debugf("%s %q", p.Binary, args) ctx, cancel := context.WithTimeout(context.Background(), p.Timeout) defer cancel() cmd := exec.CommandContext(ctx, p.Binary, args...) // Pass the email data via stdin. Normalize it to CRLF which is what the // RFC-compliant representation require. By doing this at this end, we can // keep a simpler internal representation and ensure there won't be any // inconsistencies in newlines within the message (e.g. added headers). 
cmd.Stdin = bytes.NewReader(normalize.ToCRLF(data)) output, err := cmd.CombinedOutput() if ctx.Err() == context.DeadlineExceeded { return tr.Error(errTimeout), false } if err != nil { // Determine if the error is permanent or not. // Default to permanent, but error code 75 is transient by general // convention (/usr/include/sysexits.h), and commonly relied upon. permanent := true if exiterr, ok := err.(*exec.ExitError); ok { if status, ok := exiterr.Sys().(syscall.WaitStatus); ok { permanent = status.ExitStatus() != 75 } } err = tr.Errorf("MDA delivery failed: %v - %q", err, string(output)) return err, permanent } tr.Debugf("delivered") return nil, false } // sanitizeForMDA cleans the string, removing characters that could be // problematic considering we will run an external command. // // The server does not rely on this to do substitution or proper filtering, // that's done at a different layer; this is just for defense in depth. func sanitizeForMDA(s string) string { valid := func(r rune) rune { switch { case unicode.IsSpace(r), unicode.IsControl(r), strings.ContainsRune("/;\"'\\|*&$%()[]{}`!", r): return rune(-1) default: return r } } return strings.Map(valid, s) }
package courier import ( "context" "crypto/tls" "crypto/x509" "flag" "net" "time" "golang.org/x/net/idna" "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/envelope" "blitiri.com.ar/go/chasquid/internal/expvarom" "blitiri.com.ar/go/chasquid/internal/smtp" "blitiri.com.ar/go/chasquid/internal/sts" "blitiri.com.ar/go/chasquid/internal/trace" ) var ( // Timeouts for SMTP delivery. smtpDialTimeout = 1 * time.Minute smtpTotalTimeout = 10 * time.Minute // Port for outgoing SMTP. // Tests can override this. smtpPort = flag.String("testing__outgoing_smtp_port", "25", "port to use for outgoing SMTP connections, ONLY FOR TESTING") // Allow overriding of net.LookupMX for testing purposes. // TODO: replace this with proper lookup interception once it is supported // by Go. netLookupMX = net.LookupMX ) // Exported variables. var ( tlsCount = expvarom.NewMap("chasquid/smtpOut/tlsCount", "result", "count of TLS status on outgoing connections") slcResults = expvarom.NewMap("chasquid/smtpOut/securityLevelChecks", "result", "count of security level checks on outgoing connections") stsSecurityModes = expvarom.NewMap("chasquid/smtpOut/sts/mode", "mode", "count of STS checks on outgoing connections") stsSecurityResults = expvarom.NewMap("chasquid/smtpOut/sts/security", "result", "count of STS security checks on outgoing connections") ) // SMTP delivers remote mail via outgoing SMTP. type SMTP struct { HelloDomain string Dinfo *domaininfo.DB STSCache *sts.PolicyCache } // Deliver an email. On failures, returns an error, and whether or not it is // permanent. func (s *SMTP) Deliver(from string, to string, data []byte) (error, bool) { a := &attempt{ courier: s, from: from, to: to, toDomain: envelope.DomainOf(to), data: data, tr: trace.New("Courier.SMTP", to), } defer a.tr.Finish() a.tr.Debugf("%s -> %s", from, to) // smtp.Client.Mail will add the <> for us when the address is empty. 
if a.from == "<>" { a.from = "" } mxs, err, perm := lookupMXs(a.tr, a.toDomain) if err != nil || len(mxs) == 0 { // Note this is considered a permanent error. // This is in line with what other servers (Exim) do. However, the // downside is that temporary DNS issues can affect delivery, so we // have to make sure we try hard enough on the lookup above. return a.tr.Errorf("Could not find mail server: %v", err), perm } a.stsPolicy = s.fetchSTSPolicy(a.tr, a.toDomain) for _, mx := range mxs { if a.stsPolicy != nil && !a.stsPolicy.MXIsAllowed(mx) { a.tr.Printf("%q skipped as per MTA-STA policy", mx) continue } var permanent bool err, permanent = a.deliver(mx) if err == nil { return nil, false } if permanent { return err, true } a.tr.Errorf("%q returned transient error: %v", mx, err) } // We exhausted all MXs failed to deliver, try again later. return a.tr.Errorf("all MXs returned transient failures (last: %v)", err), false } type attempt struct { courier *SMTP from string to string data []byte toDomain string stsPolicy *sts.Policy tr *trace.Trace } func (a *attempt) deliver(mx string) (error, bool) { skipTLS := false retry: conn, err := net.DialTimeout("tcp", mx+":"+*smtpPort, smtpDialTimeout) if err != nil { return a.tr.Errorf("Could not dial: %v", err), false } defer conn.Close() conn.SetDeadline(time.Now().Add(smtpTotalTimeout)) c, err := smtp.NewClient(conn, mx) if err != nil { return a.tr.Errorf("Error creating client: %v", err), false } if err = c.Hello(a.courier.HelloDomain); err != nil { return a.tr.Errorf("Error saying hello: %v", err), false } secLevel := domaininfo.SecLevel_PLAIN if ok, _ := c.Extension("STARTTLS"); ok && !skipTLS { config := &tls.Config{ ServerName: mx, // Unfortunately, many servers use self-signed and invalid // certificates. So we use a custom verification (identical to // Go's) to distinguish between invalid and valid certificates. // That information is used to track the security level, to // prevent downgrade attacks. 
InsecureSkipVerify: true, VerifyConnection: func(cs tls.ConnectionState) error { secLevel = a.verifyConnection(cs) return nil }, } err = c.StartTLS(config) if err != nil { // If we could not complete a jump to TLS (either because the // STARTTLS command itself failed server-side, or because we got a // TLS negotiation error), retry but without trying to use TLS. // This should be quite rare, but it can happen if the server // certificate is not parseable by the Go library, or if it has a // broken TLS stack. // Note that invalid and self-signed certs do NOT fall in this // category, those are handled by the VerifyConnection function // above, and don't need a retry. This is only needed for lower // level errors. tlsCount.Add("tls:failed", 1) a.tr.Errorf("TLS error, retrying without TLS: %v", err) skipTLS = true conn.Close() goto retry } } else { tlsCount.Add("plain", 1) a.tr.Debugf("Insecure - NOT using TLS") } if !a.courier.Dinfo.OutgoingSecLevel(a.tr, a.toDomain, secLevel) { // We consider the failure transient, so transient misconfigurations // do not affect deliveries. slcResults.Add("fail", 1) return a.tr.Errorf("Security level check failed (level:%s)", secLevel), false } slcResults.Add("pass", 1) if a.stsPolicy != nil && a.stsPolicy.Mode == sts.Enforce { // The connection MUST be validated by TLS. 
// https://tools.ietf.org/html/rfc8461#section-4.2 if secLevel != domaininfo.SecLevel_TLS_SECURE { stsSecurityResults.Add("fail", 1) return a.tr.Errorf("invalid security level (%v) for STS policy", secLevel), false } stsSecurityResults.Add("pass", 1) a.tr.Debugf("STS policy: connection is using valid TLS") } if err = c.MailAndRcpt(a.from, a.to); err != nil { return a.tr.Errorf("MAIL+RCPT %v", err), smtp.IsPermanent(err) } w, err := c.Data() if err != nil { return a.tr.Errorf("DATA %v", err), smtp.IsPermanent(err) } _, err = w.Write(a.data) if err != nil { return a.tr.Errorf("DATA writing: %v", err), smtp.IsPermanent(err) } err = w.Close() if err != nil { return a.tr.Errorf("DATA closing %v", err), smtp.IsPermanent(err) } _ = c.Quit() a.tr.Debugf("done") return nil, false } // CA roots to validate against, so we can override it for testing. var certRoots *x509.CertPool = nil func (a *attempt) verifyConnection(cs tls.ConnectionState) domaininfo.SecLevel { // Validate certificates, using the same logic Go does, and following the // official example at // https://pkg.go.dev/crypto/tls#example-Config-VerifyConnection. opts := x509.VerifyOptions{ DNSName: cs.ServerName, Intermediates: x509.NewCertPool(), Roots: certRoots, } for _, cert := range cs.PeerCertificates[1:] { opts.Intermediates.AddCert(cert) } _, err := cs.PeerCertificates[0].Verify(opts) if err != nil { // Invalid TLS cert, since it could not be verified. 
a.tr.Debugf("Insecure - using TLS, but with an invalid cert") tlsCount.Add("tls:insecure", 1) return domaininfo.SecLevel_TLS_INSECURE } else { tlsCount.Add("tls:secure", 1) a.tr.Debugf("Secure - using TLS") return domaininfo.SecLevel_TLS_SECURE } } func (s *SMTP) fetchSTSPolicy(tr *trace.Trace, domain string) *sts.Policy { if s.STSCache == nil { return nil } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) defer cancel() policy, err := s.STSCache.Fetch(ctx, domain) if err != nil { return nil } tr.Debugf("got STS policy") stsSecurityModes.Add(string(policy.Mode), 1) return policy } func lookupMXs(tr *trace.Trace, domain string) ([]string, error, bool) { domain, err := idna.ToASCII(domain) if err != nil { return nil, err, true } mxs := []string{} mxRecords, err := netLookupMX(domain) if err != nil { // There was an error. It could be that the domain has no MX, in which // case we have to fall back to A, or a bigger problem. dnsErr, ok := err.(*net.DNSError) if !ok { tr.Debugf("Error resolving MX on %q: %v", domain, err) return nil, err, false } else if dnsErr.IsNotFound { // MX not found, fall back to A. tr.Debugf("MX for %s not found, falling back to A", domain) mxs = []string{domain} } else { tr.Debugf("MX lookup error on %q: %v", domain, dnsErr) return nil, err, !dnsErr.Temporary() } } else { // Convert the DNS records to a plain string slice. They're already // sorted by priority. for _, r := range mxRecords { mxs = append(mxs, r.Host) } } // Note that mxs could be empty; in that case we do NOT fall back to A. // This case is explicitly covered by the SMTP RFC. // https://tools.ietf.org/html/rfc5321#section-5.1 // Cap the list of MXs to 5 hosts, to keep delivery attempt times // sane and prevent abuse. if len(mxs) > 5 { mxs = mxs[:5] } tr.Debugf("MXs: %v", mxs) return mxs, nil, true }
package dkim import ( "errors" "fmt" "regexp" "strings" ) var ( errNoBody = errors.New("no body found") errUnknownCanonicalization = errors.New("unknown canonicalization") ) type canonicalization string var ( simpleCanonicalization canonicalization = "simple" relaxedCanonicalization canonicalization = "relaxed" ) func (c canonicalization) body(b string) string { switch c { case simpleCanonicalization: return simpleBody(b) case relaxedCanonicalization: return relaxBody(b) default: panic("unknown canonicalization") } } func (c canonicalization) headers(hs headers) headers { switch c { case simpleCanonicalization: return hs case relaxedCanonicalization: return relaxHeaders(hs) default: panic("unknown canonicalization") } } func (c canonicalization) header(h header) header { switch c { case simpleCanonicalization: return h case relaxedCanonicalization: return relaxHeader(h) default: panic("unknown canonicalization") } } func stringToCanonicalization(s string) (canonicalization, error) { switch s { case "simple": return simpleCanonicalization, nil case "relaxed": return relaxedCanonicalization, nil default: return "", fmt.Errorf("%w: %s", errUnknownCanonicalization, s) } } // Notes on whitespace reduction: // https://datatracker.ietf.org/doc/html/rfc6376#section-2.8 // There are only 3 forms of whitespace: // - WSP = SP / HTAB // Simple whitespace: space or tab. // - LWSP = *(WSP / CRLF WSP) // Linear whitespace: any number of { simple whitespace OR CRLF followed by // simple whitespace }. // - FWS = [*WSP CRLF] 1*WSP // Folding whitespace: optional { simple whitespace OR CRLF } followed by // one or more simple whitespace. func simpleBody(body string) string { // https://datatracker.ietf.org/doc/html/rfc6376#section-3.4.3 // Replace repeated CRLF at the end of the body with a single CRLF. body = repeatedCRLFAtTheEnd.ReplaceAllString(body, "\r\n") // Ensure a non-empty body ends with a single CRLF. // All bodies (including an empty one) must end with a CRLF. 
if !strings.HasSuffix(body, "\r\n") { body += "\r\n" } return body } var ( // Continued header: WSP after CRLF. continuedHeader = regexp.MustCompile(`\r\n[ \t]+`) // WSP before CRLF. wspBeforeCRLF = regexp.MustCompile(`[ \t]+\r\n`) // Repeated WSP. repeatedWSP = regexp.MustCompile(`[ \t]+`) // Empty lines at the end of the body. repeatedCRLFAtTheEnd = regexp.MustCompile(`(\r\n)+$`) ) func relaxBody(body string) string { // https://datatracker.ietf.org/doc/html/rfc6376#section-3.4.4 body = wspBeforeCRLF.ReplaceAllLiteralString(body, "\r\n") body = repeatedWSP.ReplaceAllLiteralString(body, " ") body = repeatedCRLFAtTheEnd.ReplaceAllLiteralString(body, "\r\n") // Ensure a non-empty body ends with a single CRLF. if len(body) >= 1 && !strings.HasSuffix(body, "\r\n") { body += "\r\n" } return body } func relaxHeader(h header) header { // https://datatracker.ietf.org/doc/html/rfc6376#section-3.4.2 // Convert all header field names to lowercase. name := strings.ToLower(h.Name) // Remove WSP before the ":" separating the name and value. name = strings.TrimRight(name, " \t") // Unfold continuation lines in values. value := continuedHeader.ReplaceAllString(h.Value, " ") // Reduce all sequences of WSP to a single SP. value = repeatedWSP.ReplaceAllLiteralString(value, " ") // Delete all WSP at the end of each unfolded header field value. value = strings.TrimRight(value, " \t") // Remove WSP after the ":" separating the name and value. value = strings.TrimLeft(value, " \t") return header{ Name: name, Value: value, // The "source" is the relaxed field: name, colon, and value (with // no space around the colon). Source: name + ":" + value, } } func relaxHeaders(hs headers) headers { rh := make(headers, 0, len(hs)) for _, h := range hs { rh = append(rh, relaxHeader(h)) } return rh }
package dkim import ( "context" "net" ) type contextKey string const traceKey contextKey = "trace" func trace(ctx context.Context, f string, args ...interface{}) { traceFunc, ok := ctx.Value(traceKey).(TraceFunc) if !ok { return } traceFunc(f, args...) } type TraceFunc func(f string, a ...interface{}) func WithTraceFunc(ctx context.Context, trace TraceFunc) context.Context { return context.WithValue(ctx, traceKey, trace) } const lookupTXTKey contextKey = "lookupTXT" func lookupTXT(ctx context.Context, domain string) ([]string, error) { lookupTXTFunc, ok := ctx.Value(lookupTXTKey).(lookupTXTFunc) if !ok { return net.LookupTXT(domain) } return lookupTXTFunc(ctx, domain) } type lookupTXTFunc func(ctx context.Context, domain string) ([]string, error) func WithLookupTXTFunc(ctx context.Context, lookupTXT lookupTXTFunc) context.Context { return context.WithValue(ctx, lookupTXTKey, lookupTXT) } const maxHeadersKey contextKey = "maxHeaders" func WithMaxHeaders(ctx context.Context, maxHeaders int) context.Context { return context.WithValue(ctx, maxHeadersKey, maxHeaders) } func maxHeaders(ctx context.Context) int { maxHeaders, ok := ctx.Value(maxHeadersKey).(int) if !ok { // By default, cap the number of headers to 5 (arbitrarily chosen, may // be adjusted in the future). return 5 } return maxHeaders }
package dkim import ( "context" "crypto" "crypto/ed25519" "crypto/rsa" "crypto/x509" "encoding/base64" "errors" "fmt" "slices" "strings" ) func findPublicKeys(ctx context.Context, domain, selector string) ([]*publicKey, error) { // Subdomain where the key lives. // https://datatracker.ietf.org/doc/html/rfc6376#section-3.6.2 d := selector + "._domainkey." + domain values, err := lookupTXT(ctx, d) if err != nil { trace(ctx, "TXT lookup of %q failed: %v", d, err) return nil, err } // There should be only a single record; RFC 6376 says the results are // undefined if there are multiple TXT records. // https://datatracker.ietf.org/doc/html/rfc6376#section-3.6.2.2 // // What other implementations do: // - dkimpy: Use the first TXT record (whatever it is). // - OpenDKIM: Use the first TXT record (whatever it is). // - driusan/dkim: Use the first TXT record that can be parsed as a key. // - go-msgauth: Reject if there are multiple records. // // What we do: use _all_ TXT records that can be parsed as keys. This is // possibly too much, and we could reconsider this in the future. pks := []*publicKey{} for _, v := range values { trace(ctx, "TXT record for %q: %q", d, v) pk, err := parsePublicKey(v) if err != nil { trace(ctx, "Skipping: %v", err) continue } trace(ctx, "Parsed public key: %s", pk) pks = append(pks, pk) } return pks, nil } // Function to verify a signature with this public key. type verifyFunc func(h crypto.Hash, hashed, signature []byte) error type publicKey struct { H []crypto.Hash K keyType P []byte T []string // t= tag, representing flags. verify verifyFunc } func (pk *publicKey) String() string { return fmt.Sprintf("[%s:%.8x]", pk.K, pk.P) } func (pk *publicKey) Matches(kt keyType, h crypto.Hash) bool { if pk.K != kt { return false } if len(pk.H) > 0 { return slices.Contains(pk.H, h) } return true } func (pk *publicKey) StrictDomainCheck() bool { // t=s is set. 
return slices.Contains(pk.T, "s") } func parsePublicKey(v string) (*publicKey, error) { // Public key is a tag-value list. // https://datatracker.ietf.org/doc/html/rfc6376#section-3.6.1 tags, err := parseTags(v) if err != nil { return nil, err } // "v" is optional, but if present it must be "DKIM1". ver, ok := tags["v"] if ok && ver != "DKIM1" { return nil, fmt.Errorf("%w: %q", errInvalidVersion, ver) } pk := &publicKey{ // The default key type is rsa. K: keyTypeRSA, } // h is a colon-separated list of hashing algorithm names. if tags["h"] != "" { hs := strings.Split(eatWhitespace.Replace(tags["h"]), ":") for _, h := range hs { x, err := hashFromString(h) if err != nil { // Unrecognized algorithms must be ignored. // https://datatracker.ietf.org/doc/html/rfc6376#section-3.6.1 continue } pk.H = append(pk.H, x) } } // k is key type (may not be present, rsa is used in that case). if tags["k"] != "" { pk.K, err = keyTypeFromString(tags["k"]) if err != nil { return nil, err } } // p is public-key data, base64-encoded, and whitespace in it must be // ignored. Required. p, err := base64.StdEncoding.DecodeString( eatWhitespace.Replace(tags["p"])) if err != nil { return nil, fmt.Errorf("error decoding p=: %w", err) } pk.P = p switch pk.K { case keyTypeRSA: pk.verify, err = parseRSAPublicKey(p) case keyTypeEd25519: pk.verify, err = parseEd25519PublicKey(p) } // t is a colon-separated list of flags. if t := eatWhitespace.Replace(tags["t"]); t != "" { pk.T = strings.Split(t, ":") } if err != nil { return nil, err } return pk, nil } var ( errInvalidRSAPublicKey = errors.New("invalid RSA public key") errNotRSAPublicKey = errors.New("not an RSA public key") errRSAKeyTooSmall = errors.New("RSA public key too small") errInvalidEd25519Key = errors.New("invalid Ed25519 public key") ) func parseRSAPublicKey(p []byte) (verifyFunc, error) { // Either PKCS#1 or SubjectPublicKeyInfo. // See https://www.rfc-editor.org/errata/eid3017. 
pub, err := x509.ParsePKIXPublicKey(p) if err != nil { pub, err = x509.ParsePKCS1PublicKey(p) } if err != nil { return nil, fmt.Errorf("%w: %w", errInvalidRSAPublicKey, err) } rsaPub, ok := pub.(*rsa.PublicKey) if !ok { return nil, errNotRSAPublicKey } // Enforce 1024-bit minimum. // https://datatracker.ietf.org/doc/html/rfc8301#section-3.2 if rsaPub.Size()*8 < 1024 { return nil, errRSAKeyTooSmall } return func(h crypto.Hash, hashed, signature []byte) error { return rsa.VerifyPKCS1v15(rsaPub, h, hashed, signature) }, nil } func parseEd25519PublicKey(p []byte) (verifyFunc, error) { // https: //datatracker.ietf.org/doc/html/rfc8463 if len(p) != ed25519.PublicKeySize { return nil, errInvalidEd25519Key } pub := ed25519.PublicKey(p) return func(h crypto.Hash, hashed, signature []byte) error { if ed25519.Verify(pub, hashed, signature) { return nil } return errors.New("signature verification failed") }, nil }
package dkim import ( "crypto" "encoding/base64" "errors" "fmt" "slices" "strconv" "strings" "time" ) // https://datatracker.ietf.org/doc/html/rfc6376#section-6 type dkimSignature struct { // Version. Must be "1". v string // Algorithm. Like "rsa-sha256". a string // Key type, extracted from a=. KeyType keyType // Hash, extracted from a=. Hash crypto.Hash // Signature data. // Decoded from base64, ignoring whitespace. b []byte // Hash of canonicalized body. // Decoded from base64, ignoring whitespace. bh []byte // Canonicalization modes. cH canonicalization cB canonicalization // Domain ("SDID"), in plain text. // IDNs MUST be encoded as A-labels. d string // Signed header fields. // Colon-separated list of header fields. h []string // AUID, in plain text. i string // Body octet count of the canonicalized body. l uint64 // Query methods used for DNS lookup. // Colon-separated list of methods. Only "dns/txt" is valid. q []string // Selector. s string // Timestamp. In Seconds since the UNIX epoch. t time.Time // Signature expiration. In Seconds since the UNIX epoch. x time.Time // Copied header fields. // Has a specific encoding but whitespace is ignored. z string } func (sig *dkimSignature) canonicalizationFromString(s string) error { if s == "" { sig.cH = simpleCanonicalization sig.cB = simpleCanonicalization return nil } // Either "header/body" or "header". In the latter case, "simple" is used // for the body canonicalization. // No whitespace around the '/' is allowed. 
hs, bs, _ := strings.Cut(s, "/") if bs == "" { bs = "simple" } var err error sig.cH, err = stringToCanonicalization(hs) if err != nil { return fmt.Errorf("header: %w", err) } sig.cB, err = stringToCanonicalization(bs) if err != nil { return fmt.Errorf("body: %w", err) } return nil } func (sig *dkimSignature) checkRequiredTags() error { // https://datatracker.ietf.org/doc/html/rfc6376#section-6.1.1 if sig.a == "" { return fmt.Errorf("%w: a=", errMissingRequiredTag) } if len(sig.b) == 0 { return fmt.Errorf("%w: b=", errMissingRequiredTag) } if len(sig.bh) == 0 { return fmt.Errorf("%w: bh=", errMissingRequiredTag) } if sig.d == "" { return fmt.Errorf("%w: d=", errMissingRequiredTag) } if len(sig.h) == 0 { return fmt.Errorf("%w: h=", errMissingRequiredTag) } if sig.s == "" { return fmt.Errorf("%w: s=", errMissingRequiredTag) } // h= must contain From. var isFrom = func(s string) bool { return strings.EqualFold(s, "from") } if !slices.ContainsFunc(sig.h, isFrom) { return fmt.Errorf("%w: h= does not contain 'from'", errInvalidTag) } // If i= is present, its domain must be equal to, or a subdomain of, d=. if sig.i != "" { _, domain, _ := strings.Cut(sig.i, "@") if domain != sig.d && !strings.HasSuffix(domain, "."+sig.d) { return fmt.Errorf("%w: i= is not a subdomain of d=", errInvalidTag) } } return nil } var ( errInvalidSignature = errors.New("invalid signature") errInvalidVersion = errors.New("invalid version") errBadATag = errors.New("invalid a= tag") errUnsupportedHash = errors.New("unsupported hash") errUnsupportedKeyType = errors.New("unsupported key type") errMissingRequiredTag = errors.New("missing required tag") ) // String replacer that removes whitespace. var eatWhitespace = strings.NewReplacer(" ", "", "\t", "", "\r", "", "\n", "") func dkimSignatureFromHeader(header string) (*dkimSignature, error) { tags, err := parseTags(header) if err != nil { return nil, err } sig := &dkimSignature{ v: tags["v"], a: tags["a"], } // v= tag is mandatory and must be 1. 
if sig.v != "1" { return nil, errInvalidVersion } // a= tag is mandatory; check that we can parse it and that we support the // algorithms. ktS, hS, found := strings.Cut(sig.a, "-") if !found { return nil, errBadATag } sig.KeyType, err = keyTypeFromString(ktS) if err != nil { return nil, fmt.Errorf("%w: %s", err, sig.a) } sig.Hash, err = hashFromString(hS) if err != nil { return nil, fmt.Errorf("%w: %s", err, sig.a) } // b is base64-encoded, and whitespace in it must be ignored. sig.b, err = base64.StdEncoding.DecodeString( eatWhitespace.Replace(tags["b"])) if err != nil { return nil, fmt.Errorf("%w: failed to decode b: %w", errInvalidSignature, err) } // bh - same as b. sig.bh, err = base64.StdEncoding.DecodeString( eatWhitespace.Replace(tags["bh"])) if err != nil { return nil, fmt.Errorf("%w: failed to decode bh: %w", errInvalidSignature, err) } err = sig.canonicalizationFromString(tags["c"]) if err != nil { return nil, fmt.Errorf("%w: failed to parse c: %w", errInvalidSignature, err) } sig.d = tags["d"] // h is a colon-separated list of header fields. if tags["h"] != "" { sig.h = strings.Split(eatWhitespace.Replace(tags["h"]), ":") } sig.i = tags["i"] if tags["l"] != "" { sig.l, err = strconv.ParseUint(tags["l"], 10, 64) if err != nil { return nil, fmt.Errorf("%w: failed to parse l: %w", errInvalidSignature, err) } } // q is a colon-separated list of query methods. 
if tags["q"] != "" { sig.q = strings.Split(eatWhitespace.Replace(tags["q"]), ":") } if len(sig.q) > 0 && !slices.Contains(sig.q, "dns/txt") { return nil, fmt.Errorf("%w: no dns/txt query method in q", errInvalidSignature) } sig.s = tags["s"] if tags["t"] != "" { sig.t, err = unixStrToTime(tags["t"]) if err != nil { return nil, fmt.Errorf("%w: failed to parse t: %w", errInvalidSignature, err) } } if tags["x"] != "" { sig.x, err = unixStrToTime(tags["x"]) if err != nil { return nil, fmt.Errorf("%w: failed to parse x: %w", errInvalidSignature, err) } } sig.z = eatWhitespace.Replace(tags["z"]) // Check required tags are present. if err := sig.checkRequiredTags(); err != nil { return nil, err } return sig, nil } func unixStrToTime(s string) (time.Time, error) { ti, err := strconv.ParseUint(s, 10, 64) if err != nil { return time.Time{}, err } return time.Unix(int64(ti), 0), nil } type keyType string const ( keyTypeRSA keyType = "rsa" keyTypeEd25519 keyType = "ed25519" ) func keyTypeFromString(s string) (keyType, error) { switch s { case "rsa": return keyTypeRSA, nil case "ed25519": return keyTypeEd25519, nil default: return "", errUnsupportedKeyType } } func hashFromString(s string) (crypto.Hash, error) { switch s { // Note SHA1 is not supported: as per RFC 8301, it must not be used // for signing or verifying. // https://datatracker.ietf.org/doc/html/rfc8301#section-3.1 case "sha256": return crypto.SHA256, nil default: return 0, errUnsupportedHash } } // DKIM Tag=Value lists, as defined in RFC 6376, Section 3.2. // https://datatracker.ietf.org/doc/html/rfc6376#section-3.2 type tags map[string]string var errInvalidTag = errors.New("invalid tag") func parseTags(s string) (tags, error) { // First trim space, and trailing semicolon, to simplify parsing below. 
s = strings.TrimSpace(s) s = strings.TrimSuffix(s, ";") tags := make(tags) for _, tv := range strings.Split(s, ";") { t, v, found := strings.Cut(tv, "=") if !found { return nil, fmt.Errorf("%w: missing '='", errInvalidTag) } // Trim leading and trailing whitespace from tag and value, as per // RFC. t = strings.TrimSpace(t) v = strings.TrimSpace(v) if t == "" { return nil, fmt.Errorf("%w: missing tag name", errInvalidTag) } // RFC 6376, Section 3.2: Tags with duplicate names MUST NOT occur // within a single tag-list; if a tag name does occur more than once, // the entire tag-list is invalid. if _, exists := tags[t]; exists { return nil, fmt.Errorf("%w: duplicate tag", errInvalidTag) } tags[t] = v } return tags, nil }
package dkim

import (
	"errors"
	"fmt"
	"strings"
)

// header is a single message header field, as found in the message.
type header struct {
	// Name as it appears in the message (case preserved, no colon).
	Name string

	// Everything after the first ':', untouched, including continuation
	// lines joined with "\r\n".
	Value string

	// The full original text of the header, including continuation lines.
	Source string
}

type headers []header

// FindAll the headers with the given name, in order of appearance.
// Names are compared case-insensitively.
func (h headers) FindAll(name string) headers {
	hs := make(headers, 0)
	for _, header := range h {
		if strings.EqualFold(header.Name, name) {
			hs = append(hs, header)
		}
	}
	return hs
}

var errInvalidHeader = errors.New("invalid header")

// Parse a RFC822 message, return the headers, body, and error if any.
// We expect it to only contain CRLF line endings.
// Does NOT touch whitespace, this is important to preserve the original
// message and headers, which is required for the signature.
//
// The body is everything after the first empty line. If there is no empty
// line, the whole message is headers and the body is empty.
func parseMessage(message string) (headers, string, error) {
	headers := make(headers, 0)
	lines := strings.Split(message, "\r\n")

	// Index of the empty line separating headers from body. Starts out of
	// range so a message without that line is not mistaken for one whose
	// body starts right after the first line.
	eoh := len(lines)
	for i, line := range lines {
		if line == "" {
			eoh = i
			break
		}

		if strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t") {
			// Continuation of the previous header.
			if len(headers) == 0 {
				return nil, "", fmt.Errorf(
					"%w: bad continuation", errInvalidHeader)
			}
			headers[len(headers)-1].Value += "\r\n" + line
			headers[len(headers)-1].Source += "\r\n" + line
		} else {
			// New header.
			h, err := parseHeader(line)
			if err != nil {
				return nil, "", err
			}

			headers = append(headers, h)
		}
	}

	body := ""
	if eoh < len(lines) {
		body = strings.Join(lines[eoh+1:], "\r\n")
	}
	return headers, body, nil
}

// parseHeader parses a single (unfolded) header line, splitting it at the
// first colon. It returns errInvalidHeader if there is no colon.
func parseHeader(line string) (header, error) {
	name, value, found := strings.Cut(line, ":")
	if !found {
		return header{}, fmt.Errorf("%w: no colon", errInvalidHeader)
	}

	return header{
		Name:   name,
		Value:  value,
		Source: line,
	}, nil
}
package dkim import ( "context" "crypto" "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/sha256" "encoding/base64" "fmt" "strings" "time" ) type Signer struct { // Domain to sign for. Domain string // Selector to use. Selector string // Signer containing the private key. // This can be an *rsa.PrivateKey or a ed25519.PrivateKey. Signer crypto.Signer } var headersToSign = []string{ // https://datatracker.ietf.org/doc/html/rfc6376#section-5.4.1 "From", // Required. "Reply-To", "Subject", "Date", "To", "Cc", "Resent-Date", "Resent-From", "Resent-To", "Resent-Cc", "In-Reply-To", "References", "List-Id", "List-Help", "List-Unsubscribe", "List-Subscribe", "List-Post", "List-Owner", "List-Archive", // Our additions. "Message-ID", } var extraHeadersToSign = []string{ // Headers to add an extra of, to prevent additions after signing. // If they're included here, they must be in headersToSign too. "From", "Subject", "Date", "To", "Cc", "Message-ID", } // Sign the given message. Returns the *value* of the DKIM-Signature header to // be added to the message. It will usually be multi-line, but without // indenting. func (s *Signer) Sign(ctx context.Context, message string) (string, error) { headers, body, err := parseMessage(message) if err != nil { return "", err } algoStr, err := s.algoStr() if err != nil { return "", err } trace(ctx, "Signing for %s / %s with %s", s.Domain, s.Selector, algoStr) dkimSignature := fmt.Sprintf( "v=1; a=%s; c=relaxed/relaxed;\r\n", algoStr) dkimSignature += fmt.Sprintf( "d=%s; s=%s; t=%d;\r\n", s.Domain, s.Selector, time.Now().Unix()) // Add the headers to sign. hsForHeader := []string{} for _, h := range headersToSign { // Append the header as many times as it appears in the message. for i := 0; i < len(headers.FindAll(h)); i++ { hsForHeader = append(hsForHeader, h) } } hsForHeader = append(hsForHeader, extraHeadersToSign...) dkimSignature += fmt.Sprintf( "h=%s;\r\n", formatHeaders(hsForHeader)) // Compute and add bh= (body hash). 
bodyH := sha256.Sum256([]byte( relaxedCanonicalization.body(body))) dkimSignature += fmt.Sprintf( "bh=%s;\r\n", base64.StdEncoding.EncodeToString(bodyH[:])) // Compute b= (signature). // First, the canonicalized headers. b := sha256.New() for _, h := range headersToSign { for _, header := range headers.FindAll(h) { hsrc := relaxedCanonicalization.header(header).Source + "\r\n" trace(ctx, "Hashing header: %q", hsrc) b.Write([]byte(hsrc)) } } // Now, the (canonicalized) DKIM-Signature header itself, but with an // empty b= tag, without a trailing \r\n, and ending with ";". // We include the ";" because we will add it at the end (see below). It is // legal not to include that final ";", we just choose to do so. // We replace \r\n with \r\n\t so the canonicalization considers them // proper continuations, and works correctly. dkimSignature += "b=" dkimSignatureForSigning := strings.ReplaceAll( dkimSignature, "\r\n", "\r\n\t") + ";" relaxedDH := relaxedCanonicalization.header(header{ Name: "DKIM-Signature", Value: dkimSignatureForSigning, Source: dkimSignatureForSigning, }) b.Write([]byte(relaxedDH.Source)) trace(ctx, "Hashing header: %q", relaxedDH.Source) bSum := b.Sum(nil) trace(ctx, "Resulting hash: %q", base64.StdEncoding.EncodeToString(bSum)) // Finally, sign the hash. 
sig, err := s.sign(bSum) if err != nil { return "", err } sigb64 := base64.StdEncoding.EncodeToString(sig) dkimSignature += breakLongLines(sigb64) + ";" return dkimSignature, nil } func (s *Signer) algoStr() (string, error) { switch k := s.Signer.(type) { case *rsa.PrivateKey: return "rsa-sha256", nil case ed25519.PrivateKey: return "ed25519-sha256", nil default: return "", fmt.Errorf("%w: %T", errUnsupportedKeyType, k) } } func (s *Signer) sign(bSum []byte) ([]byte, error) { var h crypto.Hash switch s.Signer.(type) { case *rsa.PrivateKey: h = crypto.SHA256 case ed25519.PrivateKey: h = crypto.Hash(0) } return s.Signer.Sign(rand.Reader, bSum, h) } func breakLongLines(s string) string { // Break long lines, indenting with 2 spaces for continuation (to make // it clear it's under the same tag). const limit = 70 var sb strings.Builder for len(s) > 0 { if len(s) > limit { sb.WriteString(s[:limit]) sb.WriteString("\r\n ") s = s[limit:] } else { sb.WriteString(s) s = "" } } return sb.String() } func formatHeaders(hs []string) string { // Format the list of headers for inclusion in the DKIM-Signature header. // This includes converting them to lowercase, and line-wrapping. // Extra lines will be indented with 2 spaces, to make it clear they're // under the same tag. const limit = 70 var sb strings.Builder line := "" for i, h := range hs { if len(line)+1+len(h) > limit { sb.WriteString(line + "\r\n ") line = "" } if i > 0 { line += ":" } line += h } sb.WriteString(line) return strings.TrimSpace(strings.ToLower(sb.String())) }
package dkim import ( "bytes" "context" "crypto" "encoding/base64" "errors" "fmt" "net" "regexp" "slices" "strings" ) // These two errors are returned when the verification fails, but the header // is considered valid. var ( ErrBodyHashMismatch = errors.New("body hash mismatch") ErrVerificationFailed = errors.New("verification failed") ) // Evaluation states, as per // https://datatracker.ietf.org/doc/html/rfc6376#section-3.9. type EvaluationState string const ( SUCCESS EvaluationState = "SUCCESS" PERMFAIL EvaluationState = "PERMFAIL" TEMPFAIL EvaluationState = "TEMPFAIL" ) type VerifyResult struct { // How many signatures were found. Found uint // How many signatures were verified successfully. Valid uint // The details for each signature that was found. Results []*OneResult } type OneResult struct { // The raw signature header. SignatureHeader string // Domain and selector from the signature header. Domain string Selector string // Base64-encoded signature. May be missing if it is not present in the // header. B string // The result of the evaluation. State EvaluationState Error error } // Returns the DKIM-specific contents for an Authentication-Results header. // It is just the contents, the header needs to still be constructed. // Note that the output will need to be indented by the caller. // https://datatracker.ietf.org/doc/html/rfc8601#section-2.7.1 func (r *VerifyResult) AuthenticationResults() string { // The weird placement of the ";" is due to the specification saying they // have to be before each method, not at the end. // By doing it this way, we can concate the output of this function with // other results. ar := &strings.Builder{} if r.Found == 0 { // https://datatracker.ietf.org/doc/html/rfc8601#section-2.7.1 ar.WriteString(";dkim=none\r\n") return ar.String() } for _, res := range r.Results { // Map state to the corresponding result. 
// https://datatracker.ietf.org/doc/html/rfc8601#section-2.7.1 switch res.State { case SUCCESS: ar.WriteString(";dkim=pass") case TEMPFAIL: // The reason must come before the properties, include it here. fmt.Fprintf(ar, ";dkim=temperror reason=%q\r\n", res.Error) case PERMFAIL: // The reason must come before the properties, include it here. if errors.Is(res.Error, ErrVerificationFailed) || errors.Is(res.Error, ErrBodyHashMismatch) { fmt.Fprintf(ar, ";dkim=fail reason=%q\r\n", res.Error) } else { fmt.Fprintf(ar, ";dkim=permerror reason=%q\r\n", res.Error) } } if res.B != "" { // Include a partial b= tag to help identify which signature // is being referred to. // https://datatracker.ietf.org/doc/html/rfc6008#section-4 fmt.Fprintf(ar, " header.b=%.12s", res.B) } ar.WriteString(" header.d=" + res.Domain + "\r\n") } return ar.String() } func VerifyMessage(ctx context.Context, message string) (*VerifyResult, error) { // https://datatracker.ietf.org/doc/html/rfc6376#section-6 headers, body, err := parseMessage(message) if err != nil { trace(ctx, "Error parsing message: %v", err) return nil, err } results := &VerifyResult{ Results: []*OneResult{}, } for i, sig := range headers.FindAll("DKIM-Signature") { trace(ctx, "Found DKIM-Signature header: %s", sig.Value) if i >= maxHeaders(ctx) { // Protect from potential DoS by capping the number of signatures. // https://datatracker.ietf.org/doc/html/rfc6376#section-4.2 // https://datatracker.ietf.org/doc/html/rfc6376#section-8.4 trace(ctx, "Too many DKIM-Signature headers found") break } results.Found++ res := verifySignature(ctx, sig, headers, body) results.Results = append(results.Results, res) if res.State == SUCCESS { results.Valid++ } } trace(ctx, "Found %d signatures, %d valid", results.Found, results.Valid) return results, nil } // Regular expression that matches the "b=" tag. // First capture group is the "b=" part (including any whitespace up to the // '='). 
var bTag = regexp.MustCompile(`(b[ \t\r\n]*=)[^;]+`) func verifySignature(ctx context.Context, sigH header, headers headers, body string) *OneResult { result := &OneResult{ SignatureHeader: sigH.Value, } sig, err := dkimSignatureFromHeader(sigH.Value) if err != nil { // Header validation errors are a PERMFAIL. // https://datatracker.ietf.org/doc/html/rfc6376#section-6.1.1 result.Error = err result.State = PERMFAIL return result } result.Domain = sig.d result.Selector = sig.s result.B = base64.StdEncoding.EncodeToString(sig.b) // Get the public key. // https://datatracker.ietf.org/doc/html/rfc6376#section-6.1.2 pubKeys, err := findPublicKeys(ctx, sig.d, sig.s) if err != nil { result.Error = err // DNS errors when looking up the public key are a TEMPFAIL; all // others are PERMFAIL. // https://datatracker.ietf.org/doc/html/rfc6376#section-6.1.2 if dnsErr, ok := err.(*net.DNSError); ok && dnsErr.Temporary() { result.State = TEMPFAIL } else { result.State = PERMFAIL } return result } // Compute the verification. // https://datatracker.ietf.org/doc/html/rfc6376#section-6.1.3 // Step 1: Prepare a canonicalized version of the body, truncate it to l= // (if present). // https://datatracker.ietf.org/doc/html/rfc6376#section-3.7 bodyC := sig.cB.body(body) if sig.l > 0 { bodyC = bodyC[:sig.l] } // Step 2: Compute the hash of the canonicalized body. bodyH := hashWith(sig.Hash, []byte(bodyC)) // Step 3: Verify the hash of the body by comparing it with bh=. if !bytes.Equal(bodyH, sig.bh) { bodyHStr := base64.StdEncoding.EncodeToString(bodyH) trace(ctx, "Body hash mismatch: %q", bodyHStr) result.Error = fmt.Errorf("%w (got %s)", ErrBodyHashMismatch, bodyHStr) result.State = PERMFAIL return result } trace(ctx, "Body hash matches: %q", base64.StdEncoding.EncodeToString(bodyH)) // Step 4 A: Hash the (canonicalized) headers that appear in the h= tag. 
// https://datatracker.ietf.org/doc/html/rfc6376#section-3.7 b := sig.Hash.New() for _, header := range headersToInclude(sigH, sig.h, headers) { hsrc := sig.cH.header(header).Source + "\r\n" trace(ctx, "Hashing header: %q", hsrc) b.Write([]byte(hsrc)) } // Step 4 B: Hash the (canonicalized) DKIM-Signature header itself, but // with an empty b= tag, and without a trailing \r\n. // https://datatracker.ietf.org/doc/html/rfc6376#section-3.7 sigC := sig.cH.header(sigH) sigCStr := bTag.ReplaceAllString(sigC.Source, "$1") trace(ctx, "Hashing header: %q", sigCStr) b.Write([]byte(sigCStr)) bSum := b.Sum(nil) trace(ctx, "Resulting hash: %q", base64.StdEncoding.EncodeToString(bSum)) // Step 4 C: Validate the signature. for _, pubKey := range pubKeys { if !pubKey.Matches(sig.KeyType, sig.Hash) { trace(ctx, "PK %v: key type or hash mismatch, skipping", pubKey) continue } if sig.i != "" && pubKey.StrictDomainCheck() { _, domain, _ := strings.Cut(sig.i, "@") if domain != sig.d { trace(ctx, "PK %v: Strict domain check failed: %q != %q (%q)", pubKey, sig.d, domain, sig.i) continue } trace(ctx, "PK %v: Strict domain check passed", pubKey) } err := pubKey.verify(sig.Hash, bSum, sig.b) if err != nil { trace(ctx, "PK %v: Verification failed: %v", pubKey, err) continue } trace(ctx, "PK %v: Verification succeeded", pubKey) result.State = SUCCESS return result } result.State = PERMFAIL result.Error = ErrVerificationFailed return result } func headersToInclude(sigH header, hTag []string, headers headers) []header { // Return the actual headers to include in the hash, based on the list // given in the h= tag. // This is complicated because: // - Headers can be included multiple times. In that case, we must pick // the last instance (which hasn't been already included). // https://datatracker.ietf.org/doc/html/rfc6376#section-5.4.2 // - Headers may appear fewer times than they are requested. // - DKIM-Signature header may be included, but we must not include the // one being verified. 
// https://datatracker.ietf.org/doc/html/rfc6376#section-3.7 // - Headers may be missing, and that's allowed. // https://datatracker.ietf.org/doc/html/rfc6376#section-5.4 seen := map[string]int{} include := []header{} for _, h := range hTag { all := headers.FindAll(h) slices.Reverse(all) // We keep track of the last instance of each header that we // included, and find the next one every time it appears in h=. // We have to be careful because the header itself may not be present, // or we may be asked to include it more times than it appears. lh := strings.ToLower(h) i := seen[lh] if i >= len(all) { continue } seen[lh]++ selected := all[i] if selected == sigH { continue } include = append(include, selected) } return include } func hashWith(a crypto.Hash, data []byte) []byte { h := a.New() h.Write(data) return h.Sum(nil) }
// Package domaininfo implements a domain information database, to keep track // of things we know about a particular domain. package domaininfo import ( "fmt" "sync" "blitiri.com.ar/go/chasquid/internal/protoio" "blitiri.com.ar/go/chasquid/internal/trace" ) // Command to generate domaininfo.pb.go. //go:generate protoc --go_out=. --go_opt=paths=source_relative domaininfo.proto // DB represents the persistent domain information database. type DB struct { // Persistent store with the list of domains we know. store *protoio.Store info map[string]*Domain sync.Mutex } // New opens a domain information database on the given dir, creating it if // necessary. The returned database will not be loaded. func New(dir string) (*DB, error) { st, err := protoio.NewStore(dir) if err != nil { return nil, err } l := &DB{ store: st, info: map[string]*Domain{}, } err = l.Reload() if err != nil { return nil, err } return l, nil } // Reload the database from disk. func (db *DB) Reload() error { tr := trace.New("DomainInfo.Reload", "reload") defer tr.Finish() db.Lock() defer db.Unlock() // Clear the map, in case it has data. db.info = map[string]*Domain{} ids, err := db.store.ListIDs() if err != nil { tr.Error(err) return err } for _, id := range ids { d := &Domain{} _, err := db.store.Get(id, d) if err != nil { tr.Errorf("id %q: %v", id, err) return fmt.Errorf("error loading %q: %v", id, err) } db.info[d.Name] = d } tr.Debugf("loaded %d domains", len(ids)) return nil } func (db *DB) write(tr *trace.Trace, d *Domain) error { tr = tr.NewChild("DomainInfo.write", d.Name) defer tr.Finish() err := db.store.Put(d.Name, d) if err != nil { tr.Error(err) } else { tr.Debugf("saved") } return err } // IncomingSecLevel checks an incoming security level for the domain. // Returns true if allowed, false otherwise. 
func (db *DB) IncomingSecLevel(tr *trace.Trace, domain string, level SecLevel) bool { tr = tr.NewChild("DomainInfo.Incoming", domain) defer tr.Finish() tr.Debugf("incoming at level %s", level) db.Lock() defer db.Unlock() d, exists := db.info[domain] if !exists { d = &Domain{Name: domain} db.info[domain] = d defer db.write(tr, d) } if level < d.IncomingSecLevel { tr.Errorf("%s incoming denied: %s < %s", d.Name, level, d.IncomingSecLevel) return false } else if level == d.IncomingSecLevel { tr.Debugf("%s incoming allowed: %s == %s", d.Name, level, d.IncomingSecLevel) return true } else { tr.Printf("%s incoming level raised: %s > %s", d.Name, level, d.IncomingSecLevel) d.IncomingSecLevel = level if exists { defer db.write(tr, d) } return true } } // OutgoingSecLevel checks an incoming security level for the domain. // Returns true if allowed, false otherwise. func (db *DB) OutgoingSecLevel(tr *trace.Trace, domain string, level SecLevel) bool { tr = tr.NewChild("DomainInfo.Outgoing", domain) defer tr.Finish() tr.Debugf("outgoing at level %s", level) db.Lock() defer db.Unlock() d, exists := db.info[domain] if !exists { d = &Domain{Name: domain} db.info[domain] = d defer db.write(tr, d) } if level < d.OutgoingSecLevel { tr.Errorf("%s outgoing denied: %s < %s", d.Name, level, d.OutgoingSecLevel) return false } else if level == d.OutgoingSecLevel { tr.Debugf("%s outgoing allowed: %s == %s", d.Name, level, d.OutgoingSecLevel) return true } else { tr.Printf("%s outgoing level raised: %s > %s", d.Name, level, d.OutgoingSecLevel) d.OutgoingSecLevel = level if exists { defer db.write(tr, d) } return true } } // Clear sets the security level for the given domain to plain. // This can be used for manual overrides in case there's an operational need // to do so. 
func (db *DB) Clear(tr *trace.Trace, domain string) bool { tr = tr.NewChild("DomainInfo.SetToPlain", domain) defer tr.Finish() db.Lock() defer db.Unlock() d, exists := db.info[domain] if !exists { tr.Debugf("does not exist") return false } d.IncomingSecLevel = SecLevel_PLAIN d.OutgoingSecLevel = SecLevel_PLAIN db.write(tr, d) tr.Printf("set to plain") return true }
// Package dovecot implements functions to interact with Dovecot's // authentication service. // // In particular, it supports doing user authorization, and checking if a user // exists. It is a very basic implementation, with only the minimum needed to // cover chasquid's needs. // // https://wiki.dovecot.org/Design/AuthProtocol // https://wiki.dovecot.org/Services#auth package dovecot import ( "encoding/base64" "errors" "fmt" "net" "net/textproto" "os" "strings" "sync" "time" "unicode" ) // DefaultTimeout to use. We expect Dovecot to be quite fast, but don't want // to hang forever if something gets stuck. const DefaultTimeout = 5 * time.Second var ( errUsernameNotSafe = errors.New("username not safe (contains spaces)") errFailedToConnect = errors.New("failed to connect to dovecot") errNoUserdbSocket = errors.New("unable to find userdb socket") errNoClientSocket = errors.New("unable to find client socket") ) var defaultUserdbPaths = []string{ "/var/run/dovecot/auth-chasquid-userdb", "/var/run/dovecot/auth-userdb", } var defaultClientPaths = []string{ "/var/run/dovecot/auth-chasquid-client", "/var/run/dovecot/auth-client", } // Auth represents a particular Dovecot auth service to use. type Auth struct { addr struct { mu *sync.Mutex userdb string client string } // Timeout for connection and I/O operations (applies on each call). // Set to DefaultTimeout by NewAuth. Timeout time.Duration } // NewAuth returns a new connection against Dovecot authentication service. It // takes the addresses of userdb and client sockets (usually paths as // configured in dovecot). func NewAuth(userdb, client string) *Auth { a := &Auth{} a.addr.mu = &sync.Mutex{} a.addr.userdb = userdb a.addr.client = client a.Timeout = DefaultTimeout return a } // String representation of this Auth, for human consumption. 
func (a *Auth) String() string {
	a.addr.mu.Lock()
	defer a.addr.mu.Unlock()
	return fmt.Sprintf("DovecotAuth(%q, %q)", a.addr.userdb, a.addr.client)
}

// Check to see if this auth is functional: both the userdb and the client
// sockets must be found and accept connections.
func (a *Auth) Check() error {
	u, c, err := a.getAddrs()
	if err != nil {
		return err
	}
	if !(a.canDial(u) && a.canDial(c)) {
		return errFailedToConnect
	}
	return nil
}

// Exists returns true if the user exists, false otherwise.
// It issues a USER lookup over the userdb socket.
func (a *Auth) Exists(user string) (bool, error) {
	if !isUsernameSafe(user) {
		return false, errUsernameNotSafe
	}

	userdbAddr, _, err := a.getAddrs()
	if err != nil {
		return false, err
	}

	conn, err := a.dial("unix", userdbAddr)
	if err != nil {
		return false, err
	}
	defer conn.Close()

	// Dovecot greets us with version and server pid.
	// VERSION\t<major>\t<minor>
	// SPID\t<pid>
	// Underlying errors are wrapped (%w) so callers can inspect them, e.g.
	// with errors.Is for timeouts or EOF.
	err = expect(conn, "VERSION\t1")
	if err != nil {
		return false, fmt.Errorf("error receiving version: %w", err)
	}
	err = expect(conn, "SPID\t")
	if err != nil {
		return false, fmt.Errorf("error receiving SPID: %w", err)
	}

	// Send our version, and then the request.
	err = write(conn, "VERSION\t1\t1\n")
	if err != nil {
		return false, err
	}

	err = write(conn, fmt.Sprintf("USER\t1\t%s\tservice=smtp\n", user))
	if err != nil {
		return false, err
	}

	// Get the response, and we're done.
	resp, err := conn.ReadLine()
	if err != nil {
		return false, fmt.Errorf("error receiving response: %w", err)
	}
	switch {
	case strings.HasPrefix(resp, "USER\t1\t"):
		return true, nil
	case strings.HasPrefix(resp, "NOTFOUND\t"):
		return false, nil
	}
	return false, fmt.Errorf("invalid response: %q", resp)
}

// Authenticate returns true if the password is valid for the user, false
// otherwise. It uses PLAIN authentication over the client socket.
func (a *Auth) Authenticate(user, passwd string) (bool, error) {
	if !isUsernameSafe(user) {
		return false, errUsernameNotSafe
	}

	_, clientAddr, err := a.getAddrs()
	if err != nil {
		return false, err
	}

	conn, err := a.dial("unix", clientAddr)
	if err != nil {
		return false, err
	}
	defer conn.Close()

	// Send our version, and then our PID.
	err = write(conn, fmt.Sprintf("VERSION\t1\t1\nCPID\t%d\n", os.Getpid()))
	if err != nil {
		return false, err
	}

	// Read the server-side handshake. We don't care about the contents
	// really, so just read all lines until we see the DONE.
	for {
		line, err := conn.ReadLine()
		if err != nil {
			return false, fmt.Errorf("error receiving handshake: %w", err)
		}
		if line == "DONE" {
			break
		}
	}

	// We only support PLAIN authentication, so construct the request.
	// Note we set the "secured" option, with the assumption that we got the
	// password via a secure channel (like TLS). This is always true for
	// chasquid by design, and simplifies the API.
	// TODO: does dovecot handle utf8 domains well? do we need to encode them
	// in IDNA first?
	req := base64.StdEncoding.EncodeToString(
		[]byte(fmt.Sprintf("%s\x00%s\x00%s", user, user, passwd)))
	err = write(conn, fmt.Sprintf(
		"AUTH\t1\tPLAIN\tservice=smtp\tsecured\tno-penalty\tnologin\tresp=%s\n", req))
	if err != nil {
		return false, err
	}

	// Get the response, and we're done.
	resp, err := conn.ReadLine()
	if err != nil {
		return false, fmt.Errorf("error receiving response: %w", err)
	}
	switch {
	case strings.HasPrefix(resp, "OK\t1"):
		return true, nil
	case strings.HasPrefix(resp, "FAIL\t1"):
		return false, nil
	}
	return false, fmt.Errorf("invalid response: %q", resp)
}

// Reload the authenticator. It's a no-op for dovecot, but it is needed to
// conform with the auth.Backend interface.
func (a *Auth) Reload() error {
	return nil
}

// dial connects to addr within a.Timeout, and sets the deadline for all I/O
// on the resulting connection to a.Timeout after it was established.
func (a *Auth) dial(network, addr string) (*textproto.Conn, error) {
	raw, err := net.DialTimeout(network, addr, a.Timeout)
	if err != nil {
		return nil, err
	}

	raw.SetDeadline(time.Now().Add(a.Timeout))
	return textproto.NewConn(raw), nil
}

// expect reads a single line from conn, and fails unless it begins with
// prefix.
func expect(conn *textproto.Conn, prefix string) error {
	line, err := conn.ReadLine()
	switch {
	case err != nil:
		return err
	case !strings.HasPrefix(line, prefix):
		return fmt.Errorf("got %q", line)
	default:
		return nil
	}
}

// write sends msg as-is (the caller supplies any line terminators) and
// flushes it to the connection.
func write(conn *textproto.Conn, msg string) error {
	if _, err := conn.W.WriteString(msg); err != nil {
		return err
	}
	return conn.W.Flush()
}

// isUsernameSafe to use in the dovecot protocol?
// Dovecot's protocol is tab/newline delimited and not very robust wrt.
// whitespace, so any whitespace rune makes the name unsafe.
func isUsernameSafe(user string) bool {
	return strings.IndexFunc(user, unicode.IsSpace) < 0
}

// getAddrs returns the addresses to the userdb and client sockets. Any that
// are still empty are resolved by probing the default paths, and cached.
func (a *Auth) getAddrs() (string, string, error) {
	a.addr.mu.Lock()
	defer a.addr.mu.Unlock()

	// firstDialable returns the first candidate that accepts a connection,
	// or "" if none does.
	firstDialable := func(candidates []string) string {
		for _, p := range candidates {
			if a.canDial(p) {
				return p
			}
		}
		return ""
	}

	if a.addr.userdb == "" {
		a.addr.userdb = firstDialable(defaultUserdbPaths)
		if a.addr.userdb == "" {
			return "", "", errNoUserdbSocket
		}
	}

	if a.addr.client == "" {
		a.addr.client = firstDialable(defaultClientPaths)
		if a.addr.client == "" {
			return "", "", errNoClientSocket
		}
	}

	return a.addr.userdb, a.addr.client, nil
}

// canDial reports whether a unix socket connection to path succeeds. The
// connection is closed right away.
func (a *Auth) canDial(path string) bool {
	conn, err := a.dial("unix", path)
	if err == nil {
		conn.Close()
	}
	return err == nil
}
// Package envelope implements functions related to handling email envelopes // (basically tuples of (from, to, data). package envelope import ( "fmt" "strings" "blitiri.com.ar/go/chasquid/internal/set" ) // Split an user@domain address into user and domain. func Split(addr string) (string, string) { ps := strings.SplitN(addr, "@", 2) if len(ps) != 2 { return addr, "" } return ps[0], ps[1] } // UserOf user@domain returns user. func UserOf(addr string) string { user, _ := Split(addr) return user } // DomainOf user@domain returns domain. func DomainOf(addr string) string { _, domain := Split(addr) return domain } // DomainIn checks that the domain of the address is on the given set. func DomainIn(addr string, locals *set.String) bool { domain := DomainOf(addr) if domain == "" { return true } return locals.Has(domain) } // AddHeader adds (prepends) a MIME header to the message. func AddHeader(data []byte, k, v string) []byte { if len(v) > 0 { // If the value contains newlines, indent them properly. if v[len(v)-1] == '\n' { v = v[:len(v)-1] } v = strings.Replace(v, "\n", "\n\t", -1) } header := []byte(fmt.Sprintf("%s: %s\n", k, v)) return append(header, data...) }
// Package expvarom implements an OpenMetrics HTTP exporter for the variables // from the expvar package. // // This is useful for small servers that want to support both packages with // simple enough variables, without introducing any dependencies beyond the // standard library. // // Some functions to add descriptions and map labels are exported for // convenience, but their usage is optional. // // For more complex usage (like histograms, counters vs. gauges, etc.), use // the OpenMetrics libraries directly. // // The exporter uses the text-based format, as documented in: // https://prometheus.io/docs/instrumenting/exposition_formats/#text-based-format // https://github.com/OpenObservability/OpenMetrics/blob/master/specification/OpenMetrics.md // // Note the adoption of that format as OpenMetrics' one isn't finalized yet, // and it is possible that it will change in the future. // // Backwards compatibility is NOT guaranteed, until the format is fully // standardized. package expvarom import ( "expvar" "fmt" "io" "net/http" "sort" "strconv" "strings" "sync" "unicode/utf8" ) type exportedVar struct { Name string Desc string LabelName string I *expvar.Int F *expvar.Float M *expvar.Map } var ( infoMu = sync.Mutex{} descriptions = map[string]string{} mapLabelNames = map[string]string{} ) // MetricsHandler implements an http.HandlerFunc which serves the registered // metrics, using the OpenMetrics text-based format. func MetricsHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/openmetrics-text; version=1.0.0; charset=utf-8") vars := []exportedVar{} ignored := []string{} expvar.Do(func(kv expvar.KeyValue) { evar := exportedVar{ Name: metricNameToOM(kv.Key), } switch value := kv.Value.(type) { case *expvar.Int: evar.I = value case *expvar.Float: evar.F = value case *expvar.Map: evar.M = value default: // Unsupported type, ignore this variable. 
ignored = append(ignored, evar.Name) return } infoMu.Lock() evar.Desc = descriptions[kv.Key] evar.LabelName = mapLabelNames[kv.Key] infoMu.Unlock() // OM maps need a label name, while expvar ones do not. If we weren't // told what to use, use a generic "key". if evar.LabelName == "" { evar.LabelName = "key" } vars = append(vars, evar) }) // Sort the variables for reproducibility and readability. sort.Slice(vars, func(i, j int) bool { return vars[i].Name < vars[j].Name }) for _, v := range vars { writeVar(w, &v) } fmt.Fprintf(w, "# Generated by expvarom\n") fmt.Fprintf(w, "# EXPERIMENTAL - Format is not fully standard yet\n") fmt.Fprintf(w, "# Ignored variables: %q\n", ignored) fmt.Fprintf(w, "# EOF\n") // Mandated by the standard. } func writeVar(w io.Writer, v *exportedVar) { if v.Desc != "" { fmt.Fprintf(w, "# HELP %s %s\n", v.Name, v.Desc) } if v.I != nil { fmt.Fprintf(w, "%s %d\n\n", v.Name, v.I.Value()) return } if v.F != nil { fmt.Fprintf(w, "%s %g\n\n", v.Name, v.F.Value()) return } if v.M != nil { count := 0 v.M.Do(func(kv expvar.KeyValue) { vs := "" switch value := kv.Value.(type) { case *expvar.Int: vs = strconv.FormatInt(value.Value(), 10) case *expvar.Float: vs = strconv.FormatFloat(value.Value(), 'g', -1, 64) default: // We only support Int and Float in maps. return } labelValue := quoteLabelValue(kv.Key) fmt.Fprintf(w, "%s{%s=%s} %s\n", v.Name, v.LabelName, labelValue, vs) count++ }) if count > 0 { fmt.Fprintf(w, "\n") } } } // metricNameToOM converts an expvar metric name into an OpenMetrics-compliant // metric name. The latter is more restrictive, as it must match the regexp // "[a-zA-Z_:][a-zA-Z0-9_:]*", AND the ':' is not allowed for a direct // exporter. 
// // https://prometheus.io/docs/concepts/data_model/#metric-names-and-labels func metricNameToOM(name string) string { n := "" for _, c := range name { if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' { n += string(c) } else { n += "_" } } // If it begins with a number, prepend 'i' as a compromise. if len(n) > 0 && n[0] >= '0' && n[0] <= '9' { n = "i" + n } return n } // According to the spec, we only need to replace these 3 characters in label // values. var labelValueReplacer = strings.NewReplacer( `\`, `\\`, `"`, `\"`, "\n", `\n`) // quoteLabelValue takes an arbitrary string, and quotes it so it can be // used as a label value. Output includes the wrapping `"`. func quoteLabelValue(v string) string { // The spec requires label values to be valid UTF8, with `\`, `"` and "\n" // escaped. If it's invalid UTF8, hard-quote it first. This will result // in uglier looking values, but they will be well formed. if !utf8.ValidString(v) { v = strconv.QuoteToASCII(v) v = v[1 : len(v)-1] } return `"` + labelValueReplacer.Replace(v) + `"` } // NewInt registers a new expvar.Int variable, with the given description. func NewInt(name, desc string) *expvar.Int { infoMu.Lock() descriptions[name] = desc infoMu.Unlock() return expvar.NewInt(name) } // NewFloat registers a new expvar.Float variable, with the given description. func NewFloat(name, desc string) *expvar.Float { infoMu.Lock() descriptions[name] = desc infoMu.Unlock() return expvar.NewFloat(name) } // NewMap registers a new expvar.Map variable, with the given label // name and description. func NewMap(name, labelName, desc string) *expvar.Map { // Prevent accidents when using the description as the label name. if strings.Contains(labelName, " ") { panic(fmt.Sprintf( "label name has spaces, mix up with the description? %q", labelName)) } infoMu.Lock() descriptions[name] = desc mapLabelNames[name] = labelName infoMu.Unlock() return expvar.NewMap(name) }
// Package haproxy implements the handshake for the HAProxy client protocol // version 1, as described in // https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt. package haproxy import ( "bufio" "errors" "net" "strconv" "strings" ) var ( errInvalidProtoID = errors.New("invalid protocol identifier") errUnkProtocol = errors.New("unknown protocol") errInvalidFields = errors.New("invalid number of fields") errInvalidSrcIP = errors.New("invalid src ip") errInvalidDstIP = errors.New("invalid dst ip") errInvalidSrcPort = errors.New("invalid src port") errInvalidDstPort = errors.New("invalid dst port") ) // Handshake performs the HAProxy protocol v1 handshake on the given reader, // which is expected to be backed by a network connection. // It returns the source and destination addresses, or an error if the // handshake could not complete. // Note that any timeouts or limits must be set by the caller on the // underlying connection, this is helper only to perform the handshake. func Handshake(r *bufio.Reader) (src, dst net.Addr, err error) { line, err := r.ReadString('\n') if err != nil { return nil, nil, err } fields := strings.Fields(line) if len(fields) < 2 || fields[0] != "PROXY" { return nil, nil, errInvalidProtoID } switch fields[1] { case "TCP4", "TCP6": // Allowed to continue, nothing to do. default: return nil, nil, errUnkProtocol } if len(fields) != 6 { return nil, nil, errInvalidFields } srcIP := net.ParseIP(fields[2]) if srcIP == nil { return nil, nil, errInvalidSrcIP } dstIP := net.ParseIP(fields[3]) if dstIP == nil { return nil, nil, errInvalidDstIP } srcPort, err := strconv.ParseUint(fields[4], 10, 16) if err != nil { return nil, nil, errInvalidSrcPort } dstPort, err := strconv.ParseUint(fields[5], 10, 16) if err != nil { return nil, nil, errInvalidDstPort } src = &net.TCPAddr{IP: srcIP, Port: int(srcPort)} dst = &net.TCPAddr{IP: dstIP, Port: int(dstPort)} return src, dst, nil }
// Local RPC package. // // This is a simple RPC package that uses a line-oriented protocol for // encoding and decoding, and Unix sockets for transport. It is meant to be // used for lightweight occasional communication between processes on the // same machine. package localrpc import ( "errors" "net" "net/textproto" "net/url" "os" "strings" "time" "blitiri.com.ar/go/chasquid/internal/trace" ) // Handler is the type of RPC request handlers. type Handler func(tr *trace.Trace, input url.Values) (url.Values, error) // // Server // // Server represents the RPC server. type Server struct { handlers map[string]Handler lis net.Listener } // NewServer creates a new local RPC server. func NewServer() *Server { return &Server{ handlers: make(map[string]Handler), } } var errUnknownMethod = errors.New("unknown method") // Register a handler for the given name. func (s *Server) Register(name string, handler Handler) { s.handlers[name] = handler } // ListenAndServe starts the server. func (s *Server) ListenAndServe(path string) error { tr := trace.New("LocalRPC.Server", path) defer tr.Finish() // Previous instances of the server may have shut down uncleanly, leaving // behind the socket file. Remove it just in case. os.Remove(path) var err error s.lis, err = net.Listen("unix", path) if err != nil { return err } tr.Printf("Listening") for { conn, err := s.lis.Accept() if err != nil { tr.Errorf("Accept error: %v", err) return err } go s.handleConn(tr, conn) } } // Close stops the server. func (s *Server) Close() error { return s.lis.Close() } func (s *Server) handleConn(tr *trace.Trace, conn net.Conn) { tr = tr.NewChild("LocalRPC.Handle", conn.RemoteAddr().String()) defer tr.Finish() // Set a generous deadline to prevent client issues from tying up a server // goroutine indefinitely. conn.SetDeadline(time.Now().Add(5 * time.Second)) tconn := textproto.NewConn(conn) defer tconn.Close() // Read the request. 
name, inS, err := readRequest(&tconn.Reader) if err != nil { tr.Debugf("error reading request: %v", err) return } tr.Debugf("<- %s %s", name, inS) // Find the handler. handler, ok := s.handlers[name] if !ok { writeError(tr, tconn, errUnknownMethod) return } // Unmarshal the input. inV, err := url.ParseQuery(inS) if err != nil { writeError(tr, tconn, err) return } // Call the handler. outV, err := handler(tr, inV) if err != nil { writeError(tr, tconn, err) return } // Send the response. outS := outV.Encode() tr.Debugf("-> 200 %s", outS) tconn.PrintfLine("200 %s", outS) } func readRequest(r *textproto.Reader) (string, string, error) { line, err := r.ReadLine() if err != nil { return "", "", err } sp := strings.SplitN(line, " ", 2) if len(sp) == 1 { return sp[0], "", nil } return sp[0], sp[1], nil } func writeError(tr *trace.Trace, tconn *textproto.Conn, err error) { tr.Errorf("-> 500 %s", err.Error()) tconn.PrintfLine("500 %s", err.Error()) } // Default server. This is a singleton server that can be used for // convenience. var DefaultServer = NewServer() // // Client // // Client for the localrpc server. type Client struct { path string } // NewClient creates a new client for the given path. func NewClient(path string) *Client { return &Client{path: path} } // CallWithValues calls the given method. func (c *Client) CallWithValues(name string, input url.Values) (url.Values, error) { conn, err := textproto.Dial("unix", c.path) if err != nil { return nil, err } defer conn.Close() err = conn.PrintfLine("%s %s", name, input.Encode()) if err != nil { return nil, err } code, msg, err := conn.ReadCodeLine(0) if err != nil { return nil, err } if code != 200 { return nil, errors.New(msg) } return url.ParseQuery(msg) } // Call the given method. The arguments are key-value strings, and must be // provided in pairs. 
func (c *Client) Call(name string, args ...string) (url.Values, error) {
	// An odd count would leave the last key without a value; report it as
	// an error instead of panicking with an out-of-range index.
	if len(args)%2 != 0 {
		return nil, errors.New("odd number of arguments, expected key-value pairs")
	}

	v := url.Values{}
	for i := 0; i < len(args); i += 2 {
		v.Set(args[i], args[i+1])
	}

	return c.CallWithValues(name, v)
}
// Package maillog implements a log specifically for email. package maillog import ( "fmt" "io" "log/syslog" "net" "sync" "time" "blitiri.com.ar/go/chasquid/internal/trace" "blitiri.com.ar/go/log" ) // Global event logs. var ( authLog = trace.New("Authentication", "Incoming SMTP") ) // Logger contains a backend used to log data to, such as a file or syslog. // It implements various user-friendly methods for logging mail information to // it. type Logger struct { inner *log.Logger once sync.Once } // New creates a new Logger which will write messages to the given writer. func New(w io.WriteCloser) *Logger { inner := log.New(w) // Don't include level or caller in the output, it doesn't add value for // this type of log. inner.LogLevel = false inner.LogCaller = false return &Logger{inner: inner} } // NewFile creates a new Logger which will write messages to the file at the // given path. func NewFile(path string) (*Logger, error) { inner, err := log.NewFile(path) if err != nil { return nil, err } // Don't include level or caller in the output, it doesn't add value for // this type of log. inner.LogLevel = false inner.LogCaller = false return &Logger{inner: inner}, nil } // NewSyslog creates a new Logger which will write messages to syslog. func NewSyslog() (*Logger, error) { inner, err := log.NewSyslog(syslog.LOG_INFO|syslog.LOG_MAIL, "chasquid") if err != nil { return nil, err } // Like NewFile, we skip level and caller. In addition, we skip time, as // syslog usually adds that on its own. inner.LogLevel = false inner.LogCaller = false inner.LogTime = false return &Logger{inner: inner}, nil } func (l *Logger) printf(format string, args ...interface{}) { err := l.inner.Log(log.Info, 2, format, args...) if err != nil { l.once.Do(func() { log.Errorf("failed to write to maillog: %v", err) log.Errorf("(will not report this again)") }) } } // Reopen the underlying logger. 
func (l *Logger) Reopen() error {
	return l.inner.Reopen()
}

// Listening logs that the daemon is listening on the given address.
func (l *Logger) Listening(a string) {
	l.printf("daemon listening on %s\n", a)
}

// Auth logs an authentication request, both to the mail log and to the
// authentication trace.
func (l *Logger) Auth(netAddr net.Addr, user string, successful bool) {
	res := "succeeded"
	if !successful {
		res = "failed"
	}
	msg := fmt.Sprintf("%s auth %s for %s\n", netAddr, res, user)

	// msg contains arbitrary text (such as the user name), so it must be
	// passed as an argument and never used as the format string: a "%" in
	// it would otherwise be interpreted as a formatting verb.
	l.printf("%s", msg)
	authLog.Debugf("%s", msg)
}

// Rejected logs that we've rejected an email.
func (l *Logger) Rejected(netAddr net.Addr, from string, to []string, err string) {
	if from != "" {
		from = fmt.Sprintf(" from=%s", from)
	}
	toStr := ""
	if len(to) > 0 {
		toStr = fmt.Sprintf(" to=%v", to)
	}
	l.printf("%s rejected%s%s - %v\n", netAddr, from, toStr, err)
}

// Queued logs that we have queued an email.
func (l *Logger) Queued(netAddr net.Addr, from string, to []string, id string) {
	l.printf("%s from=%s queued ip=%s to=%v\n", id, from, netAddr, to)
}

// SendAttempt logs that we have attempted to send an email.
func (l *Logger) SendAttempt(id, from, to string, err error, permanent bool) {
	if err == nil {
		l.printf("%s from=%s to=%s sent\n", id, from, to)
		return
	}

	t := "(temporary)"
	if permanent {
		t = "(permanent)"
	}
	l.printf("%s from=%s to=%s failed %s: %v\n", id, from, to, t, err)
}

// QueueLoop logs that we have completed a queue loop.
func (l *Logger) QueueLoop(id, from string, nextDelay time.Duration) {
	if nextDelay > 0 {
		l.printf("%s from=%s completed loop, next in %v\n", id, from, nextDelay)
	} else {
		l.printf("%s from=%s all done\n", id, from)
	}
}

// nopCloser wraps an io.Writer with a Close method that does nothing.
type nopCloser struct {
	io.Writer
}

func (nopCloser) Close() error { return nil }

// Default logger, used in the following top-level functions.
var Default *Logger = New(nopCloser{io.Discard})

// Listening logs that the daemon is listening on the given address.
func Listening(a string) {
	Default.Listening(a)
}

// Auth logs an authentication request.
func Auth(netAddr net.Addr, user string, successful bool) {
	Default.Auth(netAddr, user, successful)
}

// Rejected logs that we've rejected an email, through the Default logger.
func Rejected(netAddr net.Addr, from string, to []string, err string) {
	Default.Rejected(netAddr, from, to, err)
}

// Queued logs that we have queued an email, through the Default logger.
func Queued(netAddr net.Addr, from string, to []string, id string) {
	Default.Queued(netAddr, from, to, id)
}

// SendAttempt logs that we have attempted to send an email, through the
// Default logger.
func SendAttempt(id, from, to string, err error, permanent bool) {
	Default.SendAttempt(id, from, to, err, permanent)
}

// QueueLoop logs that we have completed a queue loop, through the Default
// logger.
func QueueLoop(id, from string, nextDelay time.Duration) {
	Default.QueueLoop(id, from, nextDelay)
}
package nettrace

import "context"

// ctxKeyT is an unexported key type, so our context value cannot clash with
// keys from other packages.
type ctxKeyT string

const ctxKey ctxKeyT = "blitiri.com.ar/go/srv/nettrace"

// NewContext returns a new context with the given trace attached.
func NewContext(ctx context.Context, tr Trace) context.Context {
	return context.WithValue(ctx, ctxKey, tr)
}

// FromContext returns the trace attached to the given context (if any).
func FromContext(ctx context.Context) (Trace, bool) {
	tr, ok := ctx.Value(ctxKey).(Trace)
	return tr, ok
}

// FromContextOrNew returns the trace attached to the given context. If there
// is none, it creates one and returns it along with a derived context that
// carries it.
func FromContextOrNew(ctx context.Context, family, title string) (Trace, context.Context) {
	if existing, ok := FromContext(ctx); ok {
		return existing, ctx
	}

	fresh := New(family, title)
	return fresh, NewContext(ctx, fresh)
}

// ChildFromContext returns a new trace that is a child of the one attached to
// the context, or a new root trace if the context has none.
func ChildFromContext(ctx context.Context, family, title string) Trace {
	if parent, ok := FromContext(ctx); ok {
		return parent.NewChild(family, title)
	}
	return New(family, title)
}
package nettrace

import "time"

// evtRing is a fixed-capacity ring of events. Once full, each new event
// overwrites the oldest one, and the timestamp of the first overwritten event
// is kept in firstDrop.
type evtRing struct {
	evts      []event
	max       int
	pos       int // Points to the latest element.
	firstDrop time.Time
}

func newEvtRing(n int) *evtRing {
	return &evtRing{max: n, pos: -1}
}

// Add stores a copy of e, replacing the oldest event if the ring is full.
func (r *evtRing) Add(e *event) {
	if len(r.evts) >= r.max {
		// Full: step onto the oldest slot, remember the first drop, and
		// overwrite it.
		r.pos = (r.pos + 1) % r.max
		if r.firstDrop.IsZero() {
			r.firstDrop = r.evts[r.pos].When
		}
		r.evts[r.pos] = *e
		return
	}

	r.evts = append(r.evts, *e)
	r.pos++
}

// Do calls f on each stored event, from oldest to newest.
func (r *evtRing) Do(f func(e *event)) {
	n := len(r.evts)
	// The oldest element sits right after pos.
	for off := 1; off <= n; off++ {
		f(&r.evts[(r.pos+off)%n])
	}
}
package nettrace

import (
	"time"
)

// histogram accumulates latency statistics: a per-bucket count, plus the
// total number of requests, their summed latency, and the smallest and
// largest latency seen.
type histogram struct {
	count [nBuckets]uint64

	totalQ uint64
	totalT time.Duration

	min time.Duration
	max time.Duration
}

// Add records one request with the given latency, counting it in the given
// bucket. The caller chooses the bucket; an out-of-range index panics.
func (h *histogram) Add(bucket int, latency time.Duration) {
	if h.totalQ == 0 || latency < h.min {
		h.min = latency
	}
	if latency > h.max {
		h.max = latency
	}

	h.count[bucket]++
	h.totalQ++
	h.totalT += latency
}

// histSnapshot is a point-in-time copy of a histogram's statistics.
type histSnapshot struct {
	Counts        map[time.Duration]line
	Count         uint64
	Avg, Min, Max time.Duration
}

// line holds the statistics for a single bucket.
type line struct {
	Start     time.Duration
	BucketIdx int
	Count     uint64
	Percent   float32
	CumPct    float32
}

// Snapshot returns the current statistics, with one line per bucket keyed by
// the bucket's start. Average and percentages are left at zero when no
// requests were recorded.
func (h *histogram) Snapshot() *histSnapshot {
	snap := &histSnapshot{
		Counts: map[time.Duration]line{},
		Count:  h.totalQ,
		Min:    h.min,
		Max:    h.max,
	}
	if h.totalQ > 0 {
		snap.Avg = time.Duration(uint64(h.totalT) / h.totalQ)
	}

	var running uint64
	for i := range h.count {
		n := h.count[i]
		running += n

		l := line{Start: buckets[i], BucketIdx: i, Count: n}
		if h.totalQ > 0 {
			l.Percent = float32(n) / float32(h.totalQ) * 100
			l.CumPct = float32(running) / float32(h.totalQ) * 100
		}
		snap.Counts[buckets[i]] = l
	}

	return snap
}
package nettrace

import (
	"bytes"
	"embed"
	"fmt"
	"hash/crc32"
	"html/template"
	"net/http"
	"sort"
	"strconv"
	"time"
)

//go:embed "templates/*.tmpl" "templates/*.css"
var templatesFS embed.FS
var top *template.Template

func init() {
	top = template.Must(
		template.New("_top").Funcs(template.FuncMap{
			"stripZeros":    stripZeros,
			"roundSeconds":  roundSeconds,
			"roundDuration": roundDuration,
			"colorize":      colorize,
			"depthspan":     depthspan,
			"shorttitle":    shorttitle,
			"traceemoji":    traceemoji,
		}).ParseFS(templatesFS, "templates/*"))
}

// RegisterHandler registers the trace handler in the given ServeMux, on
// `/debug/traces`.
func RegisterHandler(mux *http.ServeMux) {
	mux.HandleFunc("/debug/traces", RenderTraces)
}

// RenderTraces is an HTTP handler function that renders the tracing
// information.
func RenderTraces(w http.ResponseWriter, req *http.Request) {
	data := &struct {
		Buckets   *[]time.Duration
		FamTraces map[string]*familyTraces

		// When displaying traces for a specific family.
		Family    string
		Bucket    int
		BucketStr string
		AllGT     bool
		Traces    []*trace

		// When displaying latencies for a specific family.
		Latencies *histSnapshot

		// When displaying a specific trace.
		Trace     *trace
		AllEvents []traceAndEvent

		// Error to show to the user.
		Error string
	}{}

	// Reference the common buckets, no need to copy them.
	data.Buckets = &buckets

	// Copy the family traces map, so we don't have to keep it locked for too
	// long. We'll still need to lock individual entries.
	data.FamTraces = copyFamilies()

	// Default to showing greater-than.
	data.AllGT = true
	if all := req.FormValue("all"); all != "" {
		data.AllGT, _ = strconv.ParseBool(all)
	}

	// Fill in the family related parameters.
	if fam := req.FormValue("fam"); fam != "" {
		if _, ok := data.FamTraces[fam]; !ok {
			data.Family = ""
			data.Error = "Unknown family"
			w.WriteHeader(http.StatusNotFound)
			goto render
		}
		data.Family = fam

		if bs := req.FormValue("b"); bs != "" {
			i, err := strconv.Atoi(bs)
			if err != nil {
				data.Error = "Invalid bucket (not a number)"
				w.WriteHeader(http.StatusBadRequest)
				goto render
			} else if i < -2 || i >= nBuckets {
				data.Error = "Invalid bucket number"
				w.WriteHeader(http.StatusBadRequest)
				goto render
			}
			data.Bucket = i
			data.Traces = data.FamTraces[data.Family].TracesFor(i, data.AllGT)

			switch i {
			case -2:
				data.BucketStr = "errors"
			case -1:
				data.BucketStr = "active"
			default:
				data.BucketStr = buckets[i].String()
			}
		}
	}

	if lat := req.FormValue("lat"); data.Family != "" && lat != "" {
		data.Latencies = data.FamTraces[data.Family].Latencies()
	}

	if traceID := req.FormValue("trace"); traceID != "" {
		refID := req.FormValue("ref")
		tr := findInFamilies(id(traceID), id(refID))
		if tr == nil {
			data.Error = "Trace not found"
			w.WriteHeader(http.StatusNotFound)
			goto render
		}
		data.Trace = tr
		data.Family = tr.Family
		data.AllEvents = allEvents(tr)
	}

render:

	// Write into a buffer, to avoid accidentally holding a lock on http
	// writes. It shouldn't happen, but just to be extra safe.
	bw := &bytes.Buffer{}
	bw.Grow(16 * 1024)
	err := top.ExecuteTemplate(bw, "index.html.tmpl", data)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		panic(err)
	}

	w.Write(bw.Bytes())
}

type traceAndEvent struct {
	Trace *trace
	Event event
	Depth uint
}

// allEvents gets all the events for the trace and its children/linked traces;
// and returns them sorted by timestamp.
func allEvents(tr *trace) []traceAndEvent {
	// Map tracking all traces we've seen, to avoid loops.
	seen := map[id]bool{}

	// Recursively gather all events.
	evts := appendAllEvents(tr, []traceAndEvent{}, seen, 0)

	// Sort them by time. The sort is stable so that events with identical
	// timestamps keep the order in which they were gathered (a trace's own
	// events before those of the traces it references), instead of being
	// shuffled arbitrarily.
	sort.SliceStable(evts, func(i, j int) bool {
		return evts[i].Event.When.Before(evts[j].Event.When)
	})

	return evts
}

// appendAllEvents appends the events of tr to evts, followed (depth-first)
// by those of every trace referenced from tr's events. Traces already in
// seen are skipped.
func appendAllEvents(tr *trace, evts []traceAndEvent, seen map[id]bool, depth uint) []traceAndEvent {
	if seen[tr.ID] {
		return evts
	}
	seen[tr.ID] = true

	subTraces := []*trace{}

	// Append all events of this trace.
	trevts := tr.Events()
	for _, e := range trevts {
		evts = append(evts, traceAndEvent{tr, e, depth})
		if e.Ref != nil {
			subTraces = append(subTraces, e.Ref)
		}
	}

	for _, t := range subTraces {
		evts = appendAllEvents(t, evts, seen, depth+1)
	}

	return evts
}

// stripZeros formats d in seconds with microsecond precision. Below one
// second, the "0" integer part and leading zeros of the fraction are shown
// as spaces (e.g. 5ms is " .  5000"), with the same width as "%.6f" of a
// single-digit number of seconds.
func stripZeros(d time.Duration) string {
	if d < time.Second {
		// Integer arithmetic: going through float seconds can land just
		// below an exact microsecond count, and truncation would then show
		// one microsecond too few.
		return fmt.Sprintf(" .%6d", int64(d/time.Microsecond))
	}
	return fmt.Sprintf("%.6f", d.Seconds())
}

func roundSeconds(d time.Duration) string {
	return fmt.Sprintf("%.6f", d.Seconds())
}

func roundDuration(d time.Duration) time.Duration {
	return d.Round(time.Millisecond)
}

// colorize returns the CSS color for an event at the given depth: the text
// color at depth 0, otherwise a nested color variable picked by depth
// (capped at 3) and a hash of the id.
func colorize(depth uint, id id) template.CSS {
	if depth == 0 {
		return template.CSS("rgba(var(--text-color))")
	}

	if depth > 3 {
		depth = 3
	}

	// Must match the number of nested color variables in the CSS.
	colori := crc32.ChecksumIEEE([]byte(id)) % 6
	return template.CSS(
		fmt.Sprintf("var(--nested-d%02d-c%02d)", depth, colori))
}

// depthspan returns an HTML span with dots visually indicating depth.
func depthspan(depth uint) template.HTML {
	s := `<span class="depth">`
	switch depth {
	case 0:
	case 1:
		s += "· "
	case 2:
		s += "· · "
	case 3:
		s += "· · · "
	case 4:
		s += "· · · · "
	default:
		s += fmt.Sprintf("· (%d) · ", depth)
	}

	s += `</span>`
	return template.HTML(s)
}

// Hand-picked emojis that have enough visual differences in most common
// renderings, and are common enough to be able to easily describe them.
var emojids = []rune(`😀🤣😇🥰🤧😈🤡👻👽🤖👋✊🦴👅` +
	`🐒🐕🦊🐱🐯🐎🐄🐷🐑🐐🐪🦒🐘🐀🦇🐓🦆🦚🦜🐢🐍🦖🐋🐟🦈🐙` +
	`🦋🐜🐝🪲🌻🌲🍉🍌🍍🍎🍑🥕🍄` +
	`🧀🍦🍰🧉🚂🚗🚜🛵🚲🛼🪂🚀🌞🌈🌊⚽`)

// shorttitle returns "<family> - <title>". When that is longer than 20
// characters, only its last 17 are kept, prefixed with "...". Lengths are
// counted in runes, so truncation never splits a multi-byte UTF-8 character.
func shorttitle(tr *trace) string {
	all := tr.Family + " - " + tr.Title
	if r := []rune(all); len(r) > 20 {
		return "..." + string(r[len(r)-17:])
	}
	return all
}

// traceemoji returns an emoji chosen by hashing the id.
func traceemoji(id id) string {
	i := crc32.ChecksumIEEE([]byte(id)) % uint32(len(emojids))
	return string(emojids[i])
}
// Package nettrace implements tracing of requests. Traces are created by // nettrace.New, and can then be viewed over HTTP on /debug/traces. package nettrace import ( "container/ring" "fmt" "math/rand" "sort" "strconv" "strings" "sync" "time" ) // IDs are of the form "family!timestamp!unique identifier", which allows for // sorting them by time much easily, and also some convenient optimizations // when looking up an id across all the known ones. // Family is not escaped. It should not contain the separator. // It is not expected to be stable, for internal use only. type id string func newID(family string, ts int64) id { return id( family + "!" + strconv.FormatInt(ts, 10) + "!" + strconv.FormatUint(rand.Uint64(), 10)) } func (id id) Family() string { sp := strings.SplitN(string(id), "!", 2) if len(sp) != 2 { return string(id) } return sp[0] } // Trace represents a single request trace. type Trace interface { // NewChild creates a new trace, that is a child of this one. NewChild(family, title string) Trace // Link to another trace with the given message. Link(other Trace, msg string) // SetMaxEvents sets the maximum number of events that will be stored in // the trace. It must be called right after initialization. SetMaxEvents(n int) // SetError marks that the trace was for an error event. SetError() // Printf adds a message to the trace. Printf(format string, a ...interface{}) // Errorf adds a message to the trace, marks it as an error, and returns // an error for it. Errorf(format string, a ...interface{}) error // Finish marks the trace as complete. // The trace should not be used after calling this method. Finish() } // A single trace. Can be active or inactive. // Exported fields are allowed to be accessed directly, e.g. by the HTTP // handler. Private ones are mutex protected. type trace struct { ID id Family string Title string Parent *trace Start time.Time // Fields below are mu-protected. 
// We keep them unexported so they're not accidentally accessed in html // templates. mu sync.Mutex end time.Time isError bool maxEvents int // We keep two separate groups: the first ~1/3rd events in a simple slice, // and the last 2/3rd in a ring so we can drop events without losing the // first ones. cutoff int firstEvents []event lastEvents *evtRing } type evtType uint8 const ( evtLOG = evtType(1 + iota) evtCHILD evtLINK evtDROP ) func (t evtType) IsLog() bool { return t == evtLOG } func (t evtType) IsChild() bool { return t == evtCHILD } func (t evtType) IsLink() bool { return t == evtLINK } func (t evtType) IsDrop() bool { return t == evtDROP } type event struct { When time.Time Type evtType Ref *trace Msg string } const defaultMaxEvents = 30 func newTrace(family, title string) *trace { start := time.Now() tr := &trace{ ID: newID(family, start.UnixNano()), Family: family, Title: title, Start: start, maxEvents: defaultMaxEvents, cutoff: defaultMaxEvents / 3, } // Pre-allocate a couple of events to speed things up. // Don't allocate lastEvents, that can be expensive and it is not always // needed. No need to slow down trace creation just for it. tr.firstEvents = make([]event, 0, 4) familiesMu.Lock() ft, ok := families[family] if !ok { ft = newFamilyTraces() families[family] = ft } familiesMu.Unlock() ft.mu.Lock() ft.active[tr.ID] = tr ft.mu.Unlock() return tr } // New creates a new trace with the given family and title. func New(family, title string) Trace { return newTrace(family, title) } func (tr *trace) append(evt *event) { tr.mu.Lock() defer tr.mu.Unlock() if len(tr.firstEvents) < tr.cutoff { tr.firstEvents = append(tr.firstEvents, *evt) return } if tr.lastEvents == nil { // The ring holds the last 2/3rds of the events. tr.lastEvents = newEvtRing(tr.maxEvents - tr.cutoff) } tr.lastEvents.Add(evt) } // String is for debugging only. 
func (tr *trace) String() string { return fmt.Sprintf("trace{%s, %s, %q, %d}", tr.Family, tr.Title, tr.ID, len(tr.Events())) } func (tr *trace) NewChild(family, title string) Trace { c := newTrace(family, title) c.Parent = tr // Add the event to the parent. evt := &event{ When: time.Now(), Type: evtCHILD, Ref: c, } tr.append(evt) return c } func (tr *trace) Link(other Trace, msg string) { evt := &event{ When: time.Now(), Type: evtLINK, Ref: other.(*trace), Msg: msg, } tr.append(evt) } func (tr *trace) SetMaxEvents(n int) { // Set a minimum of 6, so the truncation works without running into // issues. if n < 6 { n = 6 } tr.mu.Lock() tr.maxEvents = n tr.cutoff = n / 3 tr.mu.Unlock() } func (tr *trace) SetError() { tr.mu.Lock() tr.isError = true tr.mu.Unlock() } func (tr *trace) Printf(format string, a ...interface{}) { evt := &event{ When: time.Now(), Type: evtLOG, Msg: fmt.Sprintf(format, a...), } tr.append(evt) } func (tr *trace) Errorf(format string, a ...interface{}) error { tr.SetError() err := fmt.Errorf(format, a...) tr.Printf(err.Error()) return err } func (tr *trace) Finish() { tr.mu.Lock() tr.end = time.Now() tr.mu.Unlock() familiesMu.Lock() ft := families[tr.Family] familiesMu.Unlock() ft.finalize(tr) } // Duration of this trace. func (tr *trace) Duration() time.Duration { tr.mu.Lock() start, end := tr.Start, tr.end tr.mu.Unlock() if end.IsZero() { return time.Since(start) } return end.Sub(start) } // Events returns a copy of the trace events. func (tr *trace) Events() []event { tr.mu.Lock() defer tr.mu.Unlock() evts := make([]event, len(tr.firstEvents)) copy(evts, tr.firstEvents) if tr.lastEvents == nil { return evts } if !tr.lastEvents.firstDrop.IsZero() { evts = append(evts, event{ When: tr.lastEvents.firstDrop, Type: evtDROP, }) } tr.lastEvents.Do(func(e *event) { evts = append(evts, *e) }) return evts } func (tr *trace) IsError() bool { tr.mu.Lock() defer tr.mu.Unlock() return tr.isError } // // Trace hierarchy // // Each trace belongs to a family. 
For each family, we have all active traces, // and then N traces that finished <1s, N that finished <2s, etc. // We keep this many buckets of finished traces. const nBuckets = 8 // Buckets to use. Length must match nBuckets. // "Traces with a latency >= $duration". var buckets = []time.Duration{ time.Duration(0), 5 * time.Millisecond, 10 * time.Millisecond, 50 * time.Millisecond, 100 * time.Millisecond, 300 * time.Millisecond, 1 * time.Second, 10 * time.Second, } func findBucket(latency time.Duration) int { for i, d := range buckets { if latency >= d { continue } return i - 1 } return nBuckets - 1 } // How many traces we keep per bucket. const tracesInBucket = 10 type traceRing struct { ring *ring.Ring max int l int } func newTraceRing(n int) *traceRing { return &traceRing{ ring: ring.New(n), max: n, } } func (r *traceRing) Add(tr *trace) { r.ring.Value = tr r.ring = r.ring.Next() if r.l < r.max { r.l++ } } func (r *traceRing) Len() int { return r.l } func (r *traceRing) Do(f func(tr *trace)) { r.ring.Do(func(x interface{}) { if x == nil { return } f(x.(*trace)) }) } type familyTraces struct { mu sync.Mutex // All active ones. active map[id]*trace // The ones we decided to keep. // Each bucket is a ring-buffer, finishedHead keeps the head pointer. finished [nBuckets]*traceRing // The ones that errored have their own bucket. errors *traceRing // Histogram of latencies. 
latencies histogram } func newFamilyTraces() *familyTraces { ft := &familyTraces{} ft.active = map[id]*trace{} for i := 0; i < nBuckets; i++ { ft.finished[i] = newTraceRing(tracesInBucket) } ft.errors = newTraceRing(tracesInBucket) return ft } func (ft *familyTraces) LenActive() int { ft.mu.Lock() defer ft.mu.Unlock() return len(ft.active) } func (ft *familyTraces) LenErrors() int { ft.mu.Lock() defer ft.mu.Unlock() return ft.errors.Len() } func (ft *familyTraces) LenBucket(b int) int { ft.mu.Lock() defer ft.mu.Unlock() return ft.finished[b].Len() } func (ft *familyTraces) TracesFor(b int, allgt bool) []*trace { ft.mu.Lock() defer ft.mu.Unlock() trs := []*trace{} appendTrace := func(tr *trace) { trs = append(trs, tr) } if b == -2 { ft.errors.Do(appendTrace) } else if b == -1 { for _, tr := range ft.active { appendTrace(tr) } } else if b < nBuckets { ft.finished[b].Do(appendTrace) if allgt { for i := b + 1; i < nBuckets; i++ { ft.finished[i].Do(appendTrace) } } } // Sort them by start, newer first. This is the order that will be used // when displaying them. sort.Slice(trs, func(i, j int) bool { return trs[i].Start.After(trs[j].Start) }) return trs } func (ft *familyTraces) find(id id) *trace { ft.mu.Lock() defer ft.mu.Unlock() if tr, ok := ft.active[id]; ok { return tr } var found *trace for _, bs := range ft.finished { bs.Do(func(tr *trace) { if tr.ID == id { found = tr } }) if found != nil { return found } } ft.errors.Do(func(tr *trace) { if tr.ID == id { found = tr } }) if found != nil { return found } return nil } func (ft *familyTraces) finalize(tr *trace) { latency := tr.end.Sub(tr.Start) b := findBucket(latency) ft.mu.Lock() // Delete from the active list. delete(ft.active, tr.ID) // Add it to the corresponding finished bucket, based on the trace // latency. ft.finished[b].Add(tr) // Errors go on their own list, in addition to the above. 
if tr.isError { ft.errors.Add(tr) } ft.latencies.Add(b, latency) ft.mu.Unlock() } func (ft *familyTraces) Latencies() *histSnapshot { ft.mu.Lock() defer ft.mu.Unlock() return ft.latencies.Snapshot() } // // Global state // var ( familiesMu sync.Mutex families = map[string]*familyTraces{} ) func copyFamilies() map[string]*familyTraces { n := map[string]*familyTraces{} familiesMu.Lock() for f, trs := range families { n[f] = trs } familiesMu.Unlock() return n } func findInFamilies(traceID id, refID id) *trace { // First, try to find it via the family. family := traceID.Family() familiesMu.Lock() fts, ok := families[family] familiesMu.Unlock() if ok { tr := fts.find(traceID) if tr != nil { return tr } } // If that fail and we have a reference, try finding via it. // The reference can be a parent or a child. if refID != id("") { ref := findInFamilies(refID, "") if ref == nil { return nil } // Is the reference's parent the one we're looking for? if ref.Parent != nil && ref.Parent.ID == traceID { return ref.Parent } // Try to find it in the ref's events. for _, e := range ref.Events() { if e.Ref != nil && e.Ref.ID == traceID { return e.Ref } } } return nil }
// Package normalize contains functions to normalize usernames, domains and // addresses. package normalize import ( "bytes" "strings" "blitiri.com.ar/go/chasquid/internal/envelope" "golang.org/x/net/idna" "golang.org/x/text/secure/precis" "golang.org/x/text/unicode/norm" ) // User normalizes an username using PRECIS. // On error, it will also return the original username to simplify callers. func User(user string) (string, error) { norm, err := precis.UsernameCaseMapped.String(user) if err != nil { return user, err } return norm, nil } // Domain normalizes a DNS domain into a cleaned UTF-8 form. // On error, it will also return the original domain to simplify callers. func Domain(domain string) (string, error) { // For now, we just convert them to lower case and make sure it's in NFC // form for consistency. // There are other possible transformations (like nameprep) but for our // purposes these should be enough. // https://tools.ietf.org/html/rfc5891#section-5.2 // https://blog.golang.org/normalization d, err := idna.ToUnicode(domain) if err != nil { return domain, err } d = norm.NFC.String(d) d = strings.ToLower(d) return d, nil } // Addr normalizes an email address, applying User and Domain to its // respective components. // On error, it will also return the original address to simplify callers. func Addr(addr string) (string, error) { user, domain := envelope.Split(addr) user, err := User(user) if err != nil { return addr, err } domain, err = Domain(domain) if err != nil { return addr, err } return user + "@" + domain, nil } // DomainToUnicode takes an address with an ASCII domain, and convert it to // Unicode as per IDNA, including basic normalization. // The user part is unchanged. func DomainToUnicode(addr string) (string, error) { if addr == "<>" { return addr, nil } user, domain := envelope.Split(addr) domain, err := Domain(domain) return user + "@" + domain, err } // ToCRLF converts the given buffer to CRLF line endings. 
// Existing CRLF pairs come out unchanged. Bare CR characters are assumed not
// to occur in the contexts where this is used: every CR in the input is
// dropped, and one is emitted in front of each LF instead.
func ToCRLF(in []byte) []byte {
	var out bytes.Buffer
	out.Grow(len(in))
	for _, c := range in {
		if c == '\r' {
			// Skipped here; it is re-added before each '\n' below.
			continue
		}
		if c == '\n' {
			out.WriteByte('\r')
		}
		out.WriteByte(c)
	}
	return out.Bytes()
}

// StringToCRLF is the string version of ToCRLF.
func StringToCRLF(in string) string {
	return string(ToCRLF([]byte(in)))
}
// Package protoio contains I/O functions for protocol buffers. package protoio import ( "net/url" "os" "strings" "blitiri.com.ar/go/chasquid/internal/safeio" "google.golang.org/protobuf/encoding/prototext" "google.golang.org/protobuf/proto" ) // ReadMessage reads a protocol buffer message from fname, and unmarshalls it // into pb. func ReadMessage(fname string, pb proto.Message) error { in, err := os.ReadFile(fname) if err != nil { return err } return proto.Unmarshal(in, pb) } // ReadTextMessage reads a text format protocol buffer message from fname, and // unmarshalls it into pb. func ReadTextMessage(fname string, pb proto.Message) error { in, err := os.ReadFile(fname) if err != nil { return err } return prototext.Unmarshal(in, pb) } // WriteMessage marshals pb and atomically writes it into fname. func WriteMessage(fname string, pb proto.Message, perm os.FileMode) error { out, err := proto.Marshal(pb) if err != nil { return err } return safeio.WriteFile(fname, out, perm) } var textOpts = prototext.MarshalOptions{ Multiline: true, } // WriteTextMessage marshals pb in text format and atomically writes it into // fname. func WriteTextMessage(fname string, pb proto.Message, perm os.FileMode) error { out, err := textOpts.Marshal(pb) if err != nil { return err } return safeio.WriteFile(fname, out, perm) } /////////////////////////////////////////////////////////////// // Store represents a persistent protocol buffer message store. type Store struct { // Directory where the store is. dir string } // NewStore returns a new Store instance. It will create dir if needed. func NewStore(dir string) (*Store, error) { s := &Store{dir} err := os.MkdirAll(dir, 0770) return s, err } const storeIDPrefix = "s:" // idToFname takes a generic id and returns the corresponding file for it // (which may or may not exist). func (s *Store) idToFname(id string) string { return s.dir + "/" + storeIDPrefix + url.QueryEscape(id) } // Put a message into the store. 
func (s *Store) Put(id string, m proto.Message) error {
	fname := s.idToFname(id)
	return WriteTextMessage(fname, m, 0660)
}

// Get loads the message stored under id into m.
// It returns (true, nil) on success, (false, nil) if there is no such
// message, and (false, err) on any other error.
func (s *Store) Get(id string, m proto.Message) (bool, error) {
	switch err := ReadTextMessage(s.idToFname(id), m); {
	case err == nil:
		return true, nil
	case os.IsNotExist(err):
		return false, nil
	default:
		return false, err
	}
}

// ListIDs returns the ids of all messages in the store. Files without the
// store prefix, or whose names cannot be unescaped, are ignored.
func (s *Store) ListIDs() ([]string, error) {
	entries, err := os.ReadDir(s.dir)
	if err != nil {
		return nil, err
	}

	ids := []string{}
	for _, e := range entries {
		name := e.Name()
		if !strings.HasPrefix(name, storeIDPrefix) {
			continue
		}

		id, err := url.QueryUnescape(name[len(storeIDPrefix):])
		if err != nil {
			continue
		}
		ids = append(ids, id)
	}

	return ids, nil
}
package queue

import (
	"bytes"
	"net/mail"
	"strings"
	"text/template"
	"time"
)

// Maximum length of the original message to include in the DSN.
// The receiver of the DSN might have a smaller message size than what we
// accepted, so we truncate to a value that should be large enough to be
// useful, but not problematic for modern deployments.
const maxOrigMsgLen = 256 * 1024

// deliveryStatusNotification renders a delivery status notification (DSN)
// for the given item, reporting every recipient that was not sent, and
// returns the raw message. It does not queue it; the caller is responsible
// for that.
//
// domainFrom is used both as the reporting domain and for the DSN's
// Message-ID. The Message-ID and the MIME boundary are taken from newID, the
// package's ID generator.
//
// References:
//   - https://tools.ietf.org/html/rfc3464 (DSN)
//   - https://tools.ietf.org/html/rfc6533 (Internationalized DSN)
func deliveryStatusNotification(domainFrom string, item *Item) ([]byte, error) {
	info := dsnInfo{
		OurDomain:   domainFrom,
		Destination: item.From,
		MessageID:   "chasquid-dsn-" + <-newID + "@" + domainFrom,
		Date:        time.Now().Format(time.RFC1123Z),
		To:          item.To,
		Recipients:  item.Rcpt,
		FailedTo:    map[string]string{},
	}

	// Every recipient not in SENT state is reported. FailedTo is used as a
	// set of original addresses (key == value), so duplicates collapse.
	for _, rcpt := range item.Rcpt {
		if rcpt.Status != Recipient_SENT {
			info.FailedTo[rcpt.OriginalAddress] = rcpt.OriginalAddress
			switch rcpt.Status {
			case Recipient_FAILED:
				info.FailedRecipients = append(info.FailedRecipients, rcpt)
			case Recipient_PENDING:
				info.PendingRecipients = append(info.PendingRecipients, rcpt)
			}
		}
	}

	// NOTE(review): this is a byte-level cut, so it may split a multi-byte
	// UTF-8 sequence at the truncation point.
	if len(item.Data) > maxOrigMsgLen {
		info.OriginalMessage = string(item.Data[:maxOrigMsgLen])
	} else {
		info.OriginalMessage = string(item.Data)
	}

	info.OriginalMessageID = getMessageID(item.Data)

	info.Boundary = <-newID

	buf := &bytes.Buffer{}
	err := dsnTemplate.Execute(buf, info)
	return buf.Bytes(), err
}

// getMessageID returns the Message-ID header of the given raw message, or ""
// if the message cannot be parsed.
func getMessageID(data []byte) string {
	msg, err := mail.ReadMessage(bytes.NewBuffer(data))
	if err != nil {
		return ""
	}
	return msg.Header.Get("Message-ID")
}

// dsnInfo is the data passed to dsnTemplate.
type dsnInfo struct {
	OurDomain   string
	Destination string
	MessageID   string
	Date        string
	To          []string
	// Set of original addresses that were not delivered (key == value).
	FailedTo   map[string]string
	Recipients []*Recipient

	FailedRecipients  []*Recipient
	PendingRecipients []*Recipient
	// Original message, possibly truncated to maxOrigMsgLen bytes.
	OriginalMessage string

	// Message-ID of the original message.
	OriginalMessageID string

	// MIME boundary to use to form the message.
	Boundary string
}

// indent returns s with sp spaces inserted after every newline, so that
// continuation lines are indented (the first line is not).
func indent(sp int, s string) string {
	pad := strings.Repeat(" ", sp)
	return strings.Replace(s, "\n", "\n"+pad, -1)
}

// dsnTemplate renders the DSN as a multipart/report message: a human-readable
// notification, the machine-readable delivery status, and the (possibly
// truncated) original message.
//
// NOTE(review): in this copy of the file the template text has no line
// breaks (e.g. "From:" and "To:" share a line), which does not produce valid
// RFC 5322 headers. The line breaks likely need restoring from version
// control — TODO confirm.
var dsnTemplate = template.Must(
	template.New("dsn").Funcs(
		template.FuncMap{
			"indent": indent,
			"trim":   strings.TrimSpace,
		}).Parse(
		`From: Mail Delivery System <postmaster-dsn@{{.OurDomain}}> To: <{{.Destination}}> Subject: Mail delivery failed: returning message to sender Message-ID: <{{.MessageID}}> Date: {{.Date}} In-Reply-To: {{.OriginalMessageID}} References: {{.OriginalMessageID}} X-Failed-Recipients: {{range .FailedTo}}{{.}}, {{end}} Auto-Submitted: auto-replied MIME-Version: 1.0 Content-Type: multipart/report; report-type=delivery-status; boundary="{{.Boundary}}" --{{.Boundary}} Content-Type: text/plain; charset="utf-8" Content-Disposition: inline Content-Description: Notification Content-Transfer-Encoding: 8bit Delivery of your message to the following recipient(s) failed permanently: {{range .FailedTo}} - {{.}} {{end}} Technical details: {{- range .FailedRecipients}} - "{{.Address}}" ({{.Type}}) failed permanently with error: {{.LastFailureMessage | trim | indent 4}} {{- end}} {{- range .PendingRecipients}} - "{{.Address}}" ({{.Type}}) failed repeatedly and timed out, last error: {{.LastFailureMessage | trim | indent 4}} {{- end}} --{{.Boundary}} Content-Type: message/global-delivery-status Content-Description: Delivery Report Content-Transfer-Encoding: 8bit Reporting-MTA: dns; {{.OurDomain}} {{range .FailedRecipients -}} Original-Recipient: utf-8; {{.OriginalAddress}} Final-Recipient: utf-8; {{.Address}} Action: failed Status: 5.0.0 Diagnostic-Code: smtp; {{.LastFailureMessage | trim | indent 4}} {{end -}} {{range .PendingRecipients -}} Original-Recipient: utf-8; {{.OriginalAddress}} Final-Recipient: utf-8; {{.Address}} Action: failed Status: 4.0.0 Diagnostic-Code: smtp; {{.LastFailureMessage | trim | indent 4}} {{end}} --{{.Boundary}} Content-Type: message/rfc822 Content-Description: Undelivered Message Content-Transfer-Encoding: 8bit {{.OriginalMessage}} --{{.Boundary}}-- `))
// Package queue implements our email queue. // Accepted envelopes get put in the queue, and processed asynchronously. package queue // Command to generate queue.pb.go from queue.proto. //go:generate protoc --go_out=. --go_opt=paths=source_relative -I=${GOPATH}/src -I. queue.proto import ( "bytes" "context" "encoding/base64" "fmt" "math/rand" "os" "os/exec" "path/filepath" "strings" "sync" "time" "blitiri.com.ar/go/chasquid/internal/aliases" "blitiri.com.ar/go/chasquid/internal/courier" "blitiri.com.ar/go/chasquid/internal/envelope" "blitiri.com.ar/go/chasquid/internal/expvarom" "blitiri.com.ar/go/chasquid/internal/maillog" "blitiri.com.ar/go/chasquid/internal/protoio" "blitiri.com.ar/go/chasquid/internal/set" "blitiri.com.ar/go/chasquid/internal/trace" "blitiri.com.ar/go/log" "golang.org/x/net/idna" ) const ( // Maximum size of the queue; we reject emails when we hit this. maxQueueSize = 200 // Give up sending attempts after this duration. giveUpAfter = 20 * time.Hour // Prefix for item file names. // This is for convenience, versioning, and to be able to tell them apart // temporary files and other cruft. // It's important that it's outside the base64 space so it doesn't get // generated accidentally. itemFilePrefix = "m:" ) var ( errQueueFull = fmt.Errorf("Queue size too big, try again later") ) // Exported variables. var ( putCount = expvarom.NewInt("chasquid/queue/putCount", "count of envelopes attempted to be put in the queue") itemsWritten = expvarom.NewInt("chasquid/queue/itemsWritten", "count of items the queue wrote to disk") dsnQueued = expvarom.NewInt("chasquid/queue/dsnQueued", "count of DSNs that we generated (queued)") deliverAttempts = expvarom.NewMap("chasquid/queue/deliverAttempts", "recipient_type", "attempts to deliver mail, by recipient type") ) // Channel used to get random IDs for items in the queue. var newID chan string func generateNewIDs() { // The IDs are only used internally, we are ok with using a PRNG. 
// We create our own to avoid relying on external sources initializing it // properly. prng := rand.New(rand.NewSource(time.Now().UnixNano())) // IDs are base64(8 random bytes), but the code doesn't care. buf := make([]byte, 8) id := "" for { prng.Read(buf) id = base64.RawURLEncoding.EncodeToString(buf) newID <- id } } func init() { newID = make(chan string, 4) go generateNewIDs() } // Queue that keeps mail waiting for delivery. type Queue struct { // Items in the queue. Map of id -> Item. q map[string]*Item // Mutex protecting q. mu sync.RWMutex // Couriers to use to deliver mail. localC courier.Courier remoteC courier.Courier // Domains we consider local. localDomains *set.String // Path where we store the queue. path string // Aliases resolver. aliases *aliases.Resolver } // New creates a new Queue instance. func New(path string, localDomains *set.String, aliases *aliases.Resolver, localC, remoteC courier.Courier) (*Queue, error) { err := os.MkdirAll(path, 0700) q := &Queue{ q: map[string]*Item{}, localC: localC, remoteC: remoteC, localDomains: localDomains, path: path, aliases: aliases, } return q, err } // Load the queue and launch the sending loops on startup. func (q *Queue) Load() error { files, err := filepath.Glob(q.path + "/" + itemFilePrefix + "*") if err != nil { return err } for _, fname := range files { item, err := ItemFromFile(fname) if err != nil { log.Errorf("error loading queue item from %q: %v", fname, err) continue } q.mu.Lock() q.q[item.ID] = item q.mu.Unlock() go item.SendLoop(q) } return nil } // Len returns the number of elements in the queue. func (q *Queue) Len() int { q.mu.RLock() defer q.mu.RUnlock() return len(q.q) } // Put an envelope in the queue. 
func (q *Queue) Put(tr *trace.Trace, from string, to []string, data []byte) (string, error) { tr = tr.NewChild("Queue.Put", from) defer tr.Finish() if q.Len() >= maxQueueSize { tr.Errorf("queue full") return "", errQueueFull } putCount.Add(1) item := &Item{ Message: Message{ ID: <-newID, From: from, Data: data, }, CreatedAt: time.Now(), } for _, t := range to { item.To = append(item.To, t) rcpts, err := q.aliases.Resolve(tr, t) if err != nil { return "", fmt.Errorf("error resolving aliases for %q: %v", t, err) } // Add the recipients (after resolving aliases); this conversion is // not very pretty but at least it's self contained. for _, aliasRcpt := range rcpts { r := &Recipient{ Address: aliasRcpt.Addr, Status: Recipient_PENDING, OriginalAddress: t, } switch aliasRcpt.Type { case aliases.EMAIL: r.Type = Recipient_EMAIL case aliases.PIPE: r.Type = Recipient_PIPE default: log.Errorf("unknown alias type %v when resolving %q", aliasRcpt.Type, t) return "", tr.Errorf("internal error - unknown alias type") } item.Rcpt = append(item.Rcpt, r) tr.Debugf("recipient: %v", r.Address) } } err := item.WriteTo(q.path) if err != nil { return "", tr.Errorf("failed to write item: %v", err) } q.mu.Lock() q.q[item.ID] = item q.mu.Unlock() // Begin to send it right away. go item.SendLoop(q) tr.Debugf("queued") return item.ID, nil } // Remove an item from the queue. func (q *Queue) Remove(id string) { path := fmt.Sprintf("%s/%s%s", q.path, itemFilePrefix, id) err := os.Remove(path) if err != nil { log.Errorf("failed to remove queue file %q: %v", path, err) } q.mu.Lock() delete(q.q, id) q.mu.Unlock() } // DumpString returns a human-readable string with the current queue. // Useful for debugging purposes. 
func (q *Queue) DumpString() string { q.mu.RLock() defer q.mu.RUnlock() s := "# Queue status\n\n" s += fmt.Sprintf("date: %v\n", time.Now()) s += fmt.Sprintf("length: %d\n\n", len(q.q)) for id, item := range q.q { s += fmt.Sprintf("## Item %s\n", id) item.Lock() s += fmt.Sprintf("created at: %s\n", item.CreatedAt) s += fmt.Sprintf("from: %s\n", item.From) s += fmt.Sprintf("to: %s\n", item.To) for _, rcpt := range item.Rcpt { s += fmt.Sprintf("%s %s (%s)\n", rcpt.Status, rcpt.Address, rcpt.Type) s += fmt.Sprintf(" original address: %s\n", rcpt.OriginalAddress) s += fmt.Sprintf(" last failure: %q\n", rcpt.LastFailureMessage) } item.Unlock() s += "\n" } return s } // An Item in the queue. type Item struct { // Base the item on the protobuf message. // We will use this for serialization, so any fields below are NOT // serialized. Message // Protect the entire item. sync.Mutex // Go-friendly version of Message.CreatedAtTs. CreatedAt time.Time } // ItemFromFile loads an item from the given file. func ItemFromFile(fname string) (*Item, error) { item := &Item{} err := protoio.ReadTextMessage(fname, &item.Message) if err != nil { return nil, err } item.CreatedAt = timeFromProto(item.CreatedAtTs) return item, nil } // WriteTo saves an item to the given directory. func (item *Item) WriteTo(dir string) error { item.Lock() defer item.Unlock() itemsWritten.Add(1) item.CreatedAtTs = timeToProto(item.CreatedAt) path := fmt.Sprintf("%s/%s%s", dir, itemFilePrefix, item.ID) return protoio.WriteTextMessage(path, &item.Message, 0600) } // SendLoop repeatedly attempts to send the item. func (item *Item) SendLoop(q *Queue) { tr := trace.New("Queue.SendLoop", item.ID) defer tr.Finish() tr.Printf("from %s", item.From) for time.Since(item.CreatedAt) < giveUpAfter { // Send to all recipients that are still pending. 
var wg sync.WaitGroup for _, rcpt := range item.Rcpt { if rcpt.Status != Recipient_PENDING { continue } wg.Add(1) go item.sendOneRcpt(&wg, tr, q, rcpt) } wg.Wait() // If they're all done, no need to wait. if item.countRcpt(Recipient_PENDING) == 0 { break } // TODO: Consider sending a non-final notification after 30m or so, // that some of the messages have been delayed. delay := nextDelay(item.CreatedAt) tr.Printf("waiting for %v", delay) maillog.QueueLoop(item.ID, item.From, delay) time.Sleep(delay) } // Completed to all recipients (some may not have succeeded). if item.countRcpt(Recipient_FAILED, Recipient_PENDING) > 0 && item.From != "<>" { sendDSN(tr, q, item) } tr.Printf("all done") maillog.QueueLoop(item.ID, item.From, 0) q.Remove(item.ID) } // sendOneRcpt, and update it with the results. func (item *Item) sendOneRcpt(wg *sync.WaitGroup, tr *trace.Trace, q *Queue, rcpt *Recipient) { defer wg.Done() to := rcpt.Address tr.Debugf("%s sending", to) err, permanent := item.deliver(q, rcpt) item.Lock() if err != nil { rcpt.LastFailureMessage = err.Error() if permanent { tr.Errorf("%s permanent error: %v", to, err) maillog.SendAttempt(item.ID, item.From, to, err, true) rcpt.Status = Recipient_FAILED } else { tr.Printf("%s temporary error: %v", to, err) maillog.SendAttempt(item.ID, item.From, to, err, false) } } else { tr.Printf("%s sent", to) maillog.SendAttempt(item.ID, item.From, to, nil, false) rcpt.Status = Recipient_SENT } item.Unlock() err = item.WriteTo(q.path) if err != nil { tr.Errorf("failed to write: %v", err) } } // deliver the item to the given recipient, using the couriers from the queue. // Return an error (if any), and whether it is permanent or not. 
func (item *Item) deliver(q *Queue, rcpt *Recipient) (err error, permanent bool) { if rcpt.Type == Recipient_PIPE { deliverAttempts.Add("pipe", 1) c := strings.Fields(rcpt.Address) if len(c) == 0 { return fmt.Errorf("empty pipe"), true } ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() cmd := exec.CommandContext(ctx, c[0], c[1:]...) cmd.Stdin = bytes.NewReader(item.Data) return cmd.Run(), true } // Recipient type is EMAIL. if envelope.DomainIn(rcpt.Address, q.localDomains) { deliverAttempts.Add("email:local", 1) return q.localC.Deliver(item.From, rcpt.Address, item.Data) } deliverAttempts.Add("email:remote", 1) from := item.From if !envelope.DomainIn(item.From, q.localDomains) { // We're sending from a non-local to a non-local. This should // happen only when there's an alias to forward email to a // non-local domain. In this case, using the original From is // problematic, as we may not be an authorized sender for this. // Some MTAs (like Exim) will do it anyway, others (like // gmail) will construct a special address based on the // original address. We go with the latter. // Note this assumes "+" is an alias suffix separator. // We use the IDNA version of the domain if possible, because // we can't know if the other side will support SMTPUTF8. from = fmt.Sprintf("%s+fwd_from=%s@%s", envelope.UserOf(rcpt.OriginalAddress), strings.Replace(from, "@", "=", -1), mustIDNAToASCII(envelope.DomainOf(rcpt.OriginalAddress))) } return q.remoteC.Deliver(from, rcpt.Address, item.Data) } // countRcpt counts how many recipients are in the given status. func (item *Item) countRcpt(statuses ...Recipient_Status) int { c := 0 for _, rcpt := range item.Rcpt { for _, status := range statuses { if rcpt.Status == status { c++ break } } } return c } func sendDSN(tr *trace.Trace, q *Queue, item *Item) { tr.Debugf("sending DSN") // Pick a (local) domain to send the DSN from. We should always find one, // as otherwise we're relaying. 
domain := "unknown" if item.From != "<>" && envelope.DomainIn(item.From, q.localDomains) { domain = envelope.DomainOf(item.From) } else { for _, rcpt := range item.Rcpt { if envelope.DomainIn(rcpt.OriginalAddress, q.localDomains) { domain = envelope.DomainOf(rcpt.OriginalAddress) break } } } msg, err := deliveryStatusNotification(domain, item) if err != nil { tr.Errorf("failed to build DSN: %v", err) return } // TODO: DKIM signing. id, err := q.Put(tr, "<>", []string{item.From}, msg) if err != nil { tr.Errorf("failed to queue DSN: %v", err) return } tr.Printf("queued DSN: %s", id) dsnQueued.Add(1) } func nextDelay(createdAt time.Time) time.Duration { var delay time.Duration since := time.Since(createdAt) switch { case since < 1*time.Minute: delay = 1 * time.Minute case since < 5*time.Minute: delay = 5 * time.Minute case since < 10*time.Minute: delay = 10 * time.Minute default: delay = 20 * time.Minute } // Perturb the delay, to avoid all queued emails to be retried at the // exact same time after a restart. delay += time.Duration(rand.Intn(60)) * time.Second return delay } func mustIDNAToASCII(s string) string { a, err := idna.ToASCII(s) if err != nil { return a } return s } func timeFromProto(ts *Timestamp) time.Time { return time.Unix(ts.Seconds, int64(ts.Nanos)).UTC() } func timeToProto(t time.Time) *Timestamp { return &Timestamp{ Seconds: t.Unix(), Nanos: int32(t.Nanosecond()), } }
// Package safeio implements convenient I/O routines that provide additional // levels of safety in the presence of unexpected failures. package safeio import ( "os" "path" "syscall" ) // osFile is an interface to the methods of os.File that we need, so we can // simulate failures in tests. type osFile interface { Name() string Chmod(os.FileMode) error Chown(int, int) error Write([]byte) (int, error) Close() error } var createTemp func(dir, pattern string) (osFile, error) = func( dir, pattern string) (osFile, error) { return os.CreateTemp(dir, pattern) } // FileOp represents an operation on a file (passed by its name). type FileOp func(fname string) error // WriteFile writes data to a file named by filename, atomically. // // It's a wrapper to os.WriteFile, but provides atomicity (and increased // safety) by writing to a temporary file and renaming it at the end. // // Before the final rename, the given ops (if any) are called. They can be // used to manipulate the file before it is atomically renamed. // If any operation fails, the file is removed and the error is returned. // // Note this relies on same-directory Rename being atomic, which holds in most // reasonably modern filesystems. func WriteFile(filename string, data []byte, perm os.FileMode, ops ...FileOp) error { // Note we create the temporary file in the same directory, otherwise we // would have no expectation of Rename being atomic. // We make the file names start with "." so there's no confusion with the // originals. 
tmpf, err := createTemp(path.Dir(filename), "."+path.Base(filename)) if err != nil { return err } if err = tmpf.Chmod(perm); err != nil { tmpf.Close() os.Remove(tmpf.Name()) return err } if uid, gid := getOwner(filename); uid >= 0 { if err = tmpf.Chown(uid, gid); err != nil { tmpf.Close() os.Remove(tmpf.Name()) return err } } if _, err = tmpf.Write(data); err != nil { tmpf.Close() os.Remove(tmpf.Name()) return err } if err = tmpf.Close(); err != nil { os.Remove(tmpf.Name()) return err } for _, op := range ops { if err = op(tmpf.Name()); err != nil { os.Remove(tmpf.Name()) return err } } return os.Rename(tmpf.Name(), filename) } func getOwner(fname string) (uid, gid int) { uid = -1 gid = -1 stat, err := os.Stat(fname) if err == nil { if sysstat, ok := stat.Sys().(*syscall.Stat_t); ok { uid = int(sysstat.Uid) gid = int(sysstat.Gid) } } return }
// Package set implements sets for various types. Well, only string for now :)
package set

// String is a set of strings.
//
// The zero value is an empty set that is ready to use.
type String struct {
	m map[string]struct{}
}

// NewString returns a new string set, containing the given values (if any).
func NewString(values ...string) *String {
	s := new(String)
	s.Add(values...)
	return s
}

// Add inserts the given values into the set; values already present are left
// as they are. s must not be nil (calling Add on a nil *String panics).
func (s *String) Add(values ...string) {
	if s.m == nil {
		s.m = make(map[string]struct{}, len(values))
	}
	for _, value := range values {
		s.m[value] = struct{}{}
	}
}

// Has reports whether value is in the set.
//
// Unlike Add, Has deliberately tolerates a nil receiver (reporting false), to
// simplify callers' code.
func (s *String) Has(value string) bool {
	if s == nil {
		return false
	}
	// Looking up a key in a nil map is valid and yields "not present".
	_, present := s.m[value]
	return present
}
// Package smtp implements the Simple Mail Transfer Protocol as defined in RFC // 5321. It extends net/smtp as follows: // // - Supports SMTPUTF8, via MailAndRcpt. // - Adds IsPermanent. package smtp import ( "bufio" "io" "net" "net/smtp" "net/textproto" "unicode" "blitiri.com.ar/go/chasquid/internal/envelope" "golang.org/x/net/idna" ) // A Client represents a client connection to an SMTP server. type Client struct { *smtp.Client } // NewClient uses the given connection to create a new Client. func NewClient(conn net.Conn, host string) (*Client, error) { c, err := smtp.NewClient(conn, host) if err != nil { return nil, err } // Wrap the textproto.Conn reader so we are not exposed to a memory // exhaustion DoS on very long replies from the server. // Limit to 2 MiB total (all replies through the lifetime of the client), // which should be plenty for our uses of SMTP. lr := &io.LimitedReader{R: c.Text.Reader.R, N: 2 * 1024 * 1024} c.Text.Reader.R = bufio.NewReader(lr) return &Client{c}, nil } // cmd sends a command and returns the response over the text connection. // Based on Go's method of the same name. func (c *Client) cmd(expectCode int, format string, args ...interface{}) (int, string, error) { id, err := c.Text.Cmd(format, args...) if err != nil { return 0, "", err } c.Text.StartResponse(id) defer c.Text.EndResponse(id) return c.Text.ReadResponse(expectCode) } // MailAndRcpt issues MAIL FROM and RCPT TO commands, in sequence. // It will check the addresses, decide if SMTPUTF8 is needed, and apply the // necessary transformations. 
func (c *Client) MailAndRcpt(from string, to string) error { from, fromNeeds, err := c.prepareForSMTPUTF8(from) if err != nil { return err } to, toNeeds, err := c.prepareForSMTPUTF8(to) if err != nil { return err } smtputf8Needed := fromNeeds || toNeeds cmdStr := "MAIL FROM:<%s>" if ok, _ := c.Extension("8BITMIME"); ok { cmdStr += " BODY=8BITMIME" } if smtputf8Needed { cmdStr += " SMTPUTF8" } _, _, err = c.cmd(250, cmdStr, from) if err != nil { return err } _, _, err = c.cmd(25, "RCPT TO:<%s>", to) return err } // prepareForSMTPUTF8 prepares the address for SMTPUTF8. // It returns: // - The address to use. It is based on addr, and possibly modified to make // it not need the extension, if the server does not support it. // - Whether the address needs the extension or not. // - An error if the address needs the extension, but the client does not // support it. func (c *Client) prepareForSMTPUTF8(addr string) (string, bool, error) { // ASCII address pass through. if isASCII(addr) { return addr, false, nil } // Non-ASCII address also pass through if the server supports the // extension. // Note there's a chance the server wants the domain in IDNA anyway, but // it could also require it to be UTF8. We assume that if it supports // SMTPUTF8 then it knows what its doing. if ok, _ := c.Extension("SMTPUTF8"); ok { return addr, true, nil } // Something is not ASCII, and the server does not support SMTPUTF8: // - If it's the local part, there's no way out and is required. // - If it's the domain, use IDNA. user, domain := envelope.Split(addr) if !isASCII(user) { return addr, true, &textproto.Error{Code: 599, Msg: "local part is not ASCII but server does not support SMTPUTF8"} } // If it's only the domain, convert to IDNA and move on. domain, err := idna.ToASCII(domain) if err != nil { // The domain is not IDNA compliant, which is odd. // Fail with a permanent error, not ideal but this should not // happen. 
return addr, true, &textproto.Error{ Code: 599, Msg: "non-ASCII domain is not IDNA safe"} } return user + "@" + domain, false, nil } // isASCII returns true if all the characters in s are ASCII, false otherwise. func isASCII(s string) bool { for _, c := range s { if c > unicode.MaxASCII { return false } } return true } // IsPermanent returns true if the error is permanent, and false otherwise. // If it can't tell, it returns false. func IsPermanent(err error) bool { terr, ok := err.(*textproto.Error) if !ok { return false } // Error codes 5yz are permanent. // https://tools.ietf.org/html/rfc5321#section-4.2.1 if terr.Code >= 500 && terr.Code < 600 { return true } return false }
package smtpsrv import ( "bufio" "bytes" "context" "crypto/tls" "flag" "fmt" "io" "math/rand" "net" "net/mail" "os" "os/exec" "strconv" "strings" "syscall" "time" "blitiri.com.ar/go/chasquid/internal/aliases" "blitiri.com.ar/go/chasquid/internal/auth" "blitiri.com.ar/go/chasquid/internal/dkim" "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/envelope" "blitiri.com.ar/go/chasquid/internal/expvarom" "blitiri.com.ar/go/chasquid/internal/haproxy" "blitiri.com.ar/go/chasquid/internal/maillog" "blitiri.com.ar/go/chasquid/internal/normalize" "blitiri.com.ar/go/chasquid/internal/queue" "blitiri.com.ar/go/chasquid/internal/set" "blitiri.com.ar/go/chasquid/internal/tlsconst" "blitiri.com.ar/go/chasquid/internal/trace" "blitiri.com.ar/go/spf" ) // Exported variables. var ( commandCount = expvarom.NewMap("chasquid/smtpIn/commandCount", "command", "count of SMTP commands received, by command") responseCodeCount = expvarom.NewMap("chasquid/smtpIn/responseCodeCount", "code", "response codes returned to SMTP commands") spfResultCount = expvarom.NewMap("chasquid/smtpIn/spfResultCount", "result", "SPF result count") loopsDetected = expvarom.NewInt("chasquid/smtpIn/loopsDetected", "count of loops detected") tlsCount = expvarom.NewMap("chasquid/smtpIn/tlsCount", "status", "count of TLS usage in incoming connections") slcResults = expvarom.NewMap("chasquid/smtpIn/securityLevelChecks", "result", "incoming security level check results") hookResults = expvarom.NewMap("chasquid/smtpIn/hookResults", "result", "count of hook invocations, by result") wrongProtoCount = expvarom.NewMap("chasquid/smtpIn/wrongProtoCount", "command", "count of commands for other protocols") dkimSigned = expvarom.NewInt("chasquid/smtpIn/dkimSigned", "count of successful DKIM signs") dkimSignErrors = expvarom.NewInt("chasquid/smtpIn/dkimSignErrors", "count of DKIM sign errors") dkimVerifyFound = expvarom.NewInt("chasquid/smtpIn/dkimVerifyFound", "count of messages with at least 
one DKIM signature") dkimVerifyNotFound = expvarom.NewInt("chasquid/smtpIn/dkimVerifyNotFound", "count of messages with no DKIM signatures") dkimVerifyValid = expvarom.NewInt("chasquid/smtpIn/dkimVerifyValid", "count of messages with at least one valid DKIM signature") dkimVerifyErrors = expvarom.NewInt("chasquid/smtpIn/dkimVerifyErrors", "count of DKIM verification errors") ) var ( maxReceivedHeaders = flag.Int("testing__max_received_headers", 50, "max Received headers, for loop detection; ONLY FOR TESTING") // Some go tests disable SPF, to avoid leaking DNS lookups. disableSPFForTesting = false ) // SocketMode represents the mode for a socket (listening or connection). // We keep them distinct, as policies can differ between them. type SocketMode struct { // Is this mode submission? IsSubmission bool // Is this mode TLS-wrapped? That means that we don't use STARTTLS, the // connection is directly established over TLS (like HTTPS). TLS bool } func (mode SocketMode) String() string { s := "SMTP" if mode.IsSubmission { s = "submission" } if mode.TLS { s += "+TLS" } return s } // Valid socket modes. var ( ModeSMTP = SocketMode{IsSubmission: false, TLS: false} ModeSubmission = SocketMode{IsSubmission: true, TLS: false} ModeSubmissionTLS = SocketMode{IsSubmission: true, TLS: true} ) // Conn represents an incoming SMTP connection. type Conn struct { // Main hostname, used for display only. hostname string // Maximum data size. maxDataSize int64 // Post-DATA hook location. postDataHook string // Connection information. conn net.Conn mode SocketMode tlsConnState *tls.ConnectionState remoteAddr net.Addr // Reader and text writer, so we can control limits. reader *bufio.Reader writer *bufio.Writer // Tracer to use. tr *trace.Trace // TLS configuration. tlsConfig *tls.Config // Domain given at HELO/EHLO. ehloDomain string // Envelope. mailFrom string rcptTo []string data []byte // SPF results. spfResult spf.Result spfError error // DKIM verification results. 
dkimVerifyResult *dkim.VerifyResult // Are we using TLS? onTLS bool // Have we used EHLO? isESMTP bool // Authenticator, aliases and local domains, taken from the server at // creation time. authr *auth.Authenticator localDomains *set.String aliasesR *aliases.Resolver dinfo *domaininfo.DB // Map of domain -> DKIM signers. Taken from the server at creation time. dkimSigners map[string][]*dkim.Signer // Have we successfully completed AUTH? completedAuth bool // Authenticated user and domain, empty if !completedAuth. authUser string authDomain string // When we should close this connection, no matter what. deadline time.Time // Queue where we put incoming mails. queue *queue.Queue // Time we wait for network operations. commandTimeout time.Duration // Enable HAProxy on incoming connections. haproxyEnabled bool } // Close the connection. func (c *Conn) Close() { c.conn.Close() } // Handle implements the main protocol loop (reading commands, sending // replies). func (c *Conn) Handle() { defer c.Close() c.tr = trace.New("SMTP.Conn", c.conn.RemoteAddr().String()) defer c.tr.Finish() c.tr.Debugf("Connected, mode: %s", c.mode) // Set the first deadline, which covers possibly the TLS handshake and // then our initial greeting. c.conn.SetDeadline(time.Now().Add(c.commandTimeout)) if tc, ok := c.conn.(*tls.Conn); ok { // For TLS connections, complete the handshake and get the state, so // it can be used when we say hello below. err := tc.Handshake() if err != nil { c.tr.Errorf("error completing TLS handshake: %v", err) return } cstate := tc.ConnectionState() c.tlsConnState = &cstate if name := c.tlsConnState.ServerName; name != "" { c.hostname = name } } // Set up a buffered reader and writer from the conn. // They will be used to do line-oriented, limited I/O. 
c.reader = bufio.NewReader(c.conn) c.writer = bufio.NewWriter(c.conn) c.remoteAddr = c.conn.RemoteAddr() if c.haproxyEnabled { src, dst, err := haproxy.Handshake(c.reader) if err != nil { c.tr.Errorf("error in haproxy handshake: %v", err) return } c.remoteAddr = src c.tr.Debugf("haproxy handshake: %v -> %v", src, dst) } c.printfLine("220 %s ESMTP chasquid", c.hostname) var cmd, params string var err error var errCount int loop: for { if time.Since(c.deadline) > 0 { err = fmt.Errorf("connection deadline exceeded") c.tr.Error(err) break } c.conn.SetDeadline(time.Now().Add(c.commandTimeout)) cmd, params, err = c.readCommand() if err != nil { c.printfLine("554 error reading command: %v", err) break } if cmd == "AUTH" { c.tr.Debugf("-> AUTH <redacted>") } else { c.tr.Debugf("-> %s %s", cmd, params) } var code int var msg string switch cmd { case "HELO": code, msg = c.HELO(params) case "EHLO": code, msg = c.EHLO(params) case "HELP": code, msg = c.HELP(params) case "NOOP": code, msg = c.NOOP(params) case "RSET": code, msg = c.RSET(params) case "VRFY": code, msg = c.VRFY(params) case "EXPN": code, msg = c.EXPN(params) case "MAIL": code, msg = c.MAIL(params) case "RCPT": code, msg = c.RCPT(params) case "DATA": // DATA handles the whole sequence. code, msg = c.DATA(params) case "STARTTLS": code, msg = c.STARTTLS(params) case "AUTH": code, msg = c.AUTH(params) case "QUIT": _ = c.writeResponse(221, "2.0.0 Be seeing you...") break loop case "GET", "POST", "CONNECT": // HTTP protocol detection, to prevent cross-protocol attacks // (e.g. https://alpaca-attack.com/). wrongProtoCount.Add(cmd, 1) c.tr.Errorf("http command, closing connection") _ = c.writeResponse(502, "5.7.0 You hear someone cursing shoplifters") break loop default: // Sanitize it a bit to avoid filling the logs and events with // noisy data. Keep the first 6 bytes for debugging. 
cmd = fmt.Sprintf("unknown<%.6q>", cmd) code = 500 msg = "5.5.1 Unknown command" } commandCount.Add(cmd, 1) if code > 0 { c.tr.Debugf("<- %d %s", code, msg) if code >= 400 { // Be verbose about errors, to help troubleshooting. c.tr.Errorf("%s failed: %d %s", cmd, code, msg) // Close the connection after 3 errors. // This helps prevent cross-protocol attacks. errCount++ if errCount >= 3 { // https://tools.ietf.org/html/rfc5321#section-4.3.2 c.tr.Errorf("too many errors, breaking connection") _ = c.writeResponse(421, "4.5.0 Too many errors, bye") break } } err = c.writeResponse(code, msg) if err != nil { break } } else if code < 0 { // Negative code means that we have to break the connection. // TODO: This is hacky, it's probably worth it at this point to // refactor this into using a custom response type. c.tr.Errorf("%s closed the connection: %s", cmd, msg) break } } if err != nil { if err == io.EOF { c.tr.Debugf("client closed the connection") } else { c.tr.Errorf("exiting with error: %v", err) } } } // HELO SMTP command handler. func (c *Conn) HELO(params string) (code int, msg string) { if len(strings.TrimSpace(params)) == 0 { return 501, "Invisible customers are not welcome!" } c.ehloDomain = strings.Fields(params)[0] types := []string{ "general store", "used armor dealership", "second-hand bookstore", "liquor emporium", "antique weapons outlet", "delicatessen", "jewelers", "quality apparel and accessories", "hardware", "rare books", "lighting store"} t := types[rand.Int()%len(types)] msg = fmt.Sprintf("Hello my friend, welcome to chasqui's %s!", t) return 250, msg } // EHLO SMTP command handler. func (c *Conn) EHLO(params string) (code int, msg string) { if len(strings.TrimSpace(params)) == 0 { return 501, "Invisible customers are not welcome!" 
} c.ehloDomain = strings.Fields(params)[0] c.isESMTP = true buf := bytes.NewBuffer(nil) fmt.Fprintf(buf, c.hostname+" - Your hour of destiny has come.\n") fmt.Fprintf(buf, "8BITMIME\n") fmt.Fprintf(buf, "PIPELINING\n") fmt.Fprintf(buf, "SMTPUTF8\n") fmt.Fprintf(buf, "ENHANCEDSTATUSCODES\n") fmt.Fprintf(buf, "SIZE %d\n", c.maxDataSize) if c.onTLS { fmt.Fprintf(buf, "AUTH PLAIN\n") } else { fmt.Fprintf(buf, "STARTTLS\n") } fmt.Fprintf(buf, "HELP\n") return 250, buf.String() } // HELP SMTP command handler. func (c *Conn) HELP(params string) (code int, msg string) { return 214, "2.0.0 Hoy por ti, mañana por mi" } // RSET SMTP command handler. func (c *Conn) RSET(params string) (code int, msg string) { c.resetEnvelope() msgs := []string{ "Who was that Maud person anyway?", "Thinking of Maud you forget everything else.", "Your mind releases itself from mundane concerns.", "As your mind turns inward on itself, you forget everything else.", } return 250, "2.0.0 " + msgs[rand.Int()%len(msgs)] } // VRFY SMTP command handler. func (c *Conn) VRFY(params string) (code int, msg string) { // We intentionally don't implement this command. return 502, "5.5.1 You have a strange feeling for a moment, then it passes." } // EXPN SMTP command handler. func (c *Conn) EXPN(params string) (code int, msg string) { // We intentionally don't implement this command. return 502, "5.5.1 You feel disoriented for a moment." } // NOOP SMTP command handler. func (c *Conn) NOOP(params string) (code int, msg string) { return 250, "2.0.0 You hear a faint typing noise." } // MAIL SMTP command handler. func (c *Conn) MAIL(params string) (code int, msg string) { // params should be: "FROM:<name@host>", and possibly followed by // options such as "BODY=8BITMIME" (which we ignore). // Check that it begins with "FROM:" first, it's mandatory. 
if !strings.HasPrefix(strings.ToLower(params), "from:") { return 500, "5.5.2 Unknown command" } if c.mode.IsSubmission && !c.completedAuth { return 550, "5.7.9 Mail to submission port must be authenticated" } rawAddr := "" _, err := fmt.Sscanf(params[5:], "%s ", &rawAddr) if err != nil { return 500, "5.5.4 Malformed command: " + err.Error() } // Note some servers check (and fail) if we had a previous MAIL command, // but that's not according to the RFC. We reset the envelope instead. c.resetEnvelope() // Special case a null reverse-path, which is explicitly allowed and used // for notification messages. // It should be written "<>", we check for that and remove spaces just to // be more flexible. addr := "" if strings.Replace(rawAddr, " ", "", -1) == "<>" { addr = "<>" } else { e, err := mail.ParseAddress(rawAddr) if err != nil || e.Address == "" { return 501, "5.1.7 Sender address malformed" } addr = e.Address if !strings.Contains(addr, "@") { return 501, "5.1.8 Sender address must contain a domain" } // https://tools.ietf.org/html/rfc5321#section-4.5.3.1.3 if len(addr) > 256 { return 501, "5.1.7 Sender address too long" } // SPF check - https://tools.ietf.org/html/rfc7208#section-2.4 // We opt not to fail on errors, to avoid accidents from preventing // delivery. 
c.spfResult, c.spfError = c.checkSPF(addr) if c.spfResult == spf.Fail { // https://tools.ietf.org/html/rfc7208#section-8.4 maillog.Rejected(c.remoteAddr, addr, nil, fmt.Sprintf("failed SPF: %v", c.spfError)) return 550, fmt.Sprintf( "5.7.23 SPF check failed: %v", c.spfError) } if !c.secLevelCheck(addr) { maillog.Rejected(c.remoteAddr, addr, nil, "security level check failed") return 550, "5.7.3 Security level check failed" } addr, err = normalize.DomainToUnicode(addr) if err != nil { maillog.Rejected(c.remoteAddr, addr, nil, fmt.Sprintf("malformed address: %v", err)) return 501, "5.1.8 Malformed sender domain (IDNA conversion failed)" } } c.mailFrom = addr return 250, "2.1.5 You feel like you are being watched" } // checkSPF for the given address, based on the current connection. func (c *Conn) checkSPF(addr string) (spf.Result, error) { // Does not apply to authenticated connections, they're allowed regardless. if c.completedAuth { return "", nil } if disableSPFForTesting { return "", nil } if tcp, ok := c.remoteAddr.(*net.TCPAddr); ok { spfTr := c.tr.NewChild("SPF", tcp.IP.String()) defer spfTr.Finish() res, err := spf.CheckHostWithSender( tcp.IP, envelope.DomainOf(addr), addr, spf.WithTraceFunc(func(f string, a ...interface{}) { spfTr.Debugf(f, a...) })) c.tr.Debugf("SPF %v (%v)", res, err) spfResultCount.Add(string(res), 1) return res, err } return "", nil } // secLevelCheck checks if the security level is acceptable for the given // address. func (c *Conn) secLevelCheck(addr string) bool { // Only check if SPF passes. This serves two purposes: // - Skip for authenticated connections (we trust them implicitly). // - Don't apply this if we can't be sure the sender is authorized. // Otherwise anyone could raise the level of any domain. 
if c.spfResult != spf.Pass { slcResults.Add("skip", 1) c.tr.Debugf("SPF did not pass, skipping security level check") return true } domain := envelope.DomainOf(addr) level := domaininfo.SecLevel_PLAIN if c.onTLS { level = domaininfo.SecLevel_TLS_CLIENT } ok := c.dinfo.IncomingSecLevel(c.tr, domain, level) if ok { slcResults.Add("pass", 1) c.tr.Debugf("security level check for %s passed (%s)", domain, level) } else { slcResults.Add("fail", 1) c.tr.Errorf("security level check for %s failed (%s)", domain, level) } return ok } // RCPT SMTP command handler. func (c *Conn) RCPT(params string) (code int, msg string) { // params should be: "TO:<name@host>", and possibly followed by options // such as "NOTIFY=SUCCESS,DELAY" (which we ignore). // Check that it begins with "TO:" first, it's mandatory. if !strings.HasPrefix(strings.ToLower(params), "to:") { return 500, "5.5.2 Unknown command" } if c.mailFrom == "" { return 503, "5.5.1 Sender not yet given" } rawAddr := "" _, err := fmt.Sscanf(params[3:], "%s ", &rawAddr) if err != nil { return 500, "5.5.4 Malformed command: " + err.Error() } // RFC says 100 is the minimum limit for this, but it seems excessive. 
// https://tools.ietf.org/html/rfc5321#section-4.5.3.1.8 if len(c.rcptTo) > 100 { return 452, "4.5.3 Too many recipients" } e, err := mail.ParseAddress(rawAddr) if err != nil || e.Address == "" { return 501, "5.1.3 Malformed destination address" } addr, err := normalize.DomainToUnicode(e.Address) if err != nil { return 501, "5.1.2 Malformed destination domain (IDNA conversion failed)" } // https://tools.ietf.org/html/rfc5321#section-4.5.3.1.3 if len(addr) > 256 { return 501, "5.1.3 Destination address too long" } localDst := envelope.DomainIn(addr, c.localDomains) if !localDst && !c.completedAuth { maillog.Rejected(c.remoteAddr, c.mailFrom, []string{addr}, "relay not allowed") return 503, "5.7.1 Relay not allowed" } if localDst { addr, err = normalize.Addr(addr) if err != nil { maillog.Rejected(c.remoteAddr, c.mailFrom, []string{addr}, fmt.Sprintf("invalid address: %v", err)) return 550, "5.1.3 Destination address is invalid" } ok, err := c.localUserExists(addr) if err != nil { c.tr.Errorf("error checking if user %q exists: %v", addr, err) maillog.Rejected(c.remoteAddr, c.mailFrom, []string{addr}, fmt.Sprintf("error checking if user exists: %v", err)) return 451, "4.4.3 Temporary error checking address" } if !ok { maillog.Rejected(c.remoteAddr, c.mailFrom, []string{addr}, "local user does not exist") return 550, "5.1.1 Destination address is unknown (user does not exist)" } } c.rcptTo = append(c.rcptTo, addr) return 250, "2.1.5 You have an eerie feeling..." } // DATA SMTP command handler. func (c *Conn) DATA(params string) (code int, msg string) { if c.ehloDomain == "" { return 503, "5.5.1 Invisible customers are not welcome!" } if c.mailFrom == "" { return 503, "5.5.1 Sender not yet given" } if len(c.rcptTo) == 0 { return 503, "5.5.1 Need an address to send to" } // We're going ahead. 
err := c.writeResponse(354, "You suddenly realize it is unnaturally quiet") if err != nil { return 554, fmt.Sprintf("5.4.0 Error writing DATA response: %v", err) } c.tr.Debugf("<- 354 You experience a strange sense of peace") if c.onTLS { tlsCount.Add("tls", 1) } else { tlsCount.Add("plain", 1) } // Increase the deadline for the data transfer to the connection-level // one, we don't want the command timeout to interfere. c.conn.SetDeadline(c.deadline) // Read the data. Enforce CRLF correctness, and maximum size. c.data, err = readUntilDot(c.reader, c.maxDataSize) if err != nil { if err == errMessageTooLarge { // Message is too big; excess data has already been discarded. return 552, "5.3.4 Message too big" } if err == errInvalidLineEnding { // We can't properly recover from this, so we have to drop the // connection. c.writeResponse(521, "5.5.2 Error reading DATA: invalid line ending") return -1, "Invalid line ending, closing connection" } return 554, fmt.Sprintf("5.4.0 Error reading DATA: %v", err) } c.tr.Debugf("-> ... %d bytes of data", len(c.data)) if err := checkData(c.data); err != nil { maillog.Rejected(c.remoteAddr, c.mailFrom, c.rcptTo, err.Error()) return 554, err.Error() } if c.completedAuth { err = c.dkimSign() if err != nil { // If we failed to sign, then reject to prevent sending unsigned // messages. Treat the failure as temporary. c.tr.Errorf("DKIM failed: %v", err) return 451, "4.3.0 DKIM signing failed" } } else { c.dkimVerify() } c.addReceivedHeader() hookOut, permanent, err := c.runPostDataHook(c.data) if err != nil { maillog.Rejected(c.remoteAddr, c.mailFrom, c.rcptTo, err.Error()) if permanent { return 554, err.Error() } return 451, err.Error() } c.data = append(hookOut, c.data...) // There are no partial failures here: we put it in the queue, and then if // individual deliveries fail, we report via email. // If we fail to queue, return a transient error. 
msgID, err := c.queue.Put(c.tr, c.mailFrom, c.rcptTo, c.data) if err != nil { return 451, fmt.Sprintf("4.3.0 Failed to queue message: %v", err) } c.tr.Printf("Queued from %s to %s - %s", c.mailFrom, c.rcptTo, msgID) maillog.Queued(c.remoteAddr, c.mailFrom, c.rcptTo, msgID) // It is very important that we reset the envelope before returning, // so clients can send other emails right away without needing to RSET. c.resetEnvelope() msgs := []string{ "You offer the Amulet of Yendor to Anhur...", "An invisible choir sings, and you are bathed in radiance...", "The voice of Anhur booms out: Congratulations, mortal!", "In return to thy service, I grant thee the gift of Immortality!", "You ascend to the status of Demigod(dess)...", } return 250, "2.0.0 " + msgs[rand.Int()%len(msgs)] } func (c *Conn) addReceivedHeader() { var received string // Format is semi-structured, defined by // https://tools.ietf.org/html/rfc5321#section-4.4 if c.completedAuth { // For authenticated users, only show the EHLO domain they gave; // explicitly hide their network address. received += fmt.Sprintf("from %s\n", c.ehloDomain) } else { // For non-authenticated users we show the real address as canonical, // and then the given EHLO domain for convenience and // troubleshooting. 
received += fmt.Sprintf("from [%s] (%s)\n", addrLiteral(c.remoteAddr), c.ehloDomain) } received += fmt.Sprintf("by %s (chasquid) ", c.hostname) // https://www.iana.org/assignments/mail-parameters/mail-parameters.xhtml#mail-parameters-7 with := "SMTP" if c.isESMTP { with = "ESMTP" } if c.onTLS { with += "S" } if c.completedAuth { with += "A" } received += fmt.Sprintf("with %s\n", with) if c.tlsConnState != nil { // https://tools.ietf.org/html/rfc8314#section-4.3 received += fmt.Sprintf("tls %s\n", tlsconst.CipherSuiteName(c.tlsConnState.CipherSuite)) } received += fmt.Sprintf("(over %s, ", c.mode) if c.tlsConnState != nil { received += fmt.Sprintf("%s, ", tlsconst.VersionName(c.tlsConnState.Version)) } else { received += "plain text!, " } // Note we must NOT include c.rcptTo, that would leak BCCs. received += fmt.Sprintf("envelope from %q)\n", c.mailFrom) // This should be the last part in the Received header, by RFC. // The ";" is a mandatory separator. The date format is not standard but // this one seems to be widely used. // https://tools.ietf.org/html/rfc5322#section-3.6.7 received += fmt.Sprintf("; %s\n", time.Now().Format(time.RFC1123Z)) c.data = envelope.AddHeader(c.data, "Received", received) // Add Authentication-Results header too, but only if there's anything to // report. We add it above the Received header, so it can easily be // associated and traced to it, even though it is not a hard requirement. // Note we include results even if they're "none" or "neutral", as that // allows MUAs to know that the message was checked. 
arHdr := c.hostname + "\r\n"
	includeAR := false
	if c.spfResult != "" {
		// https://tools.ietf.org/html/rfc7208#section-9.1
		received = fmt.Sprintf("%s (%v)", c.spfResult, c.spfError)
		c.data = envelope.AddHeader(c.data, "Received-SPF", received)

		// https://datatracker.ietf.org/doc/html/rfc8601#section-2.7.2
		arHdr += fmt.Sprintf(";spf=%s (%v)\r\n", c.spfResult, c.spfError)
		includeAR = true
	}

	if c.dkimVerifyResult != nil {
		// https://datatracker.ietf.org/doc/html/rfc8601#section-2.7.1
		arHdr += c.dkimVerifyResult.AuthenticationResults() + "\r\n"
		includeAR = true
	}

	if includeAR {
		// Only include the Authentication-Results header if we have something
		// to report.
		c.data = envelope.AddHeader(c.data, "Authentication-Results", strings.TrimSpace(arHdr))
	}
}

// addrLiteral converts a net.Addr (must be TCP) into a string for use as
// address literal, compliant with
// https://tools.ietf.org/html/rfc5321#section-4.1.3.
func addrLiteral(addr net.Addr) string {
	tcp, ok := addr.(*net.TCPAddr)
	if !ok {
		// Fall back to Go's string representation; non-compliant but
		// better than anything for our purposes.
		return addr.String()
	}

	// IPv6 addresses take the "IPv6:" prefix.
	// IPv4 addresses are used literally.
	s := tcp.IP.String()
	if strings.Contains(s, ":") {
		return "IPv6:" + s
	}

	return s
}

// checkData performs very basic checks on the body of the email, to help
// detect very broad problems like email loops. It does not fully check the
// sanity of the headers or the structure of the payload.
func checkData(data []byte) error {
	msg, err := mail.ReadMessage(bytes.NewBuffer(data))
	if err != nil {
		return fmt.Errorf("5.6.0 Error parsing message: %v", err)
	}

	// This serves as a basic form of loop prevention. It's not infallible but
	// should catch most instances of accidental looping.
	// https://tools.ietf.org/html/rfc5321#section-6.3
	if len(msg.Header["Received"]) > *maxReceivedHeaders {
		loopsDetected.Add(1)
		return fmt.Errorf("5.4.6 Loop detected (%d hops)", *maxReceivedHeaders)
	}

	return nil
}

// Sanitize HELO/EHLO domain.
// RFC is extremely flexible with EHLO domain values, allowing all printable
// ASCII characters. They can be tricky to use in shell scripts (commonly used
// as post-data hooks), so this function sanitizes the value to make it
// shell-safe.
func sanitizeEHLODomain(s string) string {
	// Use a builder so the cost is linear in len(s), instead of the
	// quadratic cost of repeated string concatenation.
	var n strings.Builder
	for _, c := range s {
		// Allow a-zA-Z0-9 and []-.:
		// That's enough for all domains, IPv4 and IPv6 literals, and also
		// shell-safe.
		// Non-ASCII are forbidden as EHLO domains per RFC.
		switch {
		case c >= 'a' && c <= 'z',
			c >= 'A' && c <= 'Z',
			c >= '0' && c <= '9',
			c == '-', c == '.',
			c == '[', c == ']', c == ':':
			n.WriteRune(c)
		}
	}

	return n.String()
}

// runPostDataHook runs the post-data hook (if the file exists) with the
// message data on stdin and information about the connection and envelope
// in the environment.
//
// It returns the hook's stdout, to be added as headers, when the hook
// succeeds and its output looks like headers. If the hook fails, it returns
// whether the failure is permanent (the hook exited with status 20), and an
// error whose text is the last line of the hook's stdout.
func (c *Conn) runPostDataHook(data []byte) ([]byte, bool, error) {
	// TODO: check if the file is executable.
	if _, err := os.Stat(c.postDataHook); os.IsNotExist(err) {
		hookResults.Add("post-data:skip", 1)
		return nil, false, nil
	}
	tr := trace.New("Hook.Post-DATA", c.remoteAddr.String())
	defer tr.Finish()

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
	defer cancel()
	cmd := exec.CommandContext(ctx, c.postDataHook)
	cmd.Stdin = bytes.NewReader(data)

	// Prepare the environment, copying some common variables so the hook has
	// something reasonable, and then setting the specific ones for this case.
	for _, v := range strings.Fields("USER PWD SHELL PATH") {
		cmd.Env = append(cmd.Env, v+"="+os.Getenv(v))
	}
	cmd.Env = append(cmd.Env, "REMOTE_ADDR="+c.remoteAddr.String())
	cmd.Env = append(cmd.Env, "EHLO_DOMAIN="+sanitizeEHLODomain(c.ehloDomain))
	cmd.Env = append(cmd.Env, "EHLO_DOMAIN_RAW="+c.ehloDomain)
	cmd.Env = append(cmd.Env, "MAIL_FROM="+c.mailFrom)
	cmd.Env = append(cmd.Env, "RCPT_TO="+strings.Join(c.rcptTo, " "))

	if c.completedAuth {
		cmd.Env = append(cmd.Env, "AUTH_AS="+c.authUser+"@"+c.authDomain)
	} else {
		cmd.Env = append(cmd.Env, "AUTH_AS=")
	}

	cmd.Env = append(cmd.Env, "ON_TLS="+boolToStr(c.onTLS))
	cmd.Env = append(cmd.Env, "FROM_LOCAL_DOMAIN="+boolToStr(
		envelope.DomainIn(c.mailFrom, c.localDomains)))
	cmd.Env = append(cmd.Env, "SPF_PASS="+boolToStr(c.spfResult == spf.Pass))

	out, err := cmd.Output()
	tr.Debugf("stdout: %q", out)
	if err != nil {
		hookResults.Add("post-data:fail", 1)
		tr.Error(err)

		permanent := false
		if ee, ok := err.(*exec.ExitError); ok {
			tr.Printf("stderr: %q", string(ee.Stderr))
			if status, ok := ee.Sys().(syscall.WaitStatus); ok {
				permanent = status.ExitStatus() == 20
			}
		}

		// The error contains the last line of stdout, so filters can pass
		// some rejection information back to the sender.
		// The hook output must NOT be used as a format string: any '%' in
		// it would otherwise be interpreted as a formatting verb.
		err = fmt.Errorf("%s", lastLine(string(out)))
		return nil, permanent, err
	}

	// Check that output looks like headers, to avoid breaking the email
	// contents. If it does not, just skip it.
	if !isHeader(out) {
		hookResults.Add("post-data:badoutput", 1)
		tr.Errorf("error parsing post-data output: %q", out)
		return nil, false, nil
	}

	tr.Debugf("success")
	hookResults.Add("post-data:success", 1)
	return out, false, nil
}

// isHeader checks if the given buffer is a valid MIME header.
func isHeader(b []byte) bool {
	s := string(b)
	if len(s) == 0 {
		return true
	}

	// If it is just a \n, or contains two \n, then it's not a header.
	if s == "\n" || strings.Contains(s, "\n\n") {
		return false
	}

	// If it does not end in \n, not a header.
	if s[len(s)-1] != '\n' {
		return false
	}

	// Each line must either start with a space or have a ':'.
	seen := false
	for _, line := range strings.SplitAfter(s, "\n") {
		if line == "" {
			continue
		}
		if strings.HasPrefix(line, " ") || strings.HasPrefix(line, "\t") {
			if !seen {
				// Continuation without a header first (invalid).
				return false
			}
			continue
		}
		if !strings.Contains(line, ":") {
			return false
		}
		seen = true
	}
	return true
}

// lastLine returns the last "\n"-terminated line of s, without the "\n".
// Text after the final "\n" is ignored, and "" is returned if s contains no
// "\n" at all.
func lastLine(s string) string {
	l := strings.Split(s, "\n")
	if len(l) < 2 {
		return ""
	}
	return l[len(l)-2]
}

// boolToStr returns "1" for true and "0" for false, for use in hook
// environment variables.
func boolToStr(b bool) string {
	if b {
		return "1"
	}
	return "0"
}

// dkimSign adds a DKIM-Signature header to c.data for each DKIM signer
// configured for the MAIL FROM domain. It does nothing if there are none.
func (c *Conn) dkimSign() error {
	// We only sign if the user authenticated. However, the authenticated user
	// and the MAIL FROM address may be different; even the domain may be
	// different.
	// We explicitly let this happen and trust authenticated users.
	// So for DKIM signing purposes, we use the MAIL FROM domain: this
	// prevents leaking the authenticated user's domain, and is more in line
	// with expectations around signatures.
	domain := envelope.DomainOf(c.mailFrom)
	signers := c.dkimSigners[domain]
	if len(signers) == 0 {
		return nil
	}

	tr := c.tr.NewChild("DKIM.Sign", domain)
	defer tr.Finish()

	ctx := context.Background()
	ctx = dkim.WithTraceFunc(ctx, tr.Debugf)

	for _, signer := range signers {
		sig, err := signer.Sign(ctx, normalize.StringToCRLF(string(c.data)))
		if err != nil {
			dkimSignErrors.Add(1)
			return err
		}

		// The signature is returned with \r\n; however, our internal
		// representation uses \n, so normalize it.
		sig = strings.ReplaceAll(sig, "\r\n", "\n")
		c.data = envelope.AddHeader(c.data, "DKIM-Signature", sig)
	}
	dkimSigned.Add(1)

	return nil
}

// dkimVerify verifies the DKIM signatures of c.data, storing the result in
// c.dkimVerifyResult and updating the DKIM verification metrics. It never
// rejects the message.
func (c *Conn) dkimVerify() {
	tr := c.tr.NewChild("DKIM.Verify", c.mailFrom)
	defer tr.Finish()

	var err error
	ctx := context.Background()
	ctx = dkim.WithTraceFunc(ctx, tr.Debugf)

	c.dkimVerifyResult, err = dkim.VerifyMessage(
		ctx, string(normalize.ToCRLF(c.data)))
	if err != nil {
		// The only error we expect is because of a malformed mail, which is
		// checked before this is invoked.
		tr.Errorf("Error verifying DKIM: %v", err)
		dkimVerifyErrors.Add(1)
	}

	if c.dkimVerifyResult != nil {
		if c.dkimVerifyResult.Found > 0 {
			dkimVerifyFound.Add(1)
		} else {
			dkimVerifyNotFound.Add(1)
		}

		if c.dkimVerifyResult.Valid > 0 {
			dkimVerifyValid.Add(1)
		}
	}

	// Note we don't fail emails because they failed to verify, in line
	// with RFC recommendations.
	// DMARC policies may cause it to fail at some point, but that is not
	// implemented yet, and would happen separately.
	// The results will get included in the Authentication-Results header, see
	// addReceivedHeader for more details.
}

// STARTTLS SMTP command handler.
func (c *Conn) STARTTLS(params string) (code int, msg string) {
	if c.onTLS {
		return 503, "5.5.1 You are already wearing that!"
	}

	err := c.writeResponse(220, "2.0.0 You experience a strange sense of peace")
	if err != nil {
		return 554, fmt.Sprintf("5.4.0 Error writing STARTTLS response: %v", err)
	}

	c.tr.Debugf("<- 220 You experience a strange sense of peace")

	server := tls.Server(c.conn, c.tlsConfig)
	err = server.Handshake()
	if err != nil {
		return 554, fmt.Sprintf("5.5.0 Error in TLS handshake: %v", err)
	}

	c.tr.Debugf("<> ... jump to TLS was successful")

	// Override the connection. We don't need the older one anymore.
	c.conn = server
	c.reader = bufio.NewReader(c.conn)
	c.writer = bufio.NewWriter(c.conn)

	// Take the connection state, so we can use it later for logging and
	// tracing purposes.
	cstate := server.ConnectionState()
	c.tlsConnState = &cstate

	// Reset the envelope; clients must start over after switching to TLS.
	c.resetEnvelope()

	c.onTLS = true

	// If the client requested a specific server and we complied, that's our
	// identity from now on.
	if name := c.tlsConnState.ServerName; name != "" {
		c.hostname = name
	}

	// 0 indicates not to send back a reply.
	return 0, ""
}

// AUTH SMTP command handler.
func (c *Conn) AUTH(params string) (code int, msg string) {
	if !c.onTLS {
		return 503, "5.7.10 You feel vulnerable"
	}

	if c.completedAuth {
		// After a successful AUTH command completes, a server MUST reject
		// any further AUTH commands with a 503 reply.
		// https://tools.ietf.org/html/rfc4954#section-4
		return 503, "5.5.1 You are already wearing that!"
	}

	// We only support PLAIN for now, so no need to make this too complicated.
	// Params should be either "PLAIN" or "PLAIN <response>".
	// If the response is not there, we reply with 334, and expect the
	// response back from the client in the next message.

	sp := strings.SplitN(params, " ", 2)
	if len(sp) < 1 || sp[0] != "PLAIN" {
		// As we only offer plain, this should not really happen.
		return 534, "5.7.9 Asmodeus demands 534 zorkmids for safe passage"
	}

	// Note we use more "serious" error messages from now on, as these may
	// find their way to the users in some circumstances.

	// Get the response, either from the message or interactively.
	response := ""
	if len(sp) == 2 {
		response = sp[1]
	} else {
		// Reply 334 and expect the user to provide it.
		// In this case, the text IS relevant, as it is taken as the
		// server-side SASL challenge (empty for PLAIN).
		// https://tools.ietf.org/html/rfc4954#section-4
		err := c.writeResponse(334, "")
		if err != nil {
			return 554, fmt.Sprintf("5.4.0 Error writing AUTH 334: %v", err)
		}

		response, err = c.readLine()
		if err != nil {
			return 554, fmt.Sprintf("5.4.0 Error reading AUTH response: %v", err)
		}
	}

	user, domain, passwd, err := auth.DecodeResponse(response)
	if err != nil {
		// https://tools.ietf.org/html/rfc4954#section-4
		return 501, fmt.Sprintf("5.5.2 Error decoding AUTH response: %v", err)
	}

	// https://tools.ietf.org/html/rfc4954#section-6
	authOk, err := c.authr.Authenticate(c.tr, user, domain, passwd)
	if err != nil {
		c.tr.Errorf("error authenticating %q@%q: %v", user, domain, err)
		maillog.Auth(c.remoteAddr, user+"@"+domain, false)
		return 454, "4.7.0 Temporary authentication failure"
	}
	if authOk {
		c.authUser = user
		c.authDomain = domain
		c.completedAuth = true
		maillog.Auth(c.remoteAddr, user+"@"+domain, true)
		return 235, "2.7.0 Authentication successful"
	}

	maillog.Auth(c.remoteAddr, user+"@"+domain, false)
	return 535, "5.7.8 Incorrect user or password"
}

// resetEnvelope clears all per-message state: the envelope addresses, the
// message data, and the SPF and DKIM verification results.
func (c *Conn) resetEnvelope() {
	c.mailFrom = ""
	c.rcptTo = nil
	c.data = nil
	c.spfResult = ""
	c.spfError = nil

	// The DKIM verification result is per-message too, like the SPF one; do
	// not let it leak into the Authentication-Results of a later message.
	c.dkimVerifyResult = nil
}

// localUserExists reports whether addr is a known alias, or (after removing
// drop characters and suffixes) a user known to the authenticator.
func (c *Conn) localUserExists(addr string) (bool, error) {
	if c.aliasesR.Exists(c.tr, addr) {
		return true, nil
	}

	// Remove the drop chars and suffixes, if any, so the database lookup is
	// on a "clean" address.
	addr = c.aliasesR.RemoveDropsAndSuffix(addr)
	user, domain := envelope.Split(addr)
	return c.authr.Exists(c.tr, user, domain)
}

// readCommand reads one line from the client and splits it into an
// upper-cased command and its (possibly empty) parameters.
func (c *Conn) readCommand() (cmd, params string, err error) {
	msg, err := c.readLine()
	if err != nil {
		return "", "", err
	}

	sp := strings.SplitN(msg, " ", 2)
	cmd = strings.ToUpper(sp[0])
	if len(sp) > 1 {
		params = sp[1]
	}

	return cmd, params, err
}

// readLine reads a single line from the client, without the line ending.
// Lines longer than 1000 octets are consumed and discarded, and an error is
// returned instead.
func (c *Conn) readLine() (line string, err error) {
	// The bufio reader's ReadLine will only read up to the buffer size, which
	// prevents DoS due to memory exhaustion on extremely long lines.
	l, more, err := c.reader.ReadLine()
	if err != nil {
		return "", err
	}

	// As per RFC, the maximum length of a text line is 1000 octets.
	// https://tools.ietf.org/html/rfc5321#section-4.5.3.1.6
	if len(l) > 1000 || more {
		// Keep reading to maintain the protocol status, but discard the data.
		for more && err == nil {
			_, more, err = c.reader.ReadLine()
		}
		return "", fmt.Errorf("line too long")
	}

	return string(l), nil
}

// writeResponse sends an SMTP reply to the client, counts the reply code in
// the responseCodeCount metric, and flushes the buffered writer.
// It returns the first error from writing or flushing.
func (c *Conn) writeResponse(code int, msg string) error {
	responseCodeCount.Add(strconv.Itoa(code), 1)

	err := writeResponse(c.writer, code, msg)

	// Always flush (as before), but do not silently drop a flush error:
	// the reply is only really sent once it leaves the buffer.
	if ferr := c.writer.Flush(); err == nil {
		err = ferr
	}
	return err
}

// printfLine writes a formatted line, terminated by "\r\n", to the client
// and flushes the writer. Errors are not reported.
func (c *Conn) printfLine(format string, args ...interface{}) {
	fmt.Fprintf(c.writer, format+"\r\n", args...)
	c.writer.Flush()
}

// writeResponse writes a (possibly multi-line) SMTP reply to w.
// Every "\n"-separated line of msg becomes one reply line: all but the last
// use the "<code>-<text>" continuation form, and the last one uses
// "<code> <text>", as per https://tools.ietf.org/html/rfc5321#section-4.2.1.
// This is the writing version of textproto.Reader.ReadResponse().
func writeResponse(w io.Writer, code int, msg string) error {
	// strings.Split with a non-empty separator always returns at least one
	// element, so last is never negative.
	lines := strings.Split(msg, "\n")
	last := len(lines) - 1

	// The first N-1 lines use "<code>-<text>".
	for _, line := range lines[:last] {
		if _, err := fmt.Fprintf(w, "%d-%s\r\n", code, line); err != nil {
			return err
		}
	}

	// The last line uses "<code> <text>".
	_, err := fmt.Fprintf(w, "%d %s\r\n", code, lines[last])
	return err
}
package smtpsrv

import (
	"bufio"
	"bytes"
	"errors"
	"io"
)

var (
	// TODO: Include the line number and specific error, and have the
	// caller add them to the trace.

	// errMessageTooLarge is returned by readUntilDot when more than max bytes
	// (raw input bytes, including line endings) were read before the
	// terminating ".".
	errMessageTooLarge = errors.New("message too large")

	// errInvalidLineEnding is returned by readUntilDot when it finds a "\r"
	// that is not immediately followed by "\n", or a "\n" that is not
	// immediately preceded by "\r".
	errInvalidLineEnding = errors.New("invalid line ending")
)

// readUntilDot reads from r until it encounters the dot-terminated line
// ("\r\n.\r\n") that ends the data. It enforces that input lines are
// terminated by "\r\n", and that there are not "lonely" "\r" or "\n"s in the
// input.
//
// It returns \n-terminated lines, which is what we use for our internal
// representation for convenience (same as textproto DotReader does). The
// "\r"s are dropped, dot-stuffing is undone (a "." at the start of a line is
// removed), and the final "." line is not included.
//
// At most max bytes are kept in the returned buffer. If more than max bytes
// are read in total, reading still continues until the terminating "." line,
// and the truncated buffer is returned together with errMessageTooLarge.
//
// On EOF before the terminator it returns io.ErrUnexpectedEOF; on other read
// errors or invalid line endings it returns immediately, with the data read
// so far.
// NOTE(review): unlike the too-large case, those errors do not drain the
// rest of the input; presumably the caller closes the connection — confirm.
func readUntilDot(r *bufio.Reader, max int64) ([]byte, error) {
	buf := make([]byte, 0, 1024)
	// Total number of bytes read from r (including skipped ones).
	n := int64(0)

	// Little state machine.
	const (
		prevOther = iota
		prevCR
		prevCRLF
	)
	// Start as if we just came from a '\r\n'; that way we avoid the need
	// for special-casing the dot-stuffing at the very beginning.
	prev := prevCRLF
	// Sliding window with the last 4 bytes read; starts zero-filled.
	last4 := make([]byte, 4)
	// Whether the current byte must be left out of buf.
	skip := false

loop:
	for {
		b, err := r.ReadByte()
		if err == io.EOF {
			return buf, io.ErrUnexpectedEOF
		} else if err != nil {
			return buf, err
		}
		n++

		switch b {
		case '\r':
			if prev == prevCR {
				return buf, errInvalidLineEnding
			}
			prev = prevCR
			// We return a LF-terminated line, so skip the CR. This simplifies
			// internal representation and makes it easier/less error prone to
			// work with. It is converted back to CRLF on endpoints (e.g. in
			// the couriers).
			skip = true
		case '\n':
			if prev != prevCR {
				return buf, errInvalidLineEnding
			}

			// If we come from a '\r\n.\r', we're done.
			// (last4 does not include b yet, which is the final '\n'.)
			if bytes.Equal(last4, []byte("\r\n.\r")) {
				break loop
			}

			// If we are only starting and see ".\r\n", we're also done; in
			// that case the message is empty.
			// The leading zeros are the initial contents of last4.
			if n == 3 && bytes.Equal(last4, []byte("\x00\x00.\r")) {
				return []byte{}, nil
			}
			prev = prevCRLF
		default:
			if prev == prevCR {
				return buf, errInvalidLineEnding
			}
			if b == '.' && prev == prevCRLF {
				// We come from "\r\n" and got a "."; as per dot-stuffing
				// rules, we should skip that '.' in the output.
				// https://www.rfc-editor.org/rfc/rfc5321#section-4.5.2
				skip = true
			}
			prev = prevOther
		}

		// Keep the last 4 bytes separately, because they may not be in buf on
		// messages that are too large.
		copy(last4, last4[1:])
		last4[3] = b

		// Only keep up to max bytes, but keep reading regardless (see below).
		if int64(len(buf)) < max && !skip {
			buf = append(buf, b)
		}
		skip = false
	}

	// Return an error if the message is too large. It is important to do this
	// _outside_ the loop, because we need to keep reading until we get to the
	// final "." before we return an error, so the SMTP dialog can continue
	// properly after that.
	// If we return too early, the remainder of the email is interpreted as
	// part of the SMTP dialog (and exposing ourselves to smuggling attacks).
	if n > max {
		return buf, errMessageTooLarge
	}

	// If we made it this far, buf naturally ends in "\n" because we skipped
	// the '.' due to dot-stuffing, and skip "\r"s.
	return buf, nil
}
// Package smtpsrv implements chasquid's SMTP server and connection handler. package smtpsrv import ( "crypto" "crypto/ed25519" "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/pem" "flag" "fmt" "net" "net/http" "net/url" "os" "path" "strings" "time" "blitiri.com.ar/go/chasquid/internal/aliases" "blitiri.com.ar/go/chasquid/internal/auth" "blitiri.com.ar/go/chasquid/internal/courier" "blitiri.com.ar/go/chasquid/internal/dkim" "blitiri.com.ar/go/chasquid/internal/domaininfo" "blitiri.com.ar/go/chasquid/internal/localrpc" "blitiri.com.ar/go/chasquid/internal/maillog" "blitiri.com.ar/go/chasquid/internal/queue" "blitiri.com.ar/go/chasquid/internal/set" "blitiri.com.ar/go/chasquid/internal/trace" "blitiri.com.ar/go/chasquid/internal/userdb" "blitiri.com.ar/go/log" ) var ( // Reload frequency. // We should consider making this a proper option if there's interest in // changing it, but until then, it's a test-only flag for simplicity. reloadEvery = flag.Duration("testing__reload_every", 30*time.Second, "how often to reload, ONLY FOR TESTING") ) // Server represents an SMTP server instance. type Server struct { // Main hostname, used for display only. Hostname string // Maximum data size. MaxDataSize int64 // Addresses. addrs map[SocketMode][]string // Listeners (that came via systemd). listeners map[SocketMode][]net.Listener // TLS config (including loaded certificates). tlsConfig *tls.Config // Use HAProxy on incoming connections. HAProxyEnabled bool // Local domains. localDomains *set.String // User databases (per domain). // Authenticator. authr *auth.Authenticator // Aliases resolver. aliasesR *aliases.Resolver // Domain info database. dinfo *domaininfo.DB // Map of domain -> DKIM signers. dkimSigners map[string][]*dkim.Signer // Time before we give up on a connection, even if it's sending data. connTimeout time.Duration // Time we wait for command round-trips (excluding DATA). commandTimeout time.Duration // Queue where we put incoming mail. 
queue *queue.Queue // Path to the hooks. HookPath string } // NewServer returns a new empty Server. func NewServer() *Server { authr := auth.NewAuthenticator() aliasesR := aliases.NewResolver(authr.Exists) return &Server{ addrs: map[SocketMode][]string{}, listeners: map[SocketMode][]net.Listener{}, tlsConfig: &tls.Config{}, connTimeout: 20 * time.Minute, commandTimeout: 1 * time.Minute, localDomains: &set.String{}, authr: authr, aliasesR: aliasesR, dkimSigners: map[string][]*dkim.Signer{}, } } // AddCerts (TLS) to the server. func (s *Server) AddCerts(certPath, keyPath string) error { cert, err := tls.LoadX509KeyPair(certPath, keyPath) if err != nil { return err } s.tlsConfig.Certificates = append(s.tlsConfig.Certificates, cert) return nil } // AddAddr adds an address for the server to listen on. func (s *Server) AddAddr(a string, m SocketMode) { s.addrs[m] = append(s.addrs[m], a) } // AddListeners adds listeners for the server to listen on. func (s *Server) AddListeners(ls []net.Listener, m SocketMode) { s.listeners[m] = append(s.listeners[m], ls...) } // AddDomain adds a local domain to the server. func (s *Server) AddDomain(d string) { s.localDomains.Add(d) s.aliasesR.AddDomain(d) } // AddUserDB adds a userdb file as backend for the domain. func (s *Server) AddUserDB(domain, f string) (int, error) { // Load the userdb, and register it unconditionally (so reload works even // if there are errors right now). udb, err := userdb.Load(f) s.authr.Register(domain, auth.WrapNoErrorBackend(udb)) return udb.Len(), err } // AddAliasesFile adds an aliases file for the given domain. func (s *Server) AddAliasesFile(domain, f string) (int, error) { return s.aliasesR.AddAliasesFile(domain, f) } var ( errDecodingPEMBlock = fmt.Errorf("error decoding PEM block") errUnsupportedBlockType = fmt.Errorf("unsupported block type") errUnsupportedKeyType = fmt.Errorf("unsupported key type") ) // AddDKIMSigner for the given domain and selector. 
func (s *Server) AddDKIMSigner(domain, selector, keyPath string) error { key, err := os.ReadFile(keyPath) if err != nil { return err } block, _ := pem.Decode(key) if block == nil { return errDecodingPEMBlock } if strings.ToUpper(block.Type) != "PRIVATE KEY" { return fmt.Errorf("%w: %s", errUnsupportedBlockType, block.Type) } signer, err := x509.ParsePKCS8PrivateKey(block.Bytes) if err != nil { return err } switch k := signer.(type) { case *rsa.PrivateKey, ed25519.PrivateKey: // These are supported, nothing to do. default: return fmt.Errorf("%w: %T", errUnsupportedKeyType, k) } s.dkimSigners[domain] = append(s.dkimSigners[domain], &dkim.Signer{ Domain: domain, Selector: selector, Signer: signer.(crypto.Signer), }) return nil } // SetAuthFallback sets the authentication backend to use as fallback. func (s *Server) SetAuthFallback(be auth.Backend) { s.authr.Fallback = be } // SetAliasesConfig sets the aliases configuration options. func (s *Server) SetAliasesConfig(suffixSep, dropChars string) { s.aliasesR.SuffixSep = suffixSep s.aliasesR.DropChars = dropChars s.aliasesR.ResolveHook = path.Join(s.HookPath, "alias-resolve") } // SetDomainInfo sets the domain info database to use. func (s *Server) SetDomainInfo(dinfo *domaininfo.DB) { s.dinfo = dinfo } // InitQueue initializes the queue. 
func (s *Server) InitQueue(path string, localC, remoteC courier.Courier) { q, err := queue.New(path, s.localDomains, s.aliasesR, localC, remoteC) if err != nil { log.Fatalf("Error initializing queue: %v", err) } err = q.Load() if err != nil { log.Fatalf("Error loading queue: %v", err) } s.queue = q http.HandleFunc("/debug/queue", func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(q.DumpString())) }) } func (s *Server) aliasResolveRPC(tr *trace.Trace, req url.Values) (url.Values, error) { rcpts, err := s.aliasesR.Resolve(tr, req.Get("Address")) if err != nil { return nil, err } v := url.Values{} for _, rcpt := range rcpts { v.Add(string(rcpt.Type), rcpt.Addr) } return v, nil } func (s *Server) dinfoClearRPC(tr *trace.Trace, req url.Values) (url.Values, error) { domain := req.Get("Domain") exists := s.dinfo.Clear(tr, domain) if !exists { return nil, fmt.Errorf("does not exist") } return nil, nil } // periodicallyReload some of the server's information that can be changed // without the server knowing, such as aliases and the user databases. func (s *Server) periodicallyReload() { if reloadEvery == nil { return } //lint:ignore SA1015 This lasts the program's lifetime. for range time.Tick(*reloadEvery) { s.Reload() } } func (s *Server) Reload() { // Note that any error while reloading is fatal: this way, if there is an // unexpected error it can be detected (and corrected) quickly, instead of // much later (e.g. upon restart) when it might be harder to debug. if err := s.aliasesR.Reload(); err != nil { log.Fatalf("Error reloading aliases: %v", err) } if err := s.authr.Reload(); err != nil { log.Fatalf("Error reloading authenticators: %v", err) } } // ListenAndServe on the addresses and listeners that were previously added. // This function will not return. 
func (s *Server) ListenAndServe() { if len(s.tlsConfig.Certificates) == 0 { // chasquid assumes there's at least one valid certificate (for things // like STARTTLS, user authentication, etc.), so we fail if none was // found. log.Errorf("No SSL/TLS certificates found") log.Errorf("Ideally there should be a certificate for each MX you act as") log.Fatalf("At least one valid certificate is needed") } localrpc.DefaultServer.Register("AliasResolve", s.aliasResolveRPC) localrpc.DefaultServer.Register("DomaininfoClear", s.dinfoClearRPC) go s.periodicallyReload() for m, addrs := range s.addrs { for _, addr := range addrs { l, err := net.Listen("tcp", addr) if err != nil { log.Fatalf("Error listening: %v", err) } log.Infof("Server listening on %s (%v)", addr, m) maillog.Listening(addr) go s.serve(l, m) } } for m, ls := range s.listeners { for _, l := range ls { log.Infof("Server listening on %s (%v, via systemd)", l.Addr(), m) maillog.Listening(l.Addr().String()) go s.serve(l, m) } } // Never return. If the serve goroutines have problems, they will abort // execution. for { time.Sleep(24 * time.Hour) } } func (s *Server) serve(l net.Listener, mode SocketMode) { // If this mode is expected to be TLS-wrapped, make it so. if mode.TLS { l = tls.NewListener(l, s.tlsConfig) } pdhook := path.Join(s.HookPath, "post-data") for { conn, err := l.Accept() if err != nil { log.Fatalf("Error accepting: %v", err) } sc := &Conn{ hostname: s.Hostname, maxDataSize: s.MaxDataSize, postDataHook: pdhook, conn: conn, mode: mode, tlsConfig: s.tlsConfig, haproxyEnabled: s.HAProxyEnabled, onTLS: mode.TLS, authr: s.authr, aliasesR: s.aliasesR, localDomains: s.localDomains, dinfo: s.dinfo, dkimSigners: s.dkimSigners, deadline: time.Now().Add(s.connTimeout), commandTimeout: s.commandTimeout, queue: s.queue, } go sc.Handle() } }
// Package sts implements the MTA-STS (Strict Transport Security), RFC 8461. // // Note that "report" mode is not supported. // // Reference: https://tools.ietf.org/html/rfc8461 package sts import ( "bufio" "bytes" "context" "encoding/json" "errors" "fmt" "io" "mime" "net" "net/http" "os" "strconv" "strings" "time" "blitiri.com.ar/go/chasquid/internal/expvarom" "blitiri.com.ar/go/chasquid/internal/safeio" "blitiri.com.ar/go/chasquid/internal/trace" "golang.org/x/net/context/ctxhttp" "golang.org/x/net/idna" ) // Exported variables. var ( cacheFetches = expvarom.NewInt("chasquid/sts/cache/fetches", "count of total fetches in the STS cache") cacheHits = expvarom.NewInt("chasquid/sts/cache/hits", "count of hits in the STS cache") cacheExpired = expvarom.NewInt("chasquid/sts/cache/expired", "count of expired entries in the STS cache") cacheIOErrors = expvarom.NewInt("chasquid/sts/cache/ioErrors", "count of I/O errors when maintaining STS cache") cacheFailedFetch = expvarom.NewInt("chasquid/sts/cache/failedFetch", "count of failed fetches in the STS cache") cacheInvalid = expvarom.NewInt("chasquid/sts/cache/invalid", "count of invalid policies in the STS cache") cacheMarshalErrors = expvarom.NewInt("chasquid/sts/cache/marshalErrors", "count of marshalling errors when maintaining STS cache") cacheUnmarshalErrors = expvarom.NewInt("chasquid/sts/cache/unmarshalErrors", "count of unmarshalling errors in STS cache") cacheRefreshCycles = expvarom.NewInt("chasquid/sts/cache/refreshCycles", "count of STS cache refresh cycles") cacheRefreshes = expvarom.NewInt("chasquid/sts/cache/refreshes", "count of STS cache refreshes") cacheRefreshErrors = expvarom.NewInt("chasquid/sts/cache/refreshErrors", "count of STS cache refresh errors") ) // Policy represents a parsed policy. // https://tools.ietf.org/html/rfc8461#section-3.2 // The json annotations are used for serializing for caching purposes. 
type Policy struct { Version string `json:"version"` Mode Mode `json:"mode"` MXs []string `json:"mx"` MaxAge time.Duration `json:"max_age"` } // The Mode of a policy. Valid values (according to the standard) are // constants below. type Mode string // Valid modes. const ( Enforce = Mode("enforce") Testing = Mode("testing") None = Mode("none") ) // parsePolicy parses a text representation of the policy (as specified in the // RFC), and returns the corresponding Policy structure. func parsePolicy(raw []byte) (*Policy, error) { p := &Policy{} scanner := bufio.NewScanner(bytes.NewReader(raw)) for scanner.Scan() { sp := strings.SplitN(scanner.Text(), ":", 2) if len(sp) != 2 { continue } key := strings.TrimSpace(sp[0]) value := strings.TrimSpace(sp[1]) // Only care for the keys we recognize. switch key { case "version": p.Version = value case "mode": p.Mode = Mode(value) case "max_age": // On error, p.MaxAge will be 0 which is invalid. maxAge, _ := strconv.Atoi(value) p.MaxAge = time.Duration(maxAge) * time.Second case "mx": p.MXs = append(p.MXs, value) } } if err := scanner.Err(); err != nil { return nil, err } return p, nil } // Check errors. var ( ErrUnknownVersion = errors.New("unknown policy version") ErrInvalidMaxAge = errors.New("invalid max_age") ErrInvalidMode = errors.New("invalid mode") ErrInvalidMX = errors.New("invalid mx") ) // Fetch errors. var ( ErrInvalidMediaType = errors.New("invalid HTTP media type") ) // Check that the policy contents are valid. func (p *Policy) Check() error { if p.Version != "STSv1" { return ErrUnknownVersion } // A 0 max age is invalid (could also represent an Atoi error), and so is // one greater than 31557600 (1 year), as per // https://tools.ietf.org/html/rfc8461#section-3.2. if p.MaxAge <= 0 || p.MaxAge > 31557600*time.Second { return ErrInvalidMaxAge } if p.Mode != Enforce && p.Mode != Testing && p.Mode != None { return ErrInvalidMode } // "mx" field is required, and the policy is invalid if it's not present. 
// https://mailarchive.ietf.org/arch/msg/uta/Omqo1Bw6rJbrTMl2Zo69IJr35Qo if len(p.MXs) == 0 { return ErrInvalidMX } return nil } // MXIsAllowed checks if the given MX is allowed, according to the policy. // https://tools.ietf.org/html/rfc8461#section-4.1 func (p *Policy) MXIsAllowed(mx string) bool { if p.Mode != Enforce { return true } for _, pattern := range p.MXs { if matchDomain(mx, pattern) { return true } } return false } // UncheckedFetch fetches and parses the policy, but does NOT check it. // This can be useful for debugging and troubleshooting, but you should always // call Check on the policy before using it. func UncheckedFetch(ctx context.Context, domain string) (*Policy, error) { // Convert the domain to ascii form, as httpGet does not support IDNs in // any other way. domain, err := idna.ToASCII(domain) if err != nil { return nil, err } ok, err := hasSTSRecord(domain) if err != nil { return nil, err } if !ok { return nil, fmt.Errorf("MTA-STS TXT record missing") } url := urlForDomain(domain) rawPolicy, err := httpGet(ctx, url) if err != nil { return nil, err } return parsePolicy(rawPolicy) } // Fake URL for testing purposes, so we can do more end-to-end tests, // including the HTTP fetching code. var fakeURLForTesting string func urlForDomain(domain string) string { if fakeURLForTesting != "" { return fakeURLForTesting + "/" + domain } // URL composed from the domain, as explained in: // https://tools.ietf.org/html/rfc8461#section-3.3 // https://tools.ietf.org/html/rfc8461#section-3.2 return "https://mta-sts." + domain + "/.well-known/mta-sts.txt" } // Fetch a policy for the given domain. Note this results in various network // lookups and HTTPS GETs, so it can be slow. // The returned policy is parsed and sanity-checked (using Policy.Check), so // it should be safe to use. 
func Fetch(ctx context.Context, domain string) (*Policy, error) { p, err := UncheckedFetch(ctx, domain) if err != nil { return nil, err } err = p.Check() if err != nil { return nil, err } return p, nil } // httpGet performs an HTTP GET of the given URL, using the context and // rejecting redirects, as per the standard. func httpGet(ctx context.Context, url string) ([]byte, error) { client := &http.Client{ // We MUST NOT follow redirects, see // https://tools.ietf.org/html/rfc8461#section-3.3 CheckRedirect: rejectRedirect, } resp, err := ctxhttp.Get(ctx, client, url) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("HTTP response status code: %v", resp.StatusCode) } // Media type must be "text/plain" to guard against cases where webservers // allow untrusted users to host non-text content (like HTML or images) at // a user-defined path. // https://tools.ietf.org/html/rfc8461#section-3.2 mt, _, err := mime.ParseMediaType(resp.Header.Get("Content-type")) if err != nil { return nil, fmt.Errorf("HTTP media type error: %v", err) } if mt != "text/plain" { return nil, ErrInvalidMediaType } // Read but up to 10k; policies should be way smaller than that, and // having a limit prevents abuse/accidents with very large replies. return io.ReadAll(&io.LimitedReader{R: resp.Body, N: 10 * 1024}) } var errRejectRedirect = errors.New("redirects not allowed in MTA-STS") func rejectRedirect(req *http.Request, via []*http.Request) error { return errRejectRedirect } // matchDomain checks if the domain matches the given pattern, according to // from https://tools.ietf.org/html/rfc8461#section-4.1 // (based on https://tools.ietf.org/html/rfc6125#section-6.4). 
func matchDomain(domain, pattern string) bool { domain, dErr := domainToASCII(domain) pattern, pErr := domainToASCII(pattern) if dErr != nil || pErr != nil { // Domains should already have been checked and normalized by the // caller, exposing this is not worth the API complexity in this case. return false } // Simplify the case of a literal match. if domain == pattern { return true } // For wildcards, skip the first part of the domain and match the rest. // Note that if the pattern is malformed this might fail, but we are ok // with that. if strings.HasPrefix(pattern, "*.") { parts := strings.SplitN(domain, ".", 2) if len(parts) > 1 && parts[1] == pattern[2:] { return true } } return false } // domainToASCII converts the domain to ASCII form, similar to idna.ToASCII // but with some preprocessing convenient for our use cases. func domainToASCII(domain string) (string, error) { domain = strings.TrimSuffix(domain, ".") domain = strings.ToLower(domain) return idna.ToASCII(domain) } // Function that we override for testing purposes. // In the future we will override net.DefaultResolver, but we don't do that // yet for backwards compatibility. var lookupTXT = net.LookupTXT // hasSTSRecord checks if there is a valid MTA-STS TXT record for the domain. // We don't do full parsing and don't care about the "id=" field, as it is // unused in this implementation. func hasSTSRecord(domain string) (bool, error) { txts, err := lookupTXT("_mta-sts." + domain) if err != nil { return false, err } for _, txt := range txts { if strings.HasPrefix(txt, "v=STSv1;") { return true, nil } } return false, nil } // PolicyCache is a caching layer for fetching policies. // // Policies are cached by domain, and stored in a single directory. // The files will have as mtime the time when the policy expires, this makes // the store simpler, as it can avoid keeping additional metadata. // // There is no in-memory caching. 
This may be added in the future, but for // now disk is good enough for our purposes. type PolicyCache struct { dir string } // NewCache creates an instance of PolicyCache using the given directory as // backing storage. The directory will be created if it does not exist. func NewCache(dir string) (*PolicyCache, error) { c := &PolicyCache{ dir: dir, } err := os.MkdirAll(dir, 0770) return c, err } const pathPrefix = "pol:" func (c *PolicyCache) domainPath(domain string) string { // We assume the domain is well formed, sanity check just in case. if strings.Contains(domain, "/") { panic("domain contains slash") } return c.dir + "/" + pathPrefix + domain } var errExpired = errors.New("cache entry expired") func (c *PolicyCache) load(domain string) (*Policy, error) { fname := c.domainPath(domain) fi, err := os.Stat(fname) if err != nil { return nil, err } if time.Since(fi.ModTime()) > 0 { cacheExpired.Add(1) return nil, errExpired } data, err := os.ReadFile(fname) if err != nil { cacheIOErrors.Add(1) return nil, err } p := &Policy{} err = json.Unmarshal(data, p) if err != nil { cacheUnmarshalErrors.Add(1) return nil, err } // The policy should always be valid, as we marshalled it ourselves; // however, check it just to be safe. if err := p.Check(); err != nil { cacheInvalid.Add(1) return nil, fmt.Errorf( "%s unmarshalled invalid policy %v: %v", domain, p, err) } return p, nil } func (c *PolicyCache) store(domain string, p *Policy) error { data, err := json.Marshal(p) if err != nil { cacheMarshalErrors.Add(1) return fmt.Errorf("%s failed to marshal policy %v, error: %v", domain, p, err) } // Change the modification time to the future, when the policy expires. // load will check for this to detect expired cache entries, see above for // the details. 
expires := time.Now().Add(p.MaxAge) chTime := func(fname string) error { return os.Chtimes(fname, expires, expires) } fname := c.domainPath(domain) err = safeio.WriteFile(fname, data, 0640, chTime) if err != nil { cacheIOErrors.Add(1) } return err } // Fetch a policy for the given domain, using the cache. func (c *PolicyCache) Fetch(ctx context.Context, domain string) (*Policy, error) { cacheFetches.Add(1) tr := trace.New("STSCache.Fetch", domain) defer tr.Finish() p, err := c.load(domain) if err == nil { tr.Debugf("cache hit: %v", p) cacheHits.Add(1) return p, nil } p, err = Fetch(ctx, domain) if err != nil { tr.Debugf("failed to fetch: %v", err) cacheFailedFetch.Add(1) return nil, err } tr.Debugf("fetched: %v", p) // We could do this asynchronously, as we got the policy to give to the // caller. However, to make troubleshooting easier and the cost of storing // entries easier to track down, we store synchronously. // Note that even if the store returns an error, we pass on the policy: at // this point we rather use the policy even if we couldn't store it in the // cache. err = c.store(domain, p) if err != nil { tr.Errorf("failed to store: %v", err) } else { tr.Debugf("stored") } return p, nil } // PeriodicallyRefresh the cache, by re-fetching all entries. func (c *PolicyCache) PeriodicallyRefresh(ctx context.Context) { for ctx.Err() == nil { c.refresh(ctx) cacheRefreshCycles.Add(1) // Wait 10 minutes between passes; this is a background refresh and // there's no need to poke the servers very often. 
time.Sleep(10 * time.Minute) } } func (c *PolicyCache) refresh(ctx context.Context) { tr := trace.New("STSCache.Refresh", c.dir) defer tr.Finish() entries, err := os.ReadDir(c.dir) if err != nil { tr.Errorf("failed to list directory %q: %v", c.dir, err) return } tr.Debugf("%d entries", len(entries)) for _, e := range entries { if !strings.HasPrefix(e.Name(), pathPrefix) { continue } domain := e.Name()[len(pathPrefix):] cacheRefreshes.Add(1) tr.Debugf("%v: refreshing", domain) fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second) p, err := Fetch(fetchCtx, domain) cancel() if err != nil { tr.Debugf("%v: failed to fetch: %v", domain, err) cacheRefreshErrors.Add(1) continue } tr.Debugf("%v: fetched", domain) err = c.store(domain, p) if err != nil { tr.Errorf("%v: failed to store: %v", domain, err) } else { tr.Debugf("%v: stored", domain) } } tr.Debugf("refresh done") }
// Package testlib provides common test utilities. package testlib import ( "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "math/big" "net" "os" "strings" "sync" "testing" "time" ) // MustTempDir creates a temporary directory, or dies trying. func MustTempDir(t *testing.T) string { dir, err := os.MkdirTemp("", "testlib_") if err != nil { t.Fatal(err) } err = os.Chdir(dir) if err != nil { t.Fatal(err) } t.Logf("test directory: %q", dir) return dir } // RemoveIfOk removes the given directory, but only if we have not failed. We // want to keep the failed directories for debugging. func RemoveIfOk(t *testing.T, dir string) { // Safeguard, to make sure we only remove test directories. // This should help prevent accidental deletions. if !strings.Contains(dir, "testlib_") { panic("invalid/dangerous directory") } if !t.Failed() { os.RemoveAll(dir) } } // Rewrite a file with the given contents. func Rewrite(t *testing.T, path, contents string) error { // Safeguard, to make sure we only mess with test files. if !strings.Contains(path, "testlib_") { panic("invalid/dangerous path") } err := os.WriteFile(path, []byte(contents), 0600) if err != nil { t.Errorf("failed to rewrite file: %v", err) } return err } // GetFreePort returns a free TCP port. This is hacky and not race-free, but // it works well enough for testing purposes. func GetFreePort() string { l, err := net.Listen("tcp", "localhost:0") if err != nil { panic(err) } defer l.Close() return l.Addr().String() } // WaitFor f to return true (returns true), or d to pass (returns false). func WaitFor(f func() bool, d time.Duration) bool { start := time.Now() for time.Since(start) < d { if f() { return true } time.Sleep(20 * time.Millisecond) } return false } type deliverRequest struct { From string To string Data []byte } // TestCourier never fails, and always remembers everything. 
type TestCourier struct { wg sync.WaitGroup Requests []*deliverRequest ReqFor map[string]*deliverRequest sync.Mutex } // Deliver the given mail (saving it in tc.Requests). func (tc *TestCourier) Deliver(from string, to string, data []byte) (error, bool) { defer tc.wg.Done() dr := &deliverRequest{from, to, data} tc.Lock() tc.Requests = append(tc.Requests, dr) tc.ReqFor[to] = dr tc.Unlock() return nil, false } // Expect i mails to be delivered. func (tc *TestCourier) Expect(i int) { tc.wg.Add(i) } // Wait until all mails have been delivered. func (tc *TestCourier) Wait() { tc.wg.Wait() } // NewTestCourier returns a new, empty TestCourier instance. func NewTestCourier() *TestCourier { return &TestCourier{ ReqFor: map[string]*deliverRequest{}, } } type dumbCourier struct{} func (c dumbCourier) Deliver(from string, to string, data []byte) (error, bool) { return nil, false } // DumbCourier always succeeds delivery, and ignores everything. var DumbCourier = dumbCourier{} // GenerateCert generates a new, INSECURE self-signed certificate and writes // it to a pair of (cert.pem, key.pem) files to the given path. // Note the certificate is only useful for testing purposes. func GenerateCert(path string) (*tls.Config, error) { tmpl := x509.Certificate{ SerialNumber: big.NewInt(1234), Subject: pkix.Name{ Organization: []string{"chasquid_test.go"}, }, DNSNames: []string{"localhost"}, IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, NotBefore: time.Now(), NotAfter: time.Now().Add(30 * time.Minute), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, IsCA: true, } priv, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { return nil, err } derBytes, err := x509.CreateCertificate( rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv) if err != nil { return nil, err } // Create a global config for convenience. 
srvCert, err := x509.ParseCertificate(derBytes) if err != nil { return nil, err } rootCAs := x509.NewCertPool() rootCAs.AddCert(srvCert) tlsConfig := &tls.Config{ ServerName: "localhost", RootCAs: rootCAs, } certOut, err := os.Create(path + "/cert.pem") if err != nil { return nil, err } defer certOut.Close() err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) if err != nil { return nil, err } keyOut, err := os.OpenFile( path+"/key.pem", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return nil, err } defer keyOut.Close() block := &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv), } err = pem.Encode(keyOut, block) return tlsConfig, err }
// Package tlsconst contains TLS constants for human consumption. package tlsconst // Most of the constants get automatically generated from IANA's assignments. //go:generate ./generate-ciphers.py ciphers.go import "fmt" var versionName = map[uint16]string{ 0x0300: "SSL-3.0", 0x0301: "TLS-1.0", 0x0302: "TLS-1.1", 0x0303: "TLS-1.2", 0x0304: "TLS-1.3", } // VersionName returns a human-readable TLS version name. func VersionName(v uint16) string { name, ok := versionName[v] if !ok { return fmt.Sprintf("TLS-%#04x", v) } return name } // CipherSuiteName returns a human-readable TLS cipher suite name. func CipherSuiteName(s uint16) string { name, ok := cipherSuiteName[s] if !ok { return fmt.Sprintf("TLS_UNKNOWN_CIPHER_SUITE-%#04x", s) } return name }
// Package trace extends nettrace with logging. package trace import ( "fmt" "strconv" "blitiri.com.ar/go/chasquid/internal/nettrace" "blitiri.com.ar/go/log" ) func init() { } // A Trace represents an active request. type Trace struct { family string title string t nettrace.Trace } // New trace. func New(family, title string) *Trace { t := &Trace{family, title, nettrace.New(family, title)} // The default for max events is 10, which is a bit short for a normal // SMTP exchange. Expand it to 100 which should be large enough to keep // most of the traces. t.t.SetMaxEvents(100) return t } // NewChild creates a new child trace. func (t *Trace) NewChild(family, title string) *Trace { n := &Trace{family, title, t.t.NewChild(family, title)} n.t.SetMaxEvents(100) return n } // Printf adds this message to the trace's log. func (t *Trace) Printf(format string, a ...interface{}) { t.t.Printf(format, a...) log.Log(log.Info, 1, "%s %s: %s", t.family, t.title, quote(fmt.Sprintf(format, a...))) } // Debugf adds this message to the trace's log, with a debugging level. func (t *Trace) Debugf(format string, a ...interface{}) { t.t.Printf(format, a...) log.Log(log.Debug, 1, "%s %s: %s", t.family, t.title, quote(fmt.Sprintf(format, a...))) } // Errorf adds this message to the trace's log, with an error level. func (t *Trace) Errorf(format string, a ...interface{}) error { // Note we can't just call t.Error here, as it breaks caller logging. err := fmt.Errorf(format, a...) t.t.SetError() t.t.Printf("error: %v", err) log.Log(log.Info, 1, "%s %s: error: %s", t.family, t.title, quote(err.Error())) return err } // Error marks the trace as having seen an error, and also logs it to the // trace's log. func (t *Trace) Error(err error) error { t.t.SetError() t.t.Printf("error: %v", err) log.Log(log.Info, 1, "%s %s: error: %s", t.family, t.title, quote(err.Error())) return err } // Finish the trace. It should not be changed after this is called. 
func (t *Trace) Finish() { t.t.Finish() } func quote(s string) string { qs := strconv.Quote(s) return qs[1 : len(qs)-1] }
// Package userdb implements a simple user database. // // # Format // // The user database is a file containing a list of users and their passwords, // encrypted with some scheme. // We use a text-encoded protobuf, the structure can be found in userdb.proto. // // We write text instead of binary to make it easier for administrators to // troubleshoot, and since performance is not an issue for our expected usage. // // Users must be UTF-8 and NOT contain whitespace; the library will enforce // this. // // # Schemes // // The default scheme is SCRYPT, with hard-coded parameters. The API does not // allow the user to change this, at least for now. // A PLAIN scheme is also supported for debugging purposes. // // # Writing // // The functions that write a database file will not preserve ordering, // invalid lines, empty lines, or any formatting. // // It is also not safe for concurrent use from different processes. package userdb //go:generate protoc --go_out=. --go_opt=paths=source_relative userdb.proto import ( "crypto/rand" "crypto/subtle" "errors" "fmt" "os" "sync" "golang.org/x/crypto/scrypt" "blitiri.com.ar/go/chasquid/internal/normalize" "blitiri.com.ar/go/chasquid/internal/protoio" ) // DB represents a single user database. type DB struct { fname string db *ProtoDB // Lock protecting db. mu sync.RWMutex } // New returns a new user database, on the given file name. func New(fname string) *DB { return &DB{ fname: fname, db: &ProtoDB{Users: map[string]*Password{}}, } } // Load the database from the given file. // Return the database, and an error if the database could not be loaded. If // the file does not exist, that is not considered an error. func Load(fname string) (*DB, error) { db := New(fname) err := protoio.ReadTextMessage(fname, db.db) // Reading may result in an empty protobuf or dictionary; make sure we // return an empty but usable structure. // This simplifies many of our uses, as we can assume the map is not nil. 
if db.db == nil || db.db.Users == nil { db.db = &ProtoDB{Users: map[string]*Password{}} } if os.IsNotExist(err) { // If the file does not exist now, it is not an error, as it might // exist later and we want to be able to read it. err = nil } return db, err } // Reload the database, refreshing its contents from the current file on disk. // If there are errors reading from the file, they are returned and the // database is not changed. func (db *DB) Reload() error { newdb, err := Load(db.fname) if err != nil { return err } db.mu.Lock() db.db = newdb.db db.mu.Unlock() return nil } // Write the database to disk. It will do a complete rewrite each time, and is // not safe to call it from different processes in parallel. func (db *DB) Write() error { db.mu.RLock() defer db.mu.RUnlock() return protoio.WriteTextMessage(db.fname, db.db, 0660) } // Authenticate returns true if the password is valid for the user, false // otherwise. func (db *DB) Authenticate(name, plainPassword string) bool { db.mu.RLock() passwd, ok := db.db.Users[name] db.mu.RUnlock() if !ok { return false } return passwd.PasswordMatches(plainPassword) } // PasswordMatches returns true if the given password is a match. func (p *Password) PasswordMatches(plain string) bool { switch s := p.Scheme.(type) { case nil: return false case *Password_Scrypt: return s.Scrypt.PasswordMatches(plain) case *Password_Plain: return s.Plain.PasswordMatches(plain) case *Password_Denied: return false default: return false } } // AddUser to the database. If the user is already present, override it. // Note we enforce that the name has been normalized previously. func (db *DB) AddUser(name, plainPassword string) error { if norm, err := normalize.User(name); err != nil || name != norm { return errors.New("invalid username") } s := &Scrypt{ // Use hard-coded standard parameters for now. // Follow the recommendations from the scrypt paper. LogN: 14, R: 8, P: 1, KeyLen: 32, // 16 bytes of salt (will be filled later). 
Salt: make([]byte, 16), } n, err := rand.Read(s.Salt) if n != 16 || err != nil { return fmt.Errorf("failed to get salt - %d - %v", n, err) } s.Encrypted, err = scrypt.Key([]byte(plainPassword), s.Salt, 1<<s.LogN, int(s.R), int(s.P), int(s.KeyLen)) if err != nil { return fmt.Errorf("scrypt failed: %v", err) } db.mu.Lock() db.db.Users[name] = &Password{ Scheme: &Password_Scrypt{s}, } db.mu.Unlock() return nil } // AddDenied to the database. If the user is already present, override it. // Note we enforce that the name has been normalized previously. func (db *DB) AddDeniedUser(name string) error { if norm, err := normalize.User(name); err != nil || name != norm { return errors.New("invalid username") } db.mu.Lock() db.db.Users[name] = &Password{ Scheme: &Password_Denied{&Denied{}}, } db.mu.Unlock() return nil } // RemoveUser from the database. Returns True if the user was there, False // otherwise. func (db *DB) RemoveUser(name string) bool { db.mu.Lock() _, present := db.db.Users[name] delete(db.db.Users, name) db.mu.Unlock() return present } // Exists returns true if the user is present, false otherwise. func (db *DB) Exists(name string) bool { db.mu.Lock() _, present := db.db.Users[name] db.mu.Unlock() return present } // Len returns the number of users in the database. func (db *DB) Len() int { db.mu.Lock() defer db.mu.Unlock() return len(db.db.Users) } /////////////////////////////////////////////////////////// // Encryption schemes // // PasswordMatches implementation for the plain text scheme. // Useful mostly for testing and debugging. // TODO: Do we really need this? Removing it would make accidents less likely // to happen. Consider doing so when we add another scheme, so we a least have // two and multi-scheme support does not bit-rot. func (p *Plain) PasswordMatches(plain string) bool { return plain == string(p.Password) } // PasswordMatches implementation for the scrypt scheme, which we use by // default. 
func (s *Scrypt) PasswordMatches(plain string) bool { dk, err := scrypt.Key([]byte(plain), s.Salt, 1<<s.LogN, int(s.R), int(s.P), int(s.KeyLen)) if err != nil { // The encryption failed, this is due to the parameters being invalid. // We validated them before, so something went really wrong. // TODO: do we want to return false instead? panic(fmt.Sprintf("scrypt failed: %v", err)) } // This comparison should be high enough up the stack that it doesn't // matter, but do it in constant time just in case. return subtle.ConstantTimeCompare(dk, []byte(s.Encrypted)) == 1 }
package main import ( "context" "expvar" "flag" "fmt" "html/template" "net/http" "os" "runtime" "runtime/debug" "strconv" "time" "blitiri.com.ar/go/chasquid/internal/config" "blitiri.com.ar/go/chasquid/internal/expvarom" "blitiri.com.ar/go/chasquid/internal/nettrace" "blitiri.com.ar/go/log" "google.golang.org/protobuf/encoding/prototext" // To enable live profiling in the monitoring server. _ "net/http/pprof" ) // Build information, overridden at build time using // -ldflags="-X main.version=blah". var ( version = "" sourceDateTs = "" ) var ( versionVar = expvar.NewString("chasquid/version") sourceDate time.Time sourceDateVar = expvar.NewString("chasquid/sourceDateStr") sourceDateTsVar = expvarom.NewInt("chasquid/sourceDateTimestamp", "timestamp when the binary was built, in seconds since epoch") ) func parseVersionInfo() { bi, ok := debug.ReadBuildInfo() if !ok { panic("unable to read build info") } dirty := false gitRev := "" gitTime := "" for _, s := range bi.Settings { switch s.Key { case "vcs.modified": if s.Value == "true" { dirty = true } case "vcs.time": gitTime = s.Value case "vcs.revision": gitRev = s.Value } } if sourceDateTs != "" { sdts, err := strconv.ParseInt(sourceDateTs, 10, 0) if err != nil { panic(err) } sourceDate = time.Unix(sdts, 0) } else { sourceDate, _ = time.Parse(time.RFC3339, gitTime) } sourceDateVar.Set(sourceDate.Format("2006-01-02 15:04:05 -0700")) sourceDateTsVar.Set(sourceDate.Unix()) if version == "" { version = sourceDate.Format("20060102") if gitRev != "" { version += fmt.Sprintf("-%.9s", gitRev) } if dirty { version += "-dirty" } } versionVar.Set(version) } func launchMonitoringServer(conf *config.Config) { log.Infof("Monitoring HTTP server listening on %s", conf.MonitoringAddress) osHostname, _ := os.Hostname() indexData := struct { Version string GoVersion string SourceDate time.Time StartTime time.Time Config *config.Config Hostname string }{ Version: version, GoVersion: runtime.Version(), SourceDate: sourceDate, StartTime: 
time.Now(), Config: conf, Hostname: osHostname, } http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { http.NotFound(w, r) return } if err := monitoringHTMLIndex.Execute(w, indexData); err != nil { log.Infof("monitoring handler error: %v", err) } }) srv := &http.Server{Addr: conf.MonitoringAddress} http.HandleFunc("/exit", exitHandler(srv)) http.HandleFunc("/metrics", expvarom.MetricsHandler) http.HandleFunc("/debug/flags", debugFlagsHandler) http.HandleFunc("/debug/config", debugConfigHandler(conf)) http.HandleFunc("/debug/traces", nettrace.RenderTraces) if err := srv.ListenAndServe(); err != http.ErrServerClosed { log.Fatalf("Monitoring server failed: %v", err) } } // Functions available inside the templates. var tmplFuncs = template.FuncMap{ "since": time.Since, "roundDuration": roundDuration, } // Static index for the monitoring website. var monitoringHTMLIndex = template.Must( template.New("index").Funcs(tmplFuncs).Parse( `<!DOCTYPE html> <html> <head> <meta name="viewport" content="width=device-width, initial-scale=1"> <title>{{.Hostname}}: chasquid monitoring</title> <style type="text/css"> body { font-family: sans-serif; } @media (prefers-color-scheme: dark) { body { background: #121212; color: #c9d1d9; } a { color: #44b4ec; } } </style> </head> <body> <h1>chasquid @{{.Config.Hostname}}</h1> <p> chasquid {{.Version}}<br> source date {{.SourceDate.Format "2006-01-02 15:04:05 -0700"}}<br> built with {{.GoVersion}}<br> </p> <p> started {{.StartTime.Format "Mon, 2006-01-02 15:04:05 -0700"}}<br> up for {{.StartTime | since | roundDuration}}<br> os hostname <i>{{.Hostname}}</i><br> </p> <ul> <li><a href="/debug/queue">queue</a> <li>monitoring <ul> <li><a href="/debug/traces">traces</a> <li><a href="https://blitiri.com.ar/p/chasquid/monitoring/#variables"> exported variables</a>: <a href="/debug/vars">expvar</a> <small><a href="https://golang.org/pkg/expvar/">(ref)</a></small>, <a href="/metrics">openmetrics</a> <small><a 
href="https://openmetrics.io/">(ref)</a></small> </ul> <li>execution <ul> <li><a href="/debug/flags">flags</a> <li><a href="/debug/config">config</a> <li><a href="/debug/pprof/cmdline">command line</a> </ul> <li><a href="/debug/pprof">pprof</a> <small><a href="https://golang.org/pkg/net/http/pprof/">(ref)</a></small> <ul> </ul> </ul> </body> </html> `)) func exitHandler(srv *http.Server) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Use POST method for exiting", http.StatusMethodNotAllowed) return } log.Infof("Received /exit") http.Error(w, "OK exiting", http.StatusOK) // Launch srv.Shutdown asynchronously, and then exit. // The http documentation says to wait for Shutdown to return before // exiting, to gracefully close all ongoing requests. go func() { if err := srv.Shutdown(context.Background()); err != nil { log.Fatalf("Monitoring server shutdown failed: %v", err) } os.Exit(0) }() } } func debugFlagsHandler(w http.ResponseWriter, r *http.Request) { visited := make(map[string]bool) // Print set flags first, then the rest. flag.Visit(func(f *flag.Flag) { fmt.Fprintf(w, "-%s=%s\n", f.Name, f.Value.String()) visited[f.Name] = true }) fmt.Fprintf(w, "\n") flag.VisitAll(func(f *flag.Flag) { if !visited[f.Name] { fmt.Fprintf(w, "-%s=%s\n", f.Name, f.Value.String()) } }) } func debugConfigHandler(conf *config.Config) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(prototext.Format(conf))) } } func roundDuration(d time.Duration) time.Duration { return d.Round(time.Second) }