git » dnss » commit 12e5df4

Use TLS

author Alberto Bertogli
2015-09-07 01:00:43 UTC
committer Alberto Bertogli
2015-09-07 01:00:43 UTC
parent dab7f50341a51379ef44d906251dc0c2d12b31a6

Use TLS

dnss.go +10 -1
dnstogrpc/dnstogrpc.go +20 -5
grpctodns/grpctodns.go +9 -2

diff --git a/dnss.go b/dnss.go
index b18819a..3a60f49 100644
--- a/dnss.go
+++ b/dnss.go
@@ -18,6 +18,8 @@ var (
 		"address to listen on for DNS")
 	grpcupstream = flag.String("grpcupstream", "localhost:9953",
 		"address of the upstream GRPC server")
+	grpc_client_cafile = flag.String("grpc_client_cafile", "",
+		"CA file to use for the GRPC client")
 
 	enableGRPCtoDNS = flag.Bool("enable_grpc_to_dns", false,
 		"enable GRPC-to-DNS server")
@@ -25,6 +27,11 @@ var (
 		"address to listen on for GRPC")
 	dnsupstream = flag.String("dnsupstream", "8.8.8.8:53",
 		"address of the upstream DNS server")
+
+	grpccert = flag.String("grpccert", "",
+		"certificate file for the GRPC server")
+	grpckey = flag.String("grpckey", "",
+		"key file for the GRPC server")
 )
 
 func main() {
@@ -42,7 +49,7 @@ func main() {
 
 	// DNS to GRPC.
 	if *enableDNStoGRPC {
-		dtg := dnstogrpc.New(*dnsaddr, *grpcupstream)
+		dtg := dnstogrpc.New(*dnsaddr, *grpcupstream, *grpc_client_cafile)
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
@@ -55,6 +62,8 @@ func main() {
 		gtd := &grpctodns.Server{
 			Addr:     *grpcaddr,
 			Upstream: *dnsupstream,
+			CertFile: *grpccert,
+			KeyFile:  *grpckey,
 		}
 		wg.Add(1)
 		go func() {
diff --git a/dnstogrpc/dnstogrpc.go b/dnstogrpc/dnstogrpc.go
index 4df7f26..e2c24d4 100644
--- a/dnstogrpc/dnstogrpc.go
+++ b/dnstogrpc/dnstogrpc.go
@@ -13,16 +13,28 @@ import (
 	"github.com/miekg/dns"
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials"
 )
 
 type grpcclient struct {
 	Upstream string
+	CAFile   string
 	client   pb.DNSServiceClient
 }
 
 func (c *grpcclient) Connect() error {
-	// TODO: TLS
-	conn, err := grpc.Dial(c.Upstream, grpc.WithInsecure())
+	var err error
+	var creds credentials.TransportAuthenticator
+	if c.CAFile == "" {
+		creds = credentials.NewClientTLSFromCert(nil, "")
+	} else {
+		creds, err = credentials.NewClientTLSFromFile(c.CAFile, "")
+		if err != nil {
+			return err
+		}
+	}
+
+	conn, err := grpc.Dial(c.Upstream, grpc.WithTransportCredentials(creds))
 	if err != nil {
 		return err
 	}
@@ -58,10 +70,13 @@ type Server struct {
 	client *grpcclient
 }
 
-func New(addr, upstream string) *Server {
+func New(addr, upstream, caFile string) *Server {
 	return &Server{
-		Addr:   addr,
-		client: &grpcclient{Upstream: upstream},
+		Addr: addr,
+		client: &grpcclient{
+			Upstream: upstream,
+			CAFile:   caFile,
+		},
 	}
 }
 
diff --git a/grpctodns/grpctodns.go b/grpctodns/grpctodns.go
index 871f24b..ff95272 100644
--- a/grpctodns/grpctodns.go
+++ b/grpctodns/grpctodns.go
@@ -13,6 +13,7 @@ import (
 	"github.com/miekg/dns"
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials"
 )
 
 func questionsToString(qs []dns.Question) string {
@@ -36,6 +37,8 @@ func rrsToString(rrs []dns.RR) string {
 type Server struct {
 	Addr     string
 	Upstream string
+	CertFile string
+	KeyFile  string
 }
 
 func (s *Server) Query(ctx context.Context, in *pb.RawMsg) (*pb.RawMsg, error) {
@@ -83,9 +86,13 @@ func (s *Server) ListenAndServe() {
 		return
 	}
 
-	// TODO: TLS
+	ta, err := credentials.NewServerTLSFromFile(s.CertFile, s.KeyFile)
+	if err != nil {
+		log.Printf("failed to create TLS transport auth: %v", err)
+		return
+	}
 
-	grpcServer := grpc.NewServer()
+	grpcServer := grpc.NewServer(grpc.Creds(ta))
 	pb.RegisterDNSServiceServer(grpcServer, s)
 
 	log.Printf("GRPC listening on %s\n", s.Addr)