diff --git a/internal/api/external.go b/internal/api/external.go index af523f999..33503f9c9 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -151,6 +151,8 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re var grantParams models.GrantParams var err error + grantParams.ProviderID = providerType + if providerType == "twitter" { // future OAuth1.0 providers will use this method oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType) diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index f44333901..a4f48b6b5 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -261,6 +261,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { notAfter := assertion.NotAfter() var grantParams models.GrantParams + grantParams.ProviderID = "sso:" + ssoProvider.ID.String() if !notAfter.IsZero() { grantParams.SessionNotAfter = ¬After @@ -278,7 +279,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { var user *models.User // accounts potentially created via SAML can contain non-unique email addresses in the auth.users table - if user, terr = a.createAccountFromExternalIdentity(tx, r, &userProvidedData, "sso:"+ssoProvider.ID.String()); terr != nil { + if user, terr = a.createAccountFromExternalIdentity(tx, r, &userProvidedData, grantParams.ProviderID); terr != nil { return terr } if flowState != nil { diff --git a/internal/api/signup.go b/internal/api/signup.go index 9642fad1c..d2a0eae3b 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -133,6 +133,8 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return invalidSignupError(config) } + grantParams.ProviderID = params.Provider + if err != nil && !models.IsNotFoundError(err) { return internalServerError("Database error finding user").WithInternalError(err) } diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 9a8c57ccb..7354e6bd3 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -190,6 +190,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R var token *AccessTokenResponse var grantParams models.GrantParams + grantParams.ProviderID = providerType if err := db.Transaction(func(tx *storage.Connection) error { var user *models.User diff --git a/internal/models/refresh_token.go b/internal/models/refresh_token.go index 3a15b0b0e..42867d0b9 100644 --- a/internal/models/refresh_token.go +++ b/internal/models/refresh_token.go @@ -38,7 +38,8 @@ func (RefreshToken) TableName() string { // GrantParams is used to pass session-specific parameters when issuing a new // refresh token to authenticated users. type GrantParams struct { - FactorID *uuid.UUID + FactorID *uuid.UUID + ProviderID string SessionNotAfter *time.Time } @@ -125,6 +126,10 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok session.FactorID = params.FactorID } + if params.ProviderID != "" { + session.ProviderID = ¶ms.ProviderID + } + if params.SessionNotAfter != nil { session.NotAfter = params.SessionNotAfter } diff --git a/internal/models/sessions.go b/internal/models/sessions.go index d6ea13b92..fd70fe449 100644 --- a/internal/models/sessions.go +++ b/internal/models/sessions.go @@ -58,14 +58,15 @@ func (s sortAMREntries) Swap(i, j int) { } type Session struct { - ID uuid.UUID `json:"-" db:"id"` - UserID uuid.UUID `json:"user_id" db:"user_id"` - NotAfter *time.Time `json:"not_after,omitempty" db:"not_after"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` - FactorID *uuid.UUID `json:"factor_id" db:"factor_id"` - AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"` - AAL *string `json:"aal" db:"aal"` + ID uuid.UUID `json:"-" db:"id"` + UserID uuid.UUID `json:"user_id" db:"user_id"` + NotAfter *time.Time `json:"not_after,omitempty" db:"not_after"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + FactorID *uuid.UUID `json:"factor_id" db:"factor_id"` + ProviderID *string `json:"provider_id" db:"provider_id"` + AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"` + AAL *string `json:"aal" db:"aal"` } func (Session) TableName() string { diff --git a/migrations/20230901164423_add_provider_id_to_sessions.up.sql b/migrations/20230901164423_add_provider_id_to_sessions.up.sql new file mode 100644 index 000000000..77ef32e33 --- /dev/null +++ b/migrations/20230901164423_add_provider_id_to_sessions.up.sql @@ -0,0 +1 @@ +alter table {{ index .Options "Namespace" }}.sessions add column if not exists provider_id text default null;