[main] / ipratelimit / ipratelimit.go

// Package ipratelimit implements a per-IP rate limiter.
//
// It provides a Limiter, which is configured with the maximum number of
// requests to allow per IP address within a given time period.
//
// The Limiter has a fixed maximum size, to limit memory usage. If the maximum
// is reached, older entries are evicted.
//
// It is safe for concurrent use.
//
// The main use-case is to help prevent abuse, not to perform accurate request
// accounting/throttling. The implementation choices reflect this.
//
// For IPv4 addresses, we use the full address as the limiting key.
//
// For IPv6 addresses, since end users are usually assigned a range of
// /64, /56 or /48, we use the following heuristic: There are 3 rate limiters,
// one for each of the common subnet masks (/48, /56, /64). They operate in
// parallel, and any can deny access.
// By default, the rate for /64 is the one given, the rate for /56 is 4x, and
// the rate for /48 is 8x; these rates can be individually configured if
// needed.
//
// Note that rate-limiting 0.0.0.0 is not supported. It will be automatically
// treated as 0.0.0.1. The same applies to IPv6.
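//
// A minimal usage sketch (the values here are illustrative, not defaults):
//
//	// Allow up to 50 requests per minute per IPv4 address (and per IPv6
//	// /64), tracking at most 4096 addresses.
//	lim := ipratelimit.NewLimiter(50, time.Minute, 4096)
//	if !lim.Allow(net.ParseIP("203.0.113.7")) {
//		// Deny or delay the request.
//	}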
package ipratelimit // blitiri.com.ar/go/gofer/ipratelimit

import (
	"encoding/binary"
	"fmt"
	"math/big"
	"net"
	"sync"
	"time"
)

// For IPv4, we use the IP addresses just as they are, nothing fancy.
//
// For IPv6, the main challenge is that the key space is too large, and that
// users get assigned vast ranges (/48, /56, or /64). If we pick a prefix that
// is too narrow, we allow DoS bypass; if it's too wide, we over-block.
// We could do fancy heuristics for coalescing entries, but that gets
// computationally expensive.
// So we use an intermediate solution: we keep 3 rate limiters, one for
// each of the common subnet masks. They have different limits to decrease the
// chances of over-blocking. This is probably okay for coarse abuse
// prevention, but not good for precise rate limiting.
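//
// For example (an illustrative address), a request from 2001:db8:1:203::5 is
// accounted under three keys at once: 2001:db8:1::/48, 2001:db8:1:200::/56
// and 2001:db8:1:203::/64; any of the three limiters can deny it.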

// Useful articles and references on IP/HTTP rate limiting and IPv6
// assignment, for convenience:
//   - https://adam-p.ca/blog/2022/02/ipv6-rate-limiting/
//   - https://caddyserver.com/docs/json/apps/http/servers/routes/handle/rate_limit/
//   - https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-ratelimit-headers
//   - https://dev.to/satrobit/rate-limiting-in-ipv6-era-using-probabilistic-data-structures-15on
//   - https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
//   - https://konghq.com/blog/engineering/how-to-design-a-scalable-rate-limiting-algorithm
//   - https://serverfault.com/questions/863501/blocking-or-rate-limiting-ipv6-what-size-prefix
//   - https://www.nginx.com/blog/rate-limiting-nginx/
//   - https://www.rfc-editor.org/rfc/rfc9110#field.retry-after
//   - https://www.ripe.net/publications/docs/ripe-690

// Possible future changes:
//   - AllowWithInfo returning (bool, uint64, time.Duration) to indicate how
//     many more requests are left within this time window, and when the next
//     window starts.
//   - AllowOrSleep that sleeps until a request is allowed.
//   - AllowN that allows N requests at once (can be used for weighting
//     some requests differently).

// Convert an IPv4 to our uint64 representation.
// ip4 MUST be an IPv4 address. It is not checked, for performance reasons
// (skipping the check improves IPv4 performance by 1.5% to 3.5%).
func ipv4ToUint64(ip4 net.IP) uint64 {
	return uint64(binary.BigEndian.Uint32(ip4))
}

// Convert an IPv6 to a set of uint64 representations, one for /48, /56, and
// /64, as described above.
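// For example (illustrative), for 2001:db8:1:203:: the first 8 bytes, read as
// a big-endian uint64, are 0x20010db800010203, so:
//
//	ip64 = 0x20010db800010203
//	ip56 = 0x20010db800010200
//	ip48 = 0x20010db800010000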
func ipv6ExtractMasks(ip net.IP) (ip48, ip56, ip64 uint64) {
	ip64 = binary.BigEndian.Uint64(ip[0:8])
	ip56 = ip64 & 0xffff_ffff_ffff_ff00
	ip48 = ip64 & 0xffff_ffff_ffff_0000
	return
}

type entry struct {
	// Timestamp of the last request allowed.
	// We could use time.Time, but this uses less memory (8 bytes vs 24), and
	// improves performance by ~7-9% on all the "DifferentIPv*" benchmarks,
	// with no negative impact on the rest.
	lastAllowed miniTime

	// Requests left (since lastAllowed).
	requestsLeft uint64

	// Prev and next keys in the LRU list.
	// We use this as a doubly-linked list, to implement an LRU cache and keep
	// the size of the entries map bounded.
	// This could be implemented separately, but keeping it in line helps with
	// performance and memory usage.
	lruPrev, lruNext uint64
}

func (e *entry) reset() {
	e.lastAllowed = 0
	e.requestsLeft = 0
	e.lruPrev = 0
	e.lruNext = 0
}

type limiter struct {
	// Allow this many requests per period.
	Requests uint64

	// Allow requests at most once per this duration.
	Period time.Duration

	// Maximum number of entries to keep track of. This is important to keep
	// the memory usage bounded.
	Size int

	// Pool of free entries, to reduce/avoid allocations.
	// This results in a 15-35% reduction in operation latency on this
	// package's benchmarks.
	entryPool sync.Pool

	// Protects the mutable fields below.
	//
	// This is a single lock for the whole limiter, and the data layout and
	// implementation take significant advantage of this (e.g. the embedded
	// LRU list on each entry).
	//
	// This results in very fast operations when contention is low to
	// moderate, which is what we're optimizing for.
	//
	// Even on parallel benchmarks that have a fair amount of contention, this
	// does alright compared to other libraries that optimize for that.
	// However, it does not scale as well to very high contention scenarios
	// (e.g. parallel benchmarks on >32 cores).
	//
	// To improve high-contention performance, we could do fine-grained
	// locking in a variety of ways (including sharding the limiters at a high
	// level), although that often causes large performance regressions when
	// contention is low to moderate, which is our main use case.
	mu sync.Mutex

	// Map of key (IP in uint64 form) to limiter entry.
	m map[uint64]*entry

	// LRU doubly-linked list first and last entry.
	// 0 means "not present", which happens when the list is empty. This works
	// only because we never expect to have 0 (address 0.0.0.0 / 0::0) as a
	// valid key.
	lruFirst, lruLast uint64
}

func newlimiter(req uint64, period time.Duration, size int) *limiter {
	l := &limiter{
		Requests: req,
		Period:   period,
		Size:     size,

		m: make(map[uint64]*entry, size),
	}

	l.entryPool.New = func() any { return &entry{} }
	return l
}

// lruBump moves the key to the top of the LRU list.
func (l *limiter) lruBump(key uint64, e *entry) {
	if l.lruFirst == key {
		return
	}

	// Update the last pointer (if this key is the last one).
	if l.lruLast == key {
		l.lruLast = e.lruPrev
	}

	// Take the key out of the list chain.
	if e.lruPrev != 0 {
		l.m[e.lruPrev].lruNext = e.lruNext
	}
	if e.lruNext != 0 {
		l.m[e.lruNext].lruPrev = e.lruPrev
	}

	// Update the current first element.
	if l.lruFirst != 0 {
		l.m[l.lruFirst].lruPrev = key
	}

	// Adjust the key's entry pointers to be at the beginning.
	e.lruNext = l.lruFirst
	e.lruPrev = 0

	// Set this key as the new first element.
	l.lruFirst = key
}

// lruPrepend adds an element to the top of the list. If the list is over
// capacity, the last element is removed and its entry is returned to the
// pool.
func (l *limiter) lruPrepend(key uint64, e *entry) {
	if l.lruFirst == 0 {
		l.lruFirst = key
		l.lruLast = key
		return
	}

	// Add the new element to the beginning of the list.
	e.lruNext = l.lruFirst
	l.m[l.lruFirst].lruPrev = key
	l.lruFirst = key

	// If we're over capacity, remove the last element.
	if len(l.m) > l.Size {
		lastK := l.lruLast
		lastE := l.m[l.lruLast]
		l.lruLast = lastE.lruPrev
		l.m[l.lruLast].lruNext = 0
		delete(l.m, lastK)
		l.entryPool.Put(lastE)
	}
}

// For testing.
var timeNow = time.Now

func (l *limiter) allow(key uint64) bool {
	now := timeNow()

	if key == 0 {
		// We use 0 as the "null" key, because 0.0.0.0 or ::0 IP addresses are
		// not expected to be rate-limit targets. For IPv4 this usually comes
		// from an incorrect test, and is harmless. For IPv6 it happens mainly
		// on localhost requests, because ::1 becomes ::0 after masking. This
		// is a bit unfortunate, but it's not a big deal.
		// We force the key to be 1 in those cases.
		key = 1
	}

	if l.Requests == 0 {
		// Always limiting, no need to compute anything.
		return false
	}

	l.mu.Lock()
	defer l.mu.Unlock()

	e, ok := l.m[key]
	if !ok {
		// It's a new entry.
		e = l.entryPool.Get().(*entry)
		e.reset()

		l.m[key] = e
		l.lruPrepend(key, e)
	} else {
		// Pre-existing entry, just update the LRU.
		l.lruBump(key, e)
	}

	// Decide if we should allow the request.
	if sinceMiniTime(now, e.lastAllowed) >= l.Period {
		e.lastAllowed = makeMiniTime(now)
		e.requestsLeft = l.Requests - 1
		return true
	} else if e.requestsLeft > 0 {
		e.requestsLeft--
		return true
	}

	return false
}

// Limiter is a rate limiter that keeps track of requests per IP address.
type Limiter struct {
	// Individual limiters per IP type.
	ipv4, ip48, ip56, ip64 *limiter
}

// NewLimiter creates a new Limiter.  Per IP address, up to `requests` per
// `period` will be allowed.  Once they're exhausted, requests will be denied
// until `period` has passed from the first approved request. Size is the
// maximum number of IP addresses to keep track of; when that size is reached,
// older entries are removed.  See the package documentation for more details
// on how IPv6 addresses are handled.
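//
// For example (illustrative values), NewLimiter(2, time.Minute, 1000) lets
// each IPv4 address (and each IPv6 /64) make 2 requests per minute; further
// requests are denied until a minute has passed since the first allowed one.
// The derived /56 and /48 limiters allow 4x and 8x that rate respectively.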
func NewLimiter(requests uint64, period time.Duration, size int) *Limiter {
	return &Limiter{
		ipv4: newlimiter(requests, period, size),
		ip64: newlimiter(requests, period, size),
		ip56: newlimiter(requests, period/4, size),
		ip48: newlimiter(requests, period/8, size),
	}
}

// SetIPv6s64Rate sets the rate limit for IPv6 addresses with /64 prefixes.
// It can only be changed before any requests are made.
func (l *Limiter) SetIPv6s64Rate(req uint64, per time.Duration) {
	l.ip64.Requests = req
	l.ip64.Period = per
}

// SetIPv6s56Rate sets the rate limit for IPv6 addresses with /56 prefixes.
// It can only be changed before any requests are made.
func (l *Limiter) SetIPv6s56Rate(req uint64, per time.Duration) {
	l.ip56.Requests = req
	l.ip56.Period = per
}

// SetIPv6s48Rate sets the rate limit for IPv6 addresses with /48 prefixes.
// It can only be changed before any requests are made.
func (l *Limiter) SetIPv6s48Rate(req uint64, per time.Duration) {
	l.ip48.Requests = req
	l.ip48.Period = per
}

// Allow checks if the given IP address is allowed to make a request.
func (l *Limiter) Allow(ip net.IP) bool {
	if ip4 := ip.To4(); ip4 != nil {
		// Convert the IPv4 address to a 64-bit integer, and use that as key.
		return l.ipv4.allow(ipv4ToUint64(ip4))
	}

	// Check the three masks for IPv6. All must be allowed for the request
	// to be allowed.
	a48, a56, a64 := l.allowV6(ip)
	return a48 && a56 && a64
}

func (l *Limiter) allowV6(ip net.IP) (a48, a56, a64 bool) {
	ip48, ip56, ip64 := ipv6ExtractMasks(ip)
	return l.ip48.allow(ip48), l.ip56.allow(ip56), l.ip64.allow(ip64)
}

// DebugString returns a string with debugging information about the limiter.
// This is useful for debugging, but not for production use. It is not
// guaranteed to be stable.
func (l *Limiter) DebugString() string {
	s := "## IPv4\n\n"
	s += l.ipv4.debugString(kToIPv4)
	s += "\n\n"
	s += "## IPv6\n\n"
	s += "### /48\n\n"
	s += l.ip48.debugString(kToIPv6)
	s += "\n\n"
	s += "### /56\n\n"
	s += l.ip56.debugString(kToIPv6)
	s += "\n\n"
	s += "### /64\n\n"
	s += l.ip64.debugString(kToIPv6)
	s += "\n"
	return s
}

func (l *limiter) debugString(kToIP func(uint64) net.IP) string {
	l.mu.Lock()
	defer l.mu.Unlock()

	s := ""
	s += fmt.Sprintf("Allow: %d / %v\n", l.Requests, l.Period)
	s += fmt.Sprintf("Size: %d / %d\n", len(l.m), l.Size)
	s += "\n"
	k := l.lruFirst
	for k != 0 {
		e := l.m[k]
		ip := kToIP(k)
		last := sinceMiniTime(time.Now(), e.lastAllowed).Round(
			time.Millisecond)
		s += fmt.Sprintf("%-22s %3d requests left, last allowed %10s ago\n",
			ip, e.requestsLeft, last)
		k = e.lruNext
	}
	return s
}

// DebugHTML returns a string with debugging information about the limiter, in
// HTML format (just content starting with `<h2>`, no meta-tags). This is
// useful for debugging, but not for production use. It is not guaranteed to
// be stable.
func (l *Limiter) DebugHTML() string {
	s := "<h2>IPv4</h2>"
	s += l.ipv4.debugHTML(kToIPv4)
	s += "<h2>IPv6</h2>"
	s += "<h3>/48</h3>"
	s += l.ip48.debugHTML(kToIPv6)
	s += "<h3>/56</h3>"
	s += l.ip56.debugHTML(kToIPv6)
	s += "<h3>/64</h3>"
	s += l.ip64.debugHTML(kToIPv6)
	return s
}

func (l *limiter) debugHTML(kToIP func(uint64) net.IP) string {
	l.mu.Lock()
	defer l.mu.Unlock()

	s := fmt.Sprintf("Allow: %d / %v<br>\n", l.Requests, l.Period)
	s += fmt.Sprintf("Size: %d / %d<br>\n", len(l.m), l.Size)
	s += "<p>\n"
	if l.lruFirst == 0 {
		s += "(empty)<br>"
		return s
	}

	s += "<table>\n"
	s += "<tr><th>IP</th><th>Requests left</th><th>Last allowed</th></tr>\n"
	k := l.lruFirst
	for k != 0 {
		e := l.m[k]
		ip := kToIP(k)
		last := sinceMiniTime(time.Now(), e.lastAllowed).Round(
			time.Millisecond)
		s += fmt.Sprintf(`<tr><td class="ip">%v</td>`, ip)
		s += fmt.Sprintf(`<td class="requests">%d</td>`, e.requestsLeft)
		s += fmt.Sprintf(`<td class="last">%s</td></tr>`, last)
		s += "\n"
		k = e.lruNext
	}
	s += "</table>\n"
	return s
}

func kToIPv4(k uint64) net.IP {
	return net.IPv4(byte(k>>24), byte(k>>16), byte(k>>8), byte(k))
}

func kToIPv6(k uint64) net.IP {
	buf := make([]byte, 16)
	b := big.NewInt(0).SetUint64(k)
	b = b.Lsh(b, 64)
	return net.IP(b.FillBytes(buf[:]))
}

// miniTime is a small representation of time, as the number of nanoseconds
// elapsed since January 1, 1970 UTC.
// This is used to reduce memory footprint and improve performance.
type miniTime int64

func makeMiniTime(t time.Time) miniTime {
	return miniTime(t.UnixNano())
}

func sinceMiniTime(now time.Time, old miniTime) time.Duration {
	// time.Duration is an int64 nanosecond count, so we can just subtract.
	return time.Duration(now.UnixNano() - int64(old))
}