author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-08-21 16:24:05 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-08-21 16:49:06 UTC |
parent | b7ba9bf95ef6d9b01eec13b6d530041949b1f206 |
systemd.go | +51 | -12 |
systemd_test.go | +2 | -0 |
diff --git a/systemd.go b/systemd.go index a7188b8..b931cf8 100644 --- a/systemd.go +++ b/systemd.go @@ -9,6 +9,7 @@ import ( "os" "strconv" "strings" + "sync" "syscall" ) @@ -21,10 +22,24 @@ var ( firstFD = 3 ) -// Listeners creates a slice net.Listener from the file descriptors passed -// by systemd, via the LISTEN_FDS environment variable. -// See sd_listen_fds(3) and sd_listen_fds_with_names(3) for more details. -func Listeners() (map[string][]net.Listener, error) { +// Keep a single global map of listeners, to avoid repeatedly parsing which +// can be problematic (see parse). +var listeners map[string][]net.Listener +var parseError error +var mutex sync.Mutex + +// parse listeners, updating the global state. +// This function messes with file descriptors and environment, so it is not +// idempotent and must be called only once. For the callers' convenience, we +// save the listeners map globally and reuse it on the user-visible functions. +func parse() { + mutex.Lock() + defer mutex.Unlock() + + if listeners != nil { + return + } + pidStr := os.Getenv("LISTEN_PID") nfdsStr := os.Getenv("LISTEN_FDS") fdNamesStr := os.Getenv("LISTEN_FDNAMES") @@ -32,32 +47,36 @@ func Listeners() (map[string][]net.Listener, error) { // Nothing to do if the variables are not set. if pidStr == "" || nfdsStr == "" { - return nil, nil + return } pid, err := strconv.Atoi(pidStr) if err != nil { - return nil, fmt.Errorf( + parseError = fmt.Errorf( "error converting $LISTEN_PID=%q: %v", pidStr, err) + return } else if pid != os.Getpid() { - return nil, ErrPIDMismatch + parseError = ErrPIDMismatch + return } nfds, err := strconv.Atoi(os.Getenv("LISTEN_FDS")) if err != nil { - return nil, fmt.Errorf( + parseError = fmt.Errorf( "error reading $LISTEN_FDS=%q: %v", nfdsStr, err) + return } // We should have as many names as we have descriptors. // Note that if we have no descriptors, fdNames will be [""] (due to how // strings.Split works), so we consider that special case. if nfds > 0 && (fdNamesStr == "" || len(fdNames) != nfds) { - return nil, fmt.Errorf( + parseError = fmt.Errorf( "Incorrect LISTEN_FDNAMES, have you set FileDescriptorName?") + return } - listeners := map[string][]net.Listener{} + listeners = map[string][]net.Listener{} for i := 0; i < nfds; i++ { fd := firstFD + i @@ -69,8 +88,9 @@ func Listeners() (map[string][]net.Listener, error) { sysName := fmt.Sprintf("[systemd-fd-%d-%v]", fd, name) lis, err := net.FileListener(os.NewFile(uintptr(fd), sysName)) if err != nil { - return nil, fmt.Errorf( + parseError = fmt.Errorf( "Error making listener out of fd %d: %v", fd, err) + return } listeners[name] = append(listeners[name], lis) @@ -81,6 +101,25 @@ func Listeners() (map[string][]net.Listener, error) { os.Unsetenv("LISTEN_PID") os.Unsetenv("LISTEN_FDS") os.Unsetenv("LISTEN_FDNAMES") +} - return listeners, nil +// Listeners returns a map of listeners for the file descriptors passed by +// systemd via environment variables. +// +// It returns a map of the form (file descriptor name -> slice of listeners). +// +// The file descriptor name comes from the "FileDescriptorName=" option in the +// systemd socket unit. Multiple socket units can have the same name, hence +// the slice of listeners for each name. +// +// Ideally you should not need to call this more than once. If you do, the +// same listeners will be returned, as repeated calls to this function will +// return the same results: the parsing is done only once, and the results are +// saved and reused. +// +// See sd_listen_fds(3) and sd_listen_fds_with_names(3) for more details on +// how the passing works. +func Listeners() (map[string][]net.Listener, error) { + parse() + return listeners, parseError } diff --git a/systemd_test.go b/systemd_test.go index 226b124..6819ef4 100644 --- a/systemd_test.go +++ b/systemd_test.go @@ -13,6 +13,8 @@ func setenv(pid, fds string, names ...string) { os.Setenv("LISTEN_PID", pid) os.Setenv("LISTEN_FDS", fds) os.Setenv("LISTEN_FDNAMES", strings.Join(names, ":")) + listeners = nil + parseError = nil } func TestEmptyEnvironment(t *testing.T) {