git » debian:golang-blitiri-go-systemd » commit 701f817

Add Files function to get the open files directly

author Alberto Bertogli
2020-05-28 08:56:40 UTC
committer Alberto Bertogli
2020-05-28 08:56:40 UTC
parent cdc4fd023aa4baa19acdbb7f2f60b62fce1ce9f7

Add Files function to get the open files directly

For some use cases it is convenient to get the *os.File objects
directly, instead of net.Listener values; for example, when the file
descriptor refers to a packet connection (UDP), which cannot be turned
into a net.Listener.

This patch adds a Files function to support that use case.

systemd.go +23 -8
systemd_test.go +100 -35

diff --git a/systemd.go b/systemd.go
index b9de0d8..5c19ad9 100644
--- a/systemd.go
+++ b/systemd.go
@@ -22,21 +22,23 @@ var (
 	firstFD = 3
 )
 
-// Keep a single global map of listeners, to avoid repeatedly parsing which
-// can be problematic (see parse).
+// Keep single global maps of files and listeners, to avoid repeated
+// parsing, which can be problematic (see parse).
+var files map[string][]*os.File
 var listeners map[string][]net.Listener
 var parseError error
 var mutex sync.Mutex
 
-// parse listeners, updating the global state.
+// parse files, updating the global state.
 // This function messes with file descriptors and environment, so it is not
 // idempotent and must be called only once. For the callers' convenience, we
-// save the listeners map globally and reuse it on the user-visible functions.
+// save the files and listener maps globally, and reuse them in the
+// user-visible functions.
 func parse() {
 	mutex.Lock()
 	defer mutex.Unlock()
 
-	if listeners != nil {
+	if files != nil {
 		return
 	}
 
@@ -76,6 +78,7 @@ func parse() {
 		return
 	}
 
+	files = map[string][]*os.File{}
 	listeners = map[string][]net.Listener{}
 
 	for i := 0; i < nfds; i++ {
@@ -86,13 +89,14 @@ func parse() {
 		name := fdNames[i]
 
 		sysName := fmt.Sprintf("[systemd-fd-%d-%v]", fd, name)
-		lis, err := net.FileListener(os.NewFile(uintptr(fd), sysName))
+		f := os.NewFile(uintptr(fd), sysName)
+		files[name] = append(files[name], f)
+
+		lis, err := net.FileListener(f)
 		if err != nil {
 			parseError = fmt.Errorf(
 				"Error making listener out of fd %d: %v", fd, err)
-			return
 		}
-
 		listeners[name] = append(listeners[name], lis)
 	}
 
@@ -171,3 +175,14 @@ func Listen(netw, laddr string) (net.Listener, error) {
 		return net.Listen(netw, laddr)
 	}
 }
+
+// Files returns the *os.File values for the file descriptors passed by
+// systemd via environment variables, keyed by fd name. The map is the
+// package's shared cache, so callers should not modify it.
+// Normally you would use Listeners; this gives finer control, e.g. to make
+// packet connections. NOTE(review): fds net.FileListener rejects (e.g. UDP)
+// still set the returned error even though the map is populated; confirm.
+func Files() (map[string][]*os.File, error) {
+	parse()
+	return files, parseError
+}
diff --git a/systemd_test.go b/systemd_test.go
index d02915b..6686795 100644
--- a/systemd_test.go
+++ b/systemd_test.go
@@ -1,6 +1,7 @@
 package systemd
 
 import (
+	"fmt"
 	"math/rand"
 	"net"
 	"os"
@@ -13,6 +14,7 @@ func setenv(pid, fds string, names ...string) {
 	os.Setenv("LISTEN_PID", pid)
 	os.Setenv("LISTEN_FDS", fds)
 	os.Setenv("LISTEN_FDNAMES", strings.Join(names, ":"))
+	files = nil
 	listeners = nil
 	parseError = nil
 }
@@ -30,6 +32,11 @@ func TestEmptyEnvironment(t *testing.T) {
 			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q", c.pid, c.fds)
 			t.Errorf("Unexpected result: %v // %v", ls, err)
 		}
+
+		if fs, err := Files(); fs != nil || err != nil {
+			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q", c.pid, c.fds)
+			t.Errorf("Unexpected result: %v // %v", fs, err)
+		}
 	}
 }
 
@@ -53,12 +60,18 @@ func TestBadEnvironment(t *testing.T) {
 		setenv(c.pid, c.fds, c.names...)
 
 		if ls, err := Listeners(); err == nil {
-			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q LISTEN_FDNAMES=%q", c.pid, c.fds, c.names)
+			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q LISTEN_FDNAMES=%q",
+				c.pid, c.fds, c.names)
 			t.Errorf("Unexpected result: %v // %v", ls, err)
 		}
 		if ls, err := OneListener("name"); err == nil {
 			t.Errorf("Unexpected result: %v // %v", ls, err)
 		}
+		if fs, err := Files(); err == nil {
+			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q LISTEN_FDNAMES=%q",
+				c.pid, c.fds, c.names)
+			t.Errorf("Unexpected result: %v // %v", fs, err)
+		}
 	}
 }
 
@@ -77,6 +90,10 @@ func TestWrongPID(t *testing.T) {
 	if _, err := OneListener("name"); err != ErrPIDMismatch {
 		t.Errorf("Did not fail with PID mismatch: %v", err)
 	}
+
+	if _, err := Files(); err != ErrPIDMismatch {
+		t.Errorf("Did not fail with PID mismatch: %v", err)
+	}
 }
 
 func TestNoFDs(t *testing.T) {
@@ -84,6 +101,10 @@ func TestNoFDs(t *testing.T) {
 	if ls, err := Listeners(); len(ls) != 0 || err != nil {
 		t.Errorf("Got a non-empty result: %v // %v", ls, err)
 	}
+
+	if ls, err := Files(); len(ls) != 0 || err != nil {
+		t.Errorf("Got a non-empty result: %v // %v", ls, err)
+	}
 }
 
 // newListener creates a TCP listener.
@@ -121,25 +142,43 @@ func TestOneSocket(t *testing.T) {
 
 	setenv(strconv.Itoa(os.Getpid()), "1", "name")
 
-	lsMap, err := Listeners()
-	if err != nil || len(lsMap) != 1 {
-		t.Fatalf("Got an invalid result: %v // %v", lsMap, err)
-	}
+	{
+		lsMap, err := Listeners()
+		if err != nil || len(lsMap) != 1 {
+			t.Fatalf("Got an invalid result: %v // %v", lsMap, err)
+		}
 
-	ls := lsMap["name"]
+		ls := lsMap["name"]
+		if !sameAddr(ls[0].Addr(), l.Addr()) {
+			t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
+				l.Addr(), ls[0].Addr())
+		}
 
-	if !sameAddr(ls[0].Addr(), l.Addr()) {
-		t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
-			l.Addr(), ls[0].Addr())
+		oneL, err := OneListener("name")
+		if err != nil {
+			t.Errorf("OneListener error: %v", err)
+		}
+		if !sameAddr(oneL.Addr(), l.Addr()) {
+			t.Errorf("OneListener address mismatch, expected %#v, got %#v",
+				l.Addr(), ls[0].Addr())
+		}
 	}
 
-	oneL, err := OneListener("name")
-	if err != nil {
-		t.Errorf("OneListener error: %v", err)
-	}
-	if !sameAddr(oneL.Addr(), l.Addr()) {
-		t.Errorf("OneListener address mismatch, expected %#v, got %#v",
-			l.Addr(), ls[0].Addr())
+	{
+		fsMap, err := Files()
+		if err != nil || len(fsMap) != 1 || len(fsMap["name"]) != 1 {
+			t.Fatalf("Got an invalid result: %v // %v", fsMap, err)
+		}
+
+		f := fsMap["name"][0]
+		flis, err := net.FileListener(f)
+		if err != nil {
+			t.Errorf("File was not a listener: %v", err)
+		}
+		if !sameAddr(flis.Addr(), l.Addr()) {
+			t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
+				l.Addr(), flis.Addr())
+		}
 	}
 
 	if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" {
@@ -164,34 +203,60 @@ func TestManySockets(t *testing.T) {
 		t.Logf("Looping for FDs: %d %d", f0, f1)
 	}
 
+	expected := []*net.TCPListener{l0, l1}
+
 	firstFD = f0
 
 	setenv(strconv.Itoa(os.Getpid()), "2", "name1", "name2")
 
-	lsMap, err := Listeners()
-	if err != nil || len(lsMap) != 2 {
-		t.Fatalf("Got an invalid result: %v // %v", lsMap, err)
-	}
+	{
+		lsMap, err := Listeners()
+		if err != nil || len(lsMap) != 2 {
+			t.Fatalf("Got an invalid result: %v // %v", lsMap, err)
+		}
 
-	ls := []net.Listener{
-		lsMap["name1"][0],
-		lsMap["name2"][0],
-	}
+		ls := []net.Listener{
+			lsMap["name1"][0],
+			lsMap["name2"][0],
+		}
 
-	if !sameAddr(ls[0].Addr(), l0.Addr()) {
-		t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
-			l0.Addr(), ls[0].Addr())
-	}
+		for i := 0; i < 2; i++ {
+			if !sameAddr(ls[i].Addr(), expected[i].Addr()) {
+				t.Errorf("Listener %d address mismatch, expected %#v, got %#v",
+					i, ls[i].Addr(), expected[i].Addr())
+			}
+		}
 
-	if !sameAddr(ls[1].Addr(), l1.Addr()) {
-		t.Errorf("Listener 1 address mismatch, expected %#v, got %#v",
-			l1.Addr(), ls[1].Addr())
+		oneL, _ := OneListener("name1")
+		if !sameAddr(oneL.Addr(), expected[0].Addr()) {
+			t.Errorf("OneListener address mismatch, expected %#v, got %#v",
+				oneL.Addr(), expected[0].Addr())
+		}
 	}
 
-	oneL, _ := OneListener("name1")
-	if !sameAddr(oneL.Addr(), l0.Addr()) {
-		t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
-			oneL.Addr(), ls[0].Addr())
+	{
+		fsMap, err := Files()
+		if err != nil || len(fsMap) != 2 {
+			t.Fatalf("Got an invalid result: %v // %v", fsMap, err)
+		}
+
+		for i := 0; i < 2; i++ {
+			name := fmt.Sprintf("name%d", i+1)
+			fs := fsMap[name]
+			if len(fs) != 1 {
+				t.Errorf("fsMap[%q] = %v had %d entries, expected 1",
+					name, fs, len(fs))
+			}
+
+			flis, err := net.FileListener(fs[0])
+			if err != nil {
+				t.Errorf("File was not a listener: %v", err)
+			}
+			if !sameAddr(flis.Addr(), expected[i].Addr()) {
+				t.Errorf("Listener %d address mismatch, expected %#v, got %#v",
+					i, flis.Addr(), expected[i].Addr())
+			}
+		}
 	}
 
 	if os.Getenv("LISTEN_PID") != "" ||