git » systemd » commit aec3508

Add functions Listen and OneListener

author Alberto Bertogli
2017-08-21 16:45:41 UTC
committer Alberto Bertogli
2017-08-21 16:57:59 UTC
parent cd3225afdf78a8a0b038bf10b81daf35e42d5214

Add functions Listen and OneListener

This patch introduces two new functions for convenience: Listen and
OneListener.

They allow the callers to have simplified listener opening strategies,
in particular Listen can be used in many places as a drop-in replacement
for net.Listen, which makes it very convenient for simple code.

systemd.go +48 -0
systemd_test.go +43 -0

diff --git a/systemd.go b/systemd.go
index b931cf8..b9de0d8 100644
--- a/systemd.go
+++ b/systemd.go
@@ -123,3 +123,51 @@ func Listeners() (map[string][]net.Listener, error) {
 	parse()
 	return listeners, parseError
 }
+
+// OneListener returns a listener for the first systemd socket with the given
+// name. If there are none, the listener and error will both be nil. An error
+// will be returned only if there were issues parsing the file descriptors.
+//
+// This function can be convenient for simple callers where you know there's
+// only one file descriptor being passed with the given name; any further
+// sockets sharing that name are ignored.
+//
+// It reads the same parsed state that Listeners() returns.
+func OneListener(name string) (net.Listener, error) {
+	parse()
+	if err := parseError; err != nil {
+		return nil, err
+	}
+
+	if named := listeners[name]; len(named) > 0 {
+		return named[0], nil
+	}
+	return nil, nil
+}
+
+// Listen returns a listener for the given address, similar to net.Listen.
+//
+// If the address begins with "&" it is interpreted as a systemd socket being
+// passed. For example, using "&http" would mean we expect a systemd socket
+// passed to us, named with "FileDescriptorName=http" in its unit. In that
+// case netw is ignored, and an error is returned if no such socket exists.
+//
+// Otherwise, it uses net.Listen to create a new listener with the given net
+// and local address.
+//
+// This function can be convenient for simple callers where you get the
+// address from a user, and want to let them specify either "use systemd" or a
+// normal address without too much additional complexity.
+func Listen(netw, laddr string) (net.Listener, error) {
+	if !strings.HasPrefix(laddr, "&") {
+		return net.Listen(netw, laddr)
+	}
+
+	name := strings.TrimPrefix(laddr, "&")
+	lis, err := OneListener(name)
+	if lis == nil && err == nil {
+		// Unlike OneListener, Listen treats a missing socket as an error.
+		err = fmt.Errorf("systemd socket %q not found", name)
+	}
+	return lis, err
+}
diff --git a/systemd_test.go b/systemd_test.go
index 6819ef4..d02915b 100644
--- a/systemd_test.go
+++ b/systemd_test.go
@@ -56,6 +56,9 @@ func TestBadEnvironment(t *testing.T) {
 			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q LISTEN_FDNAMES=%q", c.pid, c.fds, c.names)
 			t.Errorf("Unexpected result: %v // %v", ls, err)
 		}
+		if ls, err := OneListener("name"); err == nil {
+			t.Errorf("Unexpected result: %v // %v", ls, err)
+		}
 	}
 }
 
@@ -70,6 +73,10 @@ func TestWrongPID(t *testing.T) {
 	if _, err := Listeners(); err != ErrPIDMismatch {
 		t.Errorf("Did not fail with PID mismatch: %v", err)
 	}
+
+	if _, err := OneListener("name"); err != ErrPIDMismatch {
+		t.Errorf("Did not fail with PID mismatch: %v", err)
+	}
 }
 
 func TestNoFDs(t *testing.T) {
@@ -126,6 +133,15 @@ func TestOneSocket(t *testing.T) {
 			l.Addr(), ls[0].Addr())
 	}
 
+	oneL, err := OneListener("name")
+	if err != nil {
+		t.Errorf("OneListener error: %v", err)
+	}
+	if !sameAddr(oneL.Addr(), l.Addr()) {
+		t.Errorf("OneListener address mismatch, expected %#v, got %#v",
+			l.Addr(), ls[0].Addr())
+	}
+
 	if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" {
 		t.Errorf("Failed to reset the environment")
 	}
@@ -172,9 +188,36 @@ func TestManySockets(t *testing.T) {
 			l1.Addr(), ls[1].Addr())
 	}
 
+	oneL, _ := OneListener("name1")
+	if !sameAddr(oneL.Addr(), l0.Addr()) {
+		t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
+			oneL.Addr(), ls[0].Addr())
+	}
+
 	if os.Getenv("LISTEN_PID") != "" ||
 		os.Getenv("LISTEN_FDS") != "" ||
 		os.Getenv("LISTEN_FDNAMES") != "" {
 		t.Errorf("Failed to reset the environment")
 	}
 }
+
+func TestListen(t *testing.T) {
+	orig := newListener(t)
+	firstFD = listenerFd(t, orig)
+	setenv(strconv.Itoa(os.Getpid()), "1", "name")
+
+	l, err := Listen("tcp", "&name")
+	if err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+	if !sameAddr(l.Addr(), orig.Addr()) {
+		t.Errorf("Listen address mismatch, expected %#v, got %#v", orig.Addr(), l.Addr())
+	}
+
+	l, err = Listen("tcp", ":0")
+	if err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+	defer l.Close()
+	t.Logf("listener created at %v", l.Addr())
+}