From da7c446dbfce03120e44c18d9f949aef68ef7ea0 Mon Sep 17 00:00:00 2001 From: Ben Frengley Date: Thu, 15 Dec 2022 18:05:11 +1300 Subject: [PATCH] feat/azureadv2: retrieve ID token from response if available --- providers/azureadv2/azureadv2.go | 1 + providers/azureadv2/session.go | 4 ++++ providers/azureadv2/session_test.go | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/providers/azureadv2/azureadv2.go b/providers/azureadv2/azureadv2.go index 523f79cce..bd1762656 100644 --- a/providers/azureadv2/azureadv2.go +++ b/providers/azureadv2/azureadv2.go @@ -159,6 +159,7 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { err = userFromReader(response.Body, &user) user.AccessToken = msSession.AccessToken + user.IDToken = msSession.IDToken user.RefreshToken = msSession.RefreshToken user.ExpiresAt = msSession.ExpiresAt return user, err diff --git a/providers/azureadv2/session.go b/providers/azureadv2/session.go index f2f0cd07c..9b3ba8490 100644 --- a/providers/azureadv2/session.go +++ b/providers/azureadv2/session.go @@ -13,6 +13,7 @@ import ( type Session struct { AuthURL string `json:"au"` AccessToken string `json:"at"` + IDToken string `json:"it"` RefreshToken string `json:"rt"` ExpiresAt time.Time `json:"exp"` } @@ -41,6 +42,9 @@ func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, s.AccessToken = token.AccessToken s.RefreshToken = token.RefreshToken s.ExpiresAt = token.Expiry + if idTok, ok := token.Extra("id_token").(string); ok { + s.IDToken = idTok + } return token.AccessToken, err } diff --git a/providers/azureadv2/session_test.go b/providers/azureadv2/session_test.go index 7edfde4e6..11b32b9ea 100644 --- a/providers/azureadv2/session_test.go +++ b/providers/azureadv2/session_test.go @@ -36,7 +36,7 @@ func Test_ToJSON(t *testing.T) { s := &azureadv2.Session{} data := s.Marshal() - a.Equal(`{"au":"","at":"","rt":"","exp":"0001-01-01T00:00:00Z"}`, data) + a.Equal(`{"au":"","at":"","it":"","rt":"","exp":"0001-01-01T00:00:00Z"}`, data) } func Test_String(t *testing.T) {