author | Alberto Bertogli
<albertito@blitiri.com.ar> 2023-07-21 11:41:51 UTC |
committer | Alberto Bertogli
<albertito@blitiri.com.ar> 2023-08-02 15:45:30 UTC |
parent | 136b59b9ff9a73b2c8fbba8760a214a699b959ac |
config/config.go | +82 | -11 |
config/config_test.go | +78 | -0 |
config/gofer.schema.cue | +13 | -0 |
config/gofer.yaml | +33 | -0 |
debug/debug.go | +3 | -0 |
gofer.go | +5 | -0 |
ipratelimit/ipratelimit.go | +448 | -0 |
ipratelimit/ipratelimit_test.go | +456 | -0 |
ratelimit/ratelimit.go | +106 | -0 |
server/http.go | +48 | -0 |
server/raw.go | +29 | -2 |
server/raw_test.go | +34 | -0 |
test/01-fe.yaml | +19 | -0 |
test/test.sh | +41 | -0 |
test/util/exp/exp.go | +16 | -1 |
diff --git a/config/config.go b/config/config.go index be481c0..bdda912 100644 --- a/config/config.go +++ b/config/config.go @@ -6,6 +6,9 @@ import ( "io/ioutil" "net/url" "regexp" + "strconv" + "strings" + "time" "gopkg.in/yaml.v3" ) @@ -19,6 +22,8 @@ type Config struct { Raw map[string]Raw `yaml:",omitempty"` ReqLog map[string]ReqLog `yaml:",omitempty"` + + RateLimit map[string]RateLimit `yaml:",omitempty"` } type HTTP struct { @@ -29,6 +34,8 @@ type HTTP struct { SetHeader map[string]map[string]string `yaml:",omitempty"` ReqLog map[string]string `yaml:",omitempty"` + + RateLimit map[string]string `yaml:",omitempty"` } type HTTPS struct { @@ -60,10 +67,11 @@ type DirOpts struct { } type Raw struct { - Certs string `yaml:",omitempty"` - To string `yaml:",omitempty"` - ToTLS bool `yaml:"to_tls,omitempty"` - ReqLog string `yaml:",omitempty"` + Certs string `yaml:",omitempty"` + To string `yaml:",omitempty"` + ToTLS bool `yaml:"to_tls,omitempty"` + ReqLog string `yaml:",omitempty"` + RateLimit string `yaml:",omitempty"` } type ReqLog struct { @@ -72,6 +80,15 @@ type ReqLog struct { Format string `yaml:",omitempty"` } +type RateLimit struct { + Rate Rate `yaml:",omitempty"` + Size int `yaml:",omitempty"` + + Rate64 Rate `yaml:",omitempty"` + Rate56 Rate `yaml:",omitempty"` + Rate48 Rate `yaml:",omitempty"` +} + func (c Config) String() string { d, err := yaml.Marshal(&c) if err != nil { @@ -102,7 +119,12 @@ func (c Config) Check() []error { errs = append(errs, fmt.Errorf("%q: unknown reqlog %q", addr, r.ReqLog)) } + if _, ok := c.RateLimit[r.RateLimit]; r.RateLimit != "" && !ok { + errs = append(errs, + fmt.Errorf("%q: unknown ratelimit %q", addr, r.RateLimit)) + } } + return errs } @@ -142,10 +164,27 @@ func (h HTTP) Check(c Config, addr string) []error { fmt.Errorf("%q: %q: unknown reqlog %q", addr, path, name)) } } + for path, name := range h.RateLimit { + if _, ok := c.RateLimit[name]; !ok { + errs = append(errs, + fmt.Errorf("%q: %q: unknown ratelimit %q", addr, 
path, name)) + } + } return errs } +// Count how many true values are in a series of bools. +func nTrue(bs ...bool) int { + n := 0 + for _, b := range bs { + if b { + n++ + } + } + return n +} + func Load(filename string) (*Config, error) { contents, err := ioutil.ReadFile(filename) if err != nil { @@ -220,12 +259,44 @@ func (u URL) String() string { return p.String() } -func nTrue(bs ...bool) int { - n := 0 - for _, b := range bs { - if b { - n++ - } +// Rate type to simplify rate limits in configuration. +// Format is "requests/period", e.g. "10/1s". +type Rate struct { + Requests uint64 + Period time.Duration +} + +func (r *Rate) UnmarshalYAML(unmarshal func(interface{}) error) error { + var s string + if err := unmarshal(&s); err != nil { + return err } - return n + + sp := strings.SplitN(s, "/", 2) + if len(sp) != 2 { + return fmt.Errorf("invalid rate format %q (needs a single '/')", s) + } + reqS, periodS := strings.TrimSpace(sp[0]), strings.TrimSpace(sp[1]) + + req, err := strconv.ParseUint(reqS, 10, 64) + if err != nil { + return fmt.Errorf("invalid requests in %q: %v", s, err) + } + + period, err := time.ParseDuration(periodS) + if err != nil { + return fmt.Errorf("invalid period in %q: %v", s, err) + } + if period == 0 { + return fmt.Errorf("period must be >0 in %q", s) + } + + r.Requests = req + r.Period = period + + return nil +} + +func (r Rate) MarshalYAML() (interface{}, error) { + return fmt.Sprintf("%d/%s", r.Requests, r.Period), nil } diff --git a/config/config_test.go b/config/config_test.go index a77bcac..c4f5cbc 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -7,6 +7,7 @@ import ( "regexp" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "gopkg.in/yaml.v3" @@ -188,6 +189,29 @@ raw: ` expectErrs(t, `":1234": unknown reqlog "lalala"`, loadAndCheck(t, contents)) + + // ratelimit reference (http). 
+ contents = ` +https: + ":https": + certs: "/dev/null" + routes: + "/": + file: "/dev/null" + ratelimit: + "/": "lalala" +` + expectErrs(t, `":https": "/": unknown ratelimit "lalala"`, + loadAndCheck(t, contents)) + + // ratelimit reference (raw). + contents = ` +raw: + ":1234": + ratelimit: "lalala" +` + expectErrs(t, `":1234": unknown ratelimit "lalala"`, + loadAndCheck(t, contents)) } func loadAndCheck(t *testing.T, contents string) []error { @@ -245,6 +269,12 @@ func TestRegexp(t *testing.T) { if err != unmarshalErr { t.Errorf("expected unmarshalErr, got %v", err) } + + // Test marshalling. + s, err := Regexp{orig: "ab.d"}.MarshalYAML() + if !(s == "ab.d" && err == nil) { + t.Errorf(`expected "ab.d" / nil, got %q / %v`, s, err) + } } func TestURL(t *testing.T) { @@ -278,4 +308,52 @@ func TestURL(t *testing.T) { } } +func TestRate(t *testing.T) { + r := Rate{} + err := yaml.Unmarshal([]byte(`"1000/2s"`), &r) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + expected := Rate{1000, 2 * time.Second} + if diff := cmp.Diff(expected, r); diff != "" { + t.Errorf("unexpected unmarshalling result (-want +got):\n%s", diff) + } + + // Error: missing '/'. + err = yaml.Unmarshal([]byte(`"1000"`), &r) + if !strings.Contains(err.Error(), "needs a single '/'") { + t.Errorf("expected error about missing '/', got %v", err) + } + + // Error: invalid requests. + err = yaml.Unmarshal([]byte(`"-1234/5s"`), &r) + if !strings.Contains(err.Error(), "invalid requests") { + t.Errorf("expected error about invalid requests, got %v", err) + } + + // Error: invalid period. + err = yaml.Unmarshal([]byte(`"1234/5"`), &r) + if !strings.Contains(err.Error(), "missing unit in duration") { + t.Errorf("expected error about invalid period, got %v", err) + } + + // Error: period is 0. + err = yaml.Unmarshal([]byte(`"1234/0s"`), &r) + if !strings.Contains(err.Error(), "period must be >0") { + t.Errorf("expected error about period being 0, got %v", err) + } + + // Test marshalling. 
+ s, err := Rate{1000, 2 * time.Second}.MarshalYAML() + if !(s == "1000/2s" && err == nil) { + t.Errorf(`expected "1000/2s" / nil, got %q / %v`, s, err) + } + + // Test handling unmarshal error. + err = r.UnmarshalYAML(func(interface{}) error { return unmarshalErr }) + if err != unmarshalErr { + t.Errorf("expected unmarshalErr, got %v", err) + } +} + var unmarshalErr = fmt.Errorf("error unmarshalling for testing") diff --git a/config/gofer.schema.cue b/config/gofer.schema.cue index ecf783a..f4aed6a 100644 --- a/config/gofer.schema.cue +++ b/config/gofer.schema.cue @@ -14,6 +14,16 @@ reqlog?: format?: string }) +ratelimit?: + [string]: close({ + rate: string + size?: number + + rate64?: string + rate56?: string + rate48?: string + }) + http?: [string]: close(#http) @@ -57,6 +67,8 @@ https?: reqlog?: [string]: string + ratelimit?: [string]: string + ... } @@ -66,4 +78,5 @@ raw?: to: string to_tls?: bool reqlog?: string + ratelimit?: string }) diff --git a/config/gofer.yaml b/config/gofer.yaml index 1d27d9b..0e5a10a 100644 --- a/config/gofer.yaml +++ b/config/gofer.yaml @@ -22,6 +22,33 @@ reqlog: #format: "<gofer>" +# IP rate limiting. +ratelimit: + # Name of the IP rate limit arena; just an id used to refer to it on the + # server entries below. + "rl-arena1": + # Rate to enforce. + # Format is "requests/period". For example, "50/1s" will allow 50 + # requests every second. + rate: "50/1s" + + # How many IPs to hold in memory, to keep the memory usage bounded. + # Setting this to 1000 will increase memory usage by ~128 KiB. + # Default: 1000. + #size: 1000 + + # By default, IPv6 addresses get limited at /64, /56 and /48 + # simultaneously; and the rate for /64 is the one given above, for /56 is + # 4x the /64 rate, and /48 is 8x the /64 rate. + # This is an imperfect heuristic to account for the fact that IPv6 is + # allocated to end users in different block sizes, and it is not possible + # to tell them apart. 
+ # You can configure custom rates for each one as follows: + #rate64: "50/1s" + #rate56: "200/500ms" + #rate48: "400/250ms" + + # HTTP servers. # Map of address: configuration. http: @@ -77,6 +104,12 @@ http: # "/": # "My-Header": "my header value" + # Enable IP rate limiting. The target is a rate limit arena name, which + # should match an entry in the top-level ratelimit configuration (see + # above). + ratelimit: + "/": "rl-arena1" + # Enable request logging. The target is a log name, which should match an # entry in the top-level reqlog configuration (see above). reqlog: diff --git a/debug/debug.go b/debug/debug.go index 546ca9f..45044a4 100644 --- a/debug/debug.go +++ b/debug/debug.go @@ -15,6 +15,7 @@ import ( "blitiri.com.ar/go/gofer/config" "blitiri.com.ar/go/gofer/nettrace" + "blitiri.com.ar/go/gofer/ratelimit" "blitiri.com.ar/go/log" ) @@ -95,6 +96,7 @@ func ServeDebugging(addr string, conf *config.Config) error { } http.HandleFunc("/debug/config", DumpConfigFunc(conf)) + http.HandleFunc("/debug/ratelimit", ratelimit.DebugHandler) nettrace.RegisterHandler(http.DefaultServeMux) http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/" { @@ -166,6 +168,7 @@ var htmlIndex = template.Must( <ul> <li><a href="/debug/config">configuration</a> <li><a href="/debug/traces">traces</a> + <li><a href="/debug/ratelimit">ratelimit</a> <li><a href="/debug/pprof">pprof</a> <small><a href="https://golang.org/pkg/net/http/pprof/"> (ref)</a></small> diff --git a/gofer.go b/gofer.go index 6a2ccbc..ec74846 100644 --- a/gofer.go +++ b/gofer.go @@ -10,6 +10,7 @@ import ( "blitiri.com.ar/go/gofer/config" "blitiri.com.ar/go/gofer/debug" + "blitiri.com.ar/go/gofer/ratelimit" "blitiri.com.ar/go/gofer/reqlog" "blitiri.com.ar/go/gofer/server" "blitiri.com.ar/go/log" @@ -58,6 +59,10 @@ func main() { } } + for name, rl := range conf.RateLimit { + ratelimit.FromConfig(name, rl) + } + servers := []runnerFunc{} for addr, https := range conf.HTTPS { diff 
--git a/ipratelimit/ipratelimit.go b/ipratelimit/ipratelimit.go new file mode 100644 index 0000000..cf867bd --- /dev/null +++ b/ipratelimit/ipratelimit.go @@ -0,0 +1,448 @@ +// Package ipratelimit implements a per-IP rate limiter. +// +// It implements a Limiter, which is configured with a maximum number of +// requests per given time period to allow per IP address. +// +// The Limiter has a fixed maximum size, to limit memory usage. If the maximum +// is reached, older entries are evicted. +// +// It is safe for concurrent use. +// +// The main use-case is to help prevent abuse, not to perform accurate request +// accounting/throttling. The implementation choices reflect this. +// +// For IPv4 addresses, we use the full address as the limiting key. +// +// For IPv6 addresses, since end users are usually assigned a range of +// /64, /56 or /48, we use the following heuristic: There are 3 rate limiters, +// one for each of the common subnet masks (/48, /56, /64). They operate in +// parallel, and any can deny access. +// By default, the rate for /64 is the one given, the rate for /56 is 4x, and +// the rate for /48 is 8x; these rates can be individually configured if +// needed. +// +// Note that rate-limiting 0.0.0.0 is not supported. It will be automatically +// treated as 0.0.0.1. The same applies to IPv6. +package ipratelimit // blitiri.com.ar/go/gofer/ipratelimit + +import ( + "encoding/binary" + "fmt" + "math/big" + "net" + "sync" + "time" +) + +// For IPv4, we use the IP addresses just as they are, nothing fancy. +// +// For IPv6, the main challenge is that the key space is too large, and that +// users get assigned vast ranges (/48, /56, or /64). If we pick too narrow, +// we allow DoS bypass. If it's too wide, we would over-block. +// We could do fancy heuristics for coalescing entries, but it gets +// computationally expensive. +// So we use use an intermediate solution: we keep 3 rate limiters, one for +// each of the common subnet masks. 
// They have different limits to decrease the chances of over-blocking. This
// is probably okay for coarse abuse prevention, but not good for precise rate
// limiting.

// Useful articles and references on IP/HTTP rate limiting and IPv6
// assignment, for convenience:
// - https://adam-p.ca/blog/2022/02/ipv6-rate-limiting/
// - https://caddyserver.com/docs/json/apps/http/servers/routes/handle/rate_limit/
// - https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers
// - https://dev.to/satrobit/rate-limiting-in-ipv6-era-using-probabilistic-data-structures-15on
// - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
// - https://konghq.com/blog/engineering/how-to-design-a-scalable-rate-limiting-algorithm
// - https://serverfault.com/questions/863501/blocking-or-rate-limiting-ipv6-what-size-prefix
// - https://www.nginx.com/blog/rate-limiting-nginx/
// - https://www.rfc-editor.org/rfc/rfc9110#field.retry-after
// - https://www.ripe.net/publications/docs/ripe-690

// Possible future changes:
// - AllowWithInfo returning (bool, uint64, time.Duration) to indicate how
//   many more requests are left within this time window, and when the next
//   window starts.
// - AllowOrSleep that sleeps until a request is allowed.
// - AllowN that allows N requests at once (can be used for weighting
//   some requests differently).

// ipv4ToUint64 turns an IPv4 address into the uint64 key used by the
// limiters. The 4 address bytes are read as one big-endian 32-bit number, so
// 1.2.3.4 becomes 0x01020304.
//
// ip4 MUST be the 4-byte form of an IPv4 address (e.g. the result of
// net.IP.To4). This is not checked, for performance reasons (skipping the
// check improves ipv4 performance by 1.5% to 3.5%).
func ipv4ToUint64(ip4 net.IP) uint64 {
	addr := binary.BigEndian.Uint32(ip4)
	return uint64(addr)
}

// Convert an IPv6 to a set of uint64 representations, one for /48, /56, and
// /64, as described above.
+func ipv6ExtractMasks(ip net.IP) (ip48, ip56, ip64 uint64) { + ip64 = binary.BigEndian.Uint64(ip[0:8]) + ip56 = ip64 & 0xffff_ffff_ffff_ff00 + ip48 = ip64 & 0xffff_ffff_ffff_0000 + return +} + +type entry struct { + // Timestamp of the last request allowed. + // We could use time.Time, but this uses less memory (8 bytes vs 24), and + // improves performance by ~7-9% on all the "DifferentIPv*" benchmarks, + // with no negative impact on the rest. + lastAllowed miniTime + + // Requests left (since lastAllowed). + requestsLeft uint64 + + // Prev and next keys in the LRU list. + // We use this as a doubly-linked list, to implement an LRU cache and keep + // the size of the entries map bounded. + // This could be implemented separately, but keeping it in line helps with + // performance and memory usage. + lruPrev, lruNext uint64 +} + +func (e *entry) reset() { + e.lastAllowed = 0 + e.requestsLeft = 0 + e.lruPrev = 0 + e.lruNext = 0 +} + +type limiter struct { + // Allow this many requests per period. + Requests uint64 + + // Allow requests at most once per this duration. + Period time.Duration + + // Maximum number of entries to keep track of. This is important to keep + // the memory usage bounded. + Size int + + // Pool of free entries, to reduce/avoid allocations. + // This results in an 15-35% reduction in operation latency on this + // package's benchmarks. + entryPool sync.Pool + + // Protects the mutable fields below. + // + // This is a single lock for the whole limiter, and the data layout and + // implementation take significant advantage of this (e.g. the embedded + // LRU list on each entry). + // + // This results in very fast operations when contention is low to + // moderate, which is what we're optimizing for. + // + // Even on parallel benchmarks that have a fair amount of contention, this + // does alright compared to other libraries that optimize for that. + // However, it does not scale as well to very high contention scenarios + // (e.g. 
parallel benchmarks on >32 cores). + // + // To improve high contention performance, we could do fine grained + // locking in a variety of ways (including sharding the limiters at a high + // level), although that often causes large performance regressions or + // when contention is low to moderate, which is our main use case. + mu sync.Mutex + + // Map of key (IP in uint64 form) to limiter entry. + m map[uint64]*entry + + // LRU doubly-linked list first and last entry. + // 0 means "not present", which happens when the list is empty. This works + // only because we never expect to have 0 (address 0.0.0.0 / 0::0) as a + // valid key. + lruFirst, lruLast uint64 +} + +func newlimiter(req uint64, period time.Duration, size int) *limiter { + l := &limiter{ + Requests: req, + Period: period, + Size: size, + + m: make(map[uint64]*entry, size), + } + + l.entryPool.New = func() any { return &entry{} } + return l +} + +// lruBump moves the key to the top of the LRU list. +func (l *limiter) lruBump(key uint64, e *entry) { + if l.lruFirst == key { + return + } + + // Update the last pointer (if this key is the last one). + if l.lruLast == key { + l.lruLast = e.lruPrev + } + + // Take the key out of the list chain. + if e.lruPrev != 0 { + l.m[e.lruPrev].lruNext = e.lruNext + } + if e.lruNext != 0 { + l.m[e.lruNext].lruPrev = e.lruPrev + } + + // Update the current first element. + if l.lruFirst != 0 { + l.m[l.lruFirst].lruPrev = key + } + + // Adjust the key's entry pointers to be at the beginning. + e.lruNext = l.lruFirst + e.lruPrev = 0 + + // Set this key as the new first element. + l.lruFirst = key +} + +// lruPrepend adds an element to the top of the list. If the list is full, the +// last element is removed and its entry is returned. +func (l *limiter) lruPrepend(key uint64, e *entry) { + if l.lruFirst == 0 { + l.lruFirst = key + l.lruLast = key + return + } + + // Add the new element to the beginning of the list. 
+ e.lruNext = l.lruFirst + l.m[l.lruFirst].lruPrev = key + l.lruFirst = key + + // If we're over capacity, remove the last element. + if len(l.m) > l.Size { + lastK := l.lruLast + lastE := l.m[l.lruLast] + l.lruLast = lastE.lruPrev + l.m[l.lruLast].lruNext = 0 + delete(l.m, lastK) + l.entryPool.Put(lastE) + } +} + +// For testing. +var timeNow = time.Now + +func (l *limiter) allow(key uint64) bool { + now := timeNow() + + if key == 0 { + // We use 0 as the "null" key, because 0.0.0.0 or ::0 IP addresses are + // not expected to be rate-limit targets. For IPv4 this is usually an + // incorrect test and it is harmless. For IPv6 this happens + // mainly on localhost requests, because ::1 will get masked to ::0 + // after masking. This is a bit unfortunate, but it's not a big deal. + // We force the key to be 1 in those cases. + key = 1 + } + + if l.Requests == 0 { + // Always limiting, no need to compute anything. + return false + } + + l.mu.Lock() + defer l.mu.Unlock() + + e, ok := l.m[key] + if !ok { + // It's a new entry. + e = l.entryPool.Get().(*entry) + e.reset() + + l.m[key] = e + l.lruPrepend(key, e) + } else { + // Pre-existing entry, just update the LRU. + l.lruBump(key, e) + } + + // Decide if we should allow the request. + if sinceMiniTime(now, e.lastAllowed) >= l.Period { + e.lastAllowed = makeMiniTime(now) + e.requestsLeft = l.Requests - 1 + return true + } else if e.requestsLeft > 0 { + e.requestsLeft-- + return true + } + + return false +} + +// Limiter is a rate limiter that keeps track of requests per IP address. +type Limiter struct { + // Individual limiters per IP type. + ipv4, ip48, ip56, ip64 *limiter +} + +// NewLimiter creates a new Limiter. Per IP address, up to `requests` per +// `period` will be allowed. Once they're exhausted, requests will be denied +// until `period` has passed from the first approved request. Size is the +// maximum number of IP addresses to keep track of; when that size is reached, +// older entries are removed. 
See the package documentation for more details +// on how IPv6 addresses are handled. +func NewLimiter(requests uint64, period time.Duration, size int) *Limiter { + return &Limiter{ + ipv4: newlimiter(requests, period, size), + ip64: newlimiter(requests, period, size), + ip56: newlimiter(requests, period/4, size), + ip48: newlimiter(requests, period/8, size), + } +} + +// SetIPv6s64Rate sets the rate limit for IPv6 addresses with /64 prefixes. +// It can only be changed before any requests are made. +func (l *Limiter) SetIPv6s64Rate(req uint64, per time.Duration) { + l.ip64.Requests = req + l.ip64.Period = per +} + +// SetIPv6s56Rate sets the rate limit for IPv6 addresses with /56 prefixes. +// It can only be changed before any requests are made. +func (l *Limiter) SetIPv6s56Rate(req uint64, per time.Duration) { + l.ip56.Requests = req + l.ip56.Period = per +} + +// SetIPv6s48Rate sets the rate limit for IPv6 addresses with /48 prefixes. +// It can only be changed before any requests are made. +func (l *Limiter) SetIPv6s48Rate(req uint64, per time.Duration) { + l.ip48.Requests = req + l.ip48.Period = per +} + +// Allow checks if the given IP address is allowed to make a request. +func (l *Limiter) Allow(ip net.IP) bool { + if ip4 := ip.To4(); ip4 != nil { + // Convert the IPv4 address to a 64-bit integer, and use that as key. + return l.ipv4.allow(ipv4ToUint64(ip4)) + } + + // Check if the three masks for ipv6. All must be allowed for the request + // to be allowed. + a48, a56, a64 := l.allowV6(ip) + return a48 && a56 && a64 +} + +func (l *Limiter) allowV6(ip net.IP) (a48, a56, a64 bool) { + ip48, ip56, ip64 := ipv6ExtractMasks(ip) + return l.ip48.allow(ip48), l.ip56.allow(ip56), l.ip64.allow(ip64) +} + +// DebugString returns a string with debugging information about the limiter. +// This is useful for debugging, but not for production use. It is not +// guaranteed to be stable. 
+func (l *Limiter) DebugString() string { + s := "## IPv4\n\n" + s += l.ipv4.debugString(kToIPv4) + s += "\n\n" + s += "## IPv6\n\n" + s += "### /48\n\n" + s += l.ip48.debugString(kToIPv6) + s += "\n\n" + s += "### /56\n\n" + s += l.ip56.debugString(kToIPv6) + s += "\n\n" + s += "### /64\n\n" + s += l.ip64.debugString(kToIPv6) + s += "\n" + return s +} + +func (l *limiter) debugString(kToIP func(uint64) net.IP) string { + l.mu.Lock() + defer l.mu.Unlock() + + s := "" + s += fmt.Sprintf("Allow: %d / %v\n", l.Requests, l.Period) + s += fmt.Sprintf("Size: %d / %d\n", len(l.m), l.Size) + s += "\n" + k := l.lruFirst + for k != 0 { + e := l.m[k] + ip := kToIP(k) + last := sinceMiniTime(time.Now(), e.lastAllowed).Round( + time.Millisecond) + s += fmt.Sprintf("%-22s %3d requests left, last allowed %10s ago\n", + ip, e.requestsLeft, last) + k = e.lruNext + } + return s +} + +// DebugHTML returns a string with debugging information about the limiter, in +// HTML format (just content starting with `<h2>`, no meta-tags). This is +// useful for debugging, but not for production use. It is not guaranteed to +// be stable. 
+func (l *Limiter) DebugHTML() string { + s := "<h2>IPv4</h2>" + s += l.ipv4.debugHTML(kToIPv4) + s += "<h2>IPv6</h2>" + s += "<h3>/48</h3>" + s += l.ip48.debugHTML(kToIPv6) + s += "<h3>/56</h3>" + s += l.ip56.debugHTML(kToIPv6) + s += "<h3>/64</h3>" + s += l.ip64.debugHTML(kToIPv6) + return s +} + +func (l *limiter) debugHTML(kToIP func(uint64) net.IP) string { + l.mu.Lock() + defer l.mu.Unlock() + + s := fmt.Sprintf("Allow: %d / %v<br>\n", l.Requests, l.Period) + s += fmt.Sprintf("Size: %d / %d<br>\n", len(l.m), l.Size) + s += "<p>\n" + if l.lruFirst == 0 { + s += "(empty)<br>" + return s + } + + s += "<table>\n" + s += "<tr><th>IP</th><th>Requests left</th><th>Last allowed</th></tr>\n" + k := l.lruFirst + for k != 0 { + e := l.m[k] + ip := kToIP(k) + last := sinceMiniTime(time.Now(), e.lastAllowed).Round( + time.Millisecond) + s += fmt.Sprintf(`<tr><td class="ip">%v</td>`, ip) + s += fmt.Sprintf(`<td class="requests">%d</td>`, e.requestsLeft) + s += fmt.Sprintf(`<td class="last">%s</td></tr>`, last) + s += "\n" + k = e.lruNext + } + s += "</table>\n" + return s +} + +func kToIPv4(k uint64) net.IP { + return net.IPv4(byte(k>>24), byte(k>>16), byte(k>>8), byte(k)) +} + +func kToIPv6(k uint64) net.IP { + buf := make([]byte, 16) + b := big.NewInt(0).SetUint64(k) + b = b.Lsh(b, 64) + return net.IP(b.FillBytes(buf[:])) +} + +// miniTime is a small representation of time, as the number of nanoseconds +// elapsed since January 1, 1970 UTC. +// This is used to reduce memory footprint and improve performance. +type miniTime int64 + +func makeMiniTime(t time.Time) miniTime { + return miniTime(t.UnixNano()) +} + +func sinceMiniTime(now time.Time, old miniTime) time.Duration { + // time.Duration is an int64 nanosecond count, so we can just substract. 
+ return time.Duration(now.UnixNano() - int64(old)) +} diff --git a/ipratelimit/ipratelimit_test.go b/ipratelimit/ipratelimit_test.go new file mode 100644 index 0000000..3429211 --- /dev/null +++ b/ipratelimit/ipratelimit_test.go @@ -0,0 +1,456 @@ +package ipratelimit + +import ( + "net" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestIPv4ToUint32(t *testing.T) { + cases := []struct { + ip net.IP + expected uint64 + }{ + {net.IPv4(0, 0, 0, 0), 0}, + {net.IPv4(1, 2, 3, 4), 0x01020304}, + } + for _, c := range cases { + v := ipv4ToUint64(c.ip.To4()) + if v != c.expected { + t.Errorf("IPv4ToUint32(%v) -> %v, expected %v", + c.ip, v, c.expected) + } + } +} + +func TestIPv6ExtractMasks(t *testing.T) { + cases := []struct { + ip string + eip48, eip56, eip64 uint64 + }{ + { + "0::1", 0, 0, 0, + }, + { + "1111:2222:3333:4444:5555:6666:7777:8888", + 0x1111222233330000, + 0x1111222233334400, + 0x1111222233334444, + }, + } + + for _, c := range cases { + ip := net.ParseIP(c.ip) + ip48, ip56, ip64 := ipv6ExtractMasks(ip) + if ip48 != c.eip48 || ip56 != c.eip56 || ip64 != c.eip64 { + t.Errorf("IP %q (%v)", c.ip, ip) + t.Errorf(" expected (%.16x, %.16x, %.16x)", + c.eip48, c.eip56, c.eip64) + t.Errorf(" got (%.16x, %.16x, %.16x)", + ip48, ip56, ip64) + } + } +} + +func as(yes, no int) []bool { + r := []bool{} + for i := 0; i < yes; i++ { + r = append(r, true) + } + for i := 0; i < no; i++ { + r = append(r, false) + } + return r +} + +func TestBasic(t *testing.T) { + cases := []struct { + reqs uint64 + period time.Duration + pkts uint64 + allowed []bool + }{ + {0, time.Second, 3, as(0, 3)}, + {1, time.Second, 1, as(1, 0)}, + {1, time.Second, 2, as(1, 1)}, + {2, time.Second, 2, as(2, 0)}, + {2, time.Second, 3, as(2, 1)}, + {10, time.Second, 20, as(10, 10)}, + } + for _, c := range cases { + l := NewLimiter(c.reqs, c.period, 256) + ip := net.IPv4(1, 2, 3, 4) + as := []bool{} + for i := uint64(0); i < c.pkts; i++ { + as = append(as, l.Allow(ip)) + } + if diff 
:= cmp.Diff(c.allowed, as); diff != "" { + t.Errorf( + "[rate=%d/%v, pkts=%d, allowed=%v]"+ + " mismatch (-want +got):\n%s", + c.reqs, c.period, c.pkts, c.allowed, diff) + } + } +} + +func TestBasicIPv6(t *testing.T) { + operations := []struct { + ip string + allowed bool + }{ + {"1111:2222:3333:4444::a", true}, + {"1111:2222:3333:4444::b", false}, + {"1111:2222:3333:5555::c", false}, + {"1111:2222:3333:5500::d", false}, + {"1111:2222:3333::e", false}, + } + + l := NewLimiter(1, time.Second, 256) + for i, op := range operations { + ip := net.ParseIP(op.ip) + allowed := l.Allow(ip) + if allowed != op.allowed { + t.Errorf("operation %d: Allow(%v) -> %v, expected %v", + i, ip, allowed, op.allowed) + } + } +} + +func TestIPv6Subnetting(t *testing.T) { + // These two are equal in the first 64 bits, and differ at the end. + // So they should be counted as the same at all levels. + ip64a := net.ParseIP("1111:1111:1111:1111:aaaa::a") + ip64b := net.ParseIP("1111:1111:1111:1111:bbbb::b") + + // These two are equal in the first 56 bits, and differ at the end. + // So they should be counted as the same for /48 and /56, but not /64. + ip56a := net.ParseIP("2222:2222:2222:22aa::a") + ip56b := net.ParseIP("2222:2222:2222:22bb::b") + + // These two are equal in the first 48 bits, and differ at the end. + // So they should be counted as the same for /48, but not /56 or /64. 
+ ip48a := net.ParseIP("3333:3333:3333:aaaa::a") + ip48b := net.ParseIP("3333:3333:3333:bbbb::b") + + operations := []struct { + ip net.IP + a48, a56, a64 bool + }{ + {ip64a, true, true, true}, + {ip64b, false, false, false}, + + {ip56a, true, true, true}, + {ip56b, false, false, true}, + + {ip48a, true, true, true}, + {ip48b, false, true, true}, + } + + l := NewLimiter(1, time.Second, 256) + for i, op := range operations { + a48, a56, a64 := l.allowV6(op.ip) + diff := cmp.Diff( + []bool{op.a48, op.a56, op.a64}, + []bool{a48, a56, a64}, + ) + if diff != "" { + t.Errorf("operation %d: Allow(%v) mismatch (-want +got):\n%s", + i, op.ip, diff) + } + } +} + +func TestSize(t *testing.T) { + sizes := []int{ + 1, 2, 3, 5, 8, 10, 100, 256, 10000, + } + for _, size := range sizes { + l := newlimiter(1, 0, size) + + // Note we avoid i=0 because we never expect the zero IP to be + // allowed. + i := 1 + + // First, run up to size to fill in the map. + for ; i < size+1; i++ { + ip := net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)) + if !l.allow(ipv4ToUint64(ip.To4())) { + t.Errorf("size %d, IP %v, i %d: not allowed", size, ip, i) + } + + if len(l.m) != i { + t.Errorf("size %d, IP %v, i %d: len %d != i %d", + size, ip, i, len(l.m), i) + } + } + + // Now do another size iterations, checking that the size of the maps + // stays constant. + for ; i < (size+1)*2; i++ { + ip := net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)) + if !l.allow(ipv4ToUint64(ip.To4())) { + t.Errorf("size %d, IP %v, i %d: not allowed", size, ip, i) + } + + if len(l.m) != size { + t.Errorf("size %d, IP %v, i %d: len %d != size %d", + size, ip, i, len(l.m), size) + } + } + } +} + +func TestLRU(t *testing.T) { + ip1 := net.IPv4(1, 1, 1, 1) + ip2 := net.IPv4(2, 2, 2, 2) + ip3 := net.IPv4(3, 3, 3, 3) + ip4 := net.IPv4(4, 4, 4, 4) + + // We're going to do a sequence of allow() calls, and check that the LRU + // list is as we expect after each one. 
+ operations := []struct { + ip net.IP + lru []net.IP + }{ + {ip1, []net.IP{ip1}}, + + // Bump ip1 (it is a special case when there's only one element). + {ip1, []net.IP{ip1}}, + + // Add ip2 and ip3, all straightforward. + {ip2, []net.IP{ip2, ip1}}, + {ip3, []net.IP{ip3, ip2, ip1}}, + + // Add ip4, evict ip1 which is the oldest. + {ip4, []net.IP{ip4, ip3, ip2}}, + + // Add ip1, evict ip2 which is the oldest. + {ip1, []net.IP{ip1, ip4, ip3}}, + + // Bump ip3 (last one), twice in a row. + {ip3, []net.IP{ip3, ip1, ip4}}, + {ip3, []net.IP{ip3, ip1, ip4}}, + + // Bump ip1 (middle one). + {ip1, []net.IP{ip1, ip3, ip4}}, + } + + l := newlimiter(1, 0, 3) + for i, op := range operations { + l.allow(ipv4ToUint64(op.ip.To4())) + lru := getLRU(l) + + if diff := cmp.Diff(op.lru, lru); diff != "" { + t.Errorf("operation %d: allow(%v)", i, op.ip) + t.Errorf(" expected LRU %v, got %v", op.lru, lru) + t.Errorf(" diff (-want +got):\n%s", diff) + } + } +} + +func getLRU(l *limiter) []net.IP { + r := []net.IP{} + k := l.lruFirst + for k != 0 { + ip := net.IPv4(byte(k>>24), byte(k>>16), byte(k>>8), byte(k)) + r = append(r, ip) + k = l.m[k].lruNext + } + return r +} + +func TestZeroKey(t *testing.T) { + l := newlimiter(1, 0, 3) + if !l.allow(0) { + t.Errorf("allow(0) = false, want true") + } + lru := getLRU(l) + expected := []net.IP{net.IPv4(0, 0, 0, 1)} + if diff := cmp.Diff(expected, lru); diff != "" { + t.Errorf("allow(0):") + t.Errorf(" expected LRU %v, got %v", expected, lru) + t.Errorf(" diff (-want +got):\n%s", diff) + } +} + +var ( + nowSec = int64(0) + nowNsec = int64(0) +) + +func fakeTimeNow() time.Time { + return time.Unix(nowSec, nowNsec) +} + +func TestTime(t *testing.T) { + // Override the time function so we can control the time. 
+ timeNow = fakeTimeNow + defer func() { timeNow = time.Now }() + + l := newlimiter(2, time.Second, 3) + check := func(want bool) { + t.Helper() + if got := l.allow(22); got != want { + t.Errorf("@%s: allow(22) = %v, want %v", timeNow(), got, want) + } + } + + nowSec = 500 + nowNsec = 1000 + check(true) // Request 1. + nowNsec = 1001 + check(true) // Request 2, last one allowed. + nowNsec = 1002 + check(false) // Request 3, limit exhausted. + + nowSec, nowNsec = 501, 999 + check(false) // Not yet 1s. + + nowNsec = 1000 + check(true) // Exactly 1s since last allowed. + nowNsec = 1001 + check(true) // Request 2, last one allowed. + nowNsec = 1003 + check(false) // Request 3, limit exhausted. +} + +func TestSetIPv6Rates(t *testing.T) { + check := func(l *limiter, req uint64, period time.Duration) { + t.Helper() + if l.Requests != req || l.Period != period { + t.Errorf("Requests / Period = %d / %v ; expect %d / %v", + l.Requests, l.Period, req, period) + } + } + l := NewLimiter(1, time.Second, 3) + check(l.ipv4, 1, time.Second) + check(l.ip64, 1, time.Second) + check(l.ip56, 1, time.Second/4) + check(l.ip48, 1, time.Second/8) + + l.SetIPv6s64Rate(64, time.Second/64) + check(l.ipv4, 1, time.Second) + check(l.ip64, 64, time.Second/64) + check(l.ip56, 1, time.Second/4) + check(l.ip48, 1, time.Second/8) + + l.SetIPv6s56Rate(56, time.Second/56) + check(l.ipv4, 1, time.Second) + check(l.ip64, 64, time.Second/64) + check(l.ip56, 56, time.Second/56) + check(l.ip48, 1, time.Second/8) + + l.SetIPv6s48Rate(48, time.Second/48) + check(l.ipv4, 1, time.Second) + check(l.ip64, 64, time.Second/64) + check(l.ip56, 56, time.Second/56) + check(l.ip48, 48, time.Second/48) +} + +func TestDebugString(t *testing.T) { + l := NewLimiter(1, time.Second, 3) + l.Allow(net.IPv4(1, 1, 1, 1)) + l.Allow(net.ParseIP("1111:2222:3333:4444:5555:6666:7777:8888")) + t.Logf(l.DebugString()) +} + +func TestDebugHTML(t *testing.T) { + l := NewLimiter(1, time.Second, 3) + l.Allow(net.IPv4(1, 1, 1, 1)) + 
l.Allow(net.ParseIP("1111:2222:3333:4444:5555:6666:7777:8888")) + t.Logf(l.DebugHTML()) +} + +func BenchmarkDifferentIPv4_256(b *testing.B) { + l := NewLimiter(1, time.Second, 256) + for i := 0; i < b.N; i++ { + l.Allow(net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i))) + } +} + +func BenchmarkDifferentIPv4_10000(b *testing.B) { + l := NewLimiter(1, time.Second, 10000) + for i := 0; i < b.N; i++ { + l.Allow(net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i))) + } +} + +func BenchmarkSameIPv4_Strict(b *testing.B) { + l := NewLimiter(1, time.Second, 256) + ip := net.IPv4(1, 2, 3, 4) + for i := 0; i < b.N; i++ { + l.Allow(ip) + } +} + +func BenchmarkSameIPv4_Bursty(b *testing.B) { + l := NewLimiter(100, time.Second, 256) + ip := net.IPv4(1, 2, 3, 4) + for i := 0; i < b.N; i++ { + l.Allow(ip) + } +} + +func BenchmarkSameIPv4_Strict_Parallel(b *testing.B) { + l := NewLimiter(1, time.Second, 256) + ip := net.IPv4(1, 2, 3, 4) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + l.Allow(ip) + } + }) +} + +func BenchmarkDifferentIPv6_256(b *testing.B) { + l := NewLimiter(1, time.Second, 256) + ip := net.ParseIP("1111:2222:3333:4444:5555:6666:7777:8888") + for i := 0; i < b.N; i++ { + // Change the IP at different levels, so they don't all fall under the + // same /64, /56, /48. 
+ ip[7] = byte(i) + ip[6] = byte(i >> 8) + ip[5] = byte(i) + ip[4] = byte(i >> 8) + l.Allow(ip) + } +} + +func BenchmarkDifferentIPv6_10000(b *testing.B) { + l := NewLimiter(1, time.Second, 10000) + ip := net.ParseIP("1111:2222:3333:4444:5555:6666:7777:8888") + for i := 0; i < b.N; i++ { + ip[7] = byte(i) + ip[6] = byte(i >> 8) + ip[5] = byte(i) + ip[4] = byte(i >> 8) + l.Allow(ip) + } +} + +func BenchmarkSameIPv6_Strict(b *testing.B) { + l := NewLimiter(1, time.Second, 256) + ip := net.ParseIP("1111:2222:3333:4444:5555:6666:7777:8888") + for i := 0; i < b.N; i++ { + l.Allow(ip) + } +} + +func BenchmarkSameIPv6_Bursty(b *testing.B) { + l := NewLimiter(100, time.Second, 256) + ip := net.ParseIP("1111:2222:3333:4444:5555:6666:7777:8888") + for i := 0; i < b.N; i++ { + l.Allow(ip) + } +} + +func BenchmarkSameIPv6_Strict_Parallel(b *testing.B) { + l := NewLimiter(100, time.Second, 256) + ip := net.ParseIP("1111:2222:3333:4444:5555:6666:7777:8888") + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + l.Allow(ip) + } + }) +} diff --git a/ratelimit/ratelimit.go b/ratelimit/ratelimit.go new file mode 100644 index 0000000..35d7723 --- /dev/null +++ b/ratelimit/ratelimit.go @@ -0,0 +1,106 @@ +package ratelimit + +import ( + "fmt" + "net/http" + "sort" + + "blitiri.com.ar/go/gofer/config" + "blitiri.com.ar/go/gofer/ipratelimit" + "blitiri.com.ar/go/gofer/trace" + "blitiri.com.ar/go/log" +) + +// Global registry for convenience. +// This is not pretty but it simplifies a lot of the handling for now. +var registry = map[string]*ipratelimit.Limiter{} + +var traces = map[*ipratelimit.Limiter]*trace.Trace{} + +func FromConfig(name string, conf config.RateLimit) { + if conf.Size == 0 { + conf.Size = 1000 + } + + rl := ipratelimit.NewLimiter( + conf.Rate.Requests, conf.Rate.Period, conf.Size) + + // If config has custom IPv6 rates, use them. 
+ if conf.Rate64.Period > 0 { + rl.SetIPv6s64Rate(conf.Rate64.Requests, conf.Rate64.Period) + } + if conf.Rate56.Period > 0 { + rl.SetIPv6s56Rate(conf.Rate56.Requests, conf.Rate56.Period) + } + if conf.Rate48.Period > 0 { + rl.SetIPv6s48Rate(conf.Rate48.Requests, conf.Rate48.Period) + } + + registry[name] = rl + traces[rl] = trace.New("ratelimit", name) + traces[rl].SetMaxEvents(1000) + + log.Infof("ratelimit %q: %d/%s, size %d", + name, conf.Rate.Requests, conf.Rate.Period, conf.Size) + return +} + +func FromName(name string) *ipratelimit.Limiter { + return registry[name] +} + +func Trace(rl *ipratelimit.Limiter) *trace.Trace { + return traces[rl] +} + +func DebugHandler(w http.ResponseWriter, r *http.Request) { + names := []string{} + for name := range registry { + names = append(names, name) + } + sort.Strings(names) + + fmt.Fprintf(w, `<!DOCTYPE html> +<html> + +<head> +<meta name="viewport" content="width=device-width, initial-scale=1"> +<title>ratelimit</title> +<style type="text/css"> + body { + font-family: sans-serif; + } + @media (prefers-color-scheme: dark) { + body { + background: #121212; + color: #c9d1d9; + } + a { color: #44b4ec; } + } + table { + text-align: right; + } + th { + text-align: center; + } + td, th { + padding: 0.15em 0.5em; + } + td.ip { + min-width: 10em; + text-align: left; + font-family: monospace; + } +</style> +</head> + +<body> +`) + + for _, name := range names { + fmt.Fprintf(w, "<h1>%s</h1>\n\n%s\n\n", + name, registry[name].DebugHTML()) + } + + fmt.Fprintf(w, "</body>\n</html>\n") +} diff --git a/server/http.go b/server/http.go index 384e9fd..2a6a853 100644 --- a/server/http.go +++ b/server/http.go @@ -7,6 +7,7 @@ import ( "fmt" "io" golog "log" + "net" "net/http" "net/http/cgi" "net/http/httputil" @@ -15,6 +16,8 @@ import ( "time" "blitiri.com.ar/go/gofer/config" + "blitiri.com.ar/go/gofer/ipratelimit" + "blitiri.com.ar/go/gofer/ratelimit" "blitiri.com.ar/go/gofer/reqlog" "blitiri.com.ar/go/gofer/trace" 
"blitiri.com.ar/go/gofer/util" @@ -124,6 +127,21 @@ func httpServer(addr string, conf config.HTTP) (*http.Server, error) { // Tracing for all entries. srv.Handler = WithTrace("http@"+srv.Addr, srv.Handler) + // Rate limiting goes outside of tracing, to avoid polluting per-protocol + // traces with rate-limited events (we trace those separately). + if len(conf.RateLimit) > 0 { + rlMux := http.NewServeMux() + for path, rlName := range conf.RateLimit { + l := ratelimit.FromName(rlName) + rlMux.Handle(path, WithRateLimit(srv.Handler, l)) + log.Infof("%s ratelimit %q to %q", srv.Addr, path, rlName) + } + if _, ok := conf.RateLimit["/"]; !ok { + rlMux.Handle("/", srv.Handler) + } + srv.Handler = rlMux + } + return srv, nil } @@ -491,3 +509,33 @@ func reqLog(r *http.Request, status int, length int64, latency time.Duration) { Latency: latency, }) } + +func WithRateLimit(parent http.Handler, rl *ipratelimit.Limiter) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var ip net.IP + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + ratelimit.Trace(rl).Errorf( + "[http] failed to split remote address %q: %v", + r.RemoteAddr, err) + goto allow + } + + ip = net.ParseIP(host) + if ip == nil { + ratelimit.Trace(rl).Errorf( + "[http] failed to parse IP address: %q", r.RemoteAddr) + goto allow + } + + if !rl.Allow(ip) { + ratelimit.Trace(rl).Printf( + "[http] rate limit exceeded for %q", ip) + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + return + } + + allow: + parent.ServeHTTP(w, r) + }) +} diff --git a/server/raw.go b/server/raw.go index b5bcd15..937f536 100644 --- a/server/raw.go +++ b/server/raw.go @@ -7,6 +7,8 @@ import ( "time" "blitiri.com.ar/go/gofer/config" + "blitiri.com.ar/go/gofer/ipratelimit" + "blitiri.com.ar/go/gofer/ratelimit" "blitiri.com.ar/go/gofer/reqlog" "blitiri.com.ar/go/gofer/trace" "blitiri.com.ar/go/gofer/util" @@ -37,6 +39,7 @@ func Raw(addr string, conf config.Raw) error { } 
rlog := reqlog.FromName(conf.ReqLog) + lim := ratelimit.FromName(conf.RateLimit) log.Infof("%s raw proxy starting on %q", addr, lis.Addr()) for { @@ -45,14 +48,38 @@ func Raw(addr string, conf config.Raw) error { return log.Errorf("%s error accepting: %v", addr, err) } - go forward(conn, conf.To, conf.ToTLS, rlog) + go forward(conn, conf.To, conf.ToTLS, rlog, lim) } } -func forward(src net.Conn, dstAddr string, dstTLS bool, rlog *reqlog.Log) { +func allowed(addr net.Addr, lim *ipratelimit.Limiter) bool { + // We only support raw proxying over TCP, so we can assume the address is + // a TCP address. If not, fail-open just to be safe. + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + ratelimit.Trace(lim).Errorf( + "[raw] non-TCP address %q", addr) + return true + } + + if !lim.Allow(tcpAddr.IP) { + ratelimit.Trace(lim).Printf( + "[raw] rate limit exceeded for %q", tcpAddr.IP) + return false + } + + return true +} + +func forward(src net.Conn, dstAddr string, dstTLS bool, + rlog *reqlog.Log, lim *ipratelimit.Limiter) { defer src.Close() start := time.Now() + if lim != nil && !allowed(src.RemoteAddr(), lim) { + return + } + tr := trace.New("raw", fmt.Sprintf("%s -> %s", src.LocalAddr(), dstAddr)) defer tr.Finish() diff --git a/server/raw_test.go b/server/raw_test.go new file mode 100644 index 0000000..3bb6cb6 --- /dev/null +++ b/server/raw_test.go @@ -0,0 +1,34 @@ +package server + +import ( + "net" + "testing" + "time" + + "blitiri.com.ar/go/gofer/config" + "blitiri.com.ar/go/gofer/ratelimit" +) + +func TestAllowedOnNonTCP(t *testing.T) { + // Use a rate limit with 0 requests per second to disable ratelimiting. 
+ ratelimit.FromConfig("test-rl", config.RateLimit{ + Rate: config.Rate{Requests: 0, Period: time.Second}}) + rl := ratelimit.FromName("test-rl") + + tcp := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234} + if allowed(tcp, rl) { + t.Errorf("allowed(tcp %v) = true, expected false", tcp) + } + + // Try a few different non-TCP addresses, to make sure we fail-open on + // them. + addrs := []net.Addr{ + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}, + &net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}, + } + for _, addr := range addrs { + if !allowed(addr, rl) { + t.Errorf("allowed(%v) = false, expected true", addr) + } + } +} diff --git a/test/01-fe.yaml b/test/01-fe.yaml index b7f4234..992688e 100644 --- a/test/01-fe.yaml +++ b/test/01-fe.yaml @@ -20,16 +20,30 @@ _routes: &routes proxy: "http://localhost:8450/dir/" "/gogo/": redirect: "https://google.com" + "/rlme/": + proxy: "http://localhost:8450/cgi/" + reqlog: "requests": file: ".01-fe.requests.log" +ratelimit: + "rl": + rate: 1/1s + rate64: 1/1s + rate56: 1/500ms + rate48: 1/250ms + "raw-rl": + rate: 1/1s + http: ":8441": routes: *routes reqlog: "/": "requests" + ratelimit: + "/rlme/": "rl" https: ":8442": @@ -68,3 +82,8 @@ raw: to: "localhost:8442" to_tls: true reqlog: "requests" + + ":8449": + to: "localhost:8450" + reqlog: "requests" + ratelimit: "raw-rl" diff --git a/test/test.sh b/test/test.sh index aa540f1..8362ee2 100755 --- a/test/test.sh +++ b/test/test.sh @@ -259,6 +259,8 @@ exp "http://127.0.0.1:8459/" -bodyre "gofer @" # Check that the debug / handler only serves /. exp "http://127.0.0.1:8459/notexists" -status 404 +# Rate-limiting debug handler. +exp "http://127.0.0.1:8440/debug/ratelimit" -bodyre "Allow: 1 / 1s" echo "### Raw proxying" exp http://localhost:8445/file -body "ñaca\n" @@ -271,8 +273,47 @@ if ! waitgrep -q ":8447 = 500" .01-fe.requests.log; then exit 1 fi + +echo "### Rate limiting (http)" +# First request must be allowed. 
+exp http://localhost:8441/rlme/0 -status 200 + +# Somewhere in these, we should start to get rejected (likely from the +# beginning, but there could be timing issues). +for i in `seq 1 3`; do + exp http://localhost:8441/rlme/$i -statuslist 200,429 +done + +# By this stage, they should all be rejected. +for i in `seq 4 6`; do + exp http://localhost:8441/rlme/$i -status 429 +done + + +echo "### Rate limiting (raw)" +# Because these are raw proxies, we don't get nice HTTP status on rejections, +# so we count errors instead. +# We give it a rate of 1/1s, and perform 6 requests in quick succession. +# Expect at least 1 success and 3 errors. +NSUCCESS=0 +NERR=0 +for i in `seq 1 6`; do + if exp http://localhost:8449/file >> .exp-raw-rl.log 2>&1; then + NSUCCESS=$(( NSUCCESS + 1 )) + else + NERR=$(( NERR + 1 )) + fi +done +if [ $NSUCCESS -lt 1 ] || [ $NERR -lt 3 ]; then + echo "expected >=1 successes and >=3 errors, but" \ + "got $NSUCCESS successes and $NERR errors" + exit 1 +fi + + echo "### Checking examples from doc/examples.md" ./util/check-examples.sh + echo "## Success" snoop diff --git a/test/util/exp/exp.go b/test/util/exp/exp.go index 06be6ab..5f876a4 100644 --- a/test/util/exp/exp.go +++ b/test/util/exp/exp.go @@ -34,6 +34,8 @@ func main() { "expect a redirect to this URL") status = flag.Int("status", 200, "expect this status code") + statusList = flag.String("statuslist", "", + "expect this comma-separated list of status codes") verbose = flag.Bool("v", false, "enable verbose output") hdrRE = flag.String("hdrre", "", @@ -91,7 +93,20 @@ func main() { fmt.Printf("\n") } - if resp.StatusCode != *status { + if *statusList != "" { + statuses := strings.Split(*statusList, ",") + found := false + for _, s := range statuses { + si, _ := strconv.Atoi(s) + if resp.StatusCode == si { + found = true + break + } + } + if !found { + errorf("status %d not in list: %v\n", resp.StatusCode, statuses) + } + } else if resp.StatusCode != *status { errorf("status is not %d: %q\n", 
*status, resp.Status) }