author | Alberto Bertogli <albertito@blitiri.com.ar> | 2020-05-06 00:53:22 UTC |
committer | Alberto Bertogli <albertito@blitiri.com.ar> | 2020-05-06 00:53:59 UTC |
parent | bc3868c3e8bab0174a053f3191664717f1a451c7 |
proxy/http.go | +24 | -3 |
proxy/proxy_test.go | +28 | -2 |
diff --git a/proxy/http.go b/proxy/http.go index ee0bec0..4721ec4 100644 --- a/proxy/http.go +++ b/proxy/http.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" golog "log" "net/http" "net/http/cgi" @@ -138,9 +139,19 @@ func makeProxy(from string, to url.URL) http.Handler { // joinPath joins to HTTP paths. We can't use path.Join because it strips the // final "/", which may have meaning in URLs. func joinPath(a, b string) string { + if a == "" && b == "" { + return "/" + } + if a == "" || b == "" { + return a + b + } + if strings.HasSuffix(a, "/") && strings.HasPrefix(b, "/") { + return strings.TrimSuffix(a, "/") + b + } if !strings.HasSuffix(a, "/") && !strings.HasPrefix(b, "/") { - a = a + "/" + return a + "/" + b } + return a + b } @@ -157,14 +168,24 @@ func adjustPath(req string, from string, to string) string { // Strip "from" from the request path, so that if we have this config: // // /a/ -> http://dst/b - // www.example.com/p/ -> http://dst/q + // /p/q -> http://dst/r/s + // www.example.com/t/ -> http://dst/u // // then: // /a/x goes to http://dst/b/x (not http://dst/b/a/x) - // www.example.com/p/x goes to http://dst/q/x + // /p/q goes to http://dst/r/s + // www.example.com/t/x goes to http://dst/u/x // // It is expected that `from` already has the domain removed using // stripDomain. + // + // If req doesn't have from as prefix, then we panic. + if !strings.HasPrefix(req, from) { + panic(fmt.Errorf( + "adjustPath(req=%q, from=%q, to=%q): from is not prefix", + req, from, to)) + } + dst := joinPath(to, strings.TrimPrefix(req, from)) if dst == "" || dst[0] != '/' { dst = "/" + dst diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 95ef6ca..e188fd3 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -133,7 +133,7 @@ func TestSimple(t *testing.T) { } func testGet(t *testing.T, url string, expectedStatus int) { - //t.Helper() -- Uncomment once Go 1.9 is commonplace. 
+ t.Helper() t.Logf("URL: %s", url) resp, err := http.Get(url) if err != nil { @@ -170,9 +170,10 @@ func TestJoinPath(t *testing.T) { {"a/", "", "a/"}, {"a/", "b", "a/b"}, {"a/", "b/", "a/b/"}, + {"a/", "/b/", "a/b/"}, {"/", "", "/"}, {"", "", "/"}, - {"/", "/", "//"}, // Not sure if we want this, but ok for now. + {"/", "/", "/"}, } for _, c := range cases { got := joinPath(c.a, c.b) @@ -182,6 +183,31 @@ func TestJoinPath(t *testing.T) { } } +func TestAdjustPath(t *testing.T) { + cases := []struct{ from, to, req, expected string }{ + {"/", "/", "/", "/"}, + {"/", "/", "/a", "/a"}, + {"/", "/", "/a/x", "/a/x"}, + {"/a", "/", "/a", "/"}, + {"/a", "/", "/a/", "/"}, + {"/a", "/", "/a/x", "/x"}, + {"/a/", "/", "/a/", "/"}, + {"/a/", "/", "/a/x", "/x"}, + {"/a/", "/b", "/a/", "/b"}, + {"/a/", "/b", "/a/x", "/b/x"}, + {"/p/q", "/r/s", "/p/q", "/r/s"}, + {"/p/q", "/r/s", "/p/q", "/r/s"}, + {"/p/q", "/r/s", "/p/q/x", "/r/s/x"}, + } + for _, c := range cases { + got := adjustPath(c.req, c.from, c.to) + if got != c.expected { + t.Errorf("adjustPath(%q, %q, %q) = %q, expected %q", + c.req, c.from, c.to, got, c.expected) + } + } +} + func Benchmark(b *testing.B) { makeBench := func(url string) func(b *testing.B) { return func(b *testing.B) {