git » chasquid » smarthost » tree

[smarthost] / internal / haproxy / haproxy_test.go

package haproxy

import (
	"bufio"
	"io"
	"net"
	"strings"
	"testing"
)

// TestNoNewline feeds Handshake a header that stops before any line
// terminator and checks that the returned error is io.EOF.
func TestNoNewline(t *testing.T) {
	const truncated = "PROXY "

	input := bufio.NewReader(strings.NewReader(truncated))
	if _, _, err := Handshake(input); err != io.EOF {
		t.Errorf("expected EOF, got %v", err)
	}
}

// TestBasic runs Handshake over a table of header lines and checks the
// source address, destination address and error returned for each one.
//
// Every case string is passed through newR, which appends the "\r\n" line
// terminator, so the table only lists the header text itself.
func TestBasic(t *testing.T) {
	// mustResolve builds a fixture address and aborts the test if it cannot
	// be resolved. Ignoring the error would leave a typed-nil *net.TCPAddr
	// inside a non-nil net.Addr, which would then crash the comparison below
	// rather than report a useful failure.
	mustResolve := func(hostport string) *net.TCPAddr {
		t.Helper()
		addr, err := net.ResolveTCPAddr("tcp", hostport)
		if err != nil {
			t.Fatalf("resolving fixture address %q: %v", hostport, err)
		}
		return addr
	}

	var (
		src4 = mustResolve("1.1.1.1:3333")
		dst4 = mustResolve("2.2.2.2:4444")
		src6 = mustResolve("[5::5]:7777")
		dst6 = mustResolve("[6::6]:8888")
	)

	cases := []struct {
		str      string   // header line, without the trailing "\r\n"
		src, dst net.Addr // expected addresses; nil when an error is expected
		err      error    // expected error; nil on success
	}{
		// Early line errors.
		{"", nil, nil, errInvalidProtoID},
		{"lalala", nil, nil, errInvalidProtoID},
		{"PROXY", nil, nil, errInvalidProtoID},
		{"PROXY lalala", nil, nil, errUnkProtocol},
		{"PROXY UNKNOWN", nil, nil, errUnkProtocol},

		// Number of field errors.
		{"PROXY TCP4", nil, nil, errInvalidFields},
		{"PROXY TCP4 a", nil, nil, errInvalidFields},
		{"PROXY TCP4 a b", nil, nil, errInvalidFields},
		{"PROXY TCP4 a b c", nil, nil, errInvalidFields},

		// Parsing of ipv4 addresses.
		{"PROXY TCP4 a b c d", nil, nil, errInvalidSrcIP},
		{"PROXY TCP4 1.1.1.1 b c d",
			nil, nil, errInvalidDstIP},
		{"PROXY TCP4 1.1.1.1 2.2.2.2 c d",
			nil, nil, errInvalidSrcPort},
		{"PROXY TCP4 1.1.1.1 2.2.2.2 3333 d",
			nil, nil, errInvalidDstPort},
		{"PROXY TCP4 1.1.1.1 2.2.2.2 3333 4444",
			src4, dst4, nil},

		// Parsing of ipv6 addresses.
		{"PROXY TCP6 a b c d", nil, nil, errInvalidSrcIP},
		{"PROXY TCP6 5::5 b c d",
			nil, nil, errInvalidDstIP},
		{"PROXY TCP6 5::5 6::6 c d",
			nil, nil, errInvalidSrcPort},
		{"PROXY TCP6 5::5 6::6 7777 d",
			nil, nil, errInvalidDstPort},
		{"PROXY TCP6 5::5 6::6 7777 8888",
			src6, dst6, nil},
	}

	for i, c := range cases {
		t.Logf("testing %d: %v", i, c.str)

		src, dst, err := Handshake(newR(c.str))

		if !addrEq(src, c.src) {
			t.Errorf("%d: got src %v, expected %v", i, src, c.src)
		}
		if !addrEq(dst, c.dst) {
			t.Errorf("%d: got dst %v, expected %v", i, dst, c.dst)
		}
		if err != c.err {
			t.Errorf("%d: got error %v, expected %v", i, err, c.err)
		}
	}
}

// newR returns a buffered reader over s followed by a "\r\n" line
// terminator.
func newR(s string) *bufio.Reader {
	line := s + "\r\n"
	src := strings.NewReader(line)
	return bufio.NewReader(src)
}

// addrEq reports whether a and b denote the same TCP endpoint.
//
// Two nil addresses are equal, and a nil address never equals a non-nil
// one. Non-nil addresses are equal only when both are *net.TCPAddr values
// with equal IPs and equal ports. An address of any other concrete type
// compares unequal instead of panicking, so the caller gets an ordinary
// comparison failure rather than a crashed test binary.
func addrEq(a, b net.Addr) bool {
	if a == nil || b == nil {
		return a == nil && b == nil
	}

	ta, aIsTCP := a.(*net.TCPAddr)
	tb, bIsTCP := b.(*net.TCPAddr)
	if !aIsTCP || !bIsTCP {
		return false
	}

	// A nil *net.TCPAddr stored in a net.Addr is a non-nil interface, so it
	// gets past the check above; guard it here before touching any field.
	if ta == nil || tb == nil {
		return ta == tb
	}

	return ta.IP.Equal(tb.IP) && ta.Port == tb.Port
}