package dkim
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/fs"
"net"
"os"
"path/filepath"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
// TestFromFiles runs one subtest per testdata/*.msg file checked into the
// repository. Files sharing the message's base name (.dns, .result, .error)
// provide the fake DNS records and the expected outcome; see testOne.
func TestFromFiles(t *testing.T) {
	matches, err := filepath.Glob("testdata/*.msg")
	if err != nil {
		t.Fatalf("error finding test files: %v", err)
	}
	for i := range matches {
		ext := filepath.Ext(matches[i])
		prefix := strings.TrimSuffix(matches[i], ext)
		t.Run(prefix, func(t *testing.T) {
			testOne(t, prefix)
		})
	}
}
// TestFromPrivateFiles is the same as TestFromFiles, but it runs the private
// test files under testdata/private/*/, which are not included in the git
// repository.
// This is useful for running tests on your own machine, with emails that you
// don't necessarily want to share publicly.
// When no private files are present the test is skipped rather than passing
// with zero subtests, so it is obvious that nothing was exercised.
func TestFromPrivateFiles(t *testing.T) {
	msgfs, err := filepath.Glob("testdata/private/*/*.msg")
	if err != nil {
		t.Fatalf("error finding private test files: %v", err)
	}
	if len(msgfs) == 0 {
		t.Skip("no private test files found in testdata/private/*/")
	}
	for _, msgf := range msgfs {
		base := strings.TrimSuffix(msgf, filepath.Ext(msgf))
		t.Run(base, func(t *testing.T) { testOne(t, base) })
	}
}
// testOne runs VerifyMessage against the fixture set sharing the path prefix
// base:
//
//	base.msg     the message; passed through toCRLF before verifying
//	base.dns     fake TXT records (see loadDNS)
//	base.result  expected *VerifyResult as JSON (see loadResult)
//	base.error   expected error text (see loadError)
//
// Missing fixture files read as empty. The actual outcome is always written
// to base.result.got and base.error.got (see writeResults) before comparing,
// so expectations can be refreshed by copying those files over.
func testOne(t *testing.T, base string) {
	// NOTE(review): presumably routes the package's trace output into the
	// test log; WithTraceFunc is defined elsewhere.
	ctx := WithTraceFunc(context.Background(), t.Logf)
	ctx = loadDNS(t, ctx, base+".dns")
	// toCRLF is defined elsewhere; presumably normalizes line endings to CRLF.
	msg := toCRLF(mustReadFile(t, base+".msg"))
	wantResult := loadResult(t, base+".result")
	wantError := loadError(t, base+".error")

	t.Logf("Message: %.60q", msg)
	t.Logf("Want result: %+v", wantResult)
	t.Logf("Want error: %v", wantError)

	gotResult, gotErr := VerifyMessage(ctx, msg)
	// Write the results out for easy updating.
	writeResults(t, base, gotResult, gotErr)

	if d := cmp.Diff(wantResult, gotResult, cmp.Comparer(equalErrors)); d != "" {
		t.Errorf("VerifyMessage result diff (-want +got):\n%s", d)
	}

	// cmp.Diff won't use our comparer for top-level errors, so decide
	// equality with equalErrors directly and only use cmp.Diff to render the
	// mismatch.
	if !equalErrors(wantError, gotErr) {
		t.Errorf("VerifyMessage error diff (-want +got):\n%s", cmp.Diff(wantError, gotErr))
	}
}
// equalErrors reports whether a and b are both nil, or both non-nil with
// identical Error() strings. It lets cmp.Diff compare errors by message,
// which is obviously not great, but it's good enough for this test.
func equalErrors(a, b error) bool {
	if a == nil || b == nil {
		// At least one is nil: equal only when both are.
		return a == nil && b == nil
	}
	return a.Error() == b.Error()
}
// mustReadFile returns the contents of path as a string.
// A file that does not exist reads as "", so optional fixtures can simply be
// omitted. Any other read error fails the test immediately.
func mustReadFile(t *testing.T, path string) string {
	t.Helper()
	data, err := os.ReadFile(path)
	switch {
	case err == nil:
		return string(data)
	case errors.Is(err, fs.ErrNotExist):
		return ""
	default:
		t.Fatalf("error reading %q: %v", path, err)
		return "" // unreachable: Fatalf stops the test goroutine
	}
}
// loadDNS reads the fake DNS data at path and returns ctx with a TXT lookup
// function (installed via WithLookupTXTFunc) that answers from that data.
//
// File format, one entry per line:
//
//	# comment
//	domain: <txt record value>
//	domain: TEMPERROR | PERMERROR | NOTFOUND
//
// A line ending in a backslash is joined with the next one. Empty lines,
// lines starting with '#', and lines without ':' are ignored. Repeated
// entries for a domain accumulate as multiple TXT records. The keywords make
// lookups for that domain return a *net.DNSError (temporary, permanent, or
// not-found respectively). Domains absent from the file yield no records and
// no error, as does a missing file.
func loadDNS(t *testing.T, ctx context.Context, path string) context.Context {
	t.Helper()
	results := map[string][]string{}
	// Not named "errors", which would shadow the errors package here.
	dnsErrors := map[string]error{}
	// The closure reads the maps at lookup time, so it sees every entry
	// parsed below.
	txtFunc := func(ctx context.Context, domain string) ([]string, error) {
		return results[domain], dnsErrors[domain]
	}
	ctx = WithLookupTXTFunc(ctx, txtFunc)

	c := mustReadFile(t, path)
	// Normalize CRLF (e.g. from a Windows checkout). Otherwise the unfolding
	// below misses "\\\r\n" and record values keep a trailing "\r".
	c = strings.ReplaceAll(c, "\r\n", "\n")
	// Unfold \-terminated lines.
	c = strings.ReplaceAll(c, "\\\n", "")
	for _, line := range strings.Split(c, "\n") {
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		domain, txt, ok := strings.Cut(line, ":")
		if !ok {
			continue
		}
		domain = strings.TrimSpace(domain)
		switch strings.TrimSpace(txt) {
		case "TEMPERROR":
			dnsErrors[domain] = &net.DNSError{
				Err:         "temporary error (for testing)",
				IsTemporary: true,
			}
		case "PERMERROR":
			dnsErrors[domain] = &net.DNSError{
				Err:         "permanent error (for testing)",
				IsTemporary: false,
			}
		case "NOTFOUND":
			dnsErrors[domain] = &net.DNSError{
				Err:        "domain not found (for testing)",
				IsNotFound: true,
			}
		default:
			// The value is stored as written (untrimmed), including any
			// whitespace that followed the ':'.
			results[domain] = append(results[domain], txt)
		}
	}
	t.Logf("Loaded DNS results: %#v", results)
	t.Logf("Loaded DNS errors: %v", dnsErrors)
	return ctx
}
// loadResult reads the expected VerifyMessage result from the JSON file at
// path.
//
// It returns nil when the file is missing, empty or whitespace-only, or
// contains JSON null. The null case matters because writeResults marshals a
// nil result as "null". A copied .result.got file must therefore load back as
// nil, not as an empty non-nil &VerifyResult{}, which would never compare
// equal to a nil result. Malformed JSON fails the test.
func loadResult(t *testing.T, path string) *VerifyResult {
	t.Helper()
	c := strings.TrimSpace(mustReadFile(t, path))
	if c == "" {
		return nil
	}
	// Decoding into a pointer variable leaves it nil for JSON null and
	// allocates a fresh VerifyResult for an object.
	var res *VerifyResult
	if err := json.Unmarshal([]byte(c), &res); err != nil {
		t.Fatalf("error unmarshalling %q: %v", path, err)
	}
	return res
}
// loadError reads the expected error message from path, ignoring surrounding
// whitespace. A missing or empty file, or the literal text "nil" or "<nil>"
// (the latter is how writeResults records a nil error), means no error is
// expected. Any other text becomes an error with exactly that message.
func loadError(t *testing.T, path string) error {
	t.Helper()
	switch text := strings.TrimSpace(mustReadFile(t, path)); text {
	case "", "nil", "<nil>":
		return nil
	default:
		return errors.New(text)
	}
}
// mustWriteFile writes c to path with mode 0644, creating or truncating the
// file, and fails the test immediately if the write fails.
func mustWriteFile(t *testing.T, path string, c []byte) {
	t.Helper()
	if werr := os.WriteFile(path, c, 0644); werr != nil {
		t.Fatalf("error writing %q: %v", path, werr)
	}
}
// writeResults records what VerifyMessage actually returned, next to the
// fixtures:
//
//	base.error.got   err formatted with %v (a nil error is written as "<nil>")
//	base.result.got  res as tab-indented JSON (a nil result is written as "null")
//
// Copying these over base.error / base.result updates the expectations.
func writeResults(t *testing.T, base string, res *VerifyResult, err error) {
	t.Helper()
	errText := fmt.Sprint(err)
	mustWriteFile(t, base+".error.got", []byte(errText))

	// Separate variable so the err parameter isn't overwritten.
	encoded, marshalErr := json.MarshalIndent(res, "", "\t")
	if marshalErr != nil {
		t.Fatalf("error marshalling result: %v", marshalErr)
	}
	mustWriteFile(t, base+".result.got", encoded)
}
// MarshalJSON encodes a OneResult with the default field encoding, except
// that Error is written as its message string ("" when nil), so expected
// results can spell errors out as plain text.
func (or *OneResult) MarshalJSON() ([]byte, error) {
	// Same fields as OneResult but none of its methods, so json.Marshal on
	// it doesn't recurse back into this method.
	type Plain OneResult

	var msg string
	if or.Error != nil {
		msg = or.Error.Error()
	}

	wrapper := &struct {
		// Shallower than the Error field promoted from Plain, so this one
		// wins in the JSON encoding.
		Error string
		*Plain
	}{
		Error: msg,
		Plain: (*Plain)(or),
	}
	return json.Marshal(wrapper)
}
// UnmarshalJSON is the inverse of MarshalJSON. All fields decode normally,
// and a non-empty "Error" string becomes an error carrying that message. An
// empty or absent "Error" leaves or.Error as it was.
func (or *OneResult) UnmarshalJSON(b []byte) error {
	// Same fields as OneResult but none of its methods, so json.Unmarshal on
	// it doesn't recurse back into this method.
	type Plain OneResult

	wrapper := struct {
		// Captures the "Error" key as text instead of the promoted
		// error-typed field.
		Error string
		*Plain
	}{
		Plain: (*Plain)(or),
	}
	if err := json.Unmarshal(b, &wrapper); err != nil {
		return err
	}
	if wrapper.Error == "" {
		return nil
	}
	or.Error = errors.New(wrapper.Error)
	return nil
}