author | Alberto Bertogli <albertito@blitiri.com.ar> | 2020-05-28 08:56:40 UTC |
committer | Alberto Bertogli <albertito@blitiri.com.ar> | 2020-05-28 08:56:40 UTC |
parent | cdc4fd023aa4baa19acdbb7f2f60b62fce1ce9f7 |
systemd.go | +23 | -8 |
systemd_test.go | +100 | -35 |
diff --git a/systemd.go b/systemd.go index b9de0d8..5c19ad9 100644 --- a/systemd.go +++ b/systemd.go @@ -22,21 +22,23 @@ var ( firstFD = 3 ) -// Keep a single global map of listeners, to avoid repeatedly parsing which -// can be problematic (see parse). +// Keep a single global map of files/listeners, to avoid repeatedly parsing +// which can be problematic (see parse). +var files map[string][]*os.File var listeners map[string][]net.Listener var parseError error var mutex sync.Mutex -// parse listeners, updating the global state. +// parse files, updating the global state. // This function messes with file descriptors and environment, so it is not // idempotent and must be called only once. For the callers' convenience, we -// save the listeners map globally and reuse it on the user-visible functions. +// save the files and listener maps globally, and reuse them on the +// user-visible functions. func parse() { mutex.Lock() defer mutex.Unlock() - if listeners != nil { + if files != nil { return } @@ -76,6 +78,7 @@ func parse() { return } + files = map[string][]*os.File{} listeners = map[string][]net.Listener{} for i := 0; i < nfds; i++ { @@ -86,13 +89,14 @@ func parse() { name := fdNames[i] sysName := fmt.Sprintf("[systemd-fd-%d-%v]", fd, name) - lis, err := net.FileListener(os.NewFile(uintptr(fd), sysName)) + f := os.NewFile(uintptr(fd), sysName) + files[name] = append(files[name], f) + + lis, err := net.FileListener(f) if err != nil { parseError = fmt.Errorf( "Error making listener out of fd %d: %v", fd, err) - return } - listeners[name] = append(listeners[name], lis) } @@ -171,3 +175,14 @@ func Listen(netw, laddr string) (net.Listener, error) { return net.Listen(netw, laddr) } } + +// Files returns a map of files for the file descriptors passed by +// systemd via environment variables. 
+// +// Normally you would use Listeners instead; however, this is useful if you +// need more fine-grained control over listener creation, for example if you +// need to create packet connections from them. +func Files() (map[string][]*os.File, error) { + parse() + return files, parseError +} diff --git a/systemd_test.go b/systemd_test.go index d02915b..6686795 100644 --- a/systemd_test.go +++ b/systemd_test.go @@ -1,6 +1,7 @@ package systemd import ( + "fmt" "math/rand" "net" "os" @@ -13,6 +14,7 @@ func setenv(pid, fds string, names ...string) { os.Setenv("LISTEN_PID", pid) os.Setenv("LISTEN_FDS", fds) os.Setenv("LISTEN_FDNAMES", strings.Join(names, ":")) + files = nil listeners = nil parseError = nil } @@ -30,6 +32,11 @@ func TestEmptyEnvironment(t *testing.T) { t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q", c.pid, c.fds) t.Errorf("Unexpected result: %v // %v", ls, err) } + + if fs, err := Files(); fs != nil || err != nil { + t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q", c.pid, c.fds) + t.Errorf("Unexpected result: %v // %v", fs, err) + } } } @@ -53,12 +60,18 @@ func TestBadEnvironment(t *testing.T) { setenv(c.pid, c.fds, c.names...) 
if ls, err := Listeners(); err == nil { - t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q LISTEN_FDNAMES=%q", c.pid, c.fds, c.names) + t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q LISTEN_FDNAMES=%q", + c.pid, c.fds, c.names) t.Errorf("Unexpected result: %v // %v", ls, err) } if ls, err := OneListener("name"); err == nil { t.Errorf("Unexpected result: %v // %v", ls, err) } + if fs, err := Files(); err == nil { + t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q LISTEN_FDNAMES=%q", + c.pid, c.fds, c.names) + t.Errorf("Unexpected result: %v // %v", fs, err) + } } } @@ -77,6 +90,10 @@ func TestWrongPID(t *testing.T) { if _, err := OneListener("name"); err != ErrPIDMismatch { t.Errorf("Did not fail with PID mismatch: %v", err) } + + if _, err := Files(); err != ErrPIDMismatch { + t.Errorf("Did not fail with PID mismatch: %v", err) + } } func TestNoFDs(t *testing.T) { @@ -84,6 +101,10 @@ func TestNoFDs(t *testing.T) { if ls, err := Listeners(); len(ls) != 0 || err != nil { t.Errorf("Got a non-empty result: %v // %v", ls, err) } + + if ls, err := Files(); len(ls) != 0 || err != nil { + t.Errorf("Got a non-empty result: %v // %v", ls, err) + } } // newListener creates a TCP listener. 
@@ -121,25 +142,43 @@ func TestOneSocket(t *testing.T) { setenv(strconv.Itoa(os.Getpid()), "1", "name") - lsMap, err := Listeners() - if err != nil || len(lsMap) != 1 { - t.Fatalf("Got an invalid result: %v // %v", lsMap, err) - } + { + lsMap, err := Listeners() + if err != nil || len(lsMap) != 1 { + t.Fatalf("Got an invalid result: %v // %v", lsMap, err) + } - ls := lsMap["name"] + ls := lsMap["name"] + if !sameAddr(ls[0].Addr(), l.Addr()) { + t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", + l.Addr(), ls[0].Addr()) + } - if !sameAddr(ls[0].Addr(), l.Addr()) { - t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", - l.Addr(), ls[0].Addr()) + oneL, err := OneListener("name") + if err != nil { + t.Errorf("OneListener error: %v", err) + } + if !sameAddr(oneL.Addr(), l.Addr()) { + t.Errorf("OneListener address mismatch, expected %#v, got %#v", + l.Addr(), ls[0].Addr()) + } } - oneL, err := OneListener("name") - if err != nil { - t.Errorf("OneListener error: %v", err) - } - if !sameAddr(oneL.Addr(), l.Addr()) { - t.Errorf("OneListener address mismatch, expected %#v, got %#v", - l.Addr(), ls[0].Addr()) + { + fsMap, err := Files() + if err != nil || len(fsMap) != 1 || len(fsMap["name"]) != 1 { + t.Fatalf("Got an invalid result: %v // %v", fsMap, err) + } + + f := fsMap["name"][0] + flis, err := net.FileListener(f) + if err != nil { + t.Errorf("File was not a listener: %v", err) + } + if !sameAddr(flis.Addr(), l.Addr()) { + t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", + l.Addr(), flis.Addr()) + } } if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" { @@ -164,34 +203,60 @@ func TestManySockets(t *testing.T) { t.Logf("Looping for FDs: %d %d", f0, f1) } + expected := []*net.TCPListener{l0, l1} + firstFD = f0 setenv(strconv.Itoa(os.Getpid()), "2", "name1", "name2") - lsMap, err := Listeners() - if err != nil || len(lsMap) != 2 { - t.Fatalf("Got an invalid result: %v // %v", lsMap, err) - } + { + lsMap, err := 
Listeners() + if err != nil || len(lsMap) != 2 { + t.Fatalf("Got an invalid result: %v // %v", lsMap, err) + } - ls := []net.Listener{ - lsMap["name1"][0], - lsMap["name2"][0], - } + ls := []net.Listener{ + lsMap["name1"][0], + lsMap["name2"][0], + } - if !sameAddr(ls[0].Addr(), l0.Addr()) { - t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", - l0.Addr(), ls[0].Addr()) - } + for i := 0; i < 2; i++ { + if !sameAddr(ls[i].Addr(), expected[i].Addr()) { + t.Errorf("Listener %d address mismatch, expected %#v, got %#v", + i, ls[i].Addr(), expected[i].Addr()) + } + } - if !sameAddr(ls[1].Addr(), l1.Addr()) { - t.Errorf("Listener 1 address mismatch, expected %#v, got %#v", - l1.Addr(), ls[1].Addr()) + oneL, _ := OneListener("name1") + if !sameAddr(oneL.Addr(), expected[0].Addr()) { + t.Errorf("OneListener address mismatch, expected %#v, got %#v", + oneL.Addr(), expected[0].Addr()) + } } - oneL, _ := OneListener("name1") - if !sameAddr(oneL.Addr(), l0.Addr()) { - t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", - oneL.Addr(), ls[0].Addr()) + { + fsMap, err := Files() + if err != nil || len(fsMap) != 2 { + t.Fatalf("Got an invalid result: %v // %v", fsMap, err) + } + + for i := 0; i < 2; i++ { + name := fmt.Sprintf("name%d", i+1) + fs := fsMap[name] + if len(fs) != 1 { + t.Errorf("fsMap[%q] = %v had %d entries, expected 1", + name, fs, len(fs)) + } + + flis, err := net.FileListener(fs[0]) + if err != nil { + t.Errorf("File was not a listener: %v", err) + } + if !sameAddr(flis.Addr(), expected[i].Addr()) { + t.Errorf("Listener %d address mismatch, expected %#v, got %#v", + i, flis.Addr(), expected[i].Addr()) + } + } } if os.Getenv("LISTEN_PID") != "" ||