git » dnss » commit 1c82f29

httpstodns: Add tests for the query parser

author Alberto Bertogli
2017-07-30 13:51:02 UTC
committer Alberto Bertogli
2017-07-30 16:21:46 UTC
parent c8299e88fc8a9d3fffb84ad1d7a8a710cef4c074

httpstodns: Add tests for the query parser

internal/httpstodns/parser_test.go +112 -0
internal/httpstodns/server.go +23 -6

diff --git a/internal/httpstodns/parser_test.go b/internal/httpstodns/parser_test.go
new file mode 100644
index 0000000..52e8ce8
--- /dev/null
+++ b/internal/httpstodns/parser_test.go
@@ -0,0 +1,112 @@
+// Tests for the query parsing.
+package httpstodns
+
+import (
+	"net"
+	"net/url"
+	"reflect"
+	"testing"
+
+	"github.com/miekg/dns"
+)
+
+// makeURL returns a /resolve URL carrying the given raw query string, and
+// fails the test if it cannot be parsed.
+func makeURL(t *testing.T, query string) *url.URL {
+	u, err := url.Parse("http://site/resolve?" + query)
+	if err != nil {
+		t.Fatalf("URL parsing failed: %v", err)
+	}
+
+	return u
+}
+
+// makeIPNet returns the network of the given CIDR string. It panics on bad
+// input, as it is only used to build the static test tables below.
+func makeIPNet(s string) *net.IPNet {
+	_, n, err := net.ParseCIDR(s)
+	if err != nil {
+		panic(err)
+	}
+	return n
+}
+
+// queryEq reports whether two queries are deeply equal; client subnets are
+// compared by value, not by pointer.
+func queryEq(q1, q2 query) bool {
+	return reflect.DeepEqual(q1, q2)
+}
+
+// A DNS name longer than the 253 characters accepted by parseQuery.
+const longName = "pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito"
+
+// TestParseQuery checks the result of parsing valid queries, and the
+// specific error returned for each kind of invalid query.
+func TestParseQuery(t *testing.T) {
+	cases := []struct {
+		rawQ string
+		q    query
+	}{
+		{"name=hola", query{"hola", dns.TypeA, false, nil}},
+		{"name=hola&type=a", query{"hola", dns.TypeA, false, nil}},
+		{"name=hola&type=A", query{"hola", dns.TypeA, false, nil}},
+		{"name=hola&type=1", query{"hola", dns.TypeA, false, nil}},
+		{"name=hola&type=MX", query{"hola", dns.TypeMX, false, nil}},
+		{"name=hola&type=txt", query{"hola", dns.TypeTXT, false, nil}},
+		{"name=x&cd", query{"x", dns.TypeA, true, nil}},
+		{"name=x&cd=1", query{"x", dns.TypeA, true, nil}},
+		{"name=x&cd=true", query{"x", dns.TypeA, true, nil}},
+		{"name=x&cd=0", query{"x", dns.TypeA, false, nil}},
+		{"name=x&cd=false", query{"x", dns.TypeA, false, nil}},
+		// NOTE(review): net/url no longer treats ';' as a query separator
+		// since Go 1.17; revisit this case if parseQuery uses url.Values.
+		{"name=x&type=mx;cd", query{"x", dns.TypeMX, true, nil}},
+
+		{
+			"name=x&edns_client_subnet=1.2.3.0/21",
+			query{"x", dns.TypeA, false, makeIPNet("1.2.3.0/21")},
+		},
+		{
+			"name=x&edns_client_subnet=2001:700:300::/48",
+			query{"x", dns.TypeA, false, makeIPNet("2001:700:300::/48")},
+		},
+		{
+			"name=x&type=mx&cd&edns_client_subnet=2001:700:300::/48",
+			query{"x", dns.TypeMX, true, makeIPNet("2001:700:300::/48")},
+		},
+	}
+	for _, c := range cases {
+		q, err := parseQuery(makeURL(t, c.rawQ))
+		if err != nil {
+			t.Errorf("query %q: error %v", c.rawQ, err)
+			// Comparing q as well would only add a redundant failure.
+			continue
+		}
+		if !queryEq(q, c.q) {
+			t.Errorf("query %q: expected %v, got %v", c.rawQ, c.q, q)
+		}
+	}
+
+	// The error cases use '&' separators, like the cases above, so that they
+	// exercise parseQuery's validation and not net/url's ';' handling.
+	errCases := []struct {
+		raw string
+		err error
+	}{
+		{"", emptyNameErr},
+		{"name=" + longName, nameTooLongErr},
+		{"name=x&type=0", intOutOfRangeErr},
+		{"name=x&type=-1", intOutOfRangeErr},
+		{"name=x&type=65536", unknownTypeErr},
+		{"name=x&type=merienda", unknownTypeErr},
+		{"name=x&cd=lala", invalidCDErr},
+		{"name=x&edns_client_subnet=lala", invalidSubnetErr},
+		{"name=x&edns_client_subnet=1.2.3.4", invalidSubnetErr},
+	}
+	for _, c := range errCases {
+		_, err := parseQuery(makeURL(t, c.raw))
+		if err != c.err {
+			t.Errorf("query %q: expected error %v, got %v", c.raw, c.err, err)
+		}
+	}
+}
diff --git a/internal/httpstodns/server.go b/internal/httpstodns/server.go
index 2864a70..c5d4ca3 100644
--- a/internal/httpstodns/server.go
+++ b/internal/httpstodns/server.go
@@ -153,6 +153,23 @@ type query struct {
 	clientSubnet *net.IPNet
 }
 
+// String returns a compact representation of the query, which makes test
+// failure messages (that print queries with %v) easier to read.
+func (q query) String() string {
+	return fmt.Sprintf("{%s %d %v %s}", q.name, q.rrType, q.cd, q.clientSubnet)
+}
+
+// Errors returned by parseQuery, stringToRRType and stringToBool. They are
+// fixed values so that callers (and tests) can compare against them.
+var (
+	emptyNameErr     = fmt.Errorf("empty name")
+	nameTooLongErr   = fmt.Errorf("name too long")
+	invalidSubnetErr = fmt.Errorf("invalid edns_client_subnet")
+	intOutOfRangeErr = fmt.Errorf("invalid type (int out of range)")
+	unknownTypeErr   = fmt.Errorf("invalid type (unknown string type)")
+	invalidCDErr     = fmt.Errorf("invalid cd value")
+)
+
 func parseQuery(u *url.URL) (query, error) {
 	q := query{
 		name:         "",
@@ -174,10 +191,10 @@ func parseQuery(u *url.URL) (query, error) {
 	var err error
 
 	if q.name, ok = vs["name"]; !ok || q.name == "" {
-		return q, fmt.Errorf("empty name")
+		return q, emptyNameErr
 	}
 	if len(q.name) > 253 {
-		return q, fmt.Errorf("name too long")
+		return q, nameTooLongErr
 	}
 
 	if _, ok = vs["type"]; ok {
@@ -197,7 +214,7 @@ func parseQuery(u *url.URL) (query, error) {
 	if clientSubnet, ok := vs["edns_client_subnet"]; ok {
 		_, q.clientSubnet, err = net.ParseCIDR(clientSubnet)
 		if err != nil {
-			return q, fmt.Errorf("invalid edns_client_subnet: %v", err)
+			return q, invalidSubnetErr
 		}
 	}
 
@@ -213,12 +230,12 @@ func stringToRRType(s string) (uint16, error) {
 		if 1 <= i && i <= 65535 {
 			return uint16(i), nil
 		}
-		return 0, fmt.Errorf("invalid type (int out of range)")
+		return 0, intOutOfRangeErr
 	}
 
 	rrType, ok := dns.StringToType[strings.ToUpper(s)]
 	if !ok {
-		return 0, fmt.Errorf("invalid type (unknown string type)")
+		return 0, unknownTypeErr
 	}
 	return rrType, nil
 }
@@ -233,5 +250,5 @@ func stringToBool(s string) (bool, error) {
 		return false, nil
 	}
 
-	return false, fmt.Errorf("invalid cd value")
+	return false, invalidCDErr
 }