From 92eab02551e1273dcca5ce26bf30dc1b3e670bec Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Fri, 1 Nov 2024 12:06:24 +0100 Subject: [PATCH] fix: don't return on logout, make it idempotent --- internal/api/auth.go | 10 +++++++--- internal/api/auth_test.go | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/internal/api/auth.go b/internal/api/auth.go index b03767f02..6062238c2 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -6,6 +6,7 @@ import ( "net/http" "strings" + "github.com/go-chi/chi/v5" "github.com/gofrs/uuid" jwt "github.com/golang-jwt/jwt/v5" "github.com/supabase/auth/internal/conf" @@ -25,7 +26,10 @@ func (a *API) requireAuthentication(w http.ResponseWriter, r *http.Request) (con return ctx, err } - ctx, err = a.maybeLoadUserOrSession(ctx) + routeContext := chi.RouteContext(ctx) + skipSessionMissingError := routeContext != nil && routeContext.RouteMethod == http.MethodPost && routeContext.RoutePath == "/logout" + + ctx, err = a.maybeLoadUserOrSession(ctx, skipSessionMissingError) if err != nil { return ctx, err } @@ -94,7 +98,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e return withToken(ctx, token), nil } -func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, error) { +func (a *API) maybeLoadUserOrSession(ctx context.Context, skipSessionMissingError bool) (context.Context, error) { db := a.db.WithContext(ctx) claims := getClaims(ctx) @@ -130,7 +134,7 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro } session, err = models.FindSessionByID(db, sessionId, false) if err != nil { - if models.IsNotFoundError(err) { + if models.IsNotFoundError(err) && !skipSessionMissingError { return ctx, forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist").WithInternalError(err).WithInternalMessage(fmt.Sprintf("session id (%s) doesn't exist", sessionId)) } return ctx, err diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 71afe6638..9e507bdfe 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -271,7 +271,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { ctx, err := ts.API.parseJWTClaims(userJwt, req) require.NoError(ts.T(), err) - ctx, err = ts.API.maybeLoadUserOrSession(ctx) + ctx, err = ts.API.maybeLoadUserOrSession(ctx, true) if c.ExpectedError != nil { require.Equal(ts.T(), c.ExpectedError.Error(), err.Error()) } else {