git » dnss » master » tree

[master] / internal / nettrace / http_test.go

package nettrace

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
	"time"
)

// getValues issues a GET request for /debug/traces with the given query
// values directly against RenderTraces, and returns the response body as a
// string.
//
// It reports a test error if the response status differs from code, and
// aborts the test if the response body cannot be read.
func getValues(t *testing.T, vs url.Values, code int) string {
	t.Helper()

	req := httptest.NewRequest("GET", "/debug/traces?"+vs.Encode(), nil)
	w := httptest.NewRecorder()
	RenderTraces(w, req)

	resp := w.Result()
	defer resp.Body.Close()

	// Read the body before checking the status, so that callers still get
	// the page contents to inspect even when the status is unexpected.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}

	if resp.StatusCode != code {
		t.Errorf("expected %d, got %v", code, resp)
	}

	return string(body)
}

// v groups the query parameters the tests send to the traces page. Each
// field maps to the URL parameter of the same name. The field order matters,
// because callers build values with positional literals.
type v struct {
	fam   string // "fam" parameter: trace family name.
	b     string // "b" parameter: bucket selector.
	lat   string // "lat" parameter: requests the latency view.
	trace string // "trace" parameter: id of a specific trace.
	ref   string // "ref" parameter.
	all   string // "all" parameter.
}

// getCode turns vs into URL query values, leaving out every empty field, and
// fetches the traces page through getValues. It returns the response body;
// getValues reports a test error if the status is not code.
func getCode(t *testing.T, vs v, code int) string {
	t.Helper()

	params := []struct{ key, val string }{
		{"fam", vs.fam},
		{"b", vs.b},
		{"lat", vs.lat},
		{"trace", vs.trace},
		{"ref", vs.ref},
		{"all", vs.all},
	}

	query := url.Values{}
	for _, p := range params {
		if p.val == "" {
			// Unset parameters are not sent at all.
			continue
		}
		query.Add(p.key, p.val)
	}

	return getValues(t, query, code)
}

func get(t *testing.T, fam, b, lat, trace, ref, all string) string {
	t.Helper()
	return getCode(t, v{fam, b, lat, trace, ref, all}, 200)
}

func getErr(t *testing.T, fam, b, lat, trace, ref, all string, code int, err string) string {
	t.Helper()
	body := getCode(t, v{fam, b, lat, trace, ref, all}, code)
	if !strings.Contains(body, err) {
		t.Errorf("Body does not contain error message %q", err)
		t.Logf("Body: %v", body)
	}

	return body
}

// checkContains reports a test error, and logs the full body, if body does
// not contain s.
func checkContains(t *testing.T, body, s string) {
	t.Helper()

	if strings.Contains(body, s) {
		return
	}

	t.Errorf("Body does not contain %q", s)
	t.Logf("Body: %v", body)
}

// TestHTTP creates three finished traces (one of them with an error) and one
// active trace in the "TestHTTP" family, and then checks that the different
// views of the traces page contain the expected content.
func TestHTTP(t *testing.T) {
	first := New("TestHTTP", "http")
	first.Printf("entry #1")
	first.Finish()

	second := New("TestHTTP", "http")
	second.Printf("entry #2")
	second.Finish()

	failed := New("TestHTTP", "http")
	failed.Errorf("entry #3 (error)")
	failed.Finish()

	// This one is active until the end.
	active := New("TestHTTP", "http")
	active.Printf("hola marola")
	active.Printf("entry #4")
	defer active.Finish()

	// Each check requests one view and lists the strings it must contain.
	checks := []struct {
		fam, b, lat, id, all string
		want                 []string
	}{
		// The plain index.
		{want: []string{"TestHTTP"}},

		// A specific family, but no bucket.
		{fam: "TestHTTP", want: []string{"TestHTTP"}},

		// A family and the active bucket.
		{fam: "TestHTTP", b: "-1", want: []string{"hola marola"}},

		// A family and the error bucket.
		{fam: "TestHTTP", b: "-2", want: []string{"entry #3 (error)"}},

		// A family and the first bucket.
		{fam: "TestHTTP", b: "0", want: []string{"entry #2"}},

		// Latency view. There are 3 events because the 4th is active.
		{fam: "TestHTTP", lat: "lat", want: []string{"Count: 3"}},

		// A specific trace. No family given, since it shouldn't be needed
		// (it is taken from the id).
		{id: string(active.(*trace).ID), want: []string{"hola marola"}},

		// The "all=true" views.
		{
			fam: "TestHTTP", b: "0", all: "true",
			want: []string{"entry #2", "?fam=TestHTTP&b=-2&all=true"},
		},
	}
	for _, c := range checks {
		body := get(t, c.fam, c.b, c.lat, c.id, "", c.all)
		for _, s := range c.want {
			checkContains(t, body, s)
		}
	}

	// NOTE(review): the deferred call above finishes this trace a second
	// time; presumably Finish tolerates being called twice - confirm.
	active.Finish()
}

// TestHTTPLong checks that a finished trace with many entries can be
// rendered when requested by id.
func TestHTTPLong(t *testing.T) {
	const entries = 1000

	long := New("TestHTTPLong", "verbose")
	for n := 0; n < entries; n++ {
		long.Printf("entry #%d", n)
	}
	long.Finish()

	id := string(long.(*trace).ID)
	get(t, "TestHTTPLong", "", "", id, "", "")
}

// TestHTTPErrors checks the status code and error message returned for
// requests that name an unknown family, an invalid bucket, or a trace id
// that does not exist.
func TestHTTPErrors(t *testing.T) {
	tr := New("TestHTTPErrors", "http")
	tr.Printf("entry #1")
	tr.Finish()
	id := string(tr.(*trace).ID)

	cases := []struct {
		fam, b, id string
		code       int
		msg        string
	}{
		// Unknown family.
		{fam: "unkfamily", code: 404, msg: "Unknown family"},

		// Invalid bucket.
		{fam: "TestHTTPErrors", b: "abc", code: 400, msg: "Invalid bucket"},
		{fam: "TestHTTPErrors", b: "-3", code: 400, msg: "Invalid bucket"},
		{fam: "TestHTTPErrors", b: "9", code: 400, msg: "Invalid bucket"},

		// Unknown trace id (malformed).
		{fam: "TestHTTPErrors", id: "unktrace", code: 404, msg: "Trace not found"},

		// Unknown trace id.
		{fam: "TestHTTPErrors", id: id + "xxx", code: 404, msg: "Trace not found"},
	}
	for _, c := range cases {
		getErr(t, c.fam, c.b, "", c.id, "", "", c.code, c.msg)
	}

	// Check that the trace is actually there.
	get(t, "", "", "", id, "", "")
}

// TestHTTPUroboro links two active traces to each other and checks that the
// family, active-bucket and per-trace views can all be rendered.
func TestHTTPUroboro(t *testing.T) {
	first := New("TestHTTPUroboro", "trA")
	defer first.Finish()
	first.Printf("this is trace A")

	second := New("TestHTTPUroboro", "trB")
	defer second.Finish()
	second.Printf("this is trace B")

	// Make the two traces reference each other.
	first.Link(second, "B is my friend")
	second.Link(first, "A is my friend")

	// Check that we handle cross-linked events well.
	queries := []struct{ fam, b, id string }{
		{fam: "TestHTTPUroboro"},
		{fam: "TestHTTPUroboro", b: "-1"},
		{id: string(first.(*trace).ID)},
		{id: string(second.(*trace).ID)},
	}
	for _, q := range queries {
		get(t, q.fam, q.b, "", q.id, "", "")
	}
}

// TestHTTPDeep builds a chain of ten nested traces (level-0 to level-9), all
// still active, and checks that the deepest one appears both in the active
// view and in the view of the top-level trace.
func TestHTTPDeep(t *testing.T) {
	const family = "TestHTTPDeep"

	cur := New(family, "level-0")
	defer cur.Finish()
	chain := []Trace{cur}

	for depth := 1; depth < 10; depth++ {
		cur = cur.NewChild(family, fmt.Sprintf("level-%d", depth))
		// Deliberately deferred inside the loop: every level must stay
		// active until the test function returns.
		defer cur.Finish()
		chain = append(chain, cur)
	}

	// Active view.
	active := get(t, family, "-1", "", "", "", "")
	checkContains(t, active, "level-9")

	// Recursive view.
	rootID := string(chain[0].(*trace).ID)
	recursive := get(t, family, "", "", rootID, "", "")
	checkContains(t, recursive, "level-9")
}

// TestStripZeros checks the string stripZeros produces for a handful of
// durations, both below and above one second.
func TestStripZeros(t *testing.T) {
	tests := []struct {
		in   time.Duration
		want string
	}{
		{in: 0, want: " .     0"},
		{in: time.Millisecond, want: " .  1000"},
		{in: 5 * time.Millisecond, want: " .  5000"},
		{in: time.Second, want: "1.000000"},
		{in: time.Second + 8*time.Millisecond, want: "1.008000"},
	}

	for _, tc := range tests {
		got := stripZeros(tc.in)
		if got == tc.want {
			continue
		}
		t.Errorf("stripZeros(%s) got %q, expected %q",
			tc.in, got, tc.want)
	}
}

// TestRegisterHandler checks that, after RegisterHandler is called on a mux,
// a GET for /debug/traces through that mux returns status 200 and a page
// containing the "Traces" heading.
func TestRegisterHandler(t *testing.T) {
	mux := http.NewServeMux()
	RegisterHandler(mux)

	req := httptest.NewRequest("GET", "/debug/traces", nil)
	w := httptest.NewRecorder()
	mux.ServeHTTP(w, req)

	resp := w.Result()
	defer resp.Body.Close()

	if resp.StatusCode != 200 {
		t.Errorf("expected 200, got %v", resp)
	}

	// A failed read would leave a partial body and make the content check
	// below misleading, so stop here instead of ignoring the error.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if !strings.Contains(string(body), "<h1>Traces</h1>") {
		t.Errorf("unexpected body: %s", body)
	}
}