author | Alberto Bertogli
<albertito@blitiri.com.ar> 2015-10-31 23:05:34 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2015-11-01 02:19:23 UTC |
systemd.go | +70 | -0 |
systemd_test.go | +157 | -0 |
diff --git a/systemd.go b/systemd.go new file mode 100644 index 0000000..a6a796c --- /dev/null +++ b/systemd.go @@ -0,0 +1,70 @@ +// Package systemd implements utility functions to interact with systemd. +package systemd + +import ( + "errors" + "fmt" + "net" + "os" + "strconv" + "syscall" +) + +var ( + // Error to return when $LISTEN_PID does not refer to us. + PIDMismatch = errors.New("$LISTEN_PID != our PID") + + // First FD for listeners. + // It's 3 by definition, but using a variable simplifies testing. + firstFD = 3 +) + +// Listeners creates a slice net.Listener from the file descriptors passed +// by systemd, via the LISTEN_FDS environment variable. +// See sd_listen_fds(3) for more details. +func Listeners() ([]net.Listener, error) { + pidStr := os.Getenv("LISTEN_PID") + nfdsStr := os.Getenv("LISTEN_FDS") + + // Nothing to do if the variables are not set. + if pidStr == "" || nfdsStr == "" { + return nil, nil + } + + pid, err := strconv.Atoi(pidStr) + if err != nil { + return nil, fmt.Errorf( + "error converting $LISTEN_PID=%q: %v", pidStr, err) + } else if pid != os.Getpid() { + return nil, PIDMismatch + } + + nfds, err := strconv.Atoi(os.Getenv("LISTEN_FDS")) + if err != nil { + return nil, fmt.Errorf( + "error reading $LISTEN_FDS=%q: %v", nfdsStr, err) + } + + listeners := []net.Listener{} + + for fd := firstFD; fd < firstFD+nfds; fd++ { + // We don't want childs to inherit these file descriptors. + syscall.CloseOnExec(fd) + + name := fmt.Sprintf("[systemd-fd-%d]", fd) + lis, err := net.FileListener(os.NewFile(uintptr(fd), name)) + if err != nil { + return nil, fmt.Errorf( + "Error making listener out of fd %d: %v", fd, err) + } + + listeners = append(listeners, lis) + } + + // Remove them from the environment, to prevent accidental reuse (by + // us or children processes). + os.Unsetenv("LISTEN_PID") + os.Unsetenv("LISTEN_FDS") + + return listeners, nil +} diff --git a/systemd_test.go b/systemd_test.go new file mode 100644 index 0000000..28586fe --- /dev/null +++ b/systemd_test.go @@ -0,0 +1,157 @@ +package systemd + +import ( + "math/rand" + "net" + "os" + "strconv" + "testing" +) + +func setenv(pid, fds string) { + os.Setenv("LISTEN_PID", pid) + os.Setenv("LISTEN_FDS", fds) +} + +func TestEmptyEnvironment(t *testing.T) { + cases := []struct{ pid, fds string }{ + {"", ""}, + {"123", ""}, + {"", "4"}, + } + for _, c := range cases { + setenv(c.pid, c.fds) + + if ls, err := Listeners(); ls != nil || err != nil { + t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q", c.pid, c.fds) + t.Errorf("Unexpected result: %v // %v", ls, err) + } + } +} + +func TestBadEnvironment(t *testing.T) { + ourPID := strconv.Itoa(os.Getpid()) + cases := []struct{ pid, fds string }{ + {"a", "4"}, + {ourPID, "a"}, + } + for _, c := range cases { + setenv(c.pid, c.fds) + + if ls, err := Listeners(); err == nil { + t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q", c.pid, c.fds) + t.Errorf("Unexpected result: %v // %v", ls, err) + } + } +} + +func TestWrongPID(t *testing.T) { + // Find a pid != us. 1 should always work in practice. + pid := 1 + for pid == os.Getpid() { + pid = rand.Int() + } + + setenv(strconv.Itoa(pid), "4") + if _, err := Listeners(); err != PIDMismatch { + t.Errorf("Did not fail with PID mismatch: %v", err) + } +} + +func TestNoFDs(t *testing.T) { + setenv(strconv.Itoa(os.Getpid()), "0") + if ls, err := Listeners(); len(ls) != 0 || err != nil { + t.Errorf("Got a non-empty result: %v // %v", ls, err) + } +} + +// newListener creates a TCP listener. +func newListener(t *testing.T) *net.TCPListener { + addr := &net.TCPAddr{ + Port: 0, + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + t.Fatalf("Could not create TCP listener: %v", err) + } + + return l +} + +// listenerFd returns a file descriptor for the listener. +// Note it is a NEW file descriptor, not the original one. +func listenerFd(t *testing.T, l *net.TCPListener) int { + f, err := l.File() + if err != nil { + t.Fatalf("Could not get TCP listener file: %v", err) + } + + return int(f.Fd()) +} + +func sameAddr(a, b net.Addr) bool { + return a.Network() == b.Network() && a.String() == b.String() +} + +func TestOneSocket(t *testing.T) { + l := newListener(t) + firstFD = listenerFd(t, l) + + setenv(strconv.Itoa(os.Getpid()), "1") + + ls, err := Listeners() + if err != nil || len(ls) != 1 { + t.Fatalf("Got an invalid result: %v // %v", ls, err) + } + + if !sameAddr(ls[0].Addr(), l.Addr()) { + t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", + l.Addr(), ls[0].Addr()) + } + + if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" { + t.Errorf("Failed to reset the environment") + } +} + +func TestManySockets(t *testing.T) { + // Create two contiguous listeners. + // The test environment does not guarantee us that they are contiguous, so + // keep going until they are. + var l0, l1 *net.TCPListener + var f0, f1 int = -1, -3 + + for f0+1 != f1 { + // We have to be careful with the order of these operations, because + // listenerFd will create *new* file descriptors. + l0 = newListener(t) + l1 = newListener(t) + f0 = listenerFd(t, l0) + f1 = listenerFd(t, l1) + t.Logf("Looping for FDs: %d %d", f0, f1) + } + + firstFD = f0 + + setenv(strconv.Itoa(os.Getpid()), "2") + + ls, err := Listeners() + if err != nil || len(ls) != 2 { + t.Fatalf("Got an invalid result: %v // %v", ls, err) + } + + if !sameAddr(ls[0].Addr(), l0.Addr()) { + t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", + l0.Addr(), ls[0].Addr()) + } + + if !sameAddr(ls[1].Addr(), l1.Addr()) { + t.Errorf("Listener 1 address mismatch, expected %#v, got %#v", + l1.Addr(), ls[1].Addr()) + } + + if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" { + t.Errorf("Failed to reset the environment") + } +}