git » chasquid » main » tree

[main] / internal / dkim / header.go

package dkim

import (
	"crypto"
	"encoding/base64"
	"errors"
	"fmt"
	"slices"
	"strconv"
	"strings"
	"time"
)

// https://datatracker.ietf.org/doc/html/rfc6376#section-6

// dkimSignature holds the parsed contents of a DKIM-Signature header field,
// as built by dkimSignatureFromHeader. Unexported field names mirror the RFC
// tag names.
type dkimSignature struct {
	// Version (v=). dkimSignatureFromHeader only accepts "1".
	v string

	// Algorithm (a=), verbatim. Like "rsa-sha256".
	a string

	// Key type, extracted from the part of a= before the '-'.
	KeyType keyType

	// Hash, extracted from the part of a= after the '-'.
	Hash crypto.Hash

	// Signature data (b=).
	// Decoded from base64, ignoring whitespace.
	b []byte

	// Hash of canonicalized body (bh=).
	// Decoded from base64, ignoring whitespace.
	bh []byte

	// Header (cH) and body (cB) canonicalization modes, from c=.
	// Both are "simple" when c= is absent; cB is "simple" when c= only
	// names the header mode.
	cH canonicalization
	cB canonicalization

	// Domain ("SDID", d=), in plain text.
	// IDNs MUST be encoded as A-labels.
	d string

	// Signed header fields (h=).
	// The colon-separated list from the tag, split into individual names
	// with all whitespace removed.
	h []string

	// AUID (i=), in plain text. Empty when i= is absent; no default value
	// is filled in by the parser.
	i string

	// Body length count (l=): octet count of the canonicalized body.
	// NOTE(review): this is 0 both when l= is absent and when it is "l=0";
	// the two cases cannot be told apart from this field alone.
	l uint64

	// Query methods used for DNS lookup (q=).
	// The colon-separated list, split with whitespace removed; nil when q=
	// is absent. When present it must include "dns/txt".
	q []string

	// Selector (s=).
	s string

	// Signature timestamp (t=), given in seconds since the UNIX epoch.
	// Zero time.Time when t= is absent.
	t time.Time

	// Signature expiration (x=), given in seconds since the UNIX epoch.
	// Zero time.Time when x= is absent.
	x time.Time

	// Copied header fields (z=).
	// Stored with all whitespace removed, but otherwise not decoded.
	z string
}

// canonicalizationFromString parses the c= tag value and stores the header
// and body canonicalization modes in sig.cH and sig.cB.
//
// Accepted forms are "header/body" and plain "header". When the body half is
// absent or empty it defaults to "simple", and an empty value selects
// simple/simple. The RFC allows no whitespace around the '/', and none is
// trimmed here. A rejected half is reported as an error prefixed with
// "header: " or "body: " that wraps the underlying cause.
func (sig *dkimSignature) canonicalizationFromString(s string) error {
	if len(s) == 0 {
		sig.cH, sig.cB = simpleCanonicalization, simpleCanonicalization
		return nil
	}

	headerMode, bodyMode, _ := strings.Cut(s, "/")
	if len(bodyMode) == 0 {
		bodyMode = "simple"
	}

	var err error
	if sig.cH, err = stringToCanonicalization(headerMode); err != nil {
		return fmt.Errorf("header: %w", err)
	}
	if sig.cB, err = stringToCanonicalization(bodyMode); err != nil {
		return fmt.Errorf("body: %w", err)
	}
	return nil
}

// checkRequiredTags validates a parsed signature against RFC 6376:
//   - the tags required for verification (a=, b=, bh=, d=, h=, s=) are
//     present; v= is validated separately by dkimSignatureFromHeader;
//   - h= lists the From header field (compared case-insensitively);
//   - if i= is present, its domain is d= or a subdomain of it.
//
// https://datatracker.ietf.org/doc/html/rfc6376#section-6.1.1
//
// A missing tag yields an error wrapping errMissingRequiredTag; the other
// failures wrap errInvalidTag.
func (sig *dkimSignature) checkRequiredTags() error {
	// https://datatracker.ietf.org/doc/html/rfc6376#section-6.1.1
	if sig.a == "" {
		return fmt.Errorf("%w: a=", errMissingRequiredTag)
	}
	if len(sig.b) == 0 {
		return fmt.Errorf("%w: b=", errMissingRequiredTag)
	}
	if len(sig.bh) == 0 {
		return fmt.Errorf("%w: bh=", errMissingRequiredTag)
	}
	if sig.d == "" {
		return fmt.Errorf("%w: d=", errMissingRequiredTag)
	}
	if len(sig.h) == 0 {
		return fmt.Errorf("%w: h=", errMissingRequiredTag)
	}
	if sig.s == "" {
		return fmt.Errorf("%w: s=", errMissingRequiredTag)
	}

	// h= must contain From.
	var isFrom = func(s string) bool { return strings.EqualFold(s, "from") }
	if !slices.ContainsFunc(sig.h, isFrom) {
		return fmt.Errorf("%w: h= does not contain 'from'", errInvalidTag)
	}

	// If i= is present, its domain must be equal to, or a subdomain of, d=.
	if sig.i != "" {
		// The domain is what follows the LAST '@': a quoted local-part may
		// itself contain '@', but a domain name never does. Without any '@'
		// the domain is empty, which never matches d=.
		var domain string
		if at := strings.LastIndexByte(sig.i, '@'); at >= 0 {
			domain = sig.i[at+1:]
		}

		// Domain names are case-insensitive (RFC 4343), so compare them in
		// lower case; otherwise e.g. d=Example.com with i=@mail.example.com
		// would be wrongly rejected.
		domain = strings.ToLower(domain)
		parent := strings.ToLower(sig.d)
		if domain != parent && !strings.HasSuffix(domain, "."+parent) {
			return fmt.Errorf("%w: i= is not a subdomain of d=",
				errInvalidTag)
		}
	}

	return nil
}

// Package-level values used while parsing DKIM-Signature headers. The errors
// are sentinels that the parsing functions in this file return directly or
// wrap with %w.
var (
	// A tag value could not be decoded or parsed.
	errInvalidSignature = errors.New("invalid signature")

	// v= is missing or is not "1".
	errInvalidVersion = errors.New("invalid version")

	// a= does not have the "<key type>-<hash>" shape.
	errBadATag = errors.New("invalid a= tag")

	// The hash half of a= is not one we accept.
	errUnsupportedHash = errors.New("unsupported hash")

	// The key-type half of a= is not one we accept.
	errUnsupportedKeyType = errors.New("unsupported key type")

	// A tag needed for verification is absent.
	errMissingRequiredTag = errors.New("missing required tag")

	// A t=/x= timestamp parsed as a negative number.
	errNegativeTimestamp = errors.New("negative timestamp")

	// eatWhitespace strips spaces, tabs, CRs and LFs, for tag values whose
	// embedded whitespace is not significant.
	eatWhitespace = strings.NewReplacer(" ", "", "\t", "", "\r", "", "\n", "")
)

// dkimSignatureFromHeader parses the value of a DKIM-Signature header field
// into a dkimSignature.
// https://datatracker.ietf.org/doc/html/rfc6376#section-3.5
//
// Validation stops at the first problem, checked in this order: tag-list
// syntax (parseTags), v=, a=, b=, bh=, c=, l=, q=, t=, x=, and finally
// checkRequiredTags. On any error the returned signature is nil.
//
// The values of b=, bh=, h=, q= and z= have spaces, tabs, CRs and LFs
// stripped before use.
func dkimSignatureFromHeader(header string) (*dkimSignature, error) {
	tags, err := parseTags(header)
	if err != nil {
		return nil, err
	}

	// fail wraps a per-tag parsing error so it matches errInvalidSignature
	// while still exposing the underlying cause.
	fail := func(what string, cause error) error {
		return fmt.Errorf("%w: failed to %s: %w", errInvalidSignature, what, cause)
	}

	// decodeTag base64-decodes a tag value, ignoring embedded whitespace.
	decodeTag := func(name string) ([]byte, error) {
		return base64.StdEncoding.DecodeString(eatWhitespace.Replace(tags[name]))
	}

	// listTag splits a colon-separated tag value (whitespace removed); an
	// absent or empty tag yields a nil slice.
	listTag := func(name string) []string {
		raw := tags[name]
		if raw == "" {
			return nil
		}
		return strings.Split(eatWhitespace.Replace(raw), ":")
	}

	sig := &dkimSignature{v: tags["v"], a: tags["a"]}

	// Only version "1" is accepted; a missing v= is rejected here too.
	if sig.v != "1" {
		return nil, errInvalidVersion
	}

	// a= is "<key type>-<hash>" and both halves must be supported. A
	// missing a= contains no '-' and is rejected as errBadATag.
	keyName, hashName, hasDash := strings.Cut(sig.a, "-")
	if !hasDash {
		return nil, errBadATag
	}
	if sig.KeyType, err = keyTypeFromString(keyName); err != nil {
		return nil, fmt.Errorf("%w: %s", err, sig.a)
	}
	if sig.Hash, err = hashFromString(hashName); err != nil {
		return nil, fmt.Errorf("%w: %s", err, sig.a)
	}

	// An absent b= or bh= decodes to an empty slice here;
	// checkRequiredTags rejects that later.
	if sig.b, err = decodeTag("b"); err != nil {
		return nil, fail("decode b", err)
	}
	if sig.bh, err = decodeTag("bh"); err != nil {
		return nil, fail("decode bh", err)
	}

	if err = sig.canonicalizationFromString(tags["c"]); err != nil {
		return nil, fail("parse c", err)
	}

	sig.d = tags["d"]
	sig.h = listTag("h")
	sig.i = tags["i"]

	if raw := tags["l"]; raw != "" {
		if sig.l, err = strconv.ParseUint(raw, 10, 64); err != nil {
			return nil, fail("parse l", err)
		}
	}

	// When q= is given, it must list "dns/txt".
	sig.q = listTag("q")
	if len(sig.q) != 0 && !slices.Contains(sig.q, "dns/txt") {
		return nil, fmt.Errorf("%w: no dns/txt query method in q",
			errInvalidSignature)
	}

	sig.s = tags["s"]

	if raw := tags["t"]; raw != "" {
		if sig.t, err = unixStrToTime(raw); err != nil {
			return nil, fail("parse t", err)
		}
	}
	if raw := tags["x"]; raw != "" {
		if sig.x, err = unixStrToTime(raw); err != nil {
			return nil, fail("parse x", err)
		}
	}

	sig.z = eatWhitespace.Replace(tags["z"])

	if err = sig.checkRequiredTags(); err != nil {
		return nil, err
	}
	return sig, nil
}

// unixStrToTime converts a decimal count of seconds since the UNIX epoch, as
// carried by the t= and x= tags, into a time.Time.
//
// The RFC defines these as unsigned integers. They are parsed as int64,
// which is what time.Unix takes, and negative values are rejected with
// errNegativeTimestamp. Syntax and range errors are returned unchanged from
// strconv.
func unixStrToTime(s string) (time.Time, error) {
	secs, err := strconv.ParseInt(s, 10, 64)
	switch {
	case err != nil:
		return time.Time{}, err
	case secs < 0:
		return time.Time{}, errNegativeTimestamp
	}
	return time.Unix(secs, 0), nil
}

// keyType is the signing key algorithm named by the first half of a=.
type keyType string

// Supported key types.
const (
	keyTypeRSA     keyType = "rsa"
	keyTypeEd25519 keyType = "ed25519"
)

// keyTypeFromString converts the key-type half of a= into a keyType.
// Matching is exact and case-sensitive. Anything other than "rsa" or
// "ed25519" returns errUnsupportedKeyType.
func keyTypeFromString(s string) (keyType, error) {
	candidate := keyType(s)
	if candidate != keyTypeRSA && candidate != keyTypeEd25519 {
		return "", errUnsupportedKeyType
	}
	return candidate, nil
}

// hashFromString converts the hash half of a= into a crypto.Hash.
//
// Only "sha256" is accepted. "sha1" is rejected on purpose, because RFC 8301
// says it must not be used for signing or verifying.
// https://datatracker.ietf.org/doc/html/rfc8301#section-3.1
// Any other value returns errUnsupportedHash.
func hashFromString(s string) (crypto.Hash, error) {
	if s != "sha256" {
		return 0, errUnsupportedHash
	}
	return crypto.SHA256, nil
}

// tags is a parsed DKIM tag=value list (RFC 6376, Section 3.2), keyed by
// tag name.
// https://datatracker.ietf.org/doc/html/rfc6376#section-3.2
type tags map[string]string

// errInvalidTag is returned, possibly wrapped, when a tag list or one of
// its entries is malformed.
var errInvalidTag = errors.New("invalid tag")

// parseTags parses a tag list such as "v=1; a=rsa-sha256; d=example.com".
//
// Surrounding whitespace and one trailing ';' are accepted. Whitespace
// around each tag name and value is trimmed, but whitespace inside a value
// is kept, so callers strip it where the tag's syntax ignores it. A value
// is everything after the first '=', so it may itself contain '='.
//
// The result is an error wrapping errInvalidTag when an entry has no '=',
// has an empty tag name, or repeats a tag name already seen.
func parseTags(s string) (tags, error) {
	body := strings.TrimSuffix(strings.TrimSpace(s), ";")

	result := tags{}
	for _, spec := range strings.Split(body, ";") {
		name, value, hasEq := strings.Cut(spec, "=")
		if !hasEq {
			return nil, fmt.Errorf("%w: missing '='", errInvalidTag)
		}
		name, value = strings.TrimSpace(name), strings.TrimSpace(value)

		switch _, seen := result[name]; {
		case name == "":
			return nil, fmt.Errorf("%w: missing tag name", errInvalidTag)
		case seen:
			// RFC 6376 3.2: a repeated tag name invalidates the whole list.
			return nil, fmt.Errorf("%w: duplicate tag", errInvalidTag)
		}

		result[name] = value
	}
	return result, nil
}