Skip to content

Commit

Permalink
feat: add provider ID to session
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Sep 1, 2023
1 parent f7308ad commit e8be4bd
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 10 deletions.
2 changes: 2 additions & 0 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re
var grantParams models.GrantParams
var err error

grantParams.ProviderID = providerType

if providerType == "twitter" {
// future OAuth1.0 providers will use this method
oAuthResponseData, err := a.oAuth1Callback(ctx, r, providerType)
Expand Down
3 changes: 2 additions & 1 deletion internal/api/samlacs.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {
notAfter := assertion.NotAfter()

var grantParams models.GrantParams
grantParams.ProviderID = "sso:" + ssoProvider.ID.String()

if !notAfter.IsZero() {
grantParams.SessionNotAfter = &notAfter
Expand All @@ -278,7 +279,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error {
var user *models.User

// accounts potentially created via SAML can contain non-unique email addresses in the auth.users table
if user, terr = a.createAccountFromExternalIdentity(tx, r, &userProvidedData, "sso:"+ssoProvider.ID.String()); terr != nil {
if user, terr = a.createAccountFromExternalIdentity(tx, r, &userProvidedData, grantParams.ProviderID); terr != nil {
return terr
}
if flowState != nil {
Expand Down
2 changes: 2 additions & 0 deletions internal/api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
return invalidSignupError(config)
}

grantParams.ProviderID = params.Provider

if err != nil && !models.IsNotFoundError(err) {
return internalServerError("Database error finding user").WithInternalError(err)
}
Expand Down
1 change: 1 addition & 0 deletions internal/api/token_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R

var token *AccessTokenResponse
var grantParams models.GrantParams
grantParams.ProviderID = providerType

if err := db.Transaction(func(tx *storage.Connection) error {
var user *models.User
Expand Down
7 changes: 6 additions & 1 deletion internal/models/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ func (RefreshToken) TableName() string {
// GrantParams is used to pass session-specific parameters when issuing a new
// refresh token to authenticated users.
type GrantParams struct {
FactorID *uuid.UUID
FactorID *uuid.UUID
ProviderID string

SessionNotAfter *time.Time
}
Expand Down Expand Up @@ -125,6 +126,10 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
session.FactorID = params.FactorID
}

if params.ProviderID != "" {
session.ProviderID = &params.ProviderID
}

if params.SessionNotAfter != nil {
session.NotAfter = params.SessionNotAfter
}
Expand Down
17 changes: 9 additions & 8 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,15 @@ func (s sortAMREntries) Swap(i, j int) {
}

type Session struct {
ID uuid.UUID `json:"-" db:"id"`
UserID uuid.UUID `json:"user_id" db:"user_id"`
NotAfter *time.Time `json:"not_after,omitempty" db:"not_after"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
FactorID *uuid.UUID `json:"factor_id" db:"factor_id"`
AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"`
AAL *string `json:"aal" db:"aal"`
ID uuid.UUID `json:"-" db:"id"`
UserID uuid.UUID `json:"user_id" db:"user_id"`
NotAfter *time.Time `json:"not_after,omitempty" db:"not_after"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
FactorID *uuid.UUID `json:"factor_id" db:"factor_id"`
ProviderID *string `json:"provider_id" db:"provider_id"`
AMRClaims []AMRClaim `json:"amr,omitempty" has_many:"amr_claims"`
AAL *string `json:"aal" db:"aal"`
}

func (Session) TableName() string {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
alter table {{ index .Options "Namespace" }}.sessions add column if not exists provider_id text default null;

0 comments on commit e8be4bd

Please sign in to comment.