git » dnss » commit a440db5

dnstogrpc: Introduce an abstraction for resolvers

author Alberto Bertogli
2015-10-24 10:00:32 UTC
committer Alberto Bertogli
2015-10-24 15:48:40 UTC
parent 7c1bcc804daeb3537416f1842eb549ae329c7273

dnstogrpc: Introduce an abstraction for resolvers

This patch introduces a new Resolver interface and migrates the existing GRPC
client to it (the APIs are very similar, so this is mostly just moving code
around).

Later patches will introduce a caching resolver using this same interface.

dnss.go +2 -2
dnss_test.go +2 -1
dnstogrpc/dnstogrpc.go +9 -65
dnstogrpc/resolver.go +85 -0

diff --git a/dnss.go b/dnss.go
index 516e588..c6527d9 100644
--- a/dnss.go
+++ b/dnss.go
@@ -80,8 +80,8 @@ func main() {
 
 	// DNS to GRPC.
 	if *enableDNStoGRPC {
-		dtg := dnstogrpc.New(*dnsListenAddr, *grpcUpstream, *grpcClientCAFile,
-			*dnsUnqualifiedUpstream)
+		r := dnstogrpc.NewGRPCResolver(*grpcUpstream, *grpcClientCAFile)
+		dtg := dnstogrpc.New(*dnsListenAddr, r, *dnsUnqualifiedUpstream)
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
diff --git a/dnss_test.go b/dnss_test.go
index 47448ef..f0aab8a 100644
--- a/dnss_test.go
+++ b/dnss_test.go
@@ -230,7 +230,8 @@ func realMain(m *testing.M) int {
 	}
 
 	// DNS to GRPC server.
-	dtg := dnstogrpc.New(dnsToGrpcAddr, grpcToDnsAddr, tmpDir+"/cert.pem", "")
+	r := dnstogrpc.NewGRPCResolver(grpcToDnsAddr, tmpDir+"/cert.pem")
+	dtg := dnstogrpc.New(dnsToGrpcAddr, r, "")
 	go dtg.ListenAndServe()
 
 	// GRPC to DNS server.
diff --git a/dnstogrpc/dnstogrpc.go b/dnstogrpc/dnstogrpc.go
index b569fbd..8e1a490 100644
--- a/dnstogrpc/dnstogrpc.go
+++ b/dnstogrpc/dnstogrpc.go
@@ -8,16 +8,12 @@ import (
 	"fmt"
 	"strings"
 	"sync"
-	"time"
 
-	pb "blitiri.com.ar/go/dnss/internal/proto"
-	"blitiri.com.ar/go/dnss/internal/util"
 	"github.com/golang/glog"
 	"github.com/miekg/dns"
-	"golang.org/x/net/context"
 	"golang.org/x/net/trace"
-	"google.golang.org/grpc"
-	"google.golang.org/grpc/credentials"
+
+	"blitiri.com.ar/go/dnss/internal/util"
 )
 
 // newID is a channel used to generate new request IDs.
@@ -45,68 +41,16 @@ func init() {
 	}()
 }
 
-type grpcclient struct {
-	Upstream string
-	CAFile   string
-	client   pb.DNSServiceClient
-}
-
-func (c *grpcclient) Connect() error {
-	var err error
-	var creds credentials.TransportAuthenticator
-	if c.CAFile == "" {
-		creds = credentials.NewClientTLSFromCert(nil, "")
-	} else {
-		creds, err = credentials.NewClientTLSFromFile(c.CAFile, "")
-		if err != nil {
-			return err
-		}
-	}
-
-	conn, err := grpc.Dial(c.Upstream, grpc.WithTransportCredentials(creds))
-	if err != nil {
-		return err
-	}
-
-	c.client = pb.NewDNSServiceClient(conn)
-	return nil
-}
-
-func (c *grpcclient) Query(r *dns.Msg) (*dns.Msg, error) {
-	buf, err := r.Pack()
-	if err != nil {
-		return nil, err
-	}
-
-	// Give our RPCs 2 second timeouts: DNS usually doesn't wait that long
-	// anyway.
-	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
-	defer cancel()
-
-	g, err := c.client.Query(ctx, &pb.RawMsg{Data: buf})
-	if err != nil {
-		return nil, err
-	}
-
-	m := &dns.Msg{}
-	err = m.Unpack(g.Data)
-	return m, err
-}
-
 type Server struct {
 	Addr        string
 	unqUpstream string
-
-	client *grpcclient
+	resolver    Resolver
 }
 
-func New(addr, upstream, caFile, unqUpstream string) *Server {
+func New(addr string, resolver Resolver, unqUpstream string) *Server {
 	return &Server{
-		Addr: addr,
-		client: &grpcclient{
-			Upstream: upstream,
-			CAFile:   caFile,
-		},
+		Addr:        addr,
+		resolver:    resolver,
 		unqUpstream: unqUpstream,
 	}
 }
@@ -146,7 +90,7 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
 	oldid := r.Id
 	r.Id = <-newId
 
-	from_up, err := s.client.Query(r)
+	from_up, err := s.resolver.Query(r)
 	if err != nil {
 		glog.Infof(err.Error())
 		tr.LazyPrintf(err.Error())
@@ -163,9 +107,9 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
 }
 
 func (s *Server) ListenAndServe() {
-	err := s.client.Connect()
+	err := s.resolver.Init()
 	if err != nil {
-		glog.Errorf("Error creating GRPC client: %v", err)
+		glog.Errorf("Error initializing: %v", err)
 		return
 	}
 
diff --git a/dnstogrpc/resolver.go b/dnstogrpc/resolver.go
new file mode 100644
index 0000000..c7f7d4f
--- /dev/null
+++ b/dnstogrpc/resolver.go
@@ -0,0 +1,112 @@
+package dnstogrpc
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/miekg/dns"
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials"
+
+	pb "blitiri.com.ar/go/dnss/internal/proto"
+)
+
+// Resolver is the interface for DNS resolvers that can answer queries.
+type Resolver interface {
+	// Init initializes the resolver. It must succeed before Query is used.
+	Init() error
+
+	// Maintain performs resolver maintenance. It's expected to run
+	// indefinitely, but may return early if appropriate.
+	Maintain()
+
+	// Query responds to a DNS query.
+	Query(r *dns.Msg) (*dns.Msg, error)
+}
+
+// grpcResolver implements the Resolver interface by querying a server via
+// GRPC.
+type grpcResolver struct {
+	// Upstream is the address of the GRPC server, as passed to grpc.Dial.
+	Upstream string
+
+	// CAFile is the certificate file used to validate the upstream server.
+	// If empty, a nil certificate pool is used, which crypto/tls treats as
+	// the host's root CA set.
+	CAFile string
+
+	// client is nil until Init succeeds.
+	client pb.DNSServiceClient
+}
+
+// Compile-time check that grpcResolver implements Resolver.
+var _ Resolver = (*grpcResolver)(nil)
+
+// NewGRPCResolver returns a resolver that forwards queries to the GRPC server
+// at upstream, validating it with the certificate in caFile if it's not
+// empty. Init must be called before the resolver can answer queries.
+func NewGRPCResolver(upstream, caFile string) *grpcResolver {
+	return &grpcResolver{
+		Upstream: upstream,
+		CAFile:   caFile,
+	}
+}
+
+// Init sets up the TLS credentials and the GRPC connection to the upstream
+// server. Errors mention the CA file or upstream address involved.
+func (g *grpcResolver) Init() error {
+	var creds credentials.TransportAuthenticator
+	if g.CAFile == "" {
+		creds = credentials.NewClientTLSFromCert(nil, "")
+	} else {
+		var err error
+		creds, err = credentials.NewClientTLSFromFile(g.CAFile, "")
+		if err != nil {
+			return fmt.Errorf("loading CA file %q: %v", g.CAFile, err)
+		}
+	}
+
+	conn, err := grpc.Dial(g.Upstream, grpc.WithTransportCredentials(creds))
+	if err != nil {
+		return fmt.Errorf("dialing GRPC upstream %q: %v", g.Upstream, err)
+	}
+
+	g.client = pb.NewDNSServiceClient(conn)
+	return nil
+}
+
+// Maintain is a no-op: the GRPC resolver has no state to maintain.
+func (g *grpcResolver) Maintain() {
+}
+
+// Query sends r to the upstream GRPC server and returns its reply. On error
+// the returned message is nil.
+func (g *grpcResolver) Query(r *dns.Msg) (*dns.Msg, error) {
+	// Without this check, calling Query before a successful Init would
+	// panic on the nil client.
+	if g.client == nil {
+		return nil, fmt.Errorf("GRPC resolver not initialized")
+	}
+
+	buf, err := r.Pack()
+	if err != nil {
+		return nil, err
+	}
+
+	// Give our RPCs 2 second timeouts: DNS usually doesn't wait that long
+	// anyway.
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	reply, err := g.client.Query(ctx, &pb.RawMsg{Data: buf})
+	if err != nil {
+		return nil, err
+	}
+
+	m := &dns.Msg{}
+	if err := m.Unpack(reply.Data); err != nil {
+		return nil, err
+	}
+	return m, nil
+}