author | Alberto Bertogli <albertito@blitiri.com.ar> | 2020-06-06 15:29:14 UTC |
committer | Alberto Bertogli <albertito@blitiri.com.ar> | 2020-06-06 15:29:14 UTC |
parent | f4f71ed765f7b4a964a3febf0627768face644da |
config/config.go | +2 | -1 |
server/http.go | +40 | -13 |
test/01-be.yaml | +4 | -0 |
test/test.sh | +3 | -0 |
test/util/exp.go | +26 | -2 |
diff --git a/config/config.go b/config/config.go index 3988bb8..7fb2608 100644 --- a/config/config.go +++ b/config/config.go @@ -27,7 +27,8 @@ type HTTP struct { Auth map[string]string - DirOpts map[string]DirOpts + DirOpts map[string]DirOpts + SetHeader map[string]map[string]string } type HTTPS struct { diff --git a/server/http.go b/server/http.go index 1bf6c54..aa51019 100644 --- a/server/http.go +++ b/server/http.go @@ -71,17 +71,29 @@ func httpServer(addr string, conf config.HTTP) *http.Server { "failed to load auth file %q: %v", dbPath, err) } authMux.Handle(path, - WithTrace("http:auth", - &AuthWrapper{ - handler: mux, - users: users, - })) + &AuthWrapper{ + handler: srv.Handler, + users: users, + }) log.Infof("%s auth %q -> %q", srv.Addr, path, dbPath) } srv.Handler = authMux } + // Extra headers. + if len(conf.SetHeader) > 0 { + hdrMux := http.NewServeMux() + hdrMux.Handle("/", srv.Handler) + for path, extraHdrs := range conf.SetHeader { + hdrMux.Handle(path, SetHeader(srv.Handler, extraHdrs)) + log.Infof("%s add headers %q -> %q", srv.Addr, path, extraHdrs) + } + srv.Handler = hdrMux + } + + srv.Handler = WithTrace("http@"+srv.Addr, srv.Handler) + return srv } @@ -230,7 +242,7 @@ func makeDir(from string, to url.URL, conf *config.HTTP) http.Handler { path := pathOrOpaque(to) fs := http.FileServer(NewFS(http.Dir(path), conf.DirOpts[from])) - return WithTrace("http:dir", WithLogging( + return WithLogging( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr, _ := trace.FromContext(r.Context()) tr.Printf("serving dir root %q", path) @@ -242,19 +254,19 @@ func makeDir(from string, to url.URL, conf *config.HTTP) http.Handler { tr.Printf("adjusted path: %q", r.URL.Path) fs.ServeHTTP(w, r) }), - )) + ) } func makeStatic(from string, to url.URL, conf *config.HTTP) http.Handler { path := pathOrOpaque(to) - return WithTrace("http:static", WithLogging( + return WithLogging( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr, _ := 
trace.FromContext(r.Context()) tr.Printf("statically serving %q", path) http.ServeFile(w, r, path) }), - )) + ) } func makeCGI(from string, to url.URL, conf *config.HTTP) http.Handler { @@ -262,7 +274,7 @@ func makeCGI(from string, to url.URL, conf *config.HTTP) http.Handler { path := pathOrOpaque(to) args := queryToArgs(to.RawQuery) - return WithTrace("http:cgi", WithLogging( + return WithLogging( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr, _ := trace.FromContext(r.Context()) tr.Debugf("exec %q %q", path, args) @@ -275,7 +287,7 @@ func makeCGI(from string, to url.URL, conf *config.HTTP) http.Handler { } h.ServeHTTP(w, r) }), - )) + ) } func queryToArgs(query string) []string { @@ -299,7 +311,7 @@ func queryToArgs(query string) []string { func makeRedirect(from string, to url.URL, conf *config.HTTP) http.Handler { from = stripDomain(from) - return WithTrace("http:redirect", WithLogging( + return WithLogging( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr, _ := trace.FromContext(r.Context()) target := to @@ -309,7 +321,7 @@ func makeRedirect(from string, to url.URL, conf *config.HTTP) http.Handler { http.Redirect(w, r, target.String(), http.StatusTemporaryRedirect) }), - )) + ) } type loggingTransport struct{} @@ -394,6 +406,21 @@ func (w *statusWriter) Write(b []byte) (int, error) { return n, err } +func SetHeader(parent http.Handler, hdrs map[string]string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for k, v := range hdrs { + w.Header().Set(k, v) + } + parent.ServeHTTP(w, r) + + // TODO: better chained contexts. 
+ tr, _ := trace.FromContext(r.Context()) + for k, v := range hdrs { + tr.Printf("added header: %s: %q", k, v) + } + }) +} + func WithTrace(name string, parent http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tr, ok := trace.FromContext(r.Context()) diff --git a/test/01-be.yaml b/test/01-be.yaml index 09d10ef..1a2f16a 100644 --- a/test/01-be.yaml +++ b/test/01-be.yaml @@ -24,3 +24,7 @@ http: "/authdir/ñaca": "testdata/authdb.yaml" "/authdir/withoutindex/": "testdata/authdb.yaml" + setheader: + "/file": + "X-My-Header": "my lovely header" + diff --git a/test/test.sh b/test/test.sh index 4e19450..785d007 100755 --- a/test/test.sh +++ b/test/test.sh @@ -152,6 +152,9 @@ do exp $base/authdir/withoutindex -status 301 exp $base/authdir/withoutindex/ -status 401 exp $base/authdir/withoutindex/chau -status 401 + + # Additional headers. + exp $base/file -hdrre "X-My-Header: my lovely header" done # Good auth. diff --git a/test/util/exp.go b/test/util/exp.go index 37deedc..f6daa1d 100644 --- a/test/util/exp.go +++ b/test/util/exp.go @@ -34,6 +34,8 @@ func main() { "expect this status code") verbose = flag.Bool("v", false, "enable verbose output") + hdrRE = flag.String("hdrre", "", + "expect a header matching these contents (regexp match)") caCert = flag.String("cacert", "", "file to read CA cert from") ) @@ -85,7 +87,7 @@ func main() { if *bodyRE != "" { matched, err := regexp.Match(*bodyRE, rbody) if err != nil { - errorf("regexp error: %q", err) + errorf("regexp error: %q\n", err) } if !matched { errorf("body did not match regexp: %q\n", rbody) @@ -95,7 +97,7 @@ func main() { if *bodyNotRE != "" { matched, err := regexp.Match(*bodyNotRE, rbody) if err != nil { - errorf("regexp error: %q", err) + errorf("regexp error: %q\n", err) } if matched { errorf("body matched regexp: %q\n", rbody) @@ -108,6 +110,28 @@ func main() { } } + if *hdrRE != "" { + match := false + outer: + for k, vs := range resp.Header { + for _, v := range 
vs { + hdr := fmt.Sprintf("%s: %s", k, v) + matched, err := regexp.MatchString(*hdrRE, hdr) + if err != nil { + errorf("regexp error: %q\n", err) + } + if matched { + match = true + break outer + } + } + } + + if !match { + errorf("header did not match: %v\n", resp.Header) + } + } + os.Exit(exitCode) }