// Source: git » gofer » main » tree
//
// [main] / server / server_test.go

package server

import (
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"blitiri.com.ar/go/gofer/config"
	"blitiri.com.ar/go/log"
)

// waitForHTTPServer waits up to 5 seconds for an HTTP server listening on
// addr to start, and returns an error if it fails to do so.
// It does this by repeatedly querying the server (roughly every 100ms) until
// it either replies or the deadline passes. Any HTTP reply, regardless of
// its status code, counts as the server being up.
func waitForHTTPServer(addr string) error {
	c := http.Client{
		Timeout: 100 * time.Millisecond,
	}

	// Use an explicit ticker (instead of time.Tick) so it can be stopped
	// once we are done with it.
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()

	deadline := time.Now().Add(5 * time.Second)

	var lastErr error
	for time.Now().Before(deadline) {
		resp, err := c.Get("http://" + addr + "/testpoke")
		if err == nil {
			// We only care that the server answered; release the
			// connection held by the response.
			resp.Body.Close()
			return nil
		}
		lastErr = err

		// Pace the retries.
		<-ticker.C
	}

	return fmt.Errorf("timed out waiting for %s: %v", addr, lastErr)
}

// getFreePort returns a "host:port" address on localhost with a currently
// free TCP port. This is hacky and not race-free (the port is released before
// returning, so something else could grab it), but it works well enough for
// testing purposes.
// It panics if no listener can be created, since the tests cannot proceed
// without an address.
func getFreePort() string {
	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		panic("getFreePort: " + err.Error())
	}
	defer l.Close()
	return l.Addr().String()
}

// backendResponse is the body written by the test backend for every request.
// testGet compares the body of each 200 response against it.
const backendResponse = "backend response\n"

// Addresses of the proxy under test (set by TestMain).
var (
	// httpAddr is the address passed to the HTTP frontend.
	httpAddr string
	// rawAddr is the address passed to the raw frontend.
	rawAddr string
)

// TestMain sets up the environment shared by all tests and benchmarks in
// this package: a backend test server that always replies with
// backendResponse, plus a raw and an HTTP frontend (started via Raw and HTTP)
// configured to point at it. The frontend addresses are stored in rawAddr and
// httpAddr.
// Note it leaks the frontend goroutines, we're ok with this for testing.
func TestMain(m *testing.M) {
	backend := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprint(w, backendResponse)
		}))
	// No "defer backend.Close()" here: os.Exit below does not run deferred
	// calls, so the backend is closed explicitly after m.Run returns.

	// We have two frontends: one raw and one http.
	rawAddr = getFreePort()
	httpAddr = getFreePort()

	log.Default.Level = log.Error

	pwd, err := os.Getwd()
	if err != nil {
		log.Fatalf("error getting working directory: %v", err)
	}

	const configTemplate = `
raw:
  "$RAW_ADDR":
    to: "$BACKEND_ADDR"

http:
  "$HTTP_ADDR":
    routes:
      "/be/": { proxy: "$BACKEND_URL" }
      "localhost/xy/": { proxy: "$BACKEND_URL" }
      "/static/hola": { file: "$PWD/testdata/hola" }
      "/dir/": { dir: "$PWD/testdata/" }
      "/redir/": { redirect: "http://$HTTP_ADDR/dir/" }
`
	configStr := strings.NewReplacer(
		"$RAW_ADDR", rawAddr,
		"$HTTP_ADDR", httpAddr,
		"$BACKEND_URL", backend.URL,
		"$BACKEND_ADDR", backend.Listener.Addr().String(),
		"$PWD", pwd,
	).Replace(configTemplate)

	conf, err := config.LoadString(configStr)
	if err != nil {
		log.Fatalf("error loading test config: %v", err)
	}

	go Raw(rawAddr, conf.Raw[rawAddr])
	go HTTP(httpAddr, conf.HTTP[httpAddr])

	// Fail early, with a clear message, if a frontend never comes up;
	// otherwise every test would fail later with connection errors.
	for _, addr := range []string{httpAddr, rawAddr} {
		if err := waitForHTTPServer(addr); err != nil {
			log.Fatalf("error waiting for server on %q: %v", addr, err)
		}
	}

	code := m.Run()
	backend.Close()
	os.Exit(code)
}

func TestSimple(t *testing.T) {
	// Test the raw proxy.
	testGet(t, "http://"+rawAddr+"/be", 200)

	// Test the HTTP proxy. Try a combination of URLs and error responses just
	// to exercise a bit more of the path handling and error checking code.
	testGet(t, "http://"+httpAddr+"/be", 200)
	testGet(t, "http://"+httpAddr+"/be/", 200)
	testGet(t, "http://"+httpAddr+"/be/2", 200)
	testGet(t, "http://"+httpAddr+"/be/3", 200)
	testGet(t, "http://"+httpAddr+"/x", 404)
	testGet(t, "http://"+httpAddr+"/xy/1", 404)

	// Test the domain-based routing.
	_, httpPort, _ := net.SplitHostPort(httpAddr)
	testGet(t, "http://localhost:"+httpPort+"/be/", 200)
	testGet(t, "http://localhost:"+httpPort+"/xy/1", 200)

	// Test dir and file schemes.
	testGet(t, "http://"+httpAddr+"/static/hola", 200)
	testGet(t, "http://"+httpAddr+"/dir/hola", 200)
	testGet(t, "http://"+httpAddr+"/redir/hola", 200)
}

// testGet performs a GET request to url and checks that the response has
// the expected status code. For 200 responses, it also checks that the body
// is exactly backendResponse.
// A failure to perform the request or to read the body aborts the test;
// status and body mismatches are reported as errors.
func testGet(t *testing.T, url string, expectedStatus int) {
	t.Helper()
	t.Logf("URL: %s", url)
	resp, err := http.Get(url)
	if err != nil {
		t.Fatal(err)
	}
	// Always release the response body, on every return path.
	defer resp.Body.Close()
	t.Logf("status %v", resp.Status)

	if resp.StatusCode != expectedStatus {
		t.Errorf("expected status %d, got %v", expectedStatus, resp.Status)
		t.Errorf("response: %#v", resp)
	}

	// We don't care about the body for non-200 responses.
	if resp.StatusCode != http.StatusOK {
		return
	}

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	if string(b) != backendResponse {
		t.Errorf("expected body = %q, got %q", backendResponse, string(b))
	}

	t.Logf("response body: %q", b)
}

// TestJoinPath checks joinPath against a table of (a, b) inputs and the
// result expected for each pair.
func TestJoinPath(t *testing.T) {
	type joinCase struct {
		a, b     string
		expected string
	}

	table := []joinCase{
		{a: "/a/", b: "", expected: "/a/"},
		{a: "/a/", b: "b", expected: "/a/b"},
		{a: "/a/", b: "b/", expected: "/a/b/"},
		{a: "a/", b: "", expected: "a/"},
		{a: "a/", b: "b", expected: "a/b"},
		{a: "a/", b: "b/", expected: "a/b/"},
		{a: "a/", b: "/b/", expected: "a/b/"},
		{a: "/", b: "", expected: "/"},
		{a: "", b: "", expected: "/"},
		{a: "/", b: "/", expected: "/"},
	}

	for i := range table {
		tc := table[i]
		if got := joinPath(tc.a, tc.b); got != tc.expected {
			t.Errorf("join %q, %q = %q, expected %q", tc.a, tc.b, got, tc.expected)
		}
	}
}

// TestAdjustPath checks adjustPath against a table of cases. Each case gives
// the "from" and "to" paths, the request path, and the expected result of
// adjustPath(req, from, to).
func TestAdjustPath(t *testing.T) {
	type adjustCase struct {
		from, to string
		req      string
		expected string
	}

	table := []adjustCase{
		{from: "/", to: "/", req: "/", expected: "/"},
		{from: "/", to: "/", req: "/a", expected: "/a"},
		{from: "/", to: "/", req: "/a/x", expected: "/a/x"},
		{from: "/a", to: "/", req: "/a", expected: "/"},
		{from: "/a", to: "/", req: "/a/", expected: "/"},
		{from: "/a", to: "/", req: "/a/x", expected: "/x"},
		{from: "/a/", to: "/", req: "/a/", expected: "/"},
		{from: "/a/", to: "/", req: "/a/x", expected: "/x"},
		{from: "/a/", to: "/b", req: "/a/", expected: "/b"},
		{from: "/a/", to: "/b", req: "/a/x", expected: "/b/x"},
		{from: "/p/q", to: "/r/s", req: "/p/q", expected: "/r/s"},
		// NOTE(review): this case duplicates the previous one; possibly a
		// different request path was intended. Confirm with the author.
		{from: "/p/q", to: "/r/s", req: "/p/q", expected: "/r/s"},
		{from: "/p/q", to: "/r/s", req: "/p/q/x", expected: "/r/s/x"},
	}

	for i := range table {
		tc := table[i]
		if got := adjustPath(tc.req, tc.from, tc.to); got != tc.expected {
			t.Errorf("adjustPath(%q, %q, %q) = %q, expected %q",
				tc.req, tc.from, tc.to, got, tc.expected)
		}
	}
}

// Benchmark measures sequential GET requests to the "/be" path through the
// HTTP and raw frontends, as the "HTTP" and "Raw" sub-benchmarks.
func Benchmark(b *testing.B) {
	makeBench := func(url string) func(b *testing.B) {
		return func(b *testing.B) {
			for i := 0; i < b.N; i++ {
				resp, err := http.Get(url)
				if err != nil {
					b.Fatal(err)
				}

				// Read the body to completion before closing it. Otherwise
				// the client transport may not reuse the keep-alive
				// connection, and every iteration would also pay for (and
				// use up a local port on) a brand new TCP connection.
				_, err = ioutil.ReadAll(resp.Body)
				resp.Body.Close()
				if err != nil {
					b.Fatal(err)
				}

				if resp.StatusCode != 200 {
					b.Errorf("expected status 200, got %v", resp.Status)
					b.Fatalf("response: %#v", resp)
				}
			}
		}
	}

	b.Run("HTTP", makeBench("http://"+httpAddr+"/be"))
	b.Run("Raw", makeBench("http://"+rawAddr+"/be"))
}

// BenchmarkParallel measures concurrent GET requests (via b.RunParallel) to
// the "/be" path through the HTTP and raw frontends, as the "HTTP" and "Raw"
// sub-benchmarks.
func BenchmarkParallel(b *testing.B) {
	makeBench := func(url string) func(b *testing.B) {
		// Failures are reported on the sub-benchmark's own b, not on the
		// parent benchmark's.
		return func(b *testing.B) {
			b.RunParallel(func(pb *testing.PB) {
				for pb.Next() {
					resp, err := http.Get(url)
					if err != nil {
						// This runs in a worker goroutine, where
						// Fatal/FailNow must not be called: record the
						// error and stop this worker instead.
						b.Error(err)
						return
					}

					// Read the body to completion before closing it, so
					// the client transport can reuse the connection.
					_, err = ioutil.ReadAll(resp.Body)
					resp.Body.Close()
					if err != nil {
						b.Error(err)
						return
					}

					if resp.StatusCode != 200 {
						b.Errorf("expected status 200, got %v", resp.Status)
						b.Errorf("response: %#v", resp)
						return
					}
				}
			})
		}
	}

	b.Run("HTTP", makeBench("http://"+httpAddr+"/be"))
	b.Run("Raw", makeBench("http://"+rawAddr+"/be"))
}