git » debian:golang-blitiri-go-spf » master » tree

[master] / spf_test.go

package spf

import (
	"flag"
	"fmt"
	"net"
	"os"
	"testing"
)

// txtResults and txtErrors are the canned answers served by the fake
// LookupTXT, keyed by domain. Tests fill them in before calling CheckHost.
var (
	txtResults = map[string][]string{}
	txtErrors  = map[string]error{}
)

// LookupTXT is an in-memory stand-in for a TXT lookup: it returns whatever
// is stored for domain in txtResults and txtErrors. A domain with no entry
// yields (nil, nil).
func LookupTXT(domain string) (txts []string, err error) {
	txts, err = txtResults[domain], txtErrors[domain]
	return txts, err
}

// mxResults and mxErrors are the canned answers served by the fake
// LookupMX, keyed by domain.
var (
	mxResults = map[string][]*net.MX{}
	mxErrors  = map[string]error{}
)

// LookupMX is an in-memory stand-in for an MX lookup: it returns whatever
// is stored for domain in mxResults and mxErrors. A domain with no entry
// yields (nil, nil).
func LookupMX(domain string) (mxs []*net.MX, err error) {
	mxs = mxResults[domain]
	err = mxErrors[domain]
	return
}

// ipResults and ipErrors are the canned answers served by the fake
// LookupIP, keyed by host name.
var (
	ipResults = map[string][]net.IP{}
	ipErrors  = map[string]error{}
)

// LookupIP is an in-memory stand-in for an address lookup: it returns
// whatever is stored for host in ipResults and ipErrors. A host with no
// entry yields (nil, nil).
func LookupIP(host string) (ips []net.IP, err error) {
	ips = ipResults[host]
	err = ipErrors[host]
	return ips, err
}

// addrResults and addrErrors are the canned answers served by the fake
// LookupAddr (reverse lookup), keyed by the textual IP address.
var (
	addrResults = map[string][]string{}
	addrErrors  = map[string]error{}
)

// LookupAddr is an in-memory stand-in for a reverse lookup: it returns
// whatever is stored for host in addrResults and addrErrors. An address
// with no entry yields (nil, nil).
func LookupAddr(host string) (addrs []string, err error) {
	names := addrResults[host]
	lookupErr := addrErrors[host]
	return names, lookupErr
}

// TestMain swaps the package's lookup hooks (lookupTXT, lookupMX, lookupIP,
// lookupAddr) for the map-backed fakes defined in this file, so the tests
// control every answer those hooks return, and then runs the test suite.
func TestMain(m *testing.M) {
	lookupTXT, lookupMX = LookupTXT, LookupMX
	lookupIP, lookupAddr = LookupIP, LookupAddr

	flag.Parse()
	status := m.Run()
	os.Exit(status)
}

// Fixed client/record addresses shared by the tests. ip1110 and ip6660 are
// the base addresses of the networks containing ip1111 and ip6666, so a
// prefix-length match (e.g. "/24") covers the client while an exact match
// does not.
var (
	ip1110 = net.ParseIP("1.1.1.0")
	ip1111 = net.ParseIP("1.1.1.1")
	ip6666 = net.ParseIP("2001:db8::68")
	ip6660 = net.ParseIP("2001:db8::0")
)

// TestBasic evaluates a table of SPF records published for "domain" against
// the IPv4 client ip1111 and checks the verdict CheckHost returns. Verdicts
// that signal a failure to evaluate (TempError, PermError) must also carry a
// non-nil error.
func TestBasic(t *testing.T) {
	cases := []struct {
		txt string
		res Result
	}{
		{"", None},
		{"blah", None},
		{"v=spf1", Neutral},
		{"v=spf1 ", Neutral},
		{"v=spf1 -", PermError},
		{"v=spf1 all", Pass},
		{"v=spf1  +all", Pass},
		{"v=spf1 -all ", Fail},
		{"v=spf1 ~all", SoftFail},
		{"v=spf1 ?all", Neutral},
		{"v=spf1 a ~all", SoftFail},
		{"v=spf1 a/24", Neutral},
		{"v=spf1 a:d1110/24", Pass},
		{"v=spf1 a:d1110", Neutral},
		{"v=spf1 a:d1111", Pass},
		{"v=spf1 a:nothing/24", Neutral},
		{"v=spf1 mx", Neutral},
		{"v=spf1 mx/24", Neutral},
		{"v=spf1 mx:a/montoto ~all", PermError},
		{"v=spf1 mx:d1110/24 ~all", Pass},
		{"v=spf1 ip4:1.2.3.4 ~all", SoftFail},
		{"v=spf1 ip6:12 ~all", PermError},
		{"v=spf1 ip4:1.1.1.1 -all", Pass},
		{"v=spf1 ptr -all", Pass},
		{"v=spf1 ptr:d1111 -all", Pass},
		{"v=spf1 ptr:lalala -all", Pass},
		{"v=spf1 blah", PermError},
	}

	// Fixtures: d1111 resolves to the client itself, d1110 to the base of
	// its /24. d1110's MX list points at d1110 and at "nothing", which has
	// no address entry. The client's PTR names include lalala, domain and
	// d1111.
	ipResults["d1111"] = []net.IP{ip1111}
	ipResults["d1110"] = []net.IP{ip1110}
	// Keyed fields: go vet's composites check rejects unkeyed literals of
	// struct types from other packages, such as net.MX.
	mxResults["d1110"] = []*net.MX{
		{Host: "d1110", Pref: 5},
		{Host: "nothing", Pref: 10},
	}
	addrResults["1.1.1.1"] = []string{"lalala.", "domain.", "d1111."}

	for _, c := range cases {
		txtResults["domain"] = []string{c.txt}
		res, err := CheckHost(ip1111, "domain")
		if (res == TempError || res == PermError) && (err == nil) {
			t.Errorf("%q: expected error, got nil", c.txt)
		}
		if res != c.res {
			t.Errorf("%q: expected %q, got %q", c.txt, c.res, res)
			t.Logf("%q:   error: %v", c.txt, err)
		}
	}
}

// TestIPv6 evaluates a table of SPF records published for "domain" against
// the IPv6 client ip6666 (2001:db8::68) and checks the verdict CheckHost
// returns. Verdicts that signal a failure to evaluate (TempError, PermError)
// must also carry a non-nil error.
func TestIPv6(t *testing.T) {
	cases := []struct {
		txt string
		res Result
	}{
		{"v=spf1 all", Pass},
		{"v=spf1 a ~all", SoftFail},
		{"v=spf1 a/24", Neutral},
		{"v=spf1 a:d6660/24", Pass},
		{"v=spf1 a:d6660", Neutral},
		{"v=spf1 a:d6666", Pass},
		{"v=spf1 a:nothing/24", Neutral},
		{"v=spf1 mx:d6660/24 ~all", Pass},
		{"v=spf1 ip6:2001:db8::68 ~all", Pass},
		{"v=spf1 ip6:2001:db8::1/24 ~all", Pass},
		{"v=spf1 ip6:2001:db8::1/100 ~all", Pass},
		{"v=spf1 ptr -all", Pass},
		{"v=spf1 ptr:d6666 -all", Pass},
		{"v=spf1 ptr:sonlas6 -all", Pass},
	}

	// Fixtures: d6666 resolves to the client itself, d6660 to 2001:db8::.
	// d6660's MX list points at d6660 and at "nothing", which has no address
	// entry. The client's PTR names include sonlas6, domain and d6666.
	ipResults["d6666"] = []net.IP{ip6666}
	ipResults["d6660"] = []net.IP{ip6660}
	// Keyed fields: go vet's composites check rejects unkeyed literals of
	// struct types from other packages, such as net.MX.
	mxResults["d6660"] = []*net.MX{
		{Host: "d6660", Pref: 5},
		{Host: "nothing", Pref: 10},
	}
	addrResults["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."}

	for _, c := range cases {
		txtResults["domain"] = []string{c.txt}
		res, err := CheckHost(ip6666, "domain")
		if (res == TempError || res == PermError) && (err == nil) {
			t.Errorf("%q: expected error, got nil", c.txt)
		}
		if res != c.res {
			t.Errorf("%q: expected %q, got %q", c.txt, c.res, res)
			t.Logf("%q:   error: %v", c.txt, err)
		}
	}
}

// TestNotSupported checks that records using the "exists" mechanism, the
// "exp" modifier, or a macro ("%{o}"), which by this test's name are
// unsupported, make CheckHost report Neutral for the client ip1111.
func TestNotSupported(t *testing.T) {
	records := []string{
		"v=spf1 exists:blah -all",
		"v=spf1 exp=blah -all",
		"v=spf1 a:%{o} -all",
	}

	for i := range records {
		record := records[i]
		txtResults["domain"] = []string{record}

		got, err := CheckHost(ip1111, "domain")
		if got == Neutral {
			continue
		}
		t.Errorf("%q: expected neutral, got %v", record, got)
		t.Logf("%q:   error: %v", record, err)
	}
}

// TestRecursion publishes a record for "domain" that includes "domain"
// itself and expects CheckHost to stop with PermError.
func TestRecursion(t *testing.T) {
	const selfInclude = "v=spf1 include:domain ~all"
	txtResults["domain"] = []string{selfInclude}

	if res, err := CheckHost(ip1111, "domain"); res != PermError {
		t.Errorf("expected permerror, got %v (%v)", res, err)
	}
}

// TestNoRecord checks that CheckHost returns None whenever no usable SPF
// record is found for the domain:
//   - d1: a single empty TXT record;
//   - d2: TXT records exist, but none is a "v=spf1" record;
//   - d3: no TXT entry at all;
//   - nospf: the TXT lookup itself returns an error.
func TestNoRecord(t *testing.T) {
	txtResults["d1"] = []string{""}
	txtResults["d2"] = []string{"loco", "v=spf2"}
	txtErrors["nospf"] = fmt.Errorf("no such domain")

	for _, domain := range []string{"d1", "d2", "d3", "nospf"} {
		res, err := CheckHost(ip1111, domain)
		if res != None {
			// Name the domain: without it a failure doesn't say which of the
			// four scenarios broke.
			t.Errorf("%q: expected none, got %v (%v)", domain, res, err)
		}
	}
}