author | Alberto Bertogli
<albertito@blitiri.com.ar> 2023-07-29 21:22:50 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2023-07-30 12:21:07 UTC |
parent | 764c09e94d67809c780d181b020bea0e28005a3a |
internal/localrpc/client_test.go | +76 | -0 |
internal/localrpc/e2e_test.go | +139 | -0 |
internal/localrpc/localrpc.go | +193 | -0 |
internal/localrpc/server_test.go | +66 | -0 |
diff --git a/internal/localrpc/client_test.go b/internal/localrpc/client_test.go new file mode 100644 index 0000000..055bccb --- /dev/null +++ b/internal/localrpc/client_test.go @@ -0,0 +1,76 @@ +package localrpc + +import ( + "bufio" + "errors" + "io/fs" + "net" + "net/textproto" + "os" + "path/filepath" + "testing" +) + +func NewFakeServer(t *testing.T, path, output string) { + t.Helper() + lis, err := net.Listen("unix", path) + if err != nil { + t.Fatal(err) + } + + for { + conn, err := lis.Accept() + if err != nil { + t.Fatal(err) + } + t.Logf("FakeServer %v: accepted ", conn) + + name, inS, err := readRequest( + textproto.NewReader(bufio.NewReader(conn))) + t.Logf("FakeServer %v: readRequest: %q %q / %v", conn, name, inS, err) + + n, err := conn.Write([]byte(output)) + t.Logf("FakeServer %v: writeMessage(%q): %d %v", + conn, output, n, err) + + t.Logf("FakeServer %v: closing", conn) + conn.Close() + } +} + +func TestBadServer(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "rpc-test-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + socketPath := filepath.Join(tmpDir, "rpc.sock") + + // textproto client expects a numeric code, this should cause ReadCodeLine + // to fail with textproto.ProtocolError. + go NewFakeServer(t, socketPath, "xxx") + waitForServer(t, socketPath) + + client := NewClient(socketPath) + _, err = client.Call("Echo") + if err == nil { + t.Fatal("expected error") + } + var protoErr textproto.ProtocolError + if !errors.As(err, &protoErr) { + t.Errorf("wanted textproto.ProtocolError, got: %v (%T)", err, err) + } +} + +func TestBadSocket(t *testing.T) { + c := NewClient("/does/not/exist") + _, err := c.Call("Echo") + + opErr, ok := err.(*net.OpError) + if !ok { + t.Fatalf("expected net.OpError, got %q (%T)", err, err) + } + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("wanted ErrNotExist, got: %q (%T)", opErr.Err, opErr.Err) + } +} diff --git a/internal/localrpc/e2e_test.go b/internal/localrpc/e2e_test.go new file mode 100644 index 0000000..1366a62 --- /dev/null +++ b/internal/localrpc/e2e_test.go @@ -0,0 +1,139 @@ +package localrpc + +import ( + "errors" + "net" + "net/url" + "os" + "testing" + "time" + + "blitiri.com.ar/go/chasquid/internal/trace" + "github.com/google/go-cmp/cmp" +) + +func Echo(tr *trace.Trace, input url.Values) (url.Values, error) { + return input, nil +} + +func Hola(tr *trace.Trace, input url.Values) (url.Values, error) { + output := url.Values{} + output.Set("greeting", "Hola "+input.Get("name")) + return output, nil +} + +var testErr = errors.New("test error") + +func HolaErr(tr *trace.Trace, input url.Values) (url.Values, error) { + return nil, testErr +} + +type testServer struct { + dir string + sock string + *Server +} + +func newTestServer(t *testing.T) *testServer { + t.Helper() + + tmpDir, err := os.MkdirTemp("", "rpc-test-*") + if err != nil { + t.Fatal(err) + } + + tsrv := &testServer{ + dir: tmpDir, + sock: tmpDir + "/sock", + Server: NewServer(), + } + + tsrv.Register("Echo", Echo) + tsrv.Register("Hola", Hola) + tsrv.Register("HolaErr", HolaErr) + go tsrv.ListenAndServe(tsrv.sock) + + waitForServer(t, tsrv.sock) + return tsrv +} + +func (tsrv *testServer) Cleanup() { + tsrv.Close() + os.RemoveAll(tsrv.dir) +} + +func mkV(args ...string) url.Values { + v := url.Values{} + for i := 0; i < len(args); i += 2 { + v.Set(args[i], args[i+1]) + } + return v +} + +func TestEndToEnd(t *testing.T) { + srv := newTestServer(t) + defer srv.Cleanup() + + // Run the client. + client := NewClient(srv.sock) + + cases := []struct { + method string + input url.Values + output url.Values + err error + }{ + {"Echo", nil, mkV(), nil}, + {"Echo", mkV("msg", "hola"), mkV("msg", "hola"), nil}, + {"Hola", mkV("name", "marola"), mkV("greeting", "Hola marola"), nil}, + {"HolaErr", nil, nil, testErr}, + {"UnknownMethod", nil, nil, errUnknownMethod}, + } + + for _, c := range cases { + t.Run(c.method, func(t *testing.T) { + resp, err := client.CallWithValues(c.method, c.input) + if diff := cmp.Diff(c.err, err, transformErrors); diff != "" { + t.Errorf("error mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(c.output, resp); diff != "" { + t.Errorf("output mismatch (-want +got):\n%s", diff) + } + }) + } + + // Check Call too. + output, err := client.Call("Hola", "name", "marola") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if diff := cmp.Diff(mkV("greeting", "Hola marola"), output); diff != "" { + t.Errorf("output mismatch (-want +got):\n%s", diff) + } +} + +func waitForServer(t *testing.T, path string) { + t.Helper() + for i := 0; i < 100; i++ { + time.Sleep(10 * time.Millisecond) + conn, err := net.Dial("unix", path) + if conn != nil { + conn.Close() + } + if err == nil { + return + } + } + t.Fatal("server didn't start") +} + +// Allow us to compare errors with cmp.Diff by their string content (since the +// instances/types don't carry across RPC boundaries). +var transformErrors = cmp.Transformer( + "error", + func(err error) string { + if err == nil { + return "<nil>" + } + return err.Error() + }) diff --git a/internal/localrpc/localrpc.go b/internal/localrpc/localrpc.go new file mode 100644 index 0000000..68e8ec7 --- /dev/null +++ b/internal/localrpc/localrpc.go @@ -0,0 +1,193 @@ +// Local RPC package. +// +// This is a simple RPC package that uses a line-oriented protocol for +// encoding and decoding, and Unix sockets for transport. It is meant to be +// used for lightweight occassional communication between processes on the +// same machine. +package localrpc + +import ( + "errors" + "net" + "net/textproto" + "net/url" + "os" + "strings" + "time" + + "blitiri.com.ar/go/chasquid/internal/trace" +) + +// Handler is the type of RPC request handlers. +type Handler func(tr *trace.Trace, input url.Values) (url.Values, error) + +// +// Server +// + +// Server represents the RPC server. +type Server struct { + handlers map[string]Handler + lis net.Listener +} + +// NewServer creates a new local RPC server. +func NewServer() *Server { + return &Server{ + handlers: make(map[string]Handler), + } +} + +var errUnknownMethod = errors.New("unknown method") + +// Register a handler for the given name. +func (s *Server) Register(name string, handler Handler) { + s.handlers[name] = handler +} + +// ListenAndServe starts the server. +func (s *Server) ListenAndServe(path string) error { + tr := trace.New("LocalRPC.Server", path) + defer tr.Finish() + + // Previous instances of the server may have shut down uncleanly, leaving + // behind the socket file. Remove it just in case. + os.Remove(path) + + var err error + s.lis, err = net.Listen("unix", path) + if err != nil { + return err + } + + tr.Printf("Listening") + for { + conn, err := s.lis.Accept() + if err != nil { + tr.Errorf("Accept error: %v", err) + return err + } + go s.handleConn(tr, conn) + } +} + +// Close stops the server. +func (s *Server) Close() error { + return s.lis.Close() +} + +func (s *Server) handleConn(tr *trace.Trace, conn net.Conn) { + tr = tr.NewChild("LocalRPC.Handle", conn.RemoteAddr().String()) + defer tr.Finish() + + // Set a generous deadline to prevent client issues from tying up a server + // goroutine indefinitely. + conn.SetDeadline(time.Now().Add(5 * time.Second)) + + tconn := textproto.NewConn(conn) + defer tconn.Close() + + // Read the request. + name, inS, err := readRequest(&tconn.Reader) + if err != nil { + tr.Debugf("error reading request: %v", err) + return + } + tr.Debugf("<- %s %s", name, inS) + + // Find the handler. + handler, ok := s.handlers[name] + if !ok { + writeError(tr, tconn, errUnknownMethod) + return + } + + // Unmarshal the input. + inV, err := url.ParseQuery(inS) + if err != nil { + writeError(tr, tconn, err) + return + } + + // Call the handler. + outV, err := handler(tr, inV) + if err != nil { + writeError(tr, tconn, err) + return + } + + // Send the response. + outS := outV.Encode() + tr.Debugf("-> 200 %s", outS) + tconn.PrintfLine("200 %s", outS) +} + +func readRequest(r *textproto.Reader) (string, string, error) { + line, err := r.ReadLine() + if err != nil { + return "", "", err + } + + sp := strings.SplitN(line, " ", 2) + if len(sp) == 1 { + return sp[0], "", nil + } + return sp[0], sp[1], nil +} + +func writeError(tr *trace.Trace, tconn *textproto.Conn, err error) { + tr.Errorf("-> 500 %s", err.Error()) + tconn.PrintfLine("500 %s", err.Error()) +} + +// Default server. This is a singleton server that can be used for +// convenience. +var DefaultServer = NewServer() + +// +// Client +// + +// Client for the localrpc server. +type Client struct { + path string +} + +// NewClient creates a new client for the given path. +func NewClient(path string) *Client { + return &Client{path: path} +} + +// CallWithValues calls the given method. +func (c *Client) CallWithValues(name string, input url.Values) (url.Values, error) { + conn, err := textproto.Dial("unix", c.path) + if err != nil { + return nil, err + } + defer conn.Close() + + err = conn.PrintfLine("%s %s", name, input.Encode()) + if err != nil { + return nil, err + } + + code, msg, err := conn.ReadCodeLine(0) + if err != nil { + return nil, err + } + if code != 200 { + return nil, errors.New(msg) + } + + return url.ParseQuery(msg) +} + +// Call the given method. The arguments are key-value strings, and must be +// provided in pairs. +func (c *Client) Call(name string, args ...string) (url.Values, error) { + v := url.Values{} + for i := 0; i < len(args); i += 2 { + v.Set(args[i], args[i+1]) + } + return c.CallWithValues(name, v) +} diff --git a/internal/localrpc/server_test.go b/internal/localrpc/server_test.go new file mode 100644 index 0000000..2a9d14a --- /dev/null +++ b/internal/localrpc/server_test.go @@ -0,0 +1,66 @@ +package localrpc + +import ( + "bufio" + "bytes" + "net" + "net/textproto" + "strings" + "testing" + + "blitiri.com.ar/go/chasquid/internal/trace" +) + +func TestListenError(t *testing.T) { + server := NewServer() + err := server.ListenAndServe("/dev/null") + if err == nil { + t.Errorf("ListenAndServe(/dev/null) = nil, want error") + } +} + +// Test that the server can handle a broken client sending a bad request. +func TestServerBadRequest(t *testing.T) { + server := NewServer() + server.Register("Echo", Echo) + + srvConn, cliConn := net.Pipe() + defer srvConn.Close() + defer cliConn.Close() + + // Client sends an invalid request. + go cliConn.Write([]byte("Echo this is an ; invalid ; query\n")) + + // Servers will handle the connection, and should return an error. + tr := trace.New("test", "TestBadRequest") + defer tr.Finish() + go server.handleConn(tr, srvConn) + + // Read the error that the server should have sent. + code, msg, err := textproto.NewConn(cliConn).ReadResponse(0) + if err != nil { + t.Errorf("ReadResponse error: %q", err) + } + if code != 500 { + t.Errorf("ReadResponse code %d, expected 500", code) + } + if !strings.Contains(msg, "invalid semicolon separator") { + t.Errorf("ReadResponse message %q, does not contain 'invalid semicolon separator'", msg) + } +} + +func TestShortReadRequest(t *testing.T) { + // This request is too short, it does not have any arguments. + // This does not happen with the real client, but just in case. + buf := bufio.NewReader(bytes.NewReader([]byte("Method\n"))) + method, args, err := readRequest(textproto.NewReader(buf)) + if err != nil { + t.Errorf("readRequest error: %v", err) + } + if method != "Method" { + t.Errorf("readRequest method %q, expected 'Method'", method) + } + if args != "" { + t.Errorf("readRequest args %q, expected ''", args) + } +}