git » chasquid » commit 360ac13

localrpc: Add a package for local RPC over UNIX sockets

author Alberto Bertogli
2023-07-29 21:22:50 UTC
committer Alberto Bertogli
2023-07-30 12:21:07 UTC
parent 764c09e94d67809c780d181b020bea0e28005a3a

localrpc: Add a package for local RPC over UNIX sockets

This patch adds a new package for doing local lightweight RPC calls over UNIX
sockets. This will be used in later patches for communication between chasquid
and chasquid-util.

internal/localrpc/client_test.go +76 -0
internal/localrpc/e2e_test.go +139 -0
internal/localrpc/localrpc.go +193 -0
internal/localrpc/server_test.go +66 -0

diff --git a/internal/localrpc/client_test.go b/internal/localrpc/client_test.go
new file mode 100644
index 0000000..055bccb
--- /dev/null
+++ b/internal/localrpc/client_test.go
@@ -0,0 +1,76 @@
+package localrpc
+
+import (
+	"bufio"
+	"errors"
+	"io/fs"
+	"net"
+	"net/textproto"
+	"os"
+	"path/filepath"
+	"testing"
+)
+
+// NewFakeServer listens on the UNIX socket at path and, for every accepted
+// connection, reads (and logs) one request line and replies with the raw
+// output string, whatever was asked. It serves until listening or accepting
+// fails.
+//
+// It is meant to be run in its own goroutine, so it must not call t.Fatal
+// (FailNow is only valid from the goroutine running the test). Failures are
+// reported with t.Error instead, and the fake server stops.
+func NewFakeServer(t *testing.T, path, output string) {
+	t.Helper()
+	lis, err := net.Listen("unix", path)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	defer lis.Close()
+
+	for {
+		conn, err := lis.Accept()
+		if err != nil {
+			t.Error(err)
+			return
+		}
+		t.Logf("FakeServer %v: accepted ", conn)
+
+		name, inS, err := readRequest(
+			textproto.NewReader(bufio.NewReader(conn)))
+		t.Logf("FakeServer %v: readRequest: %q %q / %v", conn, name, inS, err)
+
+		n, err := conn.Write([]byte(output))
+		t.Logf("FakeServer %v: writeMessage(%q): %d %v",
+			conn, output, n, err)
+
+		t.Logf("FakeServer %v: closing", conn)
+		conn.Close()
+	}
+}
+
+func TestBadServer(t *testing.T) {
+	tmpDir, err := os.MkdirTemp("", "rpc-test-*")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer os.RemoveAll(tmpDir)
+	socketPath := filepath.Join(tmpDir, "rpc.sock")
+
+	// textproto client expects a numeric code, this should cause ReadCodeLine
+	// to fail with textproto.ProtocolError.
+	go NewFakeServer(t, socketPath, "xxx")
+	waitForServer(t, socketPath)
+
+	client := NewClient(socketPath)
+	_, err = client.Call("Echo")
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	var protoErr textproto.ProtocolError
+	if !errors.As(err, &protoErr) {
+		t.Errorf("wanted textproto.ProtocolError, got: %v (%T)", err, err)
+	}
+}
+
+func TestBadSocket(t *testing.T) {
+	c := NewClient("/does/not/exist")
+	_, err := c.Call("Echo")
+
+	opErr, ok := err.(*net.OpError)
+	if !ok {
+		t.Fatalf("expected net.OpError, got %q (%T)", err, err)
+	}
+	if !errors.Is(err, fs.ErrNotExist) {
+		t.Errorf("wanted ErrNotExist, got: %q (%T)", opErr.Err, opErr.Err)
+	}
+}
diff --git a/internal/localrpc/e2e_test.go b/internal/localrpc/e2e_test.go
new file mode 100644
index 0000000..1366a62
--- /dev/null
+++ b/internal/localrpc/e2e_test.go
@@ -0,0 +1,139 @@
+package localrpc
+
+import (
+	"errors"
+	"net"
+	"net/url"
+	"os"
+	"testing"
+	"time"
+
+	"blitiri.com.ar/go/chasquid/internal/trace"
+	"github.com/google/go-cmp/cmp"
+)
+
+// Echo is a test handler that replies with its input unchanged.
+func Echo(_ *trace.Trace, input url.Values) (url.Values, error) { return input, nil }
+
+// Hola is a test handler that replies with a greeting for the "name" input.
+func Hola(_ *trace.Trace, input url.Values) (url.Values, error) {
+	greeting := "Hola " + input.Get("name")
+	return url.Values{"greeting": []string{greeting}}, nil
+}
+
+// testErr is the error that HolaErr always fails with.
+var testErr = errors.New("test error")
+
+// HolaErr is a test handler that ignores its input and always fails.
+func HolaErr(*trace.Trace, url.Values) (url.Values, error) {
+	return nil, testErr
+}
+
+// testServer bundles a Server with the temporary directory it lives in and
+// the path of the socket it listens on.
+type testServer struct {
+	*Server
+
+	dir, sock string
+}
+
+// newTestServer starts a Server on a socket inside a fresh temporary
+// directory, with the Echo, Hola and HolaErr handlers registered, and waits
+// until the socket accepts connections. Callers must call Cleanup when done.
+func newTestServer(t *testing.T) *testServer {
+	t.Helper()
+
+	dir, err := os.MkdirTemp("", "rpc-test-*")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	srv := NewServer()
+	for name, h := range map[string]Handler{
+		"Echo":    Echo,
+		"Hola":    Hola,
+		"HolaErr": HolaErr,
+	} {
+		srv.Register(name, h)
+	}
+
+	tsrv := &testServer{Server: srv, dir: dir, sock: dir + "/sock"}
+	go tsrv.ListenAndServe(tsrv.sock)
+	waitForServer(t, tsrv.sock)
+	return tsrv
+}
+
+// Cleanup stops the server and removes its temporary directory. Errors are
+// deliberately ignored: this is best-effort teardown for tests.
+func (tsrv *testServer) Cleanup() {
+	_ = tsrv.Close()
+	_ = os.RemoveAll(tsrv.dir)
+}
+
+// mkV builds url.Values from alternating key/value arguments. A repeated key
+// keeps only its last value. It panics if given an odd number of arguments.
+func mkV(args ...string) url.Values {
+	vals := make(url.Values)
+	for len(args) > 0 {
+		vals.Set(args[0], args[1])
+		args = args[2:]
+	}
+	return vals
+}
+
+func TestEndToEnd(t *testing.T) {
+	srv := newTestServer(t)
+	defer srv.Cleanup()
+
+	// Run the client.
+	client := NewClient(srv.sock)
+
+	cases := []struct {
+		method string
+		input  url.Values
+		output url.Values
+		err    error
+	}{
+		{"Echo", nil, mkV(), nil},
+		{"Echo", mkV("msg", "hola"), mkV("msg", "hola"), nil},
+		{"Hola", mkV("name", "marola"), mkV("greeting", "Hola marola"), nil},
+		{"HolaErr", nil, nil, testErr},
+		{"UnknownMethod", nil, nil, errUnknownMethod},
+	}
+
+	for _, c := range cases {
+		t.Run(c.method, func(t *testing.T) {
+			resp, err := client.CallWithValues(c.method, c.input)
+			if diff := cmp.Diff(c.err, err, transformErrors); diff != "" {
+				t.Errorf("error mismatch (-want +got):\n%s", diff)
+			}
+			if diff := cmp.Diff(c.output, resp); diff != "" {
+				t.Errorf("output mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+
+	// Check Call too.
+	output, err := client.Call("Hola", "name", "marola")
+	if err != nil {
+		t.Errorf("unexpected error: %v", err)
+	}
+	if diff := cmp.Diff(mkV("greeting", "Hola marola"), output); diff != "" {
+		t.Errorf("output mismatch (-want +got):\n%s", diff)
+	}
+}
+
+func waitForServer(t *testing.T, path string) {
+	t.Helper()
+	for i := 0; i < 100; i++ {
+		time.Sleep(10 * time.Millisecond)
+		conn, err := net.Dial("unix", path)
+		if conn != nil {
+			conn.Close()
+		}
+		if err == nil {
+			return
+		}
+	}
+	t.Fatal("server didn't start")
+}
+
+// transformErrors makes cmp.Diff compare errors by their message, showing a
+// nil error as "<nil>". Error instances and types do not survive the RPC
+// boundary, so the text is the only meaningful thing to compare.
+var transformErrors = cmp.Transformer("error", func(e error) string {
+	if e != nil {
+		return e.Error()
+	}
+	return "<nil>"
+})
diff --git a/internal/localrpc/localrpc.go b/internal/localrpc/localrpc.go
new file mode 100644
index 0000000..68e8ec7
--- /dev/null
+++ b/internal/localrpc/localrpc.go
@@ -0,0 +1,193 @@
+// Package localrpc implements a simple local RPC mechanism.
+//
+// It uses a line-oriented protocol for encoding and decoding, and UNIX
+// sockets for transport. It is meant to be used for lightweight, occasional
+// communication between processes on the same machine.
+package localrpc
+
+import (
+	"errors"
+	"net"
+	"net/textproto"
+	"net/url"
+	"os"
+	"strings"
+	"time"
+
+	"blitiri.com.ar/go/chasquid/internal/trace"
+)
+
+// Handler serves one RPC method. It receives the request's decoded arguments
+// and returns the values to send back, or an error whose message is sent to
+// the client instead.
+type Handler func(tr *trace.Trace, args url.Values) (url.Values, error)
+
+//
+// Server
+//
+
+// Server dispatches requests received over a UNIX socket to the Handlers
+// registered by name.
+type Server struct {
+	handlers map[string]Handler // Registered handlers, keyed by method name.
+	lis      net.Listener       // Assigned by ListenAndServe.
+}
+
+// NewServer returns a local RPC server with no handlers registered.
+func NewServer() *Server {
+	s := new(Server)
+	s.handlers = map[string]Handler{}
+	return s
+}
+
+// errUnknownMethod is reported to the client, as a 500 response, when a
+// request names a method with no registered handler.
+var errUnknownMethod = errors.New("unknown method")
+
+// Register makes handler serve the method called name, replacing any handler
+// previously registered under that name.
+//
+// The handlers map is not protected by a lock, and connection goroutines read
+// it. Register all handlers before calling ListenAndServe.
+func (s *Server) Register(name string, handler Handler) {
+	handlers := s.handlers
+	handlers[name] = handler
+}
+
+// ListenAndServe listens on the UNIX socket at path and serves each accepted
+// connection in its own goroutine. It blocks until listening or accepting
+// fails, and always returns a non-nil error. After Close, that error wraps
+// net.ErrClosed.
+func (s *Server) ListenAndServe(path string) error {
+	tr := trace.New("LocalRPC.Server", path)
+	defer tr.Finish()
+
+	// Previous instances of the server may have shut down uncleanly, leaving
+	// behind the socket file, which would make Listen fail. Remove it, but
+	// only if it really is a socket: a misconfigured path must not make us
+	// delete an unrelated file (or device node).
+	if fi, err := os.Lstat(path); err == nil && fi.Mode()&os.ModeSocket != 0 {
+		os.Remove(path)
+	}
+
+	var err error
+	s.lis, err = net.Listen("unix", path)
+	if err != nil {
+		return err
+	}
+
+	tr.Printf("Listening")
+	for {
+		conn, err := s.lis.Accept()
+		if err != nil {
+			if errors.Is(err, net.ErrClosed) {
+				// Expected after Close; not worth reporting as an error.
+				tr.Debugf("Listener closed: %v", err)
+			} else {
+				tr.Errorf("Accept error: %v", err)
+			}
+			return err
+		}
+		go s.handleConn(tr, conn)
+	}
+}
+
+// Close stops the server by closing its listener, which makes a running
+// ListenAndServe return. Connections already being handled are not
+// interrupted.
+//
+// If the server is not listening (ListenAndServe has not been called yet, or
+// it failed to listen), there is nothing to close and Close returns nil
+// instead of panicking on the missing listener.
+func (s *Server) Close() error {
+	if s.lis == nil {
+		return nil
+	}
+	return s.lis.Close()
+}
+
+// handleConn serves a single request on conn and then closes it.
+//
+// The request is one line: the method name, optionally followed by a space
+// and the URL-query-encoded arguments. The response is one line: "200 "
+// followed by the URL-query-encoded results, or "500 " followed by an error
+// message.
+func (s *Server) handleConn(tr *trace.Trace, conn net.Conn) {
+	tr = tr.NewChild("LocalRPC.Handle", conn.RemoteAddr().String())
+	defer tr.Finish()
+
+	// Set a generous deadline to prevent client issues from tying up a server
+	// goroutine indefinitely. The deadline is set before the handler runs, so
+	// a handler taking longer than this will make writing the response fail.
+	conn.SetDeadline(time.Now().Add(5 * time.Second))
+
+	tconn := textproto.NewConn(conn)
+	defer tconn.Close()
+
+	// Read the request.
+	name, inS, err := readRequest(&tconn.Reader)
+	if err != nil {
+		tr.Debugf("error reading request: %v", err)
+		return
+	}
+	tr.Debugf("<- %s %s", name, inS)
+
+	// Find the handler.
+	handler, ok := s.handlers[name]
+	if !ok {
+		writeError(tr, tconn, errUnknownMethod)
+		return
+	}
+
+	// Unmarshal the input.
+	inV, err := url.ParseQuery(inS)
+	if err != nil {
+		writeError(tr, tconn, err)
+		return
+	}
+
+	// Call the handler.
+	outV, err := handler(tr, inV)
+	if err != nil {
+		writeError(tr, tconn, err)
+		return
+	}
+
+	// Send the response. The handler has already run, so a write failure
+	// (e.g. the client went away, or the deadline passed) can only be logged.
+	outS := outV.Encode()
+	tr.Debugf("-> 200 %s", outS)
+	if err := tconn.PrintfLine("200 %s", outS); err != nil {
+		tr.Debugf("error writing response: %v", err)
+	}
+}
+
+// readRequest reads one request line from r and splits it at the first space
+// into the method name and the still-encoded arguments. A line without a
+// space is taken as a bare method name with empty arguments.
+func readRequest(r *textproto.Reader) (string, string, error) {
+	line, err := r.ReadLine()
+	if err != nil {
+		return "", "", err
+	}
+
+	i := strings.Index(line, " ")
+	if i < 0 {
+		return line, "", nil
+	}
+	return line[:i], line[i+1:], nil
+}
+
+// writeError logs err and sends it to the client as a "500" response line.
+//
+// Line breaks in the message are replaced with spaces. The protocol expects
+// a single response line, and the client reads only up to the first line
+// break, so a multi-line message would otherwise reach it truncated.
+func writeError(tr *trace.Trace, tconn *textproto.Conn, err error) {
+	msg := strings.NewReplacer("\r", " ", "\n", " ").Replace(err.Error())
+	tr.Errorf("-> 500 %s", msg)
+	if werr := tconn.PrintfLine("500 %s", msg); werr != nil {
+		tr.Debugf("error writing error response: %v", werr)
+	}
+}
+
+// DefaultServer is a package-level Server that can be used for convenience,
+// instead of creating one with NewServer.
+var DefaultServer = NewServer()
+
+//
+// Client
+//
+
+// Client for the localrpc server. Every call opens its own connection to the
+// server's UNIX socket, so a Client holds no open resources between calls.
+type Client struct {
+	path string
+}
+
+// NewClient creates a new client for the server listening on the UNIX socket
+// at path. It does not connect; a connection is made on each call.
+func NewClient(path string) *Client {
+	return &Client{path: path}
+}
+
+// CallWithValues calls the method name on the server with the given input,
+// and returns the values the server replied with.
+//
+// If the server reports a failure, the returned error carries the server's
+// error message; the original error type is not preserved. Connection,
+// protocol and decoding errors are returned unmodified.
+func (c *Client) CallWithValues(name string, input url.Values) (url.Values, error) {
+	conn, err := textproto.Dial("unix", c.path)
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+
+	err = conn.PrintfLine("%s %s", name, input.Encode())
+	if err != nil {
+		return nil, err
+	}
+
+	code, msg, err := conn.ReadCodeLine(0)
+	if err != nil {
+		return nil, err
+	}
+	if code != 200 {
+		return nil, errors.New(msg)
+	}
+
+	return url.ParseQuery(msg)
+}
+
+// Call the given method. The arguments are key-value strings, and must be
+// provided in pairs ("k1", "v1", "k2", "v2", ...); if a key is repeated, its
+// last value wins. An odd number of arguments is reported as an error,
+// without contacting the server, rather than panicking.
+func (c *Client) Call(name string, args ...string) (url.Values, error) {
+	if len(args)%2 != 0 {
+		return nil, errors.New("odd number of arguments, they must be key-value pairs")
+	}
+	v := make(url.Values, len(args)/2)
+	for i := 0; i < len(args); i += 2 {
+		v.Set(args[i], args[i+1])
+	}
+	return c.CallWithValues(name, v)
+}
diff --git a/internal/localrpc/server_test.go b/internal/localrpc/server_test.go
new file mode 100644
index 0000000..2a9d14a
--- /dev/null
+++ b/internal/localrpc/server_test.go
@@ -0,0 +1,66 @@
+package localrpc
+
+import (
+	"bufio"
+	"bytes"
+	"net"
+	"net/textproto"
+	"strings"
+	"testing"
+
+	"blitiri.com.ar/go/chasquid/internal/trace"
+)
+
+func TestListenError(t *testing.T) {
+	server := NewServer()
+	err := server.ListenAndServe("/dev/null")
+	if err == nil {
+		t.Errorf("ListenAndServe(/dev/null) = nil, want error")
+	}
+}
+
+// Test that the server can handle a broken client sending a bad request.
+func TestServerBadRequest(t *testing.T) {
+	server := NewServer()
+	server.Register("Echo", Echo)
+
+	srvConn, cliConn := net.Pipe()
+	defer srvConn.Close()
+	defer cliConn.Close()
+
+	// Client sends an invalid request.
+	go cliConn.Write([]byte("Echo this is an ; invalid ; query\n"))
+
+	// Servers will handle the connection, and should return an error.
+	tr := trace.New("test", "TestBadRequest")
+	defer tr.Finish()
+	go server.handleConn(tr, srvConn)
+
+	// Read the error that the server should have sent.
+	code, msg, err := textproto.NewConn(cliConn).ReadResponse(0)
+	if err != nil {
+		t.Errorf("ReadResponse error: %q", err)
+	}
+	if code != 500 {
+		t.Errorf("ReadResponse code %d, expected 500", code)
+	}
+	if !strings.Contains(msg, "invalid semicolon separator") {
+		t.Errorf("ReadResponse message %q, does not contain 'invalid semicolon separator'", msg)
+	}
+}
+
+// TestShortReadRequest checks that a request line holding only a method name
+// (no space, no arguments) parses as that method with empty arguments. The
+// real client always sends the separating space, but the server should cope
+// without it.
+func TestShortReadRequest(t *testing.T) {
+	in := bytes.NewReader([]byte("Method\n"))
+	method, args, err := readRequest(textproto.NewReader(bufio.NewReader(in)))
+
+	if err != nil {
+		t.Errorf("readRequest error: %v", err)
+	}
+	if method != "Method" {
+		t.Errorf("readRequest method %q, expected 'Method'", method)
+	}
+	if args != "" {
+		t.Errorf("readRequest args %q, expected ''", args)
+	}
+}