author | Alberto Bertogli
<albertito@blitiri.com.ar> 2015-09-07 01:00:43 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2015-09-07 01:00:43 UTC |
parent | dab7f50341a51379ef44d906251dc0c2d12b31a6 |
dnss.go | +10 | -1 |
dnstogrpc/dnstogrpc.go | +20 | -5 |
grpctodns/grpctodns.go | +9 | -2 |
diff --git a/dnss.go b/dnss.go index b18819a..3a60f49 100644 --- a/dnss.go +++ b/dnss.go @@ -18,6 +18,8 @@ var ( "address to listen on for DNS") grpcupstream = flag.String("grpcupstream", "localhost:9953", "address of the upstream GRPC server") + grpc_client_cafile = flag.String("grpc_client_cafile", "", + "CA file to use for the GRPC client") enableGRPCtoDNS = flag.Bool("enable_grpc_to_dns", false, "enable GRPC-to-DNS server") @@ -25,6 +27,11 @@ var ( "address to listen on for GRPC") dnsupstream = flag.String("dnsupstream", "8.8.8.8:53", "address of the upstream DNS server") + + grpccert = flag.String("grpccert", "", + "certificate file for the GRPC server") + grpckey = flag.String("grpckey", "", + "key file for the GRPC server") ) func main() { @@ -42,7 +49,7 @@ func main() { // DNS to GRPC. if *enableDNStoGRPC { - dtg := dnstogrpc.New(*dnsaddr, *grpcupstream) + dtg := dnstogrpc.New(*dnsaddr, *grpcupstream, *grpc_client_cafile) wg.Add(1) go func() { defer wg.Done() @@ -55,6 +62,8 @@ func main() { gtd := &grpctodns.Server{ Addr: *grpcaddr, Upstream: *dnsupstream, + CertFile: *grpccert, + KeyFile: *grpckey, } wg.Add(1) go func() { diff --git a/dnstogrpc/dnstogrpc.go b/dnstogrpc/dnstogrpc.go index 4df7f26..e2c24d4 100644 --- a/dnstogrpc/dnstogrpc.go +++ b/dnstogrpc/dnstogrpc.go @@ -13,16 +13,28 @@ import ( "github.com/miekg/dns" "golang.org/x/net/context" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) type grpcclient struct { Upstream string + CAFile string client pb.DNSServiceClient } func (c *grpcclient) Connect() error { - // TODO: TLS - conn, err := grpc.Dial(c.Upstream, grpc.WithInsecure()) + var err error + var creds credentials.TransportAuthenticator + if c.CAFile == "" { + creds = credentials.NewClientTLSFromCert(nil, "") + } else { + creds, err = credentials.NewClientTLSFromFile(c.CAFile, "") + if err != nil { + return err + } + } + + conn, err := grpc.Dial(c.Upstream, grpc.WithTransportCredentials(creds)) if err != nil { return err } @@ -58,10 +70,13 @@ type Server struct { client *grpcclient } -func New(addr, upstream string) *Server { +func New(addr, upstream, caFile string) *Server { return &Server{ - Addr: addr, - client: &grpcclient{Upstream: upstream}, + Addr: addr, + client: &grpcclient{ + Upstream: upstream, + CAFile: caFile, + }, } } diff --git a/grpctodns/grpctodns.go b/grpctodns/grpctodns.go index 871f24b..ff95272 100644 --- a/grpctodns/grpctodns.go +++ b/grpctodns/grpctodns.go @@ -13,6 +13,7 @@ import ( "github.com/miekg/dns" "golang.org/x/net/context" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) func questionsToString(qs []dns.Question) string { @@ -36,6 +37,8 @@ func rrsToString(rrs []dns.RR) string { type Server struct { Addr string Upstream string + CertFile string + KeyFile string } func (s *Server) Query(ctx context.Context, in *pb.RawMsg) (*pb.RawMsg, error) { @@ -83,9 +86,13 @@ func (s *Server) ListenAndServe() { return } - // TODO: TLS + ta, err := credentials.NewServerTLSFromFile(s.CertFile, s.KeyFile) + if err != nil { + log.Printf("failed to create TLS transport auth: %v", err) + return + } - grpcServer := grpc.NewServer() + grpcServer := grpc.NewServer(grpc.Creds(ta)) pb.RegisterDNSServiceServer(grpcServer, s) log.Printf("GRPC listening on %s\n", s.Addr)