Skip to content

Commit

Permalink
Add another test for device authentication as next auth mode
Browse files Browse the repository at this point in the history
  • Loading branch information
adombeck committed Feb 5, 2025
1 parent 1fa0e5f commit 6cd19a6
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 10 deletions.
2 changes: 0 additions & 2 deletions internal/broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,14 +625,12 @@ func (b *Broker) handleIsAuthenticated(ctx context.Context, session *session, au
// Refresh the token if we're online even if the token has not expired
if !session.isOffline {
authInfo, err = b.refreshToken(ctx, session.oauth2Config, authInfo)

var retrieveErr *oauth2.RetrieveError
if errors.As(err, &retrieveErr) && b.provider.IsTokenExpiredError(*retrieveErr) {
// The refresh token is expired, so the user needs to authenticate via OIDC again.
session.nextAuthModes = []string{authmodes.Device, authmodes.DeviceQr}
return AuthNext, nil
}

if err != nil {
log.Error(context.Background(), err.Error())
return AuthDenied, errorMessage{Message: "could not refresh token"}
Expand Down
13 changes: 13 additions & 0 deletions internal/broker/broker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ func TestIsAuthenticated(t *testing.T) {
dontWaitForFirstCall bool
readOnlyDataDir bool
wantGroups []info.Group
wantNextAuthModes []string
}{
"Successfully_authenticate_user_with_device_auth_and_newpassword": {firstSecret: "-", wantSecondCall: true},
"Successfully_authenticate_user_with_password": {firstMode: authmodes.Password, token: &tokenOptions{}},
Expand Down Expand Up @@ -463,6 +464,13 @@ func TestIsAuthenticated(t *testing.T) {
token: &tokenOptions{noUserInfo: true},
getUserInfoFails: true,
},
"Authenticating_with_password_when_refresh_token_is_expired_results_in_device_auth_as_next_mode": {
firstMode: authmodes.Password,
token: &tokenOptions{refreshTokenExpired: true},
wantNextAuthModes: []string{authmodes.Device, authmodes.DeviceQr},
wantSecondCall: true,
secondMode: authmodes.DeviceQr,
},

"Error_when_authentication_data_is_invalid": {invalidAuthData: true},
"Error_when_secret_can_not_be_decrypted": {firstMode: authmodes.Password, badFirstKey: true},
Expand Down Expand Up @@ -640,6 +648,11 @@ func TestIsAuthenticated(t *testing.T) {
err = os.WriteFile(filepath.Join(outDir, "first_call"), out, 0600)
require.NoError(t, err, "Failed to write first response")

if tc.wantNextAuthModes != nil {
nextAuthModes := b.GetNextAuthModes(sessionID)
require.ElementsMatch(t, tc.wantNextAuthModes, nextAuthModes, "Next auth modes should match")
}

if tc.wantGroups != nil {
type userInfoMsgType struct {
UserInfo info.User `json:"userinfo"`
Expand Down
12 changes: 12 additions & 0 deletions internal/broker/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,18 @@ func (b *Broker) DataDir() string {
return b.cfg.DataDir
}

// GetNextAuthModes returns the next auth mode of the specified session.
func (b *Broker) GetNextAuthModes(sessionID string) []string {
b.currentSessionsMu.Lock()
defer b.currentSessionsMu.Unlock()

session, ok := b.currentSessions[sessionID]
if !ok {
return nil
}
return session.nextAuthModes
}

// SetNextAuthModes sets the next auth mode of the specified session.
func (b *Broker) SetNextAuthModes(sessionID string, authModes []string) {
b.currentSessionsMu.Lock()
Expand Down
16 changes: 10 additions & 6 deletions internal/broker/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,13 @@ type tokenOptions struct {
issuer string
groups []info.Group

expired bool
noRefreshToken bool
noIDToken bool
invalid bool
invalidClaims bool
noUserInfo bool
expired bool
noRefreshToken bool
refreshTokenExpired bool
noIDToken bool
invalid bool
invalidClaims bool
noUserInfo bool
}

func generateCachedInfo(t *testing.T, options tokenOptions) *token.AuthCachedInfo {
Expand Down Expand Up @@ -226,6 +227,9 @@ func generateCachedInfo(t *testing.T, options tokenOptions) *token.AuthCachedInf
if options.noRefreshToken {
tok.Token.RefreshToken = ""
}
if options.refreshTokenExpired {
tok.Token.RefreshToken = testutils.ExpiredRefreshToken
}

if !options.noUserInfo {
tok.UserInfo = info.User{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Definitely a hashed password
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Definitely an encrypted token
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
access: next
data: '{}'
err: <nil>
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
access: next
data: '{}'
err: <nil>
6 changes: 4 additions & 2 deletions internal/providers/noprovider/noprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package noprovider
import (
"context"
"fmt"
"strings"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/ubuntu/authd-oidc-brokers/internal/providers/info"
Expand Down Expand Up @@ -103,6 +104,7 @@ func (p NoProvider) getGroups(_ *oauth2.Token) ([]info.Group, error) {

// IsTokenExpiredError returns true if the reason for the error is that the refresh token is expired.
func (p NoProvider) IsTokenExpiredError(err oauth2.RetrieveError) bool {
// There is no generic error for this, so we return false.
return false
// TODO: This is an msentraid specific error code and description.
// Change it to the ones from Google once we know them.
return err.ErrorCode == "invalid_grant" && strings.HasPrefix(err.ErrorDescription, "AADSTS50173:")
}
23 changes: 23 additions & 0 deletions internal/testutils/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"slices"
"strings"
"sync"
Expand All @@ -23,9 +24,15 @@ import (
"github.com/ubuntu/authd-oidc-brokers/internal/consts"
"github.com/ubuntu/authd-oidc-brokers/internal/providers/info"
"github.com/ubuntu/authd-oidc-brokers/internal/providers/noprovider"
"github.com/ubuntu/authd/log"
"golang.org/x/oauth2"
)

const (
// ExpiredRefreshToken is used to test the expired refresh token error.
ExpiredRefreshToken = "expired-refresh-token"
)

// MockKey is the RSA key used to sign the JWTs for the mock provider.
var MockKey *rsa.PrivateKey

Expand Down Expand Up @@ -197,6 +204,22 @@ func TokenHandler(serverURL string, opts *TokenHandlerOptions) EndpointHandler {
}

return func(w http.ResponseWriter, r *http.Request) {
s, err := httputil.DumpRequest(r, true)
if err != nil {
log.Errorf(context.Background(), "could not dump request: %v", err)
}
log.Debugf(context.Background(), "/token endpoint request:\n%s", s)

// Handle expired refresh token
refreshToken := r.FormValue("refresh_token")
if refreshToken == ExpiredRefreshToken {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
// This is an msentraid specific error code and description.
_, _ = w.Write([]byte(`{"error": "invalid_grant", "error_description": "AADSTS50173: The refresh token has expired."}`))
return
}

// Mimics user going through auth process
time.Sleep(2 * time.Second)

Expand Down

0 comments on commit 6cd19a6

Please sign in to comment.