author | Alberto Bertogli
<albertito@blitiri.com.ar> 2015-10-31 23:05:34 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2015-11-01 02:19:23 UTC |
parent | a809a3caa9962b38caded37040d97272ea979d92 |
chasquid.go | +31 | -1 |
internal/systemd/systemd.go | +70 | -0 |
internal/systemd/systemd_test.go | +157 | -0 |
diff --git a/chasquid.go b/chasquid.go index c737d1b..222a84c 100644 --- a/chasquid.go +++ b/chasquid.go @@ -17,6 +17,7 @@ import ( "time" "blitiri.com.ar/go/chasquid/internal/config" + "blitiri.com.ar/go/chasquid/internal/systemd" _ "net/http/pprof" @@ -69,14 +70,28 @@ func main() { } // Load addresses. + acount := 0 for _, addr := range conf.Address { + // The "systemd" address indicates we get listeners via systemd. if addr == "systemd" { - // TODO + ls, err := systemd.Listeners() + if err != nil { + glog.Fatalf("Error getting listeners via systemd: %v", err) + } + s.AddListeners(ls) + acount += len(ls) } else { s.AddAddr(addr) + acount++ } } + if acount == 0 { + glog.Errorf("No addresses/listeners configured") + glog.Errorf("If using systemd, check that you started chasquid.socket") + glog.Fatalf("Exiting") + } + s.ListenAndServe() } @@ -93,6 +108,9 @@ type Server struct { // Addresses. addrs []string + // Listeners (that came via systemd). + listeners []net.Listener + // TLS config. tlsConfig *tls.Config @@ -119,6 +137,10 @@ func (s *Server) AddAddr(a string) { s.addrs = append(s.addrs, a) } +func (s *Server) AddListeners(ls []net.Listener) { + s.listeners = append(s.listeners, ls...) +} + func (s *Server) getTLSConfig() (*tls.Config, error) { var err error conf := &tls.Config{} @@ -159,6 +181,14 @@ func (s *Server) ListenAndServe() { go s.serve(l) } + for _, l := range s.listeners { + defer l.Close() + glog.Infof("Server listening on %s (via systemd)", l.Addr()) + + // Serve. + go s.serve(l) + } + // Never return. If the serve goroutines have problems, they will abort // execution. for { diff --git a/internal/systemd/systemd.go b/internal/systemd/systemd.go new file mode 100644 index 0000000..a6a796c --- /dev/null +++ b/internal/systemd/systemd.go @@ -0,0 +1,70 @@ +// Package systemd implements utility functions to interact with systemd. 
+package systemd + +import ( + "errors" + "fmt" + "net" + "os" + "strconv" + "syscall" +) + +var ( + // Error to return when $LISTEN_PID does not refer to us. + PIDMismatch = errors.New("$LISTEN_PID != our PID") + + // First FD for listeners. + // It's 3 by definition, but using a variable simplifies testing. + firstFD = 3 +) + +// Listeners creates a slice of net.Listener from the file descriptors passed +// by systemd, via the LISTEN_FDS environment variable. +// See sd_listen_fds(3) for more details. +func Listeners() ([]net.Listener, error) { + pidStr := os.Getenv("LISTEN_PID") + nfdsStr := os.Getenv("LISTEN_FDS") + + // Nothing to do if the variables are not set. + if pidStr == "" || nfdsStr == "" { + return nil, nil + } + + pid, err := strconv.Atoi(pidStr) + if err != nil { + return nil, fmt.Errorf( + "error converting $LISTEN_PID=%q: %v", pidStr, err) + } else if pid != os.Getpid() { + return nil, PIDMismatch + } + + nfds, err := strconv.Atoi(os.Getenv("LISTEN_FDS")) + if err != nil { + return nil, fmt.Errorf( + "error reading $LISTEN_FDS=%q: %v", nfdsStr, err) + } + + listeners := []net.Listener{} + + for fd := firstFD; fd < firstFD+nfds; fd++ { + // We don't want children to inherit these file descriptors. + syscall.CloseOnExec(fd) + + name := fmt.Sprintf("[systemd-fd-%d]", fd) + lis, err := net.FileListener(os.NewFile(uintptr(fd), name)) + if err != nil { + return nil, fmt.Errorf( + "Error making listener out of fd %d: %v", fd, err) + } + + listeners = append(listeners, lis) + } + + // Remove them from the environment, to prevent accidental reuse (by + // us or child processes). 
+ os.Unsetenv("LISTEN_PID") + os.Unsetenv("LISTEN_FDS") + + return listeners, nil +} diff --git a/internal/systemd/systemd_test.go b/internal/systemd/systemd_test.go new file mode 100644 index 0000000..28586fe --- /dev/null +++ b/internal/systemd/systemd_test.go @@ -0,0 +1,157 @@ +package systemd + +import ( + "math/rand" + "net" + "os" + "strconv" + "testing" +) + +func setenv(pid, fds string) { + os.Setenv("LISTEN_PID", pid) + os.Setenv("LISTEN_FDS", fds) +} + +func TestEmptyEnvironment(t *testing.T) { + cases := []struct{ pid, fds string }{ + {"", ""}, + {"123", ""}, + {"", "4"}, + } + for _, c := range cases { + setenv(c.pid, c.fds) + + if ls, err := Listeners(); ls != nil || err != nil { + t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q", c.pid, c.fds) + t.Errorf("Unexpected result: %v // %v", ls, err) + } + } +} + +func TestBadEnvironment(t *testing.T) { + ourPID := strconv.Itoa(os.Getpid()) + cases := []struct{ pid, fds string }{ + {"a", "4"}, + {ourPID, "a"}, + } + for _, c := range cases { + setenv(c.pid, c.fds) + + if ls, err := Listeners(); err == nil { + t.Logf("Case: LISTEN_PID=%q LISTEN_FDS=%q", c.pid, c.fds) + t.Errorf("Unexpected result: %v // %v", ls, err) + } + } +} + +func TestWrongPID(t *testing.T) { + // Find a pid != us. 1 should always work in practice. + pid := 1 + for pid == os.Getpid() { + pid = rand.Int() + } + + setenv(strconv.Itoa(pid), "4") + if _, err := Listeners(); err != PIDMismatch { + t.Errorf("Did not fail with PID mismatch: %v", err) + } +} + +func TestNoFDs(t *testing.T) { + setenv(strconv.Itoa(os.Getpid()), "0") + if ls, err := Listeners(); len(ls) != 0 || err != nil { + t.Errorf("Got a non-empty result: %v // %v", ls, err) + } +} + +// newListener creates a TCP listener. 
+func newListener(t *testing.T) *net.TCPListener { + addr := &net.TCPAddr{ + Port: 0, + } + + l, err := net.ListenTCP("tcp", addr) + if err != nil { + t.Fatalf("Could not create TCP listener: %v", err) + } + + return l +} + +// listenerFd returns a file descriptor for the listener. +// Note it is a NEW file descriptor, not the original one. +func listenerFd(t *testing.T, l *net.TCPListener) int { + f, err := l.File() + if err != nil { + t.Fatalf("Could not get TCP listener file: %v", err) + } + + return int(f.Fd()) +} + +func sameAddr(a, b net.Addr) bool { + return a.Network() == b.Network() && a.String() == b.String() +} + +func TestOneSocket(t *testing.T) { + l := newListener(t) + firstFD = listenerFd(t, l) + + setenv(strconv.Itoa(os.Getpid()), "1") + + ls, err := Listeners() + if err != nil || len(ls) != 1 { + t.Fatalf("Got an invalid result: %v // %v", ls, err) + } + + if !sameAddr(ls[0].Addr(), l.Addr()) { + t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", + l.Addr(), ls[0].Addr()) + } + + if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" { + t.Errorf("Failed to reset the environment") + } +} + +func TestManySockets(t *testing.T) { + // Create two contiguous listeners. + // The test environment does not guarantee us that they are contiguous, so + // keep going until they are. + var l0, l1 *net.TCPListener + var f0, f1 int = -1, -3 + + for f0+1 != f1 { + // We have to be careful with the order of these operations, because + // listenerFd will create *new* file descriptors. 
+ l0 = newListener(t) + l1 = newListener(t) + f0 = listenerFd(t, l0) + f1 = listenerFd(t, l1) + t.Logf("Looping for FDs: %d %d", f0, f1) + } + + firstFD = f0 + + setenv(strconv.Itoa(os.Getpid()), "2") + + ls, err := Listeners() + if err != nil || len(ls) != 2 { + t.Fatalf("Got an invalid result: %v // %v", ls, err) + } + + if !sameAddr(ls[0].Addr(), l0.Addr()) { + t.Errorf("Listener 0 address mismatch, expected %#v, got %#v", + l0.Addr(), ls[0].Addr()) + } + + if !sameAddr(ls[1].Addr(), l1.Addr()) { + t.Errorf("Listener 1 address mismatch, expected %#v, got %#v", + l1.Addr(), ls[1].Addr()) + } + + if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" { + t.Errorf("Failed to reset the environment") + } +}