Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: configurable email and sms rate limiting #1800

Merged
merged 3 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.Get("/authorize", api.ExternalProviderRedirect)

sharedLimiter := api.limitEmailOrPhoneSentHandler(api.limiterOpts)
r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) {
r.With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(api.verifyCaptcha).Route("/signup", func(r *router) {
// rate limit per hour
limitAnonymousSignIns := api.limiterOpts.AnonymousSignIns
limitSignups := api.limiterOpts.Signups
Expand All @@ -165,24 +164,20 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
if _, err := api.limitHandler(limitSignups)(w, r); err != nil {
return err
}
// apply shared rate limiting on email / phone
if _, err := sharedLimiter(w, r); err != nil {
return err
}
return api.Signup(w, r)
})
})
r.With(api.limitHandler(api.limiterOpts.Recover)).
With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)
With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(api.limitHandler(api.limiterOpts.Resend)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)
With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(api.limiterOpts.MagicLink)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)
With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(api.limiterOpts.Otp)).
With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)
With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(api.limiterOpts.Token)).
With(api.verifyCaptcha).Post("/token", api.Token)
Expand All @@ -200,8 +195,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.With(api.requireAuthentication).Route("/user", func(r *router) {
r.Get("/", api.UserGet)
r.With(api.limitHandler(api.limiterOpts.User)).
With(sharedLimiter).Put("/", api.UserUpdate)
r.With(api.limitHandler(api.limiterOpts.User)).Put("/", api.UserUpdate)

r.Route("/identities", func(r *router) {
r.Use(api.requireManualLinkingEnabled)
Expand Down
19 changes: 0 additions & 19 deletions internal/api/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"net/url"

"github.com/didip/tollbooth/v5/limiter"
jwt "github.com/golang-jwt/jwt/v5"
"github.com/supabase/auth/internal/models"
)
Expand Down Expand Up @@ -32,7 +31,6 @@ const (
ssoProviderKey = contextKey("sso_provider")
externalHostKey = contextKey("external_host")
flowStateKey = contextKey("flow_state_id")
sharedLimiterKey = contextKey("shared_limiter")
)

// withToken adds the JWT token to the context.
Expand Down Expand Up @@ -243,20 +241,3 @@ func getExternalHost(ctx context.Context) *url.URL {
}
return obj.(*url.URL)
}

type SharedLimiter struct {
EmailLimiter *limiter.Limiter
PhoneLimiter *limiter.Limiter
}

func withLimiter(ctx context.Context, limiter *SharedLimiter) context.Context {
return context.WithValue(ctx, sharedLimiterKey, limiter)
}

func getLimiter(ctx context.Context) *SharedLimiter {
obj := ctx.Value(sharedLimiterKey)
if obj == nil {
return nil
}
return obj.(*SharedLimiter)
}
17 changes: 7 additions & 10 deletions internal/api/mail.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"strings"
"time"

"github.com/didip/tollbooth/v5"
"github.com/supabase/auth/internal/hooks"
mail "github.com/supabase/auth/internal/mailer"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -578,15 +577,13 @@ func (a *API) sendEmail(r *http.Request, tx *storage.Connection, u *models.User,
externalURL := getExternalHost(ctx)

// apply rate limiting before the email is sent out
if limiter := getLimiter(ctx); limiter != nil {
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil {
emailRateLimitCounter.Add(
ctx,
1,
metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))),
)
return EmailRateLimitExceeded
}
if ok := a.limiterOpts.Email.Allow(); !ok {
emailRateLimitCounter.Add(
ctx,
1,
metric.WithAttributeSet(attribute.NewSet(attribute.String("path", r.URL.Path))),
)
return EmailRateLimitExceeded
}

if config.Hook.SendEmail.Enabled {
Expand Down
21 changes: 0 additions & 21 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,27 +77,6 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler {
}
}

func (a *API) limitEmailOrPhoneSentHandler(limiterOptions *LimiterOptions) middlewareHandler {
return func(w http.ResponseWriter, req *http.Request) (context.Context, error) {
c := req.Context()
config := a.config
shouldRateLimitEmail := config.External.Email.Enabled && !config.Mailer.Autoconfirm
shouldRateLimitPhone := config.External.Phone.Enabled && !config.Sms.Autoconfirm

if shouldRateLimitEmail || shouldRateLimitPhone {
if req.Method == "PUT" || req.Method == "POST" {
// store rate limiter in request context
c = withLimiter(c, &SharedLimiter{
EmailLimiter: limiterOptions.Email,
PhoneLimiter: limiterOptions.Phone,
})
}
}

return c, nil
}
}

func (a *API) requireAdminCredentials(w http.ResponseWriter, req *http.Request) (context.Context, error) {
t, err := a.extractBearerToken(req)
if err != nil || t == "" {
Expand Down
174 changes: 0 additions & 174 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,52 +185,6 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() {
}
}

func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() {
// Set up rate limit config for this test
ts.Config.RateLimitEmailSent = 5
ts.Config.RateLimitSmsSent = 5
ts.Config.External.Phone.Enabled = true

cases := []struct {
desc string
expectedErrorMsg string
requestBody map[string]interface{}
}{
{
desc: "Email rate limit exceeded",
expectedErrorMsg: "429: Email rate limit exceeded",
requestBody: map[string]interface{}{
"email": "[email protected]",
},
},
{
desc: "SMS rate limit exceeded",
expectedErrorMsg: "429: SMS rate limit exceeded",
requestBody: map[string]interface{}{
"phone": "+1233456789",
},
},
}

limiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))
for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.requestBody))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()

ctx, err := limiter(w, req)
require.NoError(ts.T(), err)

// check that shared limiter is set in the request context
sharedLimiter := getLimiter(ctx)
require.NotNil(ts.T(), sharedLimiter)
})
}
}

func (ts *MiddlewareTestSuite) TestIsValidExternalHost() {
cases := []struct {
desc string
Expand Down Expand Up @@ -388,134 +342,6 @@ func (ts *MiddlewareTestSuite) TestLimitHandler() {
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code)
}

func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
// setup config for shared limiter and ip-based limiter to work
ts.Config.RateLimitHeader = "X-Rate-Limit"
ts.Config.External.Email.Enabled = true
ts.Config.External.Phone.Enabled = true
ts.Config.Mailer.Autoconfirm = false
ts.Config.Sms.Autoconfirm = false

ipBasedLimiter := func(max float64) *limiter.Limiter {
return tollbooth.NewLimiter(max, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
})
}

okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limiter := getLimiter(r.Context())
if limiter != nil {
var requestBody struct {
Email string `json:"email"`
Phone string `json:"phone"`
}
err := retrieveRequestParams(r, &requestBody)
require.NoError(ts.T(), err)

if requestBody.Email != "" {
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"email_functions"}); err != nil {
sendJSON(w, http.StatusTooManyRequests, HTTPError{
HTTPStatus: http.StatusTooManyRequests,
ErrorCode: ErrorCodeOverEmailSendRateLimit,
Message: "Email rate limit exceeded",
})
}
}
if requestBody.Phone != "" {
if err := tollbooth.LimitByKeys(limiter.EmailLimiter, []string{"phone_functions"}); err != nil {
sendJSON(w, http.StatusTooManyRequests, HTTPError{
HTTPStatus: http.StatusTooManyRequests,
ErrorCode: ErrorCodeOverSMSSendRateLimit,
Message: "SMS rate limit exceeded",
})
}
}
}
w.WriteHeader(http.StatusOK)
})

cases := []struct {
desc string
sharedLimiterConfig *conf.GlobalConfiguration
ipBasedLimiterConfig float64
body map[string]interface{}
expectedErrorCode string
}{
{
desc: "Exceed ip-based rate limit before shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 10,
RateLimitSmsSent: 10,
},
ipBasedLimiterConfig: 1,
body: map[string]interface{}{
"email": "[email protected]",
},
expectedErrorCode: ErrorCodeOverRequestRateLimit,
},
{
desc: "Exceed email shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 1,
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
body: map[string]interface{}{
"email": "[email protected]",
},
expectedErrorCode: ErrorCodeOverEmailSendRateLimit,
},
{
desc: "Exceed sms shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 1,
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
body: map[string]interface{}{
"phone": "123456789",
},
expectedErrorCode: ErrorCodeOverSMSSendRateLimit,
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent
ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent
lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig))
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config))

// get the minimum amount to reach the threshold just before the rate limit is exceeded
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig)
for i := 0; i < int(threshold); i++ {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")

w := httptest.NewRecorder()
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
}

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")

// check if the rate limit is exceeded with the expected error code
w := httptest.NewRecorder()
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code)

var data map[string]interface{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
require.Equal(ts.T(), c.expectedErrorCode, data["error_code"])
})
}
}

func (ts *MiddlewareTestSuite) TestIsValidAuthorizedEmail() {
ts.API.config.External.Email.AuthorizedAddresses = []string{"[email protected]"}

Expand Down
17 changes: 5 additions & 12 deletions internal/api/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ type Option interface {
}

type LimiterOptions struct {
Email *limiter.Limiter
Phone *limiter.Limiter
Email *RateLimiter
Phone *RateLimiter

Signups *limiter.Limiter
AnonymousSignIns *limiter.Limiter
Recover *limiter.Limiter
Expand All @@ -35,16 +36,8 @@ func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo }
func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions {
o := &LimiterOptions{}

o.Email = tollbooth.NewLimiter(gc.RateLimitEmailSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"})

o.Phone = tollbooth.NewLimiter(gc.RateLimitSmsSent/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(gc.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"})

o.Email = newRateLimiter(gc.RateLimitEmailSent)
o.Phone = newRateLimiter(gc.RateLimitSmsSent)
o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60),
&limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
Expand Down
9 changes: 2 additions & 7 deletions internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"text/template"
"time"

"github.com/didip/tollbooth/v5"
"github.com/supabase/auth/internal/hooks"

"github.com/pkg/errors"
Expand Down Expand Up @@ -45,7 +44,6 @@ func formatPhoneNumber(phone string) string {

// sendPhoneConfirmation sends an otp to the user's phone number
func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) {
ctx := r.Context()
config := a.config

var token *string
Expand Down Expand Up @@ -89,11 +87,8 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use
// not using test OTPs
if otp == "" {
// apply rate limiting before the sms is sent out
limiter := getLimiter(ctx)
if limiter != nil {
if err := tollbooth.LimitByKeys(limiter.PhoneLimiter, []string{"phone_functions"}); err != nil {
return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}
if ok := a.limiterOpts.Phone.Allow(); !ok {
return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded")
}
otp, err = crypto.GenerateOtp(config.Sms.OtpLength)
if err != nil {
Expand Down
Loading
Loading