git » gofer » commit 820ea96

Implement "set header" option

author Alberto Bertogli
2020-06-06 15:29:14 UTC
committer Alberto Bertogli
2020-06-06 15:29:14 UTC
parent f4f71ed765f7b4a964a3febf0627768face644da

Implement "set header" option

config/config.go +2 -1
server/http.go +40 -13
test/01-be.yaml +4 -0
test/test.sh +3 -0
test/util/exp.go +26 -2

diff --git a/config/config.go b/config/config.go
index 3988bb8..7fb2608 100644
--- a/config/config.go
+++ b/config/config.go
@@ -27,7 +27,8 @@ type HTTP struct {
 
 	Auth map[string]string
 
-	DirOpts map[string]DirOpts
+	DirOpts   map[string]DirOpts
+	SetHeader map[string]map[string]string // path -> header name -> value
 }
 
 type HTTPS struct {
diff --git a/server/http.go b/server/http.go
index 1bf6c54..aa51019 100644
--- a/server/http.go
+++ b/server/http.go
@@ -71,17 +71,34 @@ func httpServer(addr string, conf config.HTTP) *http.Server {
 					"failed to load auth file %q: %v", dbPath, err)
 			}
 			authMux.Handle(path,
-				WithTrace("http:auth",
-					&AuthWrapper{
-						handler: mux,
-						users:   users,
-					}))
+				&AuthWrapper{
+					handler: srv.Handler,
+					users:   users,
+				})
 
 			log.Infof("%s auth %q -> %q", srv.Addr, path, dbPath)
 		}
 		srv.Handler = authMux
 	}
 
+	// Extra headers.
+	if len(conf.SetHeader) > 0 {
+		hdrMux := http.NewServeMux()
+		for path, extraHdrs := range conf.SetHeader {
+			hdrMux.Handle(path, SetHeader(srv.Handler, extraHdrs))
+			log.Infof("%s add headers %q -> %q", srv.Addr, path, extraHdrs)
+		}
+		// Requests outside the configured paths go to the existing handler
+		// unchanged. ServeMux panics on duplicate registrations, so skip the
+		// fallback when the configuration already covers "/".
+		if _, ok := conf.SetHeader["/"]; !ok {
+			hdrMux.Handle("/", srv.Handler)
+		}
+		srv.Handler = hdrMux
+	}
+
+	srv.Handler = WithTrace("http@"+srv.Addr, srv.Handler)
+
 	return srv
 }
 
@@ -230,7 +247,7 @@ func makeDir(from string, to url.URL, conf *config.HTTP) http.Handler {
 	path := pathOrOpaque(to)
 
 	fs := http.FileServer(NewFS(http.Dir(path), conf.DirOpts[from]))
-	return WithTrace("http:dir", WithLogging(
+	return WithLogging(
 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			tr, _ := trace.FromContext(r.Context())
 			tr.Printf("serving dir root %q", path)
@@ -242,19 +259,19 @@ func makeDir(from string, to url.URL, conf *config.HTTP) http.Handler {
 			tr.Printf("adjusted path: %q", r.URL.Path)
 			fs.ServeHTTP(w, r)
 		}),
-	))
+	)
 }
 
 func makeStatic(from string, to url.URL, conf *config.HTTP) http.Handler {
 	path := pathOrOpaque(to)
 
-	return WithTrace("http:static", WithLogging(
+	return WithLogging(
 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			tr, _ := trace.FromContext(r.Context())
 			tr.Printf("statically serving %q", path)
 			http.ServeFile(w, r, path)
 		}),
-	))
+	)
 }
 
 func makeCGI(from string, to url.URL, conf *config.HTTP) http.Handler {
@@ -262,7 +279,7 @@ func makeCGI(from string, to url.URL, conf *config.HTTP) http.Handler {
 	path := pathOrOpaque(to)
 	args := queryToArgs(to.RawQuery)
 
-	return WithTrace("http:cgi", WithLogging(
+	return WithLogging(
 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			tr, _ := trace.FromContext(r.Context())
 			tr.Debugf("exec %q %q", path, args)
@@ -275,7 +292,7 @@ func makeCGI(from string, to url.URL, conf *config.HTTP) http.Handler {
 			}
 			h.ServeHTTP(w, r)
 		}),
-	))
+	)
 }
 
 func queryToArgs(query string) []string {
@@ -299,7 +316,7 @@ func queryToArgs(query string) []string {
 func makeRedirect(from string, to url.URL, conf *config.HTTP) http.Handler {
 	from = stripDomain(from)
 
-	return WithTrace("http:redirect", WithLogging(
+	return WithLogging(
 		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			tr, _ := trace.FromContext(r.Context())
 			target := to
@@ -309,7 +326,7 @@ func makeRedirect(from string, to url.URL, conf *config.HTTP) http.Handler {
 
 			http.Redirect(w, r, target.String(), http.StatusTemporaryRedirect)
 		}),
-	))
+	)
 }
 
 type loggingTransport struct{}
@@ -394,6 +411,24 @@ func (w *statusWriter) Write(b []byte) (int, error) {
 	return n, err
 }
 
+// SetHeader returns a handler that sets hdrs (header name -> value) on the
+// response and then delegates to parent, which may still override them.
+func SetHeader(parent http.Handler, hdrs map[string]string) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		// TODO: better chained contexts.
+		// The trace normally comes from the outermost WithTrace; check for it
+		// so this wrapper does not depend on that ordering.
+		tr, ok := trace.FromContext(r.Context())
+		for k, v := range hdrs {
+			w.Header().Set(k, v)
+			if ok {
+				tr.Printf("added header: %s: %q", k, v)
+			}
+		}
+		parent.ServeHTTP(w, r)
+	})
+}
+
 func WithTrace(name string, parent http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		tr, ok := trace.FromContext(r.Context())
diff --git a/test/01-be.yaml b/test/01-be.yaml
index 09d10ef..1a2f16a 100644
--- a/test/01-be.yaml
+++ b/test/01-be.yaml
@@ -24,3 +24,7 @@ http:
       "/authdir/ñaca": "testdata/authdb.yaml"
       "/authdir/withoutindex/": "testdata/authdb.yaml"
 
+    setheader:  # path -> header name -> value
+      "/file":
+        "X-My-Header": "my lovely header"
+
diff --git a/test/test.sh b/test/test.sh
index 4e19450..785d007 100755
--- a/test/test.sh
+++ b/test/test.sh
@@ -152,6 +152,9 @@ do
 	exp $base/authdir/withoutindex -status 301
 	exp $base/authdir/withoutindex/ -status 401
 	exp $base/authdir/withoutindex/chau -status 401
+
+	# Additional headers (see the "setheader" entry for "/file" in the config).
+	exp $base/file -hdrre "X-My-Header: my lovely header"
 done
 
 # Good auth.
diff --git a/test/util/exp.go b/test/util/exp.go
index 37deedc..f6daa1d 100644
--- a/test/util/exp.go
+++ b/test/util/exp.go
@@ -34,6 +36,8 @@ func main() {
 			"expect this status code")
 		verbose = flag.Bool("v", false,
 			"enable verbose output")
+		hdrRE = flag.String("hdrre", "",
+			"expect a header matching these contents (regexp match)")
 		caCert = flag.String("cacert", "",
 			"file to read CA cert from")
 	)
@@ -85,7 +87,7 @@ func main() {
 	if *bodyRE != "" {
 		matched, err := regexp.Match(*bodyRE, rbody)
 		if err != nil {
-			errorf("regexp error: %q", err)
+			errorf("regexp error: %q\n", err)
 		}
 		if !matched {
 			errorf("body did not match regexp: %q\n", rbody)
@@ -95,7 +97,7 @@ func main() {
 	if *bodyNotRE != "" {
 		matched, err := regexp.Match(*bodyNotRE, rbody)
 		if err != nil {
-			errorf("regexp error: %q", err)
+			errorf("regexp error: %q\n", err)
 		}
 		if matched {
 			errorf("body matched regexp: %q\n", rbody)
@@ -108,6 +110,30 @@ func main() {
 		}
 	}
 
+	if *hdrRE != "" {
+		// Compile once, so an invalid pattern is reported a single time
+		// instead of once per header value.
+		re, err := regexp.Compile(*hdrRE)
+		if err != nil {
+			errorf("regexp error: %q\n", err)
+		} else {
+			match := false
+		outer:
+			for k, vs := range resp.Header {
+				for _, v := range vs {
+					// Each value is matched in "Name: value" form.
+					if re.MatchString(fmt.Sprintf("%s: %s", k, v)) {
+						match = true
+						break outer
+					}
+				}
+			}
+			if !match {
+				errorf("header did not match: %v\n", resp.Header)
+			}
+		}
+	}
+
 	os.Exit(exitCode)
 }