author | Alberto Bertogli <albertito@blitiri.com.ar> | 2016-09-12 02:47:36 UTC |
committer | Alberto Bertogli <albertito@blitiri.com.ar> | 2016-10-09 23:50:24 UTC |
parent | 819a282f71a39545e1de582abce04c3adc5048ff |
systemd.go | +22 | -7 |
systemd_test.go | +36 | -15 |
diff --git a/systemd.go b/systemd.go index 4cbaf5e..75f2686 100644 --- a/systemd.go +++ b/systemd.go @@ -7,6 +7,7 @@ import ( "net" "os" "strconv" + "strings" "syscall" ) @@ -21,10 +22,12 @@ var ( // Listeners creates a slice net.Listener from the file descriptors passed // by systemd, via the LISTEN_FDS environment variable. -// See sd_listen_fds(3) for more details. -func Listeners() ([]net.Listener, error) { +// See sd_listen_fds(3) and sd_listen_fds_with_names(3) for more details. +func Listeners() (map[string][]net.Listener, error) { pidStr := os.Getenv("LISTEN_PID") nfdsStr := os.Getenv("LISTEN_FDS") + fdNamesStr := os.Getenv("LISTEN_FDNAMES") + fdNames := strings.Split(fdNamesStr, ":") // Nothing to do if the variables are not set. if pidStr == "" || nfdsStr == "" { @@ -45,26 +48,38 @@ func Listeners() ([]net.Listener, error) { "error reading $LISTEN_FDS=%q: %v", nfdsStr, err) } - listeners := []net.Listener{} + // We should have as many names as we have descriptors. + // Note that if we have no descriptors, fdNames will be [""] (due to how + // strings.Split works), so we consider that special case. + if nfds > 0 && (fdNamesStr == "" || len(fdNames) != nfds) { + return nil, fmt.Errorf( + "Incorrect LISTEN_FDNAMES, have you set FileDescriptorName?") + } - for fd := firstFD; fd < firstFD+nfds; fd++ { + listeners := map[string][]net.Listener{} + + for i := 0; i < nfds; i++ { + fd := firstFD + i // We don't want childs to inherit these file descriptors. 
syscall.CloseOnExec(fd) - name := fmt.Sprintf("[systemd-fd-%d]", fd) - lis, err := net.FileListener(os.NewFile(uintptr(fd), name)) + name := fdNames[i] + + sysName := fmt.Sprintf("[systemd-fd-%d-%v]", fd, name) + lis, err := net.FileListener(os.NewFile(uintptr(fd), sysName)) if err != nil { return nil, fmt.Errorf( "Error making listener out of fd %d: %v", fd, err) } - listeners = append(listeners, lis) + listeners[name] = append(listeners[name], lis) } // Remove them from the environment, to prevent accidental reuse (by // us or children processes). os.Unsetenv("LISTEN_PID") os.Unsetenv("LISTEN_FDS") + os.Unsetenv("LISTEN_FDNAMES") return listeners, nil } diff --git a/systemd_test.go b/systemd_test.go index 34bca4b..226b124 100644 --- a/systemd_test.go +++ b/systemd_test.go @@ -5,12 +5,14 @@ import ( "net" "os" "strconv" + "strings" "testing" ) -func setenv(pid, fds string) { +func setenv(pid, fds string, names ...string) { os.Setenv("LISTEN_PID", pid) os.Setenv("LISTEN_FDS", fds) + os.Setenv("LISTEN_FDNAMES", strings.Join(names, ":")) } func TestEmptyEnvironment(t *testing.T) { @@ -30,16 +32,26 @@ func TestEmptyEnvironment(t *testing.T) { } func TestBadEnvironment(t *testing.T) { + // Create a listener so we have something to reference. + l := newListener(t) + firstFD = listenerFd(t, l) + ourPID := strconv.Itoa(os.Getpid()) - cases := []struct{ pid, fds string }{ - {"a", "4"}, - {ourPID, "a"}, + cases := []struct { + pid, fds string + names []string + }{ + {"a", "1", []string{"name"}}, // Invalid PID. + {ourPID, "a", []string{"name"}}, // Invalid number of fds. + {"1", "1", []string{"name"}}, // PID != ourselves. + {ourPID, "1", []string{"name1", "name2"}}, // Too many names. + {ourPID, "1", []string{}}, // Not enough names. } for _, c := range cases { - setenv(c.pid, c.fds) + setenv(c.pid, c.fds, c.names...) 
if ls, err := Listeners(); err == nil { - t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q", c.pid, c.fds) + t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q LISTEN_FDNAMES=%q", c.pid, c.fds, c.names) t.Errorf("Unexpected result: %v // %v", ls, err) } } @@ -98,13 +110,15 @@ func TestOneSocket(t *testing.T) { l := newListener(t) firstFD = listenerFd(t, l) - setenv(strconv.Itoa(os.Getpid()), "1") + setenv(strconv.Itoa(os.Getpid()), "1", "name") - ls, err := Listeners() - if err != nil || len(ls) != 1 { - t.Fatalf("Got an invalid result: %v // %v", ls, err) + lsMap, err := Listeners() + if err != nil || len(lsMap) != 1 { + t.Fatalf("Got an invalid result: %v // %v", lsMap, err) } + ls := lsMap["name"] + if !sameAddr(ls[0].Addr(), l.Addr()) { t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", l.Addr(), ls[0].Addr()) @@ -134,11 +148,16 @@ func TestManySockets(t *testing.T) { firstFD = f0 - setenv(strconv.Itoa(os.Getpid()), "2") + setenv(strconv.Itoa(os.Getpid()), "2", "name1", "name2") - ls, err := Listeners() - if err != nil || len(ls) != 2 { - t.Fatalf("Got an invalid result: %v // %v", ls, err) + lsMap, err := Listeners() + if err != nil || len(lsMap) != 2 { + t.Fatalf("Got an invalid result: %v // %v", lsMap, err) + } + + ls := []net.Listener{ + lsMap["name1"][0], + lsMap["name2"][0], } if !sameAddr(ls[0].Addr(), l0.Addr()) { @@ -151,7 +170,9 @@ func TestManySockets(t *testing.T) { l1.Addr(), ls[1].Addr()) } - if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" { + if os.Getenv("LISTEN_PID") != "" || + os.Getenv("LISTEN_FDS") != "" || + os.Getenv("LISTEN_FDNAMES") != "" { t.Errorf("Failed to reset the environment") } }