diff --git a/cmd/serve_cmd.go b/cmd/serve_cmd.go index 5a0745a2e..06fa2f532 100644 --- a/cmd/serve_cmd.go +++ b/cmd/serve_cmd.go @@ -48,7 +48,10 @@ func serve(ctx context.Context) { addr := net.JoinHostPort(config.API.Host, config.API.Port) logrus.Infof("GoTrue API started on: %s", addr) - a := api.NewAPIWithVersion(config, db, utilities.Version) + opts := []api.Option{ + api.NewLimiterOptions(config), + } + a := api.NewAPIWithVersion(config, db, utilities.Version, opts...) ah := reloader.NewAtomicHandler(a) baseCtx, baseCancel := context.WithCancel(context.Background()) @@ -74,7 +77,8 @@ func serve(ctx context.Context) { fn := func(latestCfg *conf.GlobalConfiguration) { log.Info("reloading api with new configuration") - latestAPI := api.NewAPIWithVersion(latestCfg, db, utilities.Version) + latestAPI := api.NewAPIWithVersion( + latestCfg, db, utilities.Version, opts...) ah.Store(latestAPI) } diff --git a/internal/api/api.go b/internal/api/api.go index 287ae2995..dba93ae15 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -5,8 +5,6 @@ import ( "regexp" "time" - "github.com/didip/tollbooth/v5" - "github.com/didip/tollbooth/v5/limiter" "github.com/rs/cors" "github.com/sebest/xff" "github.com/sirupsen/logrus" @@ -37,6 +35,8 @@ type API struct { // overrideTime can be used to override the clock used by handlers. Should only be used in tests! overrideTime func() time.Time + + limiterOpts *LimiterOptions } func (a *API) Now() time.Time { @@ -48,8 +48,8 @@ func (a *API) Now() time.Time { } // NewAPI instantiates a new REST API -func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection) *API { - return NewAPIWithVersion(globalConfig, db, defaultVersion) +func NewAPI(globalConfig *conf.GlobalConfiguration, db *storage.Connection, opt ...Option) *API { + return NewAPIWithVersion(globalConfig, db, defaultVersion, opt...) } func (a *API) deprecationNotices() { @@ -67,9 +67,15 @@ func (a *API) deprecationNotices() { } // NewAPIWithVersion creates a new REST API using the specified version -func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string) *API { +func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Connection, version string, opt ...Option) *API { api := &API{config: globalConfig, db: db, version: version} + for _, o := range opt { + o.apply(api) + } + if api.limiterOpts == nil { + api.limiterOpts = NewLimiterOptions(globalConfig) + } if api.config.Password.HIBP.Enabled { httpClient := &http.Client{ // all HIBP API requests should finish quickly to avoid @@ -134,18 +140,12 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Get("/authorize", api.ExternalProviderRedirect) - sharedLimiter := api.limitEmailOrPhoneSentHandler() + sharedLimiter := api.limitEmailOrPhoneSentHandler(api.limiterOpts) r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite) r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) { // rate limit per hour - limitAnonymousSignIns := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"}) - - limitSignups := tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30) - + limitAnonymousSignIns := api.limiterOpts.AnonymousSignIns + limitSignups := api.limiterOpts.Signups r.Post("/", func(w http.ResponseWriter, r *http.Request) error { params := &SignupParams{} if err := retrieveRequestParams(r, params); err != nil { @@ -172,47 +172,22 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne return api.Signup(w, r) }) }) - r.With(api.limitHandler( - // Allow requests at the specified rate per 5 minutes - tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) - - r.With(api.limitHandler( - // Allow requests at the specified rate per 5 minutes - tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend) - - r.With(api.limitHandler( - // Allow requests at the specified rate per 5 minutes - tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) - - r.With(api.limitHandler( - // Allow requests at the specified rate per 5 minutes - tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp) - - r.With(api.limitHandler( - // Allow requests at the specified rate per 5 minutes. - tollbooth.NewLimiter(api.config.RateLimitTokenRefresh/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).With(api.verifyCaptcha).Post("/token", api.Token) - - r.With(api.limitHandler( - // Allow requests at the specified rate per 5 minutes. - tollbooth.NewLimiter(api.config.RateLimitVerify/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).Route("/verify", func(r *router) { + r.With(api.limitHandler(api.limiterOpts.Recover)). + With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover) + + r.With(api.limitHandler(api.limiterOpts.Resend)). + With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend) + + r.With(api.limitHandler(api.limiterOpts.MagicLink)). + With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink) + + r.With(api.limitHandler(api.limiterOpts.Otp)). + With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp) + + r.With(api.limitHandler(api.limiterOpts.Token)). + With(api.verifyCaptcha).Post("/token", api.Token) + + r.With(api.limitHandler(api.limiterOpts.Verify)).Route("/verify", func(r *router) { r.Get("/", api.Verify) r.Post("/", api.Verify) }) @@ -225,12 +200,8 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.With(api.requireAuthentication).Route("/user", func(r *router) { r.Get("/", api.UserGet) - r.With(api.limitHandler( - // Allow requests at the specified rate per 5 minutes - tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).With(sharedLimiter).Put("/", api.UserUpdate) + r.With(api.limitHandler(api.limiterOpts.User)). + With(sharedLimiter).Put("/", api.UserUpdate) r.Route("/identities", func(r *router) { r.Use(api.requireManualLinkingEnabled) @@ -245,14 +216,10 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Route("/{factor_id}", func(r *router) { r.Use(api.loadFactor) - r.With(api.limitHandler( - tollbooth.NewLimiter(api.config.MFA.RateLimitChallengeAndVerify/60, &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Minute, - }).SetBurst(30))).Post("/verify", api.VerifyFactor) - r.With(api.limitHandler( - tollbooth.NewLimiter(api.config.MFA.RateLimitChallengeAndVerify/60, &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Minute, - }).SetBurst(30))).Post("/challenge", api.ChallengeFactor) + r.With(api.limitHandler(api.limiterOpts.FactorVerify)). + Post("/verify", api.VerifyFactor) + r.With(api.limitHandler(api.limiterOpts.FactorChallenge)). + Post("/challenge", api.ChallengeFactor) r.Delete("/", api.UnenrollFactor) }) @@ -260,22 +227,14 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.Route("/sso", func(r *router) { r.Use(api.requireSAMLEnabled) - r.With(api.limitHandler( - // Allow requests at the specified rate per 5 minutes. - tollbooth.NewLimiter(api.config.RateLimitSso/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).With(api.verifyCaptcha).Post("/", api.SingleSignOn) + r.With(api.limitHandler(api.limiterOpts.SSO)). + With(api.verifyCaptcha).Post("/", api.SingleSignOn) r.Route("/saml", func(r *router) { r.Get("/metadata", api.SAMLMetadata) - r.With(api.limitHandler( - // Allow requests at the specified rate per 5 minutes. - tollbooth.NewLimiter(api.config.SAML.RateLimitAssertion/(60*5), &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(30), - )).Post("/acs", api.SamlAcs) + r.With(api.limitHandler(api.limiterOpts.SAMLAssertion)). + Post("/acs", api.SamlAcs) }) }) diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 87639a09c..a472be737 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -45,7 +45,8 @@ func setupAPIForTestWithCallback(cb func(*conf.GlobalConfiguration, *storage.Con cb(nil, conn) } - return NewAPIWithVersion(config, conn, apiTestVersion), config, nil + limiterOpts := NewLimiterOptions(config) + return NewAPIWithVersion(config, conn, apiTestVersion, limiterOpts), config, nil } func TestEmailEnabledByDefault(t *testing.T) { diff --git a/internal/api/middleware.go b/internal/api/middleware.go index e2598b180..88d95e20c 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -77,19 +77,7 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { } } -func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { - // limit per hour - emailFreq := a.config.RateLimitEmailSent / (60 * 60) - smsFreq := a.config.RateLimitSmsSent / (60 * 60) - - emailLimiter := tollbooth.NewLimiter(emailFreq, &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(int(a.config.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"}) - - phoneLimiter := tollbooth.NewLimiter(smsFreq, &limiter.ExpirableOptions{ - DefaultExpirationTTL: time.Hour, - }).SetBurst(int(a.config.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"}) - +func (a *API) limitEmailOrPhoneSentHandler(limiterOptions *LimiterOptions) middlewareHandler { return func(w http.ResponseWriter, req *http.Request) (context.Context, error) { c := req.Context() config := a.config @@ -100,8 +88,8 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { if req.Method == "PUT" || req.Method == "POST" { // store rate limiter in request context c = withLimiter(c, &SharedLimiter{ - EmailLimiter: emailLimiter, - PhoneLimiter: phoneLimiter, + EmailLimiter: limiterOptions.Email, + PhoneLimiter: limiterOptions.Phone, }) } } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index 77065e5b3..365abbbdb 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -212,7 +212,7 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() { }, } - limiter := ts.API.limitEmailOrPhoneSentHandler() + limiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config)) for _, c := range cases { ts.Run(c.desc, func() { var buffer bytes.Buffer @@ -484,7 +484,7 @@ func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() { ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig)) - sharedLimiter := ts.API.limitEmailOrPhoneSentHandler() + sharedLimiter := ts.API.limitEmailOrPhoneSentHandler(NewLimiterOptions(ts.Config)) // get the minimum amount to reach the threshold just before the rate limit is exceeded threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig) diff --git a/internal/api/options.go b/internal/api/options.go new file mode 100644 index 000000000..345e99d81 --- /dev/null +++ b/internal/api/options.go @@ -0,0 +1,107 @@ +package api + +import ( + "time" + + "github.com/didip/tollbooth/v5" + "github.com/didip/tollbooth/v5/limiter" + "github.com/supabase/auth/internal/conf" +) + +type Option interface { + apply(*API) +} + +type LimiterOptions struct { + Email *limiter.Limiter + Phone *limiter.Limiter + Signups *limiter.Limiter + AnonymousSignIns *limiter.Limiter + Recover *limiter.Limiter + Resend *limiter.Limiter + MagicLink *limiter.Limiter + Otp *limiter.Limiter + Token *limiter.Limiter + Verify *limiter.Limiter + User *limiter.Limiter + FactorVerify *limiter.Limiter + FactorChallenge *limiter.Limiter + SSO *limiter.Limiter + SAMLAssertion *limiter.Limiter +} + +func (lo *LimiterOptions) apply(a *API) { a.limiterOpts = lo } + +func NewLimiterOptions(gc *conf.GlobalConfiguration) *LimiterOptions { + o := &LimiterOptions{} + + o.Email = tollbooth.NewLimiter(gc.RateLimitEmailSent/(60*60), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(int(gc.RateLimitEmailSent)).SetMethods([]string{"PUT", "POST"}) + + o.Phone = tollbooth.NewLimiter(gc.RateLimitSmsSent/(60*60), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(int(gc.RateLimitSmsSent)).SetMethods([]string{"PUT", "POST"}) + + o.AnonymousSignIns = tollbooth.NewLimiter(gc.RateLimitAnonymousUsers/(60*60), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(int(gc.RateLimitAnonymousUsers)).SetMethods([]string{"POST"}) + + o.Token = tollbooth.NewLimiter(gc.RateLimitTokenRefresh/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.Verify = tollbooth.NewLimiter(gc.RateLimitVerify/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.User = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.FactorVerify = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60, + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Minute, + }).SetBurst(30) + + o.FactorChallenge = tollbooth.NewLimiter(gc.MFA.RateLimitChallengeAndVerify/60, + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Minute, + }).SetBurst(30) + + o.SSO = tollbooth.NewLimiter(gc.RateLimitSso/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.SAMLAssertion = tollbooth.NewLimiter(gc.SAML.RateLimitAssertion/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + o.Signups = tollbooth.NewLimiter(gc.RateLimitOtp/(60*5), + &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + + // These all use the OTP limit per 5 min with 1hour ttl and burst of 30. + o.Recover = newLimiterPer5mOver1h(gc.RateLimitOtp) + o.Resend = newLimiterPer5mOver1h(gc.RateLimitOtp) + o.MagicLink = newLimiterPer5mOver1h(gc.RateLimitOtp) + o.Otp = newLimiterPer5mOver1h(gc.RateLimitOtp) + return o +} + +func newLimiterPer5mOver1h(rate float64) *limiter.Limiter { + freq := rate / (60 * 5) + lim := tollbooth.NewLimiter(freq, &limiter.ExpirableOptions{ + DefaultExpirationTTL: time.Hour, + }).SetBurst(30) + return lim +}