Skip to content

Commit

Permalink
JWT Authentication Rework (#5007)
Browse files Browse the repository at this point in the history
* Implement refresh token reuse detection using a new Session store
* Implement JWT refresh token rotation

Instead of never expiring, 2 changes are made to refresh tokens:

1. A new refresh token is issued with each access token renewal

  This can serve as the basis to implement token "revocation" strategies
  for old refresh tokens, mitigating the impact of a stolen refresh
  token.

2. Refresh tokens have an expiration timestamp

  Paired with 1. above, this means that there is now a concept of
  "inactivity" baked into refresh tokens: if a user doesn't renew their
  access token before the refresh token expires, they will be forced to
  re-authenticate. This defaults to 12 hours of inactivity.

Signed-off-by: Cyril Cressent <[email protected]>
  • Loading branch information
ccressent authored May 17, 2023
1 parent a4f0a0a commit cedb7a3
Show file tree
Hide file tree
Showing 15 changed files with 311 additions and 51 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG-6.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Versioning](http://semver.org/spec/v2.0.0.html).
- Adding a flag at agent level to avoid collecting system.networks property in the agent entity state
- Added silences sorting by expiration to GraphQL service
- Added log-millisecond-timestamps backend configuration flag
- Added a session store, used to detect and prevent refresh token reuse

### Changed
- Users are now automatically logged out after a period of inactivity (12h)

## [6.9.2] - 2023-03-08

### Added
Expand Down
118 changes: 96 additions & 22 deletions backend/api/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,25 @@ import (
"fmt"

corev2 "github.com/sensu/core/v2"

"github.com/sensu/sensu-go/backend/authentication"
"github.com/sensu/sensu-go/backend/authentication/jwt"
"github.com/sensu/sensu-go/backend/authentication/providers/basic"
"github.com/sensu/sensu-go/backend/store"
)

// AuthenticationClient is an API client for authentication.
type AuthenticationClient struct {
auth *authentication.Authenticator
auth *authentication.Authenticator
sessionStore store.SessionStore
}

// NewAuthenticationClient creates a new AuthenticationClient, given a a store
// and an authenticator.
func NewAuthenticationClient(auth *authentication.Authenticator) *AuthenticationClient {
// NewAuthenticationClient creates a new AuthenticationClient, given an
// authenticator and a session store.
func NewAuthenticationClient(auth *authentication.Authenticator, sessionStore store.SessionStore) *AuthenticationClient {
return &AuthenticationClient{
auth: auth,
auth: auth,
sessionStore: sessionStore,
}
}

Expand All @@ -32,6 +36,13 @@ func (a *AuthenticationClient) CreateAccessToken(ctx context.Context, username,
return nil, corev2.ErrUnauthorized
}

// Initialize a new session for this user
sessionID, err := jwt.InitSession(claims.Subject)
if err != nil {
return nil, err
}
claims.SessionID = sessionID

// Add the 'system:users' group to this user
claims.Groups = append(claims.Groups, "system:users")

Expand All @@ -47,10 +58,23 @@ func (a *AuthenticationClient) CreateAccessToken(ctx context.Context, username,
}

// Create a refresh token and its signed version
refreshClaims := &corev2.Claims{StandardClaims: corev2.StandardClaims(claims.Subject)}
_, refreshTokenString, err := jwt.RefreshToken(refreshClaims)
refreshClaims := &corev2.Claims{
StandardClaims: corev2.StandardClaims(claims.Subject),
SessionID: sessionID,
}
refreshToken, refreshTokenString, err := jwt.RefreshToken(refreshClaims)
if err != nil {
return nil, fmt.Errorf("error creating access token: %s", err)
return nil, fmt.Errorf("error creating refresh token: %s", err)
}

refreshTokenClaims, err := jwt.GetClaims(refreshToken)
if err != nil {
return nil, err
}

// Store the refresh token's unique ID as part of this user's session
if err := a.sessionStore.UpdateSession(ctx, refreshTokenClaims.Subject, refreshTokenClaims.SessionID, refreshTokenClaims.Id); err != nil {
return nil, err
}

result := &corev2.Tokens{
Expand All @@ -60,7 +84,6 @@ func (a *AuthenticationClient) CreateAccessToken(ctx context.Context, username,
}

return result, nil

}

// TestCreds detects if the username and password are valid.
Expand All @@ -84,19 +107,37 @@ func (a *AuthenticationClient) TestCreds(ctx context.Context, username, password
//
// corev2.AccessTokenClaims -> *corev2.Claims
// corev2.RefreshTokenClaims -> *corev2.Claims
//
// Given that we use JWTs for authentication, logging out just destroys the
// server side session such that it's not possible to get a new access token
// anymore, regardless of the refresh token presented by the user.
//
// Again, because we use JWTs, even after logging out, the access token bearer
// can still interact with the system until the token expires (up to 5 minutes
// by default).
func (a *AuthenticationClient) Logout(ctx context.Context) error {
return nil
var accessClaims *corev2.Claims

// Retrieve the access token's claims
if value := ctx.Value(corev2.AccessTokenClaims); value != nil {
accessClaims = value.(*corev2.Claims)
} else {
return corev2.ErrInvalidToken
}

return a.sessionStore.DeleteSession(ctx, accessClaims.Subject, accessClaims.SessionID)
}

// RefreshAccessToken refreshes an access token. The context must carry the
// user's access and refresh claims, as well as the previous token value,
// with the following context key-values:
// RefreshAccessToken refreshes an access/refresh token pair. The context must
// carry the user's access and refresh claims, as well as the previous token
// value, with the following context key-values:
//
// corev2.AccessTokenClaims -> *corev2.Claims
// corev2.RefreshTokenClaims -> *corev2.Claims
// corev2.RefreshTokenString -> string
func (a *AuthenticationClient) RefreshAccessToken(ctx context.Context) (*corev2.Tokens, error) {
var accessClaims *corev2.Claims
var refreshClaims *corev2.Claims

// Get the access token claims
if value := ctx.Value(corev2.AccessTokenClaims); value != nil {
Expand All @@ -106,15 +147,25 @@ func (a *AuthenticationClient) RefreshAccessToken(ctx context.Context) (*corev2.
}

// Get the refresh token claims
if value := ctx.Value(corev2.RefreshTokenClaims); value == nil {
if value := ctx.Value(corev2.RefreshTokenClaims); value != nil {
refreshClaims = value.(*corev2.Claims)
} else {
return nil, corev2.ErrInvalidToken
}

// Get the refresh token string
var refreshTokenString string
if value := ctx.Value(corev2.RefreshTokenString); value != nil {
refreshTokenString = value.(string)
} else {
sessionID := accessClaims.SessionID

storedRefreshTokenID, err := a.sessionStore.GetSession(ctx, refreshClaims.Subject, refreshClaims.SessionID)
if err != nil {
return nil, err
}

// If the supplied refresh token's ID doesn't match what the session
// expected it to be. Whatever the reason for that, be it refresh token
// reuse or otherwise, we just tear down that session, forcing the user to
// fully reauthenticate.
if refreshClaims.Id != storedRefreshTokenID {
a.Logout(ctx)
return nil, corev2.ErrInvalidToken
}

Expand All @@ -136,6 +187,9 @@ func (a *AuthenticationClient) RefreshAccessToken(ctx context.Context) (*corev2.
return nil, err
}

// Carry over the session ID
claims.SessionID = sessionID

// Ensure the 'system:users' group is present
claims.Groups = append(claims.Groups, "system:users")

Expand All @@ -145,14 +199,34 @@ func (a *AuthenticationClient) RefreshAccessToken(ctx context.Context) (*corev2.
}

// Issue a new access token
_, accessTokenString, err := jwt.AccessToken(claims)
_, newAccessTokenString, err := jwt.AccessToken(claims)
if err != nil {
return nil, err
}

// Create a new refresh token, carrying over the session ID
newRefreshClaims := &corev2.Claims{
StandardClaims: corev2.StandardClaims(claims.Subject),
SessionID: sessionID,
}
newRefreshToken, newRefreshTokenString, err := jwt.RefreshToken(newRefreshClaims)
if err != nil {
return nil, fmt.Errorf("error creating refresh token: %s", err)
}

newRefreshTokenClaims, err := jwt.GetClaims(newRefreshToken)
if err != nil {
return nil, err
}

// Update the session with the new refresh token's unique ID
if err := a.sessionStore.UpdateSession(ctx, claims.Subject, refreshClaims.SessionID, newRefreshTokenClaims.Id); err != nil {
return nil, err
}

return &corev2.Tokens{
Access: accessTokenString,
Access: newAccessTokenString,
ExpiresAt: claims.ExpiresAt,
Refresh: refreshTokenString,
Refresh: newRefreshTokenString,
}, nil
}
36 changes: 24 additions & 12 deletions backend/api/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ func defaultStore() store.Store {
}

func contextWithClaims(claims *corev2.Claims) context.Context {
refreshClaims := &corev2.Claims{StandardClaims: corev2.StandardClaims(claims.Subject)}
refreshClaims := &corev2.Claims{
StandardClaims: corev2.StandardClaims(claims.Subject),
SessionID: claims.SessionID,
}
ctx := context.Background()
ctx = context.WithValue(ctx, corev2.AccessTokenClaims, claims)
ctx = context.WithValue(ctx, corev2.RefreshTokenClaims, refreshClaims)
Expand Down Expand Up @@ -60,6 +63,7 @@ func TestCreateAccessToken(t *testing.T) {
user := corev2.FixtureUser("foo")
store := &mockstore.MockStore{}
store.On("AuthenticateUser", mock.Anything, "foo", "P@ssw0rd!").Return(user, errors.New("error"))
store.On("UpdateSession", mock.Anything, "foo", mock.Anything, mock.Anything).Return(nil)
return store
},
Authenticator: defaultAuth,
Expand All @@ -76,6 +80,7 @@ func TestCreateAccessToken(t *testing.T) {
store := &mockstore.MockStore{}
user := corev2.FixtureUser("foo")
store.On("AuthenticateUser", mock.Anything, "foo", "P@ssw0rd!").Return(user, nil)
store.On("UpdateSession", mock.Anything, "foo", mock.Anything, mock.Anything).Return(nil)
return store
},
Authenticator: defaultAuth,
Expand All @@ -85,7 +90,7 @@ func TestCreateAccessToken(t *testing.T) {
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
store := test.Store()
authn := NewAuthenticationClient(test.Authenticator(store))
authn := NewAuthenticationClient(test.Authenticator(store), store)
tokens, err := authn.CreateAccessToken(test.Context(), test.Username, test.Password)
if test.WantError && err == nil {
t.Fatal("want error, got nil")
Expand Down Expand Up @@ -158,7 +163,7 @@ func TestTestCreds(t *testing.T) {
for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
store := test.Store()
authn := NewAuthenticationClient(test.Authenticator(store))
authn := NewAuthenticationClient(test.Authenticator(store), store)
err := authn.TestCreds(test.Context(), test.Username, test.Password)

if test.WantError && test.Error != err {
Expand All @@ -175,39 +180,46 @@ func TestTestCreds(t *testing.T) {
func TestRefreshAccessToken(t *testing.T) {
tests := []struct {
Name string
Store func() store.Store
Store func(string) store.Store
Authenticator func(store.Store) *authentication.Authenticator
Context func(*corev2.Claims) context.Context
Context func(*corev2.Claims) (context.Context, string)
WantError bool
Error error
}{
{
Name: "success",
Store: func() store.Store {
Store: func(refreshTokenId string) store.Store {
st := &mockstore.MockStore{}
user := &corev2.User{Username: "foo"}
st.On("GetUser",
mock.AnythingOfType("*context.valueCtx"), mock.AnythingOfType("string"),
).Return(user, nil)
st.On("GetSession",
mock.AnythingOfType("*context.valueCtx"), user.Username, mock.AnythingOfType("string"),
).Return(refreshTokenId, nil)
st.On("UpdateSession",
mock.AnythingOfType("*context.valueCtx"), user.Username, mock.AnythingOfType("string"), mock.AnythingOfType("string"),
).Return(nil)
return st
},
Authenticator: defaultAuth,
Context: func(claims *corev2.Claims) context.Context {
Context: func(claims *corev2.Claims) (context.Context, string) {
ctx := contextWithClaims(claims)
_, refreshTokenString, _ := jwt.RefreshToken(ctx.Value(corev2.RefreshTokenClaims).(*corev2.Claims))
refreshToken, refreshTokenString, _ := jwt.RefreshToken(ctx.Value(corev2.RefreshTokenClaims).(*corev2.Claims))
refreshTokenClaims, _ := jwt.GetClaims(refreshToken)
ctx = context.WithValue(ctx, corev2.RefreshTokenString, refreshTokenString)
return ctx
return ctx, refreshTokenClaims.Id
},
},
}

for _, test := range tests {
t.Run(test.Name, func(t *testing.T) {
claims := corev2.FixtureClaims("foo", nil)
ctx := test.Context(claims)
store := test.Store()
ctx, refreshTokenId := test.Context(claims)
store := test.Store(refreshTokenId)
authenticator := test.Authenticator(store)
auth := NewAuthenticationClient(authenticator)
auth := NewAuthenticationClient(authenticator, store)
_, err := auth.RefreshAccessToken(ctx)
if err == nil && test.WantError {
t.Fatal("got non-nil error")
Expand Down
8 changes: 4 additions & 4 deletions backend/apid/routers/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (a *AuthenticationRouter) login(w http.ResponseWriter, r *http.Request) {
// issuer URL
ctx := context.WithValue(r.Context(), jwt.IssuerURLKey, issuerURL(r))

client := api.NewAuthenticationClient(a.authenticator)
client := api.NewAuthenticationClient(a.authenticator, a.store)
tokens, err := client.CreateAccessToken(ctx, username, password)
if err != nil {
if err == corev2.ErrUnauthorized {
Expand Down Expand Up @@ -76,7 +76,7 @@ func (a *AuthenticationRouter) test(w http.ResponseWriter, r *http.Request) {
return
}

client := api.NewAuthenticationClient(a.authenticator)
client := api.NewAuthenticationClient(a.authenticator, a.store)
err := client.TestCreds(r.Context(), username, password)
if err == nil {
return
Expand All @@ -90,7 +90,7 @@ func (a *AuthenticationRouter) test(w http.ResponseWriter, r *http.Request) {

// logout handles the logout flow
func (a *AuthenticationRouter) logout(w http.ResponseWriter, r *http.Request) {
client := api.NewAuthenticationClient(a.authenticator)
client := api.NewAuthenticationClient(a.authenticator, a.store)
if err := client.Logout(r.Context()); err == nil {
return
}
Expand All @@ -100,7 +100,7 @@ func (a *AuthenticationRouter) logout(w http.ResponseWriter, r *http.Request) {

// token handles logic for issuing new access tokens
func (a *AuthenticationRouter) token(w http.ResponseWriter, r *http.Request) {
client := api.NewAuthenticationClient(a.authenticator)
client := api.NewAuthenticationClient(a.authenticator, a.store)

// Determine the URL that serves this request so it can be later used as the
// issuer URL
Expand Down
3 changes: 3 additions & 0 deletions backend/apid/routers/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func TestLoginSuccessful(t *testing.T) {
store.
On("AuthenticateUser", mock.Anything, "foo", "P@ssw0rd!").
Return(user, nil)
store.
On("UpdateSession", mock.Anything, "foo", mock.Anything, mock.Anything).
Return(nil)

req, _ := http.NewRequest(http.MethodGet, "/auth", nil)
req.SetBasicAuth("foo", "P@ssw0rd!")
Expand Down
2 changes: 1 addition & 1 deletion backend/authentication/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func (a *Authenticator) Authenticate(ctx context.Context, username, password str
// TODO(palourde): The Go runtime randomizes map iteration order so the
// providers resolution order might vary on each authentication, and
// consequently provoke weird behavior if the same username/password
// combinaison exists in multiple providers.
// combination exists in multiple providers.
for _, provider := range a.providers {
claims, err := provider.Authenticate(ctx, username, password)
if err != nil || claims == nil {
Expand Down
Loading

0 comments on commit cedb7a3

Please sign in to comment.