git » dnss » commit 9c963cb

Proxy through GRPC

author Alberto Bertogli
2015-09-06 23:25:47 UTC
committer Alberto Bertogli
2015-09-06 23:25:47 UTC
parent b79914bdb3a24d2700b37f248ee65cc241801beb

Proxy through GRPC

dnsproxy.go +1 -3
dnss.go +1 -4
dnstogrpc/dnstogrpc.go +53 -19
grpctodns/grpctodns.go +45 -31
proto/dnss.pb.go +31 -68
proto/dnss.proto +3 -3
util/strings.go +19 -0

diff --git a/dnsproxy.go b/dnsproxy.go
index 31a06b6..2fc593b 100644
--- a/dnsproxy.go
+++ b/dnsproxy.go
@@ -1,6 +1,4 @@
-// Generate the protobuf+grpc service.
-//go:generate protoc --go_out=plugins=grpc:. dnss.proto
-
+// dnsproxy is a simple DNS proxy server.
 package main
 
 import (
diff --git a/dnss.go b/dnss.go
index 33a6aa0..311508a 100644
--- a/dnss.go
+++ b/dnss.go
@@ -29,10 +29,7 @@ func main() {
 	var wg sync.WaitGroup
 
 	// DNS to GRPC.
-	dtg := &dnstogrpc.Server{
-		Addr:     *dnsaddr,
-		Upstream: *grpcupstream,
-	}
+	dtg := dnstogrpc.New(*dnsaddr, *grpcupstream)
 	wg.Add(1)
 	go func() {
 		defer wg.Done()
diff --git a/dnstogrpc/dnstogrpc.go b/dnstogrpc/dnstogrpc.go
index e59c59f..39c6a8f 100644
--- a/dnstogrpc/dnstogrpc.go
+++ b/dnstogrpc/dnstogrpc.go
@@ -4,46 +4,73 @@ package dnstogrpc
 
 import (
 	"fmt"
-	"strings"
 	"sync"
 
+	pb "blitiri.com.ar/go/dnss/proto"
+	"blitiri.com.ar/go/dnss/util"
 	"github.com/miekg/dns"
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
 )
 
-func questionsToString(qs []dns.Question) string {
-	var s []string
-	for _, q := range qs {
-		s = append(s, fmt.Sprintf("(%s %s %s)", q.Name,
-			dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass]))
-	}
-	return "Q[" + strings.Join(s, " ") + "]"
+type grpcclient struct {
+	Upstream string
+	client   pb.DNSServiceClient
 }
 
-func rrsToString(rrs []dns.RR) string {
-	var s []string
-	for _, rr := range rrs {
-		s = append(s, fmt.Sprintf("(%s)", rr))
+func (c *grpcclient) Connect() error {
+	// TODO: TLS
+	conn, err := grpc.Dial(c.Upstream, grpc.WithInsecure())
+	if err != nil {
+		return err
 	}
-	return "RR[" + strings.Join(s, " ") + "]"
 
+	c.client = pb.NewDNSServiceClient(conn)
+	return nil
 }
 
-func l(w dns.ResponseWriter, r *dns.Msg) string {
-	return fmt.Sprintf("%v %v", w.RemoteAddr(), r.Id)
+func (c *grpcclient) Query(r *dns.Msg) (*dns.Msg, error) {
+	buf, err := r.Pack()
+	if err != nil {
+		return nil, err
+	}
+
+	g, err := c.client.Query(
+		context.Background(),
+		&pb.RawMsg{Data: buf})
+	if err != nil {
+		return nil, err
+	}
+
+	m := &dns.Msg{}
+	err = m.Unpack(g.Data)
+	return m, err
 }
 
 type Server struct {
-	Addr     string
-	Upstream string
+	Addr string
+
+	client *grpcclient
+}
+
+func New(addr, upstream string) *Server {
+	return &Server{
+		Addr:   addr,
+		client: &grpcclient{Upstream: upstream},
+	}
+}
+
+func l(w dns.ResponseWriter, r *dns.Msg) string {
+	return fmt.Sprintf("%v %v", w.RemoteAddr(), r.Id)
 }
 
 func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
-	fmt.Printf("DNS  %v %v\n", l(w, r), questionsToString(r.Question))
+	fmt.Printf("DNS  %v %v\n", l(w, r), util.QuestionsToString(r.Question))
 
 	// TODO: we should create our own IDs, in case different users pick the
 	// same id and we pass that upstream.
 
-	from_up, err := dns.Exchange(r, s.Upstream)
+	from_up, err := s.client.Query(r)
 	if err != nil {
 		fmt.Printf("DNS  %v  ERR: %v\n", l(w, r), err)
 		fmt.Printf("DNS  %v  UP: %v\n", l(w, r), from_up)
@@ -62,6 +89,13 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
 }
 
 func (s *Server) ListenAndServe() {
+	err := s.client.Connect()
+	if err != nil {
+		// TODO: handle errors and reconnect.
+		fmt.Printf("Error creating GRPC client: %v\n", err)
+		return
+	}
+
 	var wg sync.WaitGroup
 	wg.Add(1)
 	go func() {
diff --git a/grpctodns/grpctodns.go b/grpctodns/grpctodns.go
index 37771fc..cc05023 100644
--- a/grpctodns/grpctodns.go
+++ b/grpctodns/grpctodns.go
@@ -4,10 +4,14 @@ package grpctodns
 
 import (
 	"fmt"
+	"net"
 	"strings"
-	"sync"
 
+	pb "blitiri.com.ar/go/dnss/proto"
+	"blitiri.com.ar/go/dnss/util"
 	"github.com/miekg/dns"
+	"golang.org/x/net/context"
+	"google.golang.org/grpc"
 )
 
 func questionsToString(qs []dns.Question) string {
@@ -37,45 +41,55 @@ type Server struct {
 	Upstream string
 }
 
-func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
-	fmt.Printf("GRPC %v %v\n", l(w, r), questionsToString(r.Question))
+func (s *Server) Query(ctx context.Context, in *pb.RawMsg) (*pb.RawMsg, error) {
+	r := &dns.Msg{}
+	err := r.Unpack(in.Data)
+	if err != nil {
+		return nil, err
+	}
+
+	fmt.Printf("GRPC %v\n", util.QuestionsToString(r.Question))

 	// TODO: we should create our own IDs, in case different users pick the
 	// same id and we pass that upstream.
-
 	from_up, err := dns.Exchange(r, s.Upstream)
 	if err != nil {
-		fmt.Printf("GRPC %v  ERR: %v\n", l(w, r), err)
-		fmt.Printf("GRPC %v  UP: %v\n", l(w, r), from_up)
+		fmt.Printf("GRPC   ERR: %v\n", err)
+		fmt.Printf("GRPC   UP: %v\n", from_up)
+		return nil, err
 	}

-	if from_up != nil {
-		if from_up.Rcode != dns.RcodeSuccess {
-			rcode := dns.RcodeToString[from_up.Rcode]
-			fmt.Printf("GPRC %v  !->  %v\n", l(w, r), rcode)
-		}
-		for _, rr := range from_up.Answer {
-			fmt.Printf("GRPC %v  ->  %v\n", l(w, r), rr)
-		}
-		w.WriteMsg(from_up)
+	if from_up == nil {
+		return nil, fmt.Errorf("no response from upstream")
 	}
+
+	if from_up.Rcode != dns.RcodeSuccess {
+		rcode := dns.RcodeToString[from_up.Rcode]
+		fmt.Printf("GRPC   !->  %v\n", rcode)
+	}
+	for _, rr := range from_up.Answer {
+		fmt.Printf("GRPC   ->  %v\n", rr)
+	}
+
+	buf, err := from_up.Pack()
+	if err != nil {
+		fmt.Printf("GRPC   ERR: %v\n", err)
+		return nil, err
+	}
+
+	return &pb.RawMsg{Data: buf}, nil
 }
 
 func (s *Server) ListenAndServe() {
-	var wg sync.WaitGroup
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
-		err := dns.ListenAndServe(s.Addr, "udp", dns.HandlerFunc(s.Handler))
-		fmt.Printf("Exiting UDP: %v\n", err)
-	}()
-
-	wg.Add(1)
-	go func() {
-		defer wg.Done()
-		err := dns.ListenAndServe(s.Addr, "tcp", dns.HandlerFunc(s.Handler))
-		fmt.Printf("Exiting TCP: %v\n", err)
-	}()
-
-	wg.Wait()
+	lis, err := net.Listen("tcp", s.Addr)
+	if err != nil {
+		fmt.Printf("failed to listen: %v\n", err)
+		return
+	}
+
+	// TODO: TLS
+	grpcServer := grpc.NewServer()
+	pb.RegisterDNSServiceServer(grpcServer, s)
+	err = grpcServer.Serve(lis)
+	fmt.Printf("Exiting GRPC: %v\n", err)
 }
diff --git a/proto/dnss.pb.go b/proto/dnss.pb.go
index 63b086a..e1c1f93 100644
--- a/proto/dnss.pb.go
+++ b/proto/dnss.pb.go
@@ -9,41 +9,42 @@ It is generated from these files:
 	dnss.proto
 
 It has these top-level messages:
-	GobMsg
+	RawMsg
 */
 package dnss
 
 import proto "github.com/golang/protobuf/proto"
+import fmt "fmt"
+import math "math"
 
 import (
 	context "golang.org/x/net/context"
 	grpc "google.golang.org/grpc"
 )
 
-// Reference imports to suppress errors if they are not otherwise used.
-var _ context.Context
-var _ grpc.ClientConn
-
 // Reference imports to suppress errors if they are not otherwise used.
 var _ = proto.Marshal
+var _ = fmt.Errorf
+var _ = math.Inf
 
-type GobMsg struct {
-	// gob-encoded message.
+type RawMsg struct {
+	// DNS-encoded message.
 	// A horrible hack, but will do for now.
 	Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"`
 }
 
-func (m *GobMsg) Reset()         { *m = GobMsg{} }
-func (m *GobMsg) String() string { return proto.CompactTextString(m) }
-func (*GobMsg) ProtoMessage()    {}
+func (m *RawMsg) Reset()         { *m = RawMsg{} }
+func (m *RawMsg) String() string { return proto.CompactTextString(m) }
+func (*RawMsg) ProtoMessage()    {}
 
-func init() {
-}
+// Reference imports to suppress errors if they are not otherwise used.
+var _ context.Context
+var _ grpc.ClientConn
 
 // Client API for DNSService service
 
 type DNSServiceClient interface {
-	Query(ctx context.Context, opts ...grpc.CallOption) (DNSService_QueryClient, error)
+	Query(ctx context.Context, in *RawMsg, opts ...grpc.CallOption) (*RawMsg, error)
 }
 
 type dNSServiceClient struct {
@@ -54,83 +55,45 @@ func NewDNSServiceClient(cc *grpc.ClientConn) DNSServiceClient {
 	return &dNSServiceClient{cc}
 }
 
-func (c *dNSServiceClient) Query(ctx context.Context, opts ...grpc.CallOption) (DNSService_QueryClient, error) {
-	stream, err := grpc.NewClientStream(ctx, &_DNSService_serviceDesc.Streams[0], c.cc, "/dnss.DNSService/Query", opts...)
+func (c *dNSServiceClient) Query(ctx context.Context, in *RawMsg, opts ...grpc.CallOption) (*RawMsg, error) {
+	out := new(RawMsg)
+	err := grpc.Invoke(ctx, "/dnss.DNSService/Query", in, out, c.cc, opts...)
 	if err != nil {
 		return nil, err
 	}
-	x := &dNSServiceQueryClient{stream}
-	return x, nil
-}
-
-type DNSService_QueryClient interface {
-	Send(*GobMsg) error
-	Recv() (*GobMsg, error)
-	grpc.ClientStream
-}
-
-type dNSServiceQueryClient struct {
-	grpc.ClientStream
-}
-
-func (x *dNSServiceQueryClient) Send(m *GobMsg) error {
-	return x.ClientStream.SendProto(m)
-}
-
-func (x *dNSServiceQueryClient) Recv() (*GobMsg, error) {
-	m := new(GobMsg)
-	if err := x.ClientStream.RecvProto(m); err != nil {
-		return nil, err
-	}
-	return m, nil
+	return out, nil
 }
 
 // Server API for DNSService service
 
 type DNSServiceServer interface {
-	Query(DNSService_QueryServer) error
+	Query(context.Context, *RawMsg) (*RawMsg, error)
 }
 
 func RegisterDNSServiceServer(s *grpc.Server, srv DNSServiceServer) {
 	s.RegisterService(&_DNSService_serviceDesc, srv)
 }
 
-func _DNSService_Query_Handler(srv interface{}, stream grpc.ServerStream) error {
-	return srv.(DNSServiceServer).Query(&dNSServiceQueryServer{stream})
-}
-
-type DNSService_QueryServer interface {
-	Send(*GobMsg) error
-	Recv() (*GobMsg, error)
-	grpc.ServerStream
-}
-
-type dNSServiceQueryServer struct {
-	grpc.ServerStream
-}
-
-func (x *dNSServiceQueryServer) Send(m *GobMsg) error {
-	return x.ServerStream.SendProto(m)
-}
-
-func (x *dNSServiceQueryServer) Recv() (*GobMsg, error) {
-	m := new(GobMsg)
-	if err := x.ServerStream.RecvProto(m); err != nil {
+func _DNSService_Query_Handler(srv interface{}, ctx context.Context, codec grpc.Codec, buf []byte) (interface{}, error) {
+	in := new(RawMsg)
+	if err := codec.Unmarshal(buf, in); err != nil {
+		return nil, err
+	}
+	out, err := srv.(DNSServiceServer).Query(ctx, in)
+	if err != nil {
 		return nil, err
 	}
-	return m, nil
+	return out, nil
 }
 
 var _DNSService_serviceDesc = grpc.ServiceDesc{
 	ServiceName: "dnss.DNSService",
 	HandlerType: (*DNSServiceServer)(nil),
-	Methods:     []grpc.MethodDesc{},
-	Streams: []grpc.StreamDesc{
+	Methods: []grpc.MethodDesc{
 		{
-			StreamName:    "Query",
-			Handler:       _DNSService_Query_Handler,
-			ServerStreams: true,
-			ClientStreams: true,
+			MethodName: "Query",
+			Handler:    _DNSService_Query_Handler,
 		},
 	},
+	Streams: []grpc.StreamDesc{},
 }
diff --git a/proto/dnss.proto b/proto/dnss.proto
index 2bc0040..8f982ea 100644
--- a/proto/dnss.proto
+++ b/proto/dnss.proto
@@ -3,13 +3,13 @@ syntax = "proto3";
 
 package dnss;
 
-message GobMsg {
-	// gob-encoded message.
+message RawMsg {
+	// DNS-encoded message.
 	// A horrible hack, but will do for now.
 	bytes data = 1;
 }
 
 service DNSService {
-	rpc Query(stream GobMsg) returns (stream GobMsg);
+	rpc Query(RawMsg) returns (RawMsg);
 }
 
diff --git a/util/strings.go b/util/strings.go
new file mode 100644
index 0000000..76ac24f
--- /dev/null
+++ b/util/strings.go
@@ -0,0 +1,19 @@
+// Package util contains utility functions for printing DNS messages.
+package util
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/miekg/dns"
+)
+
+// QuestionsToString formats qs as "Q[(name type class) ...]" for logging.
+func QuestionsToString(qs []dns.Question) string {
+	s := make([]string, 0, len(qs))
+	for _, q := range qs {
+		s = append(s, fmt.Sprintf("(%s %s %s)", q.Name,
+			dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass]))
+	}
+	return "Q[" + strings.Join(s, " ") + "]"
+}