git » systemd » commit 40368eb

Distinguish between SMTP and submission ports

author Alberto Bertogli
2016-09-12 02:47:36 UTC
committer Alberto Bertogli
2016-10-09 23:50:24 UTC
parent 819a282f71a39545e1de582abce04c3adc5048ff

Distinguish between SMTP and submission ports

We want to be able to distinguish between connections for SMTP and connections
for submission, so we can make different policy decisions.

To do that, we first make the configuration aware of the different kinds of
addresses. This patch does so in a backwards-incompatible way, which is
acceptable at this point in time.

Then, we extend our systemd socket-passing library to support socket naming,
so we can tell the different sockets apart. This uses systemd's LISTEN_FDNAMES
environment variable, which is populated from each socket unit's
FileDescriptorName= setting.

Finally, we make the server and connection types aware of the socket mode
(SMTP or submission).

systemd.go +22 -7
systemd_test.go +36 -15

diff --git a/systemd.go b/systemd.go
index 4cbaf5e..75f2686 100644
--- a/systemd.go
+++ b/systemd.go
@@ -7,6 +7,7 @@ import (
 	"net"
 	"os"
 	"strconv"
+	"strings"
 	"syscall"
 )
 
@@ -21,10 +22,12 @@ var (
 
 // Listeners creates a slice net.Listener from the file descriptors passed
 // by systemd, via the LISTEN_FDS environment variable.
-// See sd_listen_fds(3) for more details.
-func Listeners() ([]net.Listener, error) {
+// See sd_listen_fds(3) and sd_listen_fds_with_names(3) for more details.
+func Listeners() (map[string][]net.Listener, error) {
 	pidStr := os.Getenv("LISTEN_PID")
 	nfdsStr := os.Getenv("LISTEN_FDS")
+	fdNamesStr := os.Getenv("LISTEN_FDNAMES")
+	fdNames := strings.Split(fdNamesStr, ":")
 
 	// Nothing to do if the variables are not set.
 	if pidStr == "" || nfdsStr == "" {
@@ -45,26 +48,38 @@ func Listeners() ([]net.Listener, error) {
 			"error reading $LISTEN_FDS=%q: %v", nfdsStr, err)
 	}
 
-	listeners := []net.Listener{}
+	// We should have as many names as we have descriptors.
+	// Note that if we have no descriptors, fdNames will be [""] (due to how
+	// strings.Split works), so we consider that special case.
+	if nfds > 0 && (fdNamesStr == "" || len(fdNames) != nfds) {
+		return nil, fmt.Errorf(
+			"Incorrect LISTEN_FDNAMES, have you set FileDescriptorName?")
+	}
 
-	for fd := firstFD; fd < firstFD+nfds; fd++ {
+	listeners := map[string][]net.Listener{}
+
+	for i := 0; i < nfds; i++ {
+		fd := firstFD + i
 		// We don't want childs to inherit these file descriptors.
 		syscall.CloseOnExec(fd)
 
-		name := fmt.Sprintf("[systemd-fd-%d]", fd)
-		lis, err := net.FileListener(os.NewFile(uintptr(fd), name))
+		name := fdNames[i]
+
+		sysName := fmt.Sprintf("[systemd-fd-%d-%v]", fd, name)
+		lis, err := net.FileListener(os.NewFile(uintptr(fd), sysName))
 		if err != nil {
 			return nil, fmt.Errorf(
 				"Error making listener out of fd %d: %v", fd, err)
 		}
 
-		listeners = append(listeners, lis)
+		listeners[name] = append(listeners[name], lis)
 	}
 
 	// Remove them from the environment, to prevent accidental reuse (by
 	// us or children processes).
 	os.Unsetenv("LISTEN_PID")
 	os.Unsetenv("LISTEN_FDS")
+	os.Unsetenv("LISTEN_FDNAMES")
 
 	return listeners, nil
 }
diff --git a/systemd_test.go b/systemd_test.go
index 34bca4b..226b124 100644
--- a/systemd_test.go
+++ b/systemd_test.go
@@ -5,12 +5,14 @@ import (
 	"net"
 	"os"
 	"strconv"
+	"strings"
 	"testing"
 )
 
-func setenv(pid, fds string) {
+func setenv(pid, fds string, names ...string) {
 	os.Setenv("LISTEN_PID", pid)
 	os.Setenv("LISTEN_FDS", fds)
+	os.Setenv("LISTEN_FDNAMES", strings.Join(names, ":"))
 }
 
 func TestEmptyEnvironment(t *testing.T) {
@@ -30,16 +32,26 @@ func TestEmptyEnvironment(t *testing.T) {
 }
 
 func TestBadEnvironment(t *testing.T) {
+	// Create a listener so we have something to reference.
+	l := newListener(t)
+	firstFD = listenerFd(t, l)
+
 	ourPID := strconv.Itoa(os.Getpid())
-	cases := []struct{ pid, fds string }{
-		{"a", "4"},
-		{ourPID, "a"},
+	cases := []struct {
+		pid, fds string
+		names    []string
+	}{
+		{"a", "1", []string{"name"}},              // Invalid PID.
+		{ourPID, "a", []string{"name"}},           // Invalid number of fds.
+		{"1", "1", []string{"name"}},              // PID != ourselves.
+		{ourPID, "1", []string{"name1", "name2"}}, // Too many names.
+		{ourPID, "1", []string{}},                 // Not enough names.
 	}
 	for _, c := range cases {
-		setenv(c.pid, c.fds)
+		setenv(c.pid, c.fds, c.names...)
 
 		if ls, err := Listeners(); err == nil {
-			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q", c.pid, c.fds)
+			t.Logf("Case: LISTEN_PID=%q  LISTEN_FDS=%q LISTEN_FDNAMES=%q", c.pid, c.fds, c.names)
 			t.Errorf("Unexpected result: %v // %v", ls, err)
 		}
 	}
@@ -98,13 +110,15 @@ func TestOneSocket(t *testing.T) {
 	l := newListener(t)
 	firstFD = listenerFd(t, l)
 
-	setenv(strconv.Itoa(os.Getpid()), "1")
+	setenv(strconv.Itoa(os.Getpid()), "1", "name")
 
-	ls, err := Listeners()
-	if err != nil || len(ls) != 1 {
-		t.Fatalf("Got an invalid result: %v // %v", ls, err)
+	lsMap, err := Listeners()
+	if err != nil || len(lsMap) != 1 {
+		t.Fatalf("Got an invalid result: %v // %v", lsMap, err)
 	}
 
+	ls := lsMap["name"]
+
 	if !sameAddr(ls[0].Addr(), l.Addr()) {
 		t.Errorf("Listener 0 address mismatch, expected %#v, got %#v",
 			l.Addr(), ls[0].Addr())
@@ -134,11 +148,16 @@ func TestManySockets(t *testing.T) {
 
 	firstFD = f0
 
-	setenv(strconv.Itoa(os.Getpid()), "2")
+	setenv(strconv.Itoa(os.Getpid()), "2", "name1", "name2")
 
-	ls, err := Listeners()
-	if err != nil || len(ls) != 2 {
-		t.Fatalf("Got an invalid result: %v // %v", ls, err)
+	lsMap, err := Listeners()
+	if err != nil || len(lsMap) != 2 {
+		t.Fatalf("Got an invalid result: %v // %v", lsMap, err)
+	}
+
+	ls := []net.Listener{
+		lsMap["name1"][0],
+		lsMap["name2"][0],
 	}
 
 	if !sameAddr(ls[0].Addr(), l0.Addr()) {
@@ -151,7 +170,9 @@ func TestManySockets(t *testing.T) {
 			l1.Addr(), ls[1].Addr())
 	}
 
-	if os.Getenv("LISTEN_PID") != "" || os.Getenv("LISTEN_FDS") != "" {
+	if os.Getenv("LISTEN_PID") != "" ||
+		os.Getenv("LISTEN_FDS") != "" ||
+		os.Getenv("LISTEN_FDNAMES") != "" {
 		t.Errorf("Failed to reset the environment")
 	}
 }