author | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-08-21 16:45:41 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2017-08-21 16:57:59 UTC |
parent | cd3225afdf78a8a0b038bf10b81daf35e42d5214 |
systemd.go | +48 | -0 |
systemd_test.go | +43 | -0 |
diff --git a/systemd.go b/systemd.go index b931cf8..b9de0d8 100644 --- a/systemd.go +++ b/systemd.go @@ -123,3 +123,51 @@ func Listeners() (map[string][]net.Listener, error) { parse() return listeners, parseError } + +// OneListener returns a listener for the first systemd socket with the given +// name. If there are none, the listener and error will both be nil. An error +// will be returned only if there were issues parsing the file descriptors. +// +// This function can be convenient for simple callers where you know there's +// only one file descriptor being passed with the given name. +// +// This is a convenience function built on top of Listeners(). +func OneListener(name string) (net.Listener, error) { + parse() + if parseError != nil { + return nil, parseError + } + + lis := listeners[name] + if len(lis) < 1 { + return nil, nil + } + return lis[0], nil +} + +// Listen returns a listener for the given address, similar to net.Listen. +// +// If the address begins with "&" it is interpreted as a systemd socket being +// passed. For example, using "&http" would mean we expect a systemd socket +// passed to us, named with "FileDescriptorName=http" in its unit. +// +// Otherwise, it uses net.Listen to create a new listener with the given net +// and local address. +// +// This function can be convenient for simple callers where you get the +// address from a user, and want to let them specify either "use systemd" or a +// normal address without too much additional complexity. +// +// This is a convenience function built on top of Listeners(). +func Listen(netw, laddr string) (net.Listener, error) { + if strings.HasPrefix(laddr, "&") { + name := laddr[1:] + lis, err := OneListener(name) + if lis == nil && err == nil { + err = fmt.Errorf("systemd socket %q not found", name) + } + return lis, err + } else { + return net.Listen(netw, laddr) + } +} diff --git a/systemd_test.go b/systemd_test.go index 6819ef4..d02915b 100644 --- a/systemd_test.go +++ b/systemd_test.go @@ -56,6 +56,9 @@ func TestBadEnvironment(t *testing.T) { t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q LISTEN_FDNAMES=%q", c.pid, c.fds, c.names) t.Errorf("Unexpected result: %v // %v", ls, err) } + if ls, err := OneListener("name"); err == nil { + t.Errorf("Unexpected result: %v // %v", ls, err) + } } } @@ -70,6 +73,10 @@ func TestWrongPID(t *testing.T) { if _, err := Listeners(); err != ErrPIDMismatch { t.Errorf("Did not fail with PID mismatch: %v", err) } + + if _, err := OneListener("name"); err != ErrPIDMismatch { + t.Errorf("Did not fail with PID mismatch: %v", err) + } } func TestNoFDs(t *testing.T) { @@ -126,6 +133,15 @@ func TestOneSocket(t *testing.T) { l.Addr(), ls[0].Addr()) } + oneL, err := OneListener("name") + if err != nil { + t.Errorf("OneListener error: %v", err) + } + if !sameAddr(oneL.Addr(), l.Addr()) { + t.Errorf("OneListener address mismatch, expected %#v, got %#v", + l.Addr(), ls[0].Addr()) + } + if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" { t.Errorf("Failed to reset the environment") } @@ -172,9 +188,36 @@ func TestManySockets(t *testing.T) { l1.Addr(), ls[1].Addr()) } + oneL, _ := OneListener("name1") + if !sameAddr(oneL.Addr(), l0.Addr()) { + t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", + oneL.Addr(), ls[0].Addr()) + } + if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" || os.Getenv("LISTEN_FDNAMES") != "" { t.Errorf("Failed to reset the environment") } } + +func TestListen(t *testing.T) { + orig := newListener(t) + firstFD = listenerFd(t, orig) + setenv(strconv.Itoa(os.Getpid()), "1", "name") + + l, err := Listen("tcp", "&name") + if err != nil { + t.Errorf("Listen failed: %v", err) + } + if !sameAddr(l.Addr(), orig.Addr()) { + t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", + l.Addr(), orig.Addr()) + } + + l, err = Listen("tcp", ":0") + if err != nil { + t.Errorf("Listen failed: %v", err) + } + t.Logf("listener created at %v", l.Addr()) +}