git » debian:golang-blitiri-go-systemd » commit 7cdf949

Support activation of sockets without names

author Alberto Bertogli
2020-05-28 21:14:21 UTC
committer Alberto Bertogli
2020-05-28 23:27:27 UTC
parent 69739305e20c05fe85f223d927d801914435c4fa

Support activation of sockets without names

systemd allows socket activation for sockets without names, which can be
convenient for simple services.

This patch makes the package support that use case, by mapping all such
unnamed sockets to the empty name ("").

systemd.go +32 -10
systemd_test.go +43 -2

diff --git a/systemd.go b/systemd.go
index 4ba1ed4..159b528 100644
--- a/systemd.go
+++ b/systemd.go
@@ -27,6 +27,7 @@ var (
 var files map[string][]*os.File
 var listeners map[string][]net.Listener
 var parseError error
+var listenError error
 var mutex sync.Mutex
 
 // parse files, updating the global state.
@@ -69,13 +70,19 @@ func parse() {
 		return
 	}
 
-	// We should have as many names as we have descriptors.
-	// Note that if we have no descriptors, fdNames will be [""] (due to how
-	// strings.Split works), so we consider that special case.
-	if nfds > 0 && (fdNamesStr == "" || len(fdNames) != nfds) {
-		parseError = fmt.Errorf(
-			"Incorrect LISTEN_FDNAMES, have you set FileDescriptorName?")
-		return
+	// If LISTEN_FDNAMES is set at all, it should have as many names as we
+	// have descriptors. If it isn't set, then we map them all to "".
+	if fdNamesStr == "" {
+		fdNames = []string{}
+		for i := 0; i < nfds; i++ {
+			fdNames = append(fdNames, "")
+		}
+	} else {
+		if nfds > 0 && len(fdNames) != nfds {
+			parseError = fmt.Errorf(
+				"Incorrect LISTEN_FDNAMES, have you set FileDescriptorName?")
+			return
+		}
 	}
 
 	files = map[string][]*os.File{}
@@ -92,12 +99,15 @@ func parse() {
 		f := os.NewFile(uintptr(fd), sysName)
 		files[name] = append(files[name], f)
 
+		// Note this can fail for non-TCP listeners, so we put the error in a
+		// separate variable.
 		lis, err := net.FileListener(f)
 		if err != nil {
-			parseError = fmt.Errorf(
+			listenError = fmt.Errorf(
 				"Error making listener out of fd %d: %v", fd, err)
+		} else {
+			listeners[name] = append(listeners[name], lis)
 		}
-		listeners[name] = append(listeners[name], lis)
 	}
 
 	// Remove them from the environment, to prevent accidental reuse (by
@@ -116,6 +126,9 @@ func parse() {
 // systemd socket unit. Multiple socket units can have the same name, hence
 // the slice of listeners for each name.
 //
+// If the "FileDescriptorName=" option is not used, then all file descriptors
+// are mapped to the "" name.
+//
 // Ideally you should not need to call this more than once. If you do, the
 // same listeners will be returned, as repeated calls to this function will
 // return the same results: the parsing is done only once, and the results are
@@ -125,7 +138,10 @@ func parse() {
 // how the passing works.
 func Listeners() (map[string][]net.Listener, error) {
 	parse()
-	return listeners, parseError
+	if parseError != nil {
+		return listeners, parseError
+	}
+	return listeners, listenError
 }
 
 // OneListener returns a net.Listener for the first systemd socket with the
@@ -142,6 +158,9 @@ func OneListener(name string) (net.Listener, error) {
 	if parseError != nil {
 		return nil, parseError
 	}
+	if listenError != nil {
+		return nil, listenError
+	}
 
 	lis := listeners[name]
 	if len(lis) < 1 {
@@ -185,6 +204,9 @@ func Listen(netw, laddr string) (net.Listener, error) {
 // systemd socket unit. Multiple socket units can have the same name, hence
 // the slice of listeners for each name.
 //
+// If the "FileDescriptorName=" option is not used, then all file descriptors
+// are mapped to the "" name.
+//
 // Ideally you should not need to call this more than once. If you do, the
 // same files will be returned, as repeated calls to this function will return
 // the same results: the parsing is done only once, and the results are saved
diff --git a/systemd_test.go b/systemd_test.go
index e005b67..5e7686d 100644
--- a/systemd_test.go
+++ b/systemd_test.go
@@ -17,6 +17,7 @@ func setenv(pid, fds string, names ...string) {
 	files = nil
 	listeners = nil
 	parseError = nil
+	listenError = nil
 }
 
 func TestEmptyEnvironment(t *testing.T) {
@@ -54,7 +55,6 @@ func TestBadEnvironment(t *testing.T) {
 		{ourPID, "a", []string{"name"}},           // Invalid number of fds.
 		{"1", "1", []string{"name"}},              // PID != ourselves.
 		{ourPID, "1", []string{"name1", "name2"}}, // Too many names.
-		{ourPID, "1", []string{}},                 // Not enough names.
 	}
 	for _, c := range cases {
 		setenv(c.pid, c.fds, c.names...)
@@ -122,9 +122,27 @@ func TestBadFDs(t *testing.T) {
 
 	setenv(strconv.Itoa(os.Getpid()), "1")
 	firstFD = int(f.Fd())
-	if ls, err := Listeners(); len(ls) != 1 || err == nil {
+
+	if ls, err := Listeners(); len(ls) != 0 || err == nil {
 		t.Errorf("Got a non-empty result: %v // %v", ls, err)
 	}
+
+	if l, err := OneListener(""); l != nil || err == nil {
+		t.Errorf("Got a non-empty result: %v // %v", l, err)
+	}
+
+	// It's not a bad FD as far as Files() is concerned.
+	fs, err := Files()
+	if err != nil {
+		t.Errorf("Unexpected error: %v", err)
+	}
+	if len(fs) != 1 || len(fs[""]) != 1 {
+		t.Errorf("Unexpected result: %v", fs)
+	}
+	if got := fs[""][0]; got.Fd() != f.Fd() {
+		t.Errorf("File descriptor %d != expected %d (%v)",
+			got.Fd(), f.Fd(), got)
+	}
 }
 
 // newListener creates a TCP listener.
@@ -284,6 +302,29 @@ func TestManySockets(t *testing.T) {
 		os.Getenv("LISTEN_FDNAMES") != "" {
 		t.Errorf("Failed to reset the environment")
 	}
+
+	// Test that things also work with LISTEN_FDNAMES unset.
+	setenv(strconv.Itoa(os.Getpid()), "2")
+	os.Unsetenv("LISTEN_FDNAMES")
+	{
+		lsMap, err := Listeners()
+		if err != nil || len(lsMap) != 1 || len(lsMap[""]) != 2 {
+			t.Fatalf("Got an invalid result: %v // %v", lsMap, err)
+		}
+
+		ls := []net.Listener{
+			lsMap[""][0],
+			lsMap[""][1],
+		}
+
+		for i := 0; i < 2; i++ {
+			if !sameAddr(ls[i].Addr(), expected[i].Addr()) {
+				t.Errorf("Listener %d address mismatch, expected %#v, got %#v",
+					i, ls[i].Addr(), expected[i].Addr())
+			}
+		}
+	}
+
 }
 
 func TestListen(t *testing.T) {