// Fetch a URL and check that the response matches what we expect.
package main
import (
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"regexp"
"sort"
"strconv"
"strings"
)
// exitCode is the status main passes to os.Exit. It starts at zero and is
// set to 1 by errorf whenever a check fails.
var exitCode int
func main() {
// The first arg is the URL, and then we shift.
url := os.Args[1]
os.Args = append([]string{os.Args[0]}, os.Args[2:]...)
var (
body = flag.String("body", "",
"expect body with these exact contents")
bodyRE = flag.String("bodyre", "",
"expect body matching these contents (regexp match)")
bodyNotRE = flag.String("bodynotre", "",
"expect body NOT matching these contents (regexp match)")
redir = flag.String("redir", "",
"expect a redirect to this URL")
status = flag.Int("status", 200,
"expect this status code")
verbose = flag.Bool("v", false,
"enable verbose output")
hdrRE = flag.String("hdrre", "",
"expect a header matching these contents (regexp match)")
clientErrorRE = flag.String("clienterrorre", "",
"expect a client error matching these contents (regexp match)")
caCert = flag.String("cacert", "",
"file to read CA cert from")
forceLocalhost = flag.Bool("forcelocalhost", false,
"force connection to go to localhost")
)
flag.Parse()
client := &http.Client{
CheckRedirect: noRedirect,
Transport: mkTransport(*caCert, *forceLocalhost),
}
resp, err := client.Get(url)
if *clientErrorRE != "" {
if err == nil {
errorf("expected client error, got nil")
}
matched, err := regexp.MatchString(*clientErrorRE, err.Error())
if err != nil {
errorf("regexp error: %q\n", err)
}
if !matched {
errorf("client error did not match regexp: %q\n", err.Error())
}
os.Exit(exitCode)
} else if err != nil {
fatalf("error getting %q: %v\n", url, err)
}
defer resp.Body.Close()
rbody, err := ioutil.ReadAll(resp.Body)
if err != nil {
errorf("error reading body: %v\n", err)
}
if *verbose {
fmt.Printf("Request: %s\n", url)
fmt.Printf("Response:\n")
fmt.Printf(" %v %v\n", resp.Proto, resp.Status)
ks := []string{}
for k, _ := range resp.Header {
ks = append(ks, k)
}
sort.Strings(ks)
for _, k := range ks {
fmt.Printf(" %v: %s\n", k,
strings.Join(resp.Header.Values(k), ", "))
}
fmt.Printf("\n")
}
if resp.StatusCode != *status {
errorf("status is not %d: %q\n", *status, resp.Status)
}
if *body != "" {
// Unescape the body to allow control characters more easily.
*body, _ = strconv.Unquote("\"" + *body + "\"")
if string(rbody) != *body {
errorf("unexpected body: %q\n", rbody)
}
}
if *bodyRE != "" {
matched, err := regexp.Match(*bodyRE, rbody)
if err != nil {
errorf("regexp error: %q\n", err)
}
if !matched {
errorf("body did not match regexp: %q\n", rbody)
}
}
if *bodyNotRE != "" {
matched, err := regexp.Match(*bodyNotRE, rbody)
if err != nil {
errorf("regexp error: %q\n", err)
}
if matched {
errorf("body matched regexp: %q\n", rbody)
}
}
if *redir != "" {
if loc := resp.Header.Get("Location"); loc != *redir {
errorf("unexpected redir location: %q\n", loc)
}
}
if *hdrRE != "" {
match := false
outer:
for k, vs := range resp.Header {
for _, v := range vs {
hdr := fmt.Sprintf("%s: %s", k, v)
matched, err := regexp.MatchString(*hdrRE, hdr)
if err != nil {
errorf("regexp error: %q\n", err)
}
if matched {
match = true
break outer
}
}
}
if !match {
errorf("header did not match: %v\n", resp.Header)
}
}
os.Exit(exitCode)
}
// noRedirect is a redirect policy for http.Client.CheckRedirect that always
// returns http.ErrUseLastResponse, regardless of the request or the chain of
// previous requests. Per the net/http documentation, that makes the client
// return the most recent response instead of following the redirect.
func noRedirect(req *http.Request, via []*http.Request) error {
	stop := http.ErrUseLastResponse
	return stop
}
// mkTransport builds the HTTP transport to use for the request.
//
// If caCert is empty, it returns a copy of the default transport, and
// forceLocalhost is ignored. It never returns nil: a nil *http.Transport
// assigned to an http.RoundTripper field (such as http.Client.Transport)
// yields a non-nil interface holding a nil pointer, which fails on use.
//
// Otherwise, caCert is the path to a PEM file whose certificates are used as
// the root CAs for TLS verification. If forceLocalhost is also true, every
// connection is dialed to localhost, keeping only the port of the address
// that was requested.
//
// Failing to read the file, or to load certificates from it, is fatal.
func mkTransport(caCert string, forceLocalhost bool) *http.Transport {
	if caCert == "" {
		if dt, ok := http.DefaultTransport.(*http.Transport); ok {
			return dt.Clone()
		}
		// The default transport was replaced by something else; fall back
		// to a plain transport so the result is still usable.
		return &http.Transport{}
	}

	certs, err := ioutil.ReadFile(caCert)
	if err != nil {
		fatalf("error reading CA file %q: %v", caCert, err)
	}

	rootCAs := x509.NewCertPool()
	if ok := rootCAs.AppendCertsFromPEM(certs); !ok {
		fatalf("error adding certs to root")
	}

	t := &http.Transport{
		TLSClientConfig: &tls.Config{
			RootCAs: rootCAs,
		},
	}

	if forceLocalhost {
		t.Dial = func(network, addr string) (net.Conn, error) {
			// Keep the requested port, but replace the host.
			_, port, err := net.SplitHostPort(addr)
			if err != nil {
				return nil, err
			}
			return net.Dial(network, net.JoinHostPort("localhost", port))
		}
	}

	return t
}
// fatalf formats a message (fmt.Printf-style) to stderr and then terminates
// the process immediately with exit status 1.
func fatalf(s string, a ...interface{}) {
	msg := fmt.Sprintf(s, a...)
	os.Stderr.WriteString(msg)
	os.Exit(1)
}
// errorf records a failure without stopping the program: it sets exitCode
// to 1 and writes the formatted message (fmt.Printf-style) to stderr.
func errorf(s string, a ...interface{}) {
	exitCode = 1
	msg := fmt.Sprintf(s, a...)
	os.Stderr.WriteString(msg)
}