git » spf » master » tree

[master] / internal / dnstest / dns.go

// DNS resolver for testing purposes.
//
// In the future, when go fuzz can make use of _test.go files, we can rename
// this file dns_test.go and remove this extra package entirely.
// Until then, unfortunately this is the most reasonable way to share these
// helpers between go and fuzz tests.
package dnstest

import (
	"context"
	"net"
	"strings"
)

// TestResolver is a DNS resolver for testing purposes. It answers lookups
// from the maps below.
//
// The type is exported so it can be used from outside this package, but it
// is not part of the public API and is only used internally on tests.
//
// All maps are keyed by name. The Lookup* methods lowercase the name they
// are given and strip its trailing dots before consulting the maps, so keys
// must be lowercase and have no trailing dot in order to match.
//
//   - Txt holds the results for LookupTXT.
//   - Mx holds the results for LookupMX.
//   - Ip holds the results for LookupIPAddr.
//   - Addr holds the results for LookupAddr.
//   - Cname maps a name to another name; every Lookup* method checks it
//     first and, on a match, looks up the target name instead.
//   - Errors holds the error that every Lookup* method returns for a name.
type TestResolver struct {
	Txt    map[string][]string
	Mx     map[string][]*net.MX
	Ip     map[string][]net.IP
	Addr   map[string][]string
	Cname  map[string]string
	Errors map[string]error
}

// NewResolver returns a TestResolver with all of its maps initialized and
// empty, so records can be added to them right away.
func NewResolver() *TestResolver {
	r := &TestResolver{}
	r.Txt = make(map[string][]string)
	r.Mx = make(map[string][]*net.MX)
	r.Ip = make(map[string][]net.IP)
	r.Addr = make(map[string][]string)
	r.Cname = make(map[string]string)
	r.Errors = make(map[string]error)
	return r
}

// nxDomainErr is the error the Lookup* methods return for names that have
// neither a result nor an entry in Errors. IsNotFound is set so callers
// that inspect the *net.DNSError see it as a "not found" answer.
var nxDomainErr = &net.DNSError{
	Err:        "domain not found (for testing)",
	IsNotFound: true,
}

// LookupTXT returns the TXT records configured for domain.
//
// If ctx is already done, its error is returned right away. Otherwise the
// name is lowercased and stripped of trailing dots. If it has an entry in
// Cname, the lookup is repeated for the target name; nothing here detects
// cycles, so Cname must not contain any.
//
// A name with neither an entry in Txt nor an error in Errors gets a "not
// found" DNS error. In any other case the Txt entry and the Errors entry
// for the name are returned together, as they are.
func (r *TestResolver) LookupTXT(ctx context.Context, domain string) (txts []string, err error) {
	if err = ctx.Err(); err != nil {
		return nil, err
	}

	name := strings.TrimRight(strings.ToLower(domain), ".")
	if target, isAlias := r.Cname[name]; isAlias {
		return r.LookupTXT(ctx, target)
	}

	records, found := r.Txt[name]
	lookupErr := r.Errors[name]
	if !found && lookupErr == nil {
		return nil, nxDomainErr
	}
	return records, lookupErr
}

// LookupMX returns the MX records configured for domain.
//
// If ctx is already done, its error is returned right away. Otherwise the
// name is lowercased and stripped of trailing dots. If it has an entry in
// Cname, the lookup is repeated for the target name; nothing here detects
// cycles, so Cname must not contain any.
//
// A name with neither an entry in Mx nor an error in Errors gets a "not
// found" DNS error. In any other case the Mx entry and the Errors entry for
// the name are returned together, as they are.
func (r *TestResolver) LookupMX(ctx context.Context, domain string) (mxs []*net.MX, err error) {
	if err = ctx.Err(); err != nil {
		return nil, err
	}

	name := strings.TrimRight(strings.ToLower(domain), ".")
	if target, isAlias := r.Cname[name]; isAlias {
		return r.LookupMX(ctx, target)
	}

	records, found := r.Mx[name]
	lookupErr := r.Errors[name]
	if !found && lookupErr == nil {
		return nil, nxDomainErr
	}
	return records, lookupErr
}

// LookupIPAddr returns the IP addresses configured for host, each wrapped
// in a net.IPAddr.
//
// If ctx is already done, its error is returned right away. Otherwise the
// name is lowercased and stripped of trailing dots. If it has an entry in
// Cname, the lookup is repeated for the target name; nothing here detects
// cycles, so Cname must not contain any.
//
// A name with neither an entry in Ip nor an error in Errors gets a "not
// found" DNS error. In any other case the converted Ip entry and the Errors
// entry for the name are returned together.
func (r *TestResolver) LookupIPAddr(ctx context.Context, host string) (as []net.IPAddr, err error) {
	if err = ctx.Err(); err != nil {
		return nil, err
	}

	name := strings.TrimRight(strings.ToLower(host), ".")
	if target, isAlias := r.Cname[name]; isAlias {
		return r.LookupIPAddr(ctx, target)
	}

	ips, found := r.Ip[name]
	lookupErr := r.Errors[name]
	if !found && lookupErr == nil {
		return nil, nxDomainErr
	}
	return ipsToAddrs(ips), lookupErr
}

// ipsToAddrs wraps each of the given IPs in a net.IPAddr with an empty
// zone, keeping their order. The result is never nil, even if ips is empty.
func ipsToAddrs(ips []net.IP) []net.IPAddr {
	addrs := make([]net.IPAddr, len(ips))
	for i := range ips {
		addrs[i].IP = ips[i]
	}
	return addrs
}

// LookupAddr returns the names configured in Addr for host.
//
// If ctx is already done, its error is returned right away. Otherwise the
// given string is lowercased and stripped of trailing dots. If it has an
// entry in Cname, the lookup is repeated for the target; nothing here
// detects cycles, so Cname must not contain any.
//
// A key with neither an entry in Addr nor an error in Errors gets a "not
// found" DNS error. In any other case the Addr entry and the Errors entry
// for the key are returned together, as they are.
func (r *TestResolver) LookupAddr(ctx context.Context, host string) (addrs []string, err error) {
	if err = ctx.Err(); err != nil {
		return nil, err
	}

	key := strings.TrimRight(strings.ToLower(host), ".")
	if target, isAlias := r.Cname[key]; isAlias {
		return r.LookupAddr(ctx, target)
	}

	names, found := r.Addr[key]
	lookupErr := r.Errors[key]
	if !found && lookupErr == nil {
		return nil, nxDomainErr
	}
	return names, lookupErr
}