git » go-net » unix-separados » tree

[unix-separados] / unix / cred_test.go

// Copyright 2016 The Go Authors.  All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package unix

import (
	"io/ioutil"
	"net"
	"os"
	"runtime"
	"testing"
)

// tempPath returns a unique path, which does not currently exist, inside the
// default temporary directory, for use as a UNIX-domain socket address.
//
// It asks ioutil.TempFile for a fresh name with the prefix "go-nettest" in the
// default temporary directory (os.TempDir, i.e. $TMPDIR when set). It then
// closes and removes the placeholder file so the caller can bind a socket there.
//
// NOTE(review): UNIX socket paths have a small length limit on some systems.
// A long TMPDIR may make the subsequent Listen fail. The name is also only
// reserved until the placeholder is removed, so another process could, in
// principle, claim it in between.
//
// It panics if a name cannot be obtained or the placeholder cannot be
// removed. A leftover file would otherwise make Listen fail with a confusing
// address-in-use error.
func tempPath() string {
	f, err := ioutil.TempFile("", "go-nettest")
	if err != nil {
		panic(err)
	}
	addr := f.Name()
	// The file was only needed to reserve a unique name; a Close error is
	// irrelevant because the file is deleted right away.
	f.Close()
	if err := os.Remove(addr); err != nil {
		panic(err)
	}
	return addr
}

// TestGetPeerCredentials connects this process to itself over a UNIX-domain
// socket. It checks that GetPeerCredentials, called on the accepted side,
// reports this process's uid and gid, and also its pid on linux and openbsd.
func TestGetPeerCredentials(t *testing.T) {
	if runtime.GOOS == "plan9" || runtime.GOOS == "windows" {
		t.Skip("GetPeerCredentials not work on this OS")
	}

	path := tempPath()
	ln, err := net.ListenUnix("unix", &net.UnixAddr{Name: path, Net: "unix"})
	if err != nil {
		t.Fatalf("failed to listen on %q: %v", path, err)
	}
	// Deferred calls run in reverse order: the listener is closed first, then
	// the socket file is removed in case Close did not already unlink it.
	defer os.Remove(path)
	defer ln.Close()

	// t.Fatal/t.Fatalf may only be called from the goroutine running the
	// test, so the dialer reports its result back over a channel instead.
	// The channel is buffered so the send never blocks.
	type dialResult struct {
		conn net.Conn
		err  error
	}
	dialed := make(chan dialResult, 1)
	go func() {
		conn, err := net.Dial("unix", path)
		if err != nil {
			// Unblock AcceptUnix below so a failed dial fails the test
			// instead of hanging it.
			ln.Close()
		}
		dialed <- dialResult{conn: conn, err: err}
	}()

	c, acceptErr := ln.AcceptUnix()
	d := <-dialed
	if d.err != nil {
		t.Fatalf("failed to dial %q: %v", path, d.err)
	}
	defer d.conn.Close()
	if acceptErr != nil {
		t.Fatal(acceptErr)
	}
	defer c.Close()

	creds, err := GetPeerCredentials(c)
	if err != nil {
		t.Fatal(err)
	}

	if creds.Uid != os.Getuid() {
		t.Errorf("uid mismatch, got %d, want %d", creds.Uid, os.Getuid())
	}
	if creds.Gid != os.Getgid() {
		t.Errorf("gid mismatch, got %d, want %d", creds.Gid, os.Getgid())
	}

	// The pid is only checked on the platforms where this test expects the
	// peer credentials to carry one.
	if runtime.GOOS == "linux" || runtime.GOOS == "openbsd" {
		if creds.Pid != os.Getpid() {
			t.Errorf("pid mismatch, got %d, want %d", creds.Pid, os.Getpid())
		}
	}
}