git » chasquid » commit 701f359

Support getting listeners from systemd

author Alberto Bertogli
2015-10-31 23:05:34 UTC
committer Alberto Bertogli
2015-11-01 02:19:23 UTC
parent a809a3caa9962b38caded37040d97272ea979d92

Support getting listeners from systemd

Add a new module for getting listener sockets via systemd's file descriptor
passing (see sd_listen_fds(3) for more details), and make the main daemon use
it when "systemd" is given as an address.

chasquid.go +31 -1
internal/systemd/systemd.go +70 -0
internal/systemd/systemd_test.go +157 -0

diff --git a/chasquid.go b/chasquid.go
index c737d1b..222a84c 100644
--- a/chasquid.go
+++ b/chasquid.go
@@ -17,6 +17,7 @@ import (
 	"time"
 
 	"blitiri.com.ar/go/chasquid/internal/config"
+	"blitiri.com.ar/go/chasquid/internal/systemd"
 
 	_ "net/http/pprof"
 
@@ -69,14 +70,28 @@ func main() {
 	}
 
 	// Load addresses.
+	acount := 0
 	for _, addr := range conf.Address {
+		// The "systemd" address indicates we get listeners via systemd.
 		if addr == "systemd" {
-			// TODO
+			ls, err := systemd.Listeners()
+			if err != nil {
+				glog.Fatalf("Error getting listeners via systemd: %v", err)
+			}
+			s.AddListeners(ls)
+			acount += len(ls)
 		} else {
 			s.AddAddr(addr)
+			acount++
 		}
 	}
 
+	if acount == 0 {
+		glog.Errorf("No addresses/listeners configured")
+		glog.Errorf("If using systemd, check that you started chasquid.socket")
+		glog.Fatalf("Exiting")
+	}
+
 	s.ListenAndServe()
 }
 
@@ -93,6 +108,9 @@ type Server struct {
 	// Addresses.
 	addrs []string
 
+	// Listeners (that came via systemd).
+	listeners []net.Listener
+
 	// TLS config.
 	tlsConfig *tls.Config
 
@@ -119,6 +137,10 @@ func (s *Server) AddAddr(a string) {
 	s.addrs = append(s.addrs, a)
 }
 
+func (s *Server) AddListeners(ls []net.Listener) {
+	s.listeners = append(s.listeners, ls...)
+}
+
 func (s *Server) getTLSConfig() (*tls.Config, error) {
 	var err error
 	conf := &tls.Config{}
@@ -159,6 +181,14 @@ func (s *Server) ListenAndServe() {
 		go s.serve(l)
 	}
 
+	for _, l := range s.listeners {
+		defer l.Close()
+		glog.Infof("Server listening on %s (via systemd)", l.Addr())
+
+		// Serve.
+		go s.serve(l)
+	}
+
 	// Never return. If the serve goroutines have problems, they will abort
 	// execution.
 	for {
diff --git a/internal/systemd/systemd.go b/internal/systemd/systemd.go
new file mode 100644
index 0000000..a6a796c
--- /dev/null
+++ b/internal/systemd/systemd.go
@@ -0,0 +1,70 @@
+// Package systemd implements utility functions to interact with systemd.
+package systemd
+
+import (
+	"errors"
+	"fmt"
+	"net"
+	"os"
+	"strconv"
+	"syscall"
+)
+
+var (
+	// Error to return when $LISTEN_PID does not refer to us.
+	PIDMismatch = errors.New("$LISTEN_PID != our PID")
+
+	// First FD for listeners.
+	// It's 3 by definition, but using a variable simplifies testing.
+	firstFD = 3
+)
+
+// Listeners creates a slice of net.Listener from the file descriptors passed
+// by systemd, via the LISTEN_FDS environment variable.
+// See sd_listen_fds(3) for more details.
+func Listeners() ([]net.Listener, error) {
+	pidStr := os.Getenv("LISTEN_PID")
+	nfdsStr := os.Getenv("LISTEN_FDS")
+
+	// Nothing to do if the variables are not set.
+	if pidStr == "" || nfdsStr == "" {
+		return nil, nil
+	}
+
+	pid, err := strconv.Atoi(pidStr)
+	if err != nil {
+		return nil, fmt.Errorf(
+			"error converting $LISTEN_PID=%q: %v", pidStr, err)
+	} else if pid != os.Getpid() {
+		return nil, PIDMismatch
+	}
+
+	nfds, err := strconv.Atoi(os.Getenv("LISTEN_FDS"))
+	if err != nil {
+		return nil, fmt.Errorf(
+			"error reading $LISTEN_FDS=%q: %v", nfdsStr, err)
+	}
+
+	listeners := []net.Listener{}
+
+	for fd := firstFD; fd < firstFD+nfds; fd++ {
+		// We don't want children to inherit these file descriptors.
+		syscall.CloseOnExec(fd)
+
+		name := fmt.Sprintf("[systemd-fd-%d]", fd)
+		lis, err := net.FileListener(os.NewFile(uintptr(fd), name))
+		if err != nil {
+			return nil, fmt.Errorf(
+				"Error making listener out of fd %d: %v", fd, err)
+		}
+
+		listeners = append(listeners, lis)
+	}
+
+	// Remove them from the environment, to prevent accidental reuse (by
+	// us or child processes).
+	os.Unsetenv("LISTEN_PID")
+	os.Unsetenv("LISTEN_FDS")
+
+	return listeners, nil
+}
diff --git a/internal/systemd/systemd_test.go b/internal/systemd/systemd_test.go
new file mode 100644
index 0000000..28586fe
--- /dev/null
+++ b/internal/systemd/systemd_test.go
@@ -0,0 +1,157 @@
+package systemd
+
+import (
+	"math/rand"
+	"net"
+	"os"
+	"strconv"
+	"testing"
+)
+
+func setenv(pid, fds string) {
+	os.Setenv("LISTEN_PID", pid)
+	os.Setenv("LISTEN_FDS", fds)
+}
+
+func TestEmptyEnvironment(t *testing.T) {
+	cases := []struct{ pid, fds string }{
+		{"", ""},
+		{"123", ""},
+		{"", "4"},
+	}
+	for _, c := range cases {
+		setenv(c.pid, c.fds)
+
+		if ls, err := Listeners(); ls != nil || err != nil {
+			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q", c.pid, c.fds)
+			t.Errorf("Unexpected result: %v // %v", ls, err)
+		}
+	}
+}
+
+func TestBadEnvironment(t *testing.T) {
+	ourPID := strconv.Itoa(os.Getpid())
+	cases := []struct{ pid, fds string }{
+		{"a", "4"},
+		{ourPID, "a"},
+	}
+	for _, c := range cases {
+		setenv(c.pid, c.fds)
+
+		if ls, err := Listeners(); err == nil {
+			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q", c.pid, c.fds)
+			t.Errorf("Unexpected result: %v // %v", ls, err)
+		}
+	}
+}
+
+func TestWrongPID(t *testing.T) {
+	// Find a pid != us. 1 should always work in practice.
+	pid := 1
+	for pid == os.Getpid() {
+		pid = rand.Int()
+	}
+
+	setenv(strconv.Itoa(pid), "4")
+	if _, err := Listeners(); err != PIDMismatch {
+		t.Errorf("Did not fail with PID mismatch: %v", err)
+	}
+}
+
+func TestNoFDs(t *testing.T) {
+	setenv(strconv.Itoa(os.Getpid()), "0")
+	if ls, err := Listeners(); len(ls) != 0 || err != nil {
+		t.Errorf("Got a non-empty result: %v // %v", ls, err)
+	}
+}
+
+// newListener creates a TCP listener.
+func newListener(t *testing.T) *net.TCPListener {
+	addr := &net.TCPAddr{
+		Port: 0,
+	}
+
+	l, err := net.ListenTCP("tcp", addr)
+	if err != nil {
+		t.Fatalf("Could not create TCP listener: %v", err)
+	}
+
+	return l
+}
+
+// listenerFd returns a file descriptor for the listener.
+// Note it is a NEW file descriptor, not the original one.
+func listenerFd(t *testing.T, l *net.TCPListener) int {
+	f, err := l.File()
+	if err != nil {
+		t.Fatalf("Could not get TCP listener file: %v", err)
+	}
+
+	return int(f.Fd())
+}
+
+func sameAddr(a, b net.Addr) bool {
+	return a.Network() == b.Network() && a.String() == b.String()
+}
+
+func TestOneSocket(t *testing.T) {
+	l := newListener(t)
+	firstFD = listenerFd(t, l)
+
+	setenv(strconv.Itoa(os.Getpid()), "1")
+
+	ls, err := Listeners()
+	if err != nil || len(ls) != 1 {
+		t.Fatalf("Got an invalid result: %v // %v", ls, err)
+	}
+
+	if !sameAddr(ls[0].Addr(), l.Addr()) {
+		t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
+			l.Addr(), ls[0].Addr())
+	}
+
+	if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" {
+		t.Errorf("Failed to reset the environment")
+	}
+}
+
+func TestManySockets(t *testing.T) {
+	// Create two contiguous listeners.
+	// The test environment does not guarantee us that they are contiguous, so
+	// keep going until they are.
+	var l0, l1 *net.TCPListener
+	var f0, f1 int = -1, -3
+
+	for f0+1 != f1 {
+		// We have to be careful with the order of these operations, because
+		// listenerFd will create *new* file descriptors.
+		l0 = newListener(t)
+		l1 = newListener(t)
+		f0 = listenerFd(t, l0)
+		f1 = listenerFd(t, l1)
+		t.Logf("Looping for FDs: %d %d", f0, f1)
+	}
+
+	firstFD = f0
+
+	setenv(strconv.Itoa(os.Getpid()), "2")
+
+	ls, err := Listeners()
+	if err != nil || len(ls) != 2 {
+		t.Fatalf("Got an invalid result: %v // %v", ls, err)
+	}
+
+	if !sameAddr(ls[0].Addr(), l0.Addr()) {
+		t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
+			l0.Addr(), ls[0].Addr())
+	}
+
+	if !sameAddr(ls[1].Addr(), l1.Addr()) {
+		t.Errorf("Listener 1 address mismatch, expected %#v, got %#v",
+			l1.Addr(), ls[1].Addr())
+	}
+
+	if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" {
+		t.Errorf("Failed to reset the environment")
+	}
+}