author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-07-30 13:51:02 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-07-30 16:21:46 UTC |
parent | c8299e88fc8a9d3fffb84ad1d7a8a710cef4c074 |
internal/httpstodns/parser_test.go | +112 | -0 |
internal/httpstodns/server.go | +23 | -6 |
diff --git a/internal/httpstodns/parser_test.go b/internal/httpstodns/parser_test.go
new file mode 100644
index 0000000..52e8ce8
--- /dev/null
+++ b/internal/httpstodns/parser_test.go
@@ -0,0 +1,112 @@
+// Tests for the query parsing.
+package httpstodns
+
+import (
+	"net"
+	"net/url"
+	"reflect"
+	"testing"
+
+	"github.com/miekg/dns"
+)
+
+// makeURL returns a /resolve URL carrying the given raw query string, and
+// fails the test if it cannot be parsed.
+func makeURL(t *testing.T, query string) *url.URL {
+	u, err := url.Parse("http://site/resolve?" + query)
+	if err != nil {
+		t.Fatalf("URL parsing failed: %v", err)
+	}
+
+	return u
+}
+
+// makeIPNet parses a CIDR literal into a network. It panics on invalid input,
+// so it must only be given literals known to be valid.
+func makeIPNet(s string) *net.IPNet {
+	_, n, err := net.ParseCIDR(s)
+	if err != nil {
+		panic(err)
+	}
+	return n
+}
+
+// queryEq reports whether two queries are deeply equal; reflect.DeepEqual
+// compares the pointed-to client subnets, not the pointers themselves.
+func queryEq(q1, q2 query) bool {
+	return reflect.DeepEqual(q1, q2)
+}
+
+// A DNS name which is too long (> 253 characters), but otherwise valid.
+const longName = "pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito-pablitoclavounclavitoqueclavitoclavopablito"
+
+// TestParseQuery checks that well-formed queries (including defaults and the
+// different cd spellings) parse as expected, and that each malformed query is
+// rejected with the expected sentinel error.
+func TestParseQuery(t *testing.T) {
+	cases := []struct {
+		rawQ string
+		q    query
+	}{
+		{"name=hola", query{"hola", dns.TypeA, false, nil}},
+		{"name=hola&type=a", query{"hola", dns.TypeA, false, nil}},
+		{"name=hola&type=A", query{"hola", dns.TypeA, false, nil}},
+		{"name=hola&type=1", query{"hola", dns.TypeA, false, nil}},
+		{"name=hola&type=MX", query{"hola", dns.TypeMX, false, nil}},
+		{"name=hola&type=txt", query{"hola", dns.TypeTXT, false, nil}},
+		{"name=x&cd", query{"x", dns.TypeA, true, nil}},
+		{"name=x&cd=1", query{"x", dns.TypeA, true, nil}},
+		{"name=x&cd=true", query{"x", dns.TypeA, true, nil}},
+		{"name=x&cd=0", query{"x", dns.TypeA, false, nil}},
+		{"name=x&cd=false", query{"x", dns.TypeA, false, nil}},
+		{"name=x&type=mx;cd", query{"x", dns.TypeMX, true, nil}},
+
+		{
+			"name=x&edns_client_subnet=1.2.3.0/21",
+			query{"x", dns.TypeA, false, makeIPNet("1.2.3.0/21")},
+		},
+		{
+			"name=x&edns_client_subnet=2001:700:300::/48",
+			query{"x", dns.TypeA, false, makeIPNet("2001:700:300::/48")},
+		},
+		{
+			"name=x&type=mx&cd&edns_client_subnet=2001:700:300::/48",
+			query{"x", dns.TypeMX, true, makeIPNet("2001:700:300::/48")},
+		},
+	}
+	for _, c := range cases {
+		q, err := parseQuery(makeURL(t, c.rawQ))
+		if err != nil {
+			// A failed parse has no meaningful result, so don't compare it.
+			t.Errorf("query %q: error %v", c.rawQ, err)
+		} else if !queryEq(q, c.q) {
+			t.Errorf("query %q: expected %v, got %v", c.rawQ, c.q, q)
+		}
+	}
+
+	// Most error cases separate parameters with ';' instead of '&'.
+	// NOTE(review): net/url stopped accepting ';' as a separator in Go 1.17;
+	// confirm parseQuery's own splitting still handles it.
+	errCases := []struct {
+		raw string
+		err error
+	}{
+		{"", errEmptyName},
+		{"name=" + longName, errNameTooLong},
+		{"name=x;type=0", errIntOutOfRange},
+		{"name=x;type=-1", errIntOutOfRange},
+		// NOTE(review): presumably 65536 fails stringToRRType's numeric parse
+		// and falls through to the name lookup; confirm.
+		{"name=x;type=65536", errUnknownType},
+		{"name=x;type=merienda", errUnknownType},
+		{"name=x;cd=lala", errInvalidCD},
+		{"name=x;edns_client_subnet=lala", errInvalidSubnet},
+		{"name=x;edns_client_subnet=1.2.3.4", errInvalidSubnet},
+	}
+	for _, c := range errCases {
+		_, err := parseQuery(makeURL(t, c.raw))
+		if err != c.err {
+			t.Errorf("query %q: expected error %v, got %v", c.raw, c.err, err)
+		}
+	}
+}
diff --git a/internal/httpstodns/server.go b/internal/httpstodns/server.go
index 2864a70..c5d4ca3 100644
--- a/internal/httpstodns/server.go
+++ b/internal/httpstodns/server.go
@@ -153,6 +153,23 @@ type query struct {
 	clientSubnet *net.IPNet
 }
 
+// String returns a compact representation of the query, as used for example
+// when tests print queries with %v.
+func (q query) String() string {
+	return fmt.Sprintf("{%s %d %v %s}", q.name, q.rrType, q.cd, q.clientSubnet)
+}
+
+// Errors returned by parseQuery, stringToRRType and stringToBool, as fixed
+// values so a failure's reason can be checked by comparison (as tests do).
+var (
+	errEmptyName     = fmt.Errorf("empty name")
+	errNameTooLong   = fmt.Errorf("name too long")
+	errInvalidSubnet = fmt.Errorf("invalid edns_client_subnet")
+	errIntOutOfRange = fmt.Errorf("invalid type (int out of range)")
+	errUnknownType   = fmt.Errorf("invalid type (unknown string type)")
+	errInvalidCD     = fmt.Errorf("invalid cd value")
+)
+
 func parseQuery(u *url.URL) (query, error) {
 	q := query{
 		name:         "",
@@ -174,10 +191,10 @@ func parseQuery(u *url.URL) (query, error) {
 	var err error
 
 	if q.name, ok = vs["name"]; !ok || q.name == "" {
-		return q, fmt.Errorf("empty name")
+		return q, errEmptyName
 	}
 	if len(q.name) > 253 {
-		return q, fmt.Errorf("name too long")
+		return q, errNameTooLong
 	}
 
 	if _, ok = vs["type"]; ok {
@@ -197,7 +214,7 @@ func parseQuery(u *url.URL) (query, error) {
 	if clientSubnet, ok := vs["edns_client_subnet"]; ok {
 		_, q.clientSubnet, err = net.ParseCIDR(clientSubnet)
 		if err != nil {
-			return q, fmt.Errorf("invalid edns_client_subnet: %v", err)
+			return q, errInvalidSubnet
 		}
 	}
 
@@ -213,12 +230,12 @@ func stringToRRType(s string) (uint16, error) {
 		if 1 <= i && i <= 65535 {
 			return uint16(i), nil
 		}
-		return 0, fmt.Errorf("invalid type (int out of range)")
+		return 0, errIntOutOfRange
 	}
 
 	rrType, ok := dns.StringToType[strings.ToUpper(s)]
 	if !ok {
-		return 0, fmt.Errorf("invalid type (unknown string type)")
+		return 0, errUnknownType
 	}
 	return rrType, nil
 }
@@ -233,5 +250,5 @@ func stringToBool(s string) (bool, error) {
 		return false, nil
 	}
 
-	return false, fmt.Errorf("invalid cd value")
+	return false, errInvalidCD
 }