package spf
import (
"flag"
"fmt"
"io"
"net"
"os"
"strings"
"testing"
"gopkg.in/yaml.v3"
)
// Command-line flags that control which tests from the YAML suites are run.
var (
// ymlSingle, when non-empty, restricts the run to the test whose name matches
// it exactly.
ymlSingle = flag.String("yml_single", "",
"run only the test with this name")
// ymlSkipMarked, when true (the default), makes the runner skip every test
// that has a non-empty Skip field.
ymlSkipMarked = flag.Bool("yml_skip_marked", true,
"skip tests marked with the 'skip' value")
)
//////////////////////////////////////////////////////
// YAML test suite parsing.
//
// Suite is one YAML document of a test file: a description, a set of named
// tests, and the DNS zone data those tests run against.
type Suite struct {
// Description is logged when the suite starts.
Description string
// Tests maps a test name to its definition.
Tests map[string]Test
// ZoneData maps a domain name to the records to publish for it.
ZoneData map[string][]Record `yaml:"zonedata"`
}
// Test is a single test case within a Suite.
type Test struct {
Description string
Comment string
// Spec is not read by the runner in this file; presumably it references the
// specification sections the test covers (TODO confirm).
Spec stringSlice
// Helo is passed to CheckHostWithSender as the HELO domain argument.
Helo string
// Host is parsed with net.ParseIP and used as the IP to check.
Host string
// MailFrom is passed to CheckHostWithSender as the sender.
MailFrom string `yaml:"mailfrom"`
// Result lists the acceptable results; the test passes if the obtained
// result matches any of them.
Result stringSlice
// Explanation is not read by the runner in this file.
Explanation string
// Skip, when non-empty, marks the test to be skipped (subject to the
// yml_skip_marked flag).
Skip string
}
// Record is a single zone data entry for a domain. The field names mirror the
// upper-case keys used in the YAML files.
// Only one of these will be set.
type Record struct {
A stringSlice `yaml:"A"`
AAAA stringSlice `yaml:"AAAA"`
// MX is a pointer so that an absent entry can be told apart (nil).
MX *MX `yaml:"MX"`
SPF stringSlice `yaml:"SPF"`
TXT stringSlice `yaml:"TXT"`
PTR stringSlice `yaml:"PTR"`
CNAME string `yaml:"CNAME"`
// Errors. When set, the runner registers a DNS error for the domain instead
// of (or in addition to) regular records.
TIMEOUT bool `yaml:"TIMEOUT"`
SERVFAIL bool `yaml:"SERVFAIL"`
CNAMELOOP bool `yaml:"CNAMELOOP"`
}
// String renders the record for logging. It reports the first populated
// field, checked in declaration order (A, AAAA, MX, SPF, TXT, PTR, CNAME,
// then the error markers), or "<empty>" when nothing is set.
func (r Record) String() string {
	switch {
	case len(r.A) > 0:
		return fmt.Sprintf("A: %v", r.A)
	case len(r.AAAA) > 0:
		return fmt.Sprintf("AAAA: %v", r.AAAA)
	case r.MX != nil:
		return fmt.Sprintf("MX: %v", *r.MX)
	case len(r.SPF) > 0:
		return fmt.Sprintf("SPF: %v", r.SPF)
	case len(r.TXT) > 0:
		return fmt.Sprintf("TXT: %v", r.TXT)
	case len(r.PTR) > 0:
		return fmt.Sprintf("PTR: %v", r.PTR)
	case r.CNAME != "":
		return fmt.Sprintf("CNAME: %v", r.CNAME)
	case r.TIMEOUT:
		return "TIMEOUT"
	case r.SERVFAIL:
		return "SERVFAIL"
	case r.CNAMELOOP:
		return "CNAMELOOP"
	}
	return "<empty>"
}
// stringSlice is a []string that can be written in YAML either as a list or
// as a bare scalar; a scalar is turned into a one-element slice. The stock
// decoder does not do this on its own, see
// https://github.com/go-yaml/yaml/issues/100
type stringSlice []string

// UnmarshalYAML decodes value as a list of strings if possible, and otherwise
// as a single string. The error of the single-string attempt is returned if
// both fail.
func (sl *stringSlice) UnmarshalYAML(value *yaml.Node) error {
	items := []string{}
	if value.Decode(&items) != nil {
		// Not a list: fall back to a lone scalar.
		var scalar string
		if err := value.Decode(&scalar); err != nil {
			return err
		}
		items = []string{scalar}
	}

	*sl = items
	return nil
}
// MX is encoded as:
//
//	MX: [0, mail.example.com]
//
// so we have a custom decoder to handle the multi-typed list.
type MX struct {
	// Prio is the first element of the list.
	Prio uint16
	// Host is the second element of the list.
	Host string
}

// UnmarshalYAML decodes a [priority, host] list into mx.
//
// It returns an error (instead of panicking) when the list has fewer than two
// elements, when the priority is not an integer that fits in a uint16, or when
// the host is not a string. Elements past the second are ignored. mx is only
// modified on success.
func (mx *MX) UnmarshalYAML(value *yaml.Node) error {
	const maxPrio = 1<<16 - 1

	var seq []interface{}
	if err := value.Decode(&seq); err != nil {
		return err
	}
	if len(seq) < 2 {
		return fmt.Errorf("invalid MX: expected [priority, host], got %d element(s)", len(seq))
	}

	prio, ok := seq[0].(int)
	if !ok {
		return fmt.Errorf("invalid MX: priority %v (%T) is not an integer", seq[0], seq[0])
	}
	if prio < 0 || prio > maxPrio {
		return fmt.Errorf("invalid MX: priority %d out of range [0, %d]", prio, maxPrio)
	}

	host, ok := seq[1].(string)
	if !ok {
		return fmt.Errorf("invalid MX: host %v (%T) is not a string", seq[1], seq[1])
	}

	mx.Prio = uint16(prio)
	mx.Host = host
	return nil
}
//////////////////////////////////////////////////////
// Test runners.
//
// testRFC loads every YAML suite from the file fname and runs its tests.
//
// For each suite it creates a resolver via NewDefaultResolver, populates it
// from the suite's zone data, and then calls CheckHostWithSender for each
// test, reporting an error if the result is not one of the expected ones.
// The yml_single and yml_skip_marked flags filter which tests run.
//
// Any failure to open or parse the file is fatal.
func testRFC(t *testing.T, fname string) {
	input, err := os.Open(fname)
	if err != nil {
		t.Fatal(err)
	}
	// Runs on normal return and also when t.Fatal aborts the test.
	defer input.Close()

	// The file may contain multiple YAML documents, one suite each.
	suites := []Suite{}
	dec := yaml.NewDecoder(input)
	for {
		s := Suite{}
		err = dec.Decode(&s)
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatal(err)
		}
		suites = append(suites, s)
	}

	defaultTrace = t.Logf

	for _, suite := range suites {
		t.Logf("suite: %v", suite.Description)

		// Set up zone for the suite based on zonedata.
		dns := NewDefaultResolver()
		for domain, records := range suite.ZoneData {
			t.Logf(" domain %v", domain)
			for _, record := range records {
				t.Logf(" %v", record)

				if record.TIMEOUT {
					dns.Errors[domain] = &net.DNSError{
						Err:       "test timeout error",
						IsTimeout: true,
					}
				}
				if record.SERVFAIL {
					dns.Errors[domain] = &net.DNSError{
						Err:         "test servfail error",
						IsTimeout:   false,
						IsTemporary: false,
					}
				}
				if record.CNAMELOOP {
					dns.Errors[domain] = &net.DNSError{
						Err:         "CNAME loop detected",
						IsTemporary: false,
					}
				}

				for _, s := range record.A {
					dns.Ip[domain] = append(dns.Ip[domain], net.ParseIP(s))
				}
				for _, s := range record.AAAA {
					dns.Ip[domain] = append(dns.Ip[domain], net.ParseIP(s))
				}
				for _, s := range record.TXT {
					dns.Txt[domain] = append(dns.Txt[domain], s)
				}
				if record.MX != nil {
					dns.Mx[domain] = append(dns.Mx[domain],
						mx(record.MX.Host, record.MX.Prio))
				}
				for _, s := range record.PTR {
					// domain in this case is of the form:
					//   4.3.2.1.in-addr.arpa
					//   1.0.0.0.0.[...].0.0.E.B.A.B.E.F.A.C.ip6.arpa
					// We need to extract the normal string representation for
					// them, and add the record to dns.Addr[ip.String()].
					// Enforce that the record is fully qualified, that's what
					// we expect to see in practice.
					if !strings.HasSuffix(s, ".") {
						s += "."
					}
					ip := reverseDNS(t, domain).String()
					dns.Addr[ip] = append(dns.Addr[ip], s)
				}
				if record.CNAME != "" {
					dns.Cname[domain] = record.CNAME
				}
			}

			// The test suite is not well done: some tests use SPF instead of
			// TXT because they are old, and others expect the lookup to try
			// TXT first and SPF later, even though that's forbidden by the
			// standard.
			// To try to minimize changes to the suite, we work around this by
			// only adding records from SPF if there is no TXT already.
			// We need to do this in a separate step because order of
			// appearance is not guaranteed.
			if len(dns.Txt[domain]) == 0 {
				for _, record := range records {
					if len(record.SPF) > 0 {
						// The test suite expects a single-line SPF record to
						// be concatenated without spaces.
						dns.Txt[domain] = append(dns.Txt[domain],
							strings.Join(record.SPF, ""))
					}
				}
			}
		}

		// Run each test.
		for name, test := range suite.Tests {
			if *ymlSingle != "" && *ymlSingle != name {
				continue
			}
			if test.Skip != "" && *ymlSkipMarked {
				continue
			}
			t.Logf(" test %s", name)

			ip := net.ParseIP(test.Host)
			t.Logf(" checkhost %v %v", ip, test.MailFrom)
			res, err := CheckHostWithSender(ip, test.Helo, test.MailFrom)
			if !resultIn(res, test.Result) {
				t.Errorf(" failed: expected %v, got %v (%v) [%v]",
					test.Result, res, err, name)
			} else {
				t.Logf(" success: %v, %v [%v]", res, err, name)
			}
		}
	}
}
// resultIn reports whether got, as a string, equals any entry of exp.
func resultIn(got Result, exp []string) bool {
	want := string(got)
	for i := range exp {
		if exp[i] == want {
			return true
		}
	}
	return false
}
// reverseDNS takes a reverse-dns host name of the form:
//
//	4.3.2.1.in-addr.arpa
//	1.0.0.0.0.[...].0.0.E.B.A.B.E.F.A.C.ip6.arpa
//
// and returns the corresponding ip.
//
// The labels before the suffix are reversed and re-joined: with "." for
// in-addr.arpa names, and in groups of four separated by ":" for ip6.arpa
// names. If r has neither suffix, or the rebuilt string does not parse as an
// IP address, the test is aborted via t.Fatalf.
func reverseDNS(t *testing.T, r string) net.IP {
	t.Helper()

	var s string
	switch {
	case strings.HasSuffix(r, ".in-addr.arpa"):
		// IPv4: each label is an octet, least significant first.
		labels := strings.Split(strings.TrimSuffix(r, ".in-addr.arpa"), ".")
		octets := make([]string, len(labels))
		for i, l := range labels {
			octets[len(labels)-1-i] = l
		}
		s = strings.Join(octets, ".")

	case strings.HasSuffix(r, ".ip6.arpa"):
		// IPv6: each label is a nibble, least significant first. Separators
		// go strictly between groups, so there is never a trailing one to
		// cut off (which could drop a digit or slice an empty string).
		labels := strings.Split(strings.TrimSuffix(r, ".ip6.arpa"), ".")
		var b strings.Builder
		for i := range labels {
			if i > 0 && i%4 == 0 {
				b.WriteByte(':')
			}
			b.WriteString(labels[len(labels)-1-i])
		}
		s = b.String()

	default:
		t.Fatalf("invalid reverse dns %q: invalid suffix", r)
	}

	ip := net.ParseIP(s)
	if ip == nil {
		t.Fatalf("invalid reverse dns %q: bad ip %q", r, s)
	}
	return ip
}
// TestOurs runs the YAML test suites found in testdata/blitirispf-tests.yml.
func TestOurs(t *testing.T) {
testRFC(t, "testdata/blitirispf-tests.yml")
}
// TestRFC4408 runs the YAML test suites found in testdata/rfc4408-tests.yml.
func TestRFC4408(t *testing.T) {
testRFC(t, "testdata/rfc4408-tests.yml")
}
// TestRFC7208 runs the YAML test suites found in testdata/rfc7208-tests.yml.
func TestRFC7208(t *testing.T) {
testRFC(t, "testdata/rfc7208-tests.yml")
}
// TestPySPF runs the YAML test suites found in testdata/pyspf-tests.yml.
func TestPySPF(t *testing.T) {
testRFC(t, "testdata/pyspf-tests.yml")
}