Skip to content

Commit 8a17b6c

Browse files
committed
feat: Treat rate limit header value as comma-separated list
This commit updates performRateLimiting to treat the rate limit header value as a comma-separated list and enforce rate limiting based on the first value in that list. Certain HTTP headers, such as X-Forwarded-For and other headers that are combined according to RFC 7230, can be represented as a comma-separated list of values. Intermediate proxies may add their own values to these headers, modifying the resulting value. For example, an end user with a single IP address proxied through a fleet of load balancers using the X-Forwarded-For header may be associated with multiple X-Forwarded-For header values, e.g., "2.2.2.2,100.100.100.100" and "2.2.2.2,300.300.300.300". The current implementation of performRateLimiting treats each of these as separate rate limiting keys. To address this issue, this commit splits the rate limit header by commas and takes the first value (with whitespace removed) to use as the rate limiting key. Note that this logic is superficially similar to the utilities.GetIPAddress function with two key differences. In performRateLimiting, there is no set format for a given rate limiting key, nor is there a fallback value after the first value in the list that the API should use.
1 parent d8d59c9 commit 8a17b6c

File tree

2 files changed

+137
-2
lines changed

2 files changed

+137
-2
lines changed

internal/api/middleware.go

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,43 @@ var emailRateLimitCounter = observability.ObtainMetricCounter("gotrue_email_rate
6464
func (a *API) performRateLimiting(lmt *limiter.Limiter, req *http.Request) error {
6565
limitHeader := a.config.RateLimitHeader
6666

67+
// If no rate limit header was set, ignore rate limiting
6768
if limitHeader == "" {
6869
return nil
6970
}
7071

71-
key := req.Header.Get(limitHeader)
72+
valuesStr := req.Header.Get(limitHeader)
7273

73-
if key == "" {
74+
// If a rate limit header was set, but has no value, ignore rate limiting but warn with an error
75+
if valuesStr == "" {
7476
log := observability.GetLogEntry(req).Entry
7577
log.WithField("header", limitHeader).Warn("request does not have a value for the rate limiting header, rate limiting is not applied")
7678

7779
return nil
7880
}
7981

82+
// According to RFC 7230 section 3.2.2, multiple headers with the same name are equivalent
83+
// to a single header with that name where each value is separated by a comma and whitespace.
84+
//
85+
// Note that there is some ambiguity in RFC 7230 where section 3.2.4 states that
86+
// header field values (which can contain commas) are processed independently of the header
87+
// field name, and thus it is not always clear if a comma is a list delimiter or simply par
88+
// of a single value.
89+
//
90+
// Given that this function is primarily for use with headers like X-Forwarded-For which
91+
// vendors generally combine into comma-separated lists, we opt for the simpler approach
92+
// here and split the header value by commas before taking the first value.
93+
values := strings.SplitN(valuesStr, ",", 2)
94+
95+
// We will always get at least one value back, so this operation is safe
96+
key := strings.TrimSpace(values[0])
97+
98+
// If the rate limit header has at least one value, but the first value is all whitespace, return an error
99+
if key == "" {
100+
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "Invalid rate limit header value")
101+
}
102+
103+
// Otherwise, apply rate limiting based on the first rate limit header value
80104
if err := tollbooth.LimitByKeys(lmt, []string{key}); err != nil {
81105
return apierrors.NewTooManyRequestsError(apierrors.ErrorCodeOverRequestRateLimit, "Request rate limit reached")
82106
}

internal/api/middleware_test.go

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,117 @@ func TestTimeoutResponseWriter(t *testing.T) {
415415
require.Equal(t, w1.Result(), w2.Result())
416416
}
417417

418+
func (ts *MiddlewareTestSuite) TestPerformRateLimiting() {
419+
ts.Config.RateLimitHeader = "X-Test-Perform-Rate-Limiting"
420+
421+
tests := []struct {
422+
name string
423+
headerValues []string
424+
expError error
425+
}{
426+
{
427+
name: "no value",
428+
headerValues: []string{
429+
"",
430+
"",
431+
},
432+
expError: nil,
433+
},
434+
{
435+
name: "single end user value",
436+
headerValues: []string{
437+
"192.168.1.100",
438+
"192.168.1.100",
439+
},
440+
expError: apierrors.NewTooManyRequestsError(
441+
apierrors.ErrorCodeOverRequestRateLimit,
442+
"Request rate limit reached",
443+
),
444+
},
445+
{
446+
name: "same end user value, multiple proxies",
447+
headerValues: []string{
448+
"2600:cafe:beef::1,192.168.1.100",
449+
"2600:cafe:beef::1,192.168.1.200",
450+
},
451+
expError: apierrors.NewTooManyRequestsError(
452+
apierrors.ErrorCodeOverRequestRateLimit,
453+
"Request rate limit reached",
454+
),
455+
},
456+
{
457+
name: "multiple end user values, single proxy",
458+
headerValues: []string{
459+
"2600:cafe:beef::1,192.168.1.100",
460+
"3700:dead:abcd::2,192.168.1.100",
461+
},
462+
expError: nil,
463+
},
464+
{
465+
name: "same end user value, multiple proxies, with whitespace",
466+
headerValues: []string{
467+
"2600:cafe:beef::1 ,192.168.1.100",
468+
"2600:cafe:beef::1 , 192.168.1.200",
469+
},
470+
expError: apierrors.NewTooManyRequestsError(
471+
apierrors.ErrorCodeOverRequestRateLimit,
472+
"Request rate limit reached",
473+
),
474+
},
475+
{
476+
name: "malformed header, all whitespace",
477+
headerValues: []string{
478+
" ",
479+
},
480+
expError: apierrors.NewBadRequestError(
481+
apierrors.ErrorCodeOverRequestRateLimit,
482+
"Invalid rate limit header value",
483+
),
484+
},
485+
{
486+
name: "malformed header, no whitespace",
487+
headerValues: []string{
488+
",192.168.1.100",
489+
},
490+
expError: apierrors.NewBadRequestError(
491+
apierrors.ErrorCodeOverRequestRateLimit,
492+
"Invalid rate limit header value",
493+
),
494+
},
495+
{
496+
name: "malformed header, with whitespace",
497+
headerValues: []string{
498+
" ,192.168.1.100",
499+
},
500+
expError: apierrors.NewBadRequestError(
501+
apierrors.ErrorCodeOverRequestRateLimit,
502+
"Invalid rate limit header value",
503+
),
504+
},
505+
}
506+
507+
for _, tt := range tests {
508+
// Trigger a rate limiting error if we see the same end-user key twice in the same
509+
// test case
510+
lmt := tollbooth.NewLimiter(
511+
1,
512+
&limiter.ExpirableOptions{
513+
DefaultExpirationTTL: time.Hour,
514+
},
515+
)
516+
517+
var obsError error
518+
519+
for _, h := range tt.headerValues {
520+
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
521+
req.Header.Add(ts.Config.RateLimitHeader, h)
522+
obsError = ts.API.performRateLimiting(lmt, req)
523+
}
524+
525+
require.ErrorIs(ts.T(), obsError, tt.expError, "error for test '%s'", tt.name)
526+
}
527+
}
528+
418529
func (ts *MiddlewareTestSuite) TestLimitHandler() {
419530
ts.Config.RateLimitHeader = "X-Rate-Limit"
420531
lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{

0 commit comments

Comments
 (0)