git » go-net » commit c98446e

unix: implement unix.GetPeerCredentials

author Alberto Bertogli
2016-09-16 19:32:00 UTC
committer Alberto Bertogli
2016-09-16 21:03:05 UTC
parent d06f1c2971a83c319103b39e99b217c11d4280c4

unix: implement unix.GetPeerCredentials

This commit implements a new unix.GetPeerCredentials function that can be used
to get the peer credentials of a net.UnixConn.

Change-Id: I7c04ded40b957da01c104857e5c00d525fec9a05

unix/cred.go +35 -0
unix/cred_linux.go +14 -0
unix/cred_stub.go +12 -0
unix/cred_test.go +69 -0
unix/defs_linux.go +24 -0
unix/sys_linux.go +18 -0
unix/zdefs_linux.go +17 -0

diff --git a/unix/cred.go b/unix/cred.go
new file mode 100644
index 0000000..7e8b315
--- /dev/null
+++ b/unix/cred.go
@@ -0,0 +1,35 @@
+// Copyright 2016 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package unix contains utilities for UNIX sockets.
+package unix
+
+import (
+	"errors"
+	"net"
+)
+
+// errOpNoSupport is returned on platforms with no peer credentials support.
+var errOpNoSupport = errors.New("operation not supported")
+
+
+// UnixPeerCreds holds the credentials of the peer connected to a UNIX-domain
+// socket, as returned by GetPeerCredentials.
+//
+// Not every operating system reports every field; some only provide Uid and
+// Gid. Fields that are not supported on the current system are set to -1.
+type UnixPeerCreds struct {
+	Uid int // user ID of the peer
+	Gid int // group ID of the peer
+	Pid int // process ID of the peer
+}
+
+// GetPeerCredentials returns the credentials of the process on the other end
+// of the UNIX-domain connection c.
+//
+// The credentials are read through a duplicate of c's file descriptor,
+// obtained with c.File and closed again before returning. Depending on the Go
+// release, (*net.UnixConn).File may also switch c to blocking mode, which can
+// stop deadlines set on c from taking effect; see the net package
+// documentation.
+//
+// On systems where reading peer credentials is not implemented, a non-nil
+// error is returned.
+func GetPeerCredentials(c *net.UnixConn) (*UnixPeerCreds, error) {
+	f, err := c.File()
+	if err != nil {
+		return nil, err
+	}
+	// f owns a duplicate descriptor; release it once the query is done.
+	defer f.Close()
+
+	return getPeerCredentials(f.Fd())
+}
diff --git a/unix/cred_linux.go b/unix/cred_linux.go
new file mode 100644
index 0000000..58e034b
--- /dev/null
+++ b/unix/cred_linux.go
@@ -0,0 +1,14 @@
+// Copyright 2016 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unix
+
+// getPeerCredentials queries SO_PEERCRED on the socket referred to by fd and
+// converts the kernel's ucred into the portable UnixPeerCreds form.
+func getPeerCredentials(fd uintptr) (*UnixPeerCreds, error) {
+	uc, err := getSockoptUcred(fd, sysSOL_SOCKET, sysSO_PEERCRED)
+	if err != nil {
+		return nil, err
+	}
+
+	creds := &UnixPeerCreds{
+		Uid: int(uc.Uid),
+		Gid: int(uc.Gid),
+		Pid: int(uc.Pid),
+	}
+	return creds, nil
+}
diff --git a/unix/cred_stub.go b/unix/cred_stub.go
new file mode 100644
index 0000000..1faebd8
--- /dev/null
+++ b/unix/cred_stub.go
@@ -0,0 +1,12 @@
+// Copyright 2016 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Systems on which reading peer credentials is not implemented.
+// +build !linux
+
+package unix
+
+// getPeerCredentials always fails with errOpNoSupport: this platform has no
+// implementation for reading the credentials of a UNIX-socket peer.
+func getPeerCredentials(fd uintptr) (*UnixPeerCreds, error) {
+	return nil, errOpNoSupport
+}
diff --git a/unix/cred_test.go b/unix/cred_test.go
new file mode 100644
index 0000000..056aef5
--- /dev/null
+++ b/unix/cred_test.go
@@ -0,0 +1,69 @@
+// Copyright 2016 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unix
+
+import (
+	"io/ioutil"
+	"net"
+	"os"
+	"runtime"
+	"testing"
+)
+
+// tempPath returns a unique filesystem path suitable for binding a
+// UNIX-domain socket in tests.
+//
+// It reserves a unique name with ioutil.TempFile in the default temporary
+// directory (os.TempDir, which honours TMPDIR), then closes and removes that
+// file so a socket can be created at the returned path. It panics if the
+// temporary file cannot be created.
+func tempPath() string {
+	f, err := ioutil.TempFile("", "go-nettest")
+	if err != nil {
+		panic(err)
+	}
+	addr := f.Name()
+	// Best effort: a failure here surfaces later when the socket is bound.
+	f.Close()
+	os.Remove(addr)
+	return addr
+}
+
+func TestGetPeerCredentials(t *testing.T) {
+	// GetPeerCredentials has no implementation on these systems.
+	if runtime.GOOS == "plan9" || runtime.GOOS == "windows" {
+		t.Skipf("GetPeerCredentials is not supported on %s", runtime.GOOS)
+	}
+
+	path := tempPath()
+	ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"})
+	if err != nil {
+		t.Fatalf("failed to listen on %q: %v", path, err)
+	}
+	// Closing the listener also removes the socket file it created.
+	defer ln.Close()
+
+	// Dial from the test goroutine itself: connecting to a listening UNIX
+	// stream socket completes once the connection is queued (as long as the
+	// backlog is not full), before Accept runs. This keeps every t.Fatal call
+	// on the test goroutine, as package testing requires.
+	cc, err := net.Dial("unix", path)
+	if err != nil {
+		t.Fatalf("failed to dial %q: %v", path, err)
+	}
+	defer cc.Close()
+
+	c, err := ln.AcceptUnix()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer c.Close()
+
+	creds, err := GetPeerCredentials(c)
+	if err == errOpNoSupport {
+		t.Skipf("GetPeerCredentials is not implemented on %s", runtime.GOOS)
+	}
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if creds.Uid != os.Getuid() {
+		t.Errorf("uid mismatch, got %d, want %d", creds.Uid, os.Getuid())
+	}
+	if creds.Gid != os.Getgid() {
+		t.Errorf("gid mismatch, got %d, want %d", creds.Gid, os.Getgid())
+	}
+
+	// Only some systems report the peer's process ID.
+	if runtime.GOOS == "linux" || runtime.GOOS == "openbsd" {
+		if creds.Pid != os.Getpid() {
+			t.Errorf("pid mismatch, got %d, want %d", creds.Pid, os.Getpid())
+		}
+	}
+}
diff --git a/unix/defs_linux.go b/unix/defs_linux.go
new file mode 100644
index 0000000..14b4d3d
--- /dev/null
+++ b/unix/defs_linux.go
@@ -0,0 +1,24 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build ignore
+
+package unix
+
+/*
+#define _GNU_SOURCE
+
+#include <sys/socket.h>
+#include <sys/un.h>
+*/
+import "C"
+
+// ucred mirrors the C struct ucred from <sys/socket.h>; running
+// cgo -godefs on this file produces the plain Go version in zdefs_linux.go.
+type ucred C.struct_ucred
+
+const (
+	// Option name and level passed to getsockopt to read the peer's
+	// credentials.
+	sysSO_PEERCRED = C.SO_PEERCRED
+	sysSOL_SOCKET  = C.SOL_SOCKET
+
+	// Size in bytes of struct ucred, used as the getsockopt option length.
+	sizeofUcred = C.sizeof_struct_ucred
+)
diff --git a/unix/sys_linux.go b/unix/sys_linux.go
new file mode 100644
index 0000000..b90ee55
--- /dev/null
+++ b/unix/sys_linux.go
@@ -0,0 +1,18 @@
+// Copyright 2016 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unix
+
+import (
+	"unsafe"
+
+	"golang.org/x/net/internal/netsyscall"
+)
+
+// getSockoptUcred reads the struct ucred valued socket option (level, opt)
+// from the socket referred to by fd.
+//
+// On failure it returns a nil *ucred together with the error, so callers
+// cannot accidentally use a zero or partially filled structure.
+func getSockoptUcred(fd uintptr, level, opt int) (*ucred, error) {
+	var value ucred
+	// NOTE(review): presumably forwarded as getsockopt(2)'s value-result
+	// option length (buffer size in, bytes written out) — confirm in netsyscall.
+	vallen := uint32(sizeofUcred)
+	if err := netsyscall.Getsockopt(fd, level, opt, unsafe.Pointer(&value), &vallen); err != nil {
+		return nil, err
+	}
+	return &value, nil
+}
diff --git a/unix/zdefs_linux.go b/unix/zdefs_linux.go
new file mode 100644
index 0000000..2754665
--- /dev/null
+++ b/unix/zdefs_linux.go
@@ -0,0 +1,17 @@
+// Created by cgo -godefs - DO NOT EDIT
+// cgo -godefs defs_linux.go
+
+package unix
+
+// ucred is the Go layout of the C struct ucred, generated from defs_linux.go;
+// regenerate with cgo -godefs rather than editing by hand.
+type ucred struct {
+	Pid int32
+	Uid uint32
+	Gid uint32
+}
+
+// Values of the C constants named in defs_linux.go, as emitted by
+// cgo -godefs; regenerate rather than editing by hand.
+const (
+	sysSO_PEERCRED = 0x11
+	sysSOL_SOCKET  = 0x1
+
+	sizeofUcred = 0xc
+)