From a9c32bb1b4c88ca9fe819d4ae89e52c3fe5c7acd Mon Sep 17 00:00:00 2001 From: zufardhiyaulhaq Date: Thu, 4 Apr 2024 23:51:29 +0700 Subject: [PATCH] feat: implement google group claim in JWT --- connector/google/google.go | 13 ++++++++++++- server/oauth2.go | 2 ++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/connector/google/google.go b/connector/google/google.go index d590867285..2bbd270118 100644 --- a/connector/google/google.go +++ b/connector/google/google.go @@ -58,6 +58,9 @@ type Config struct { // If this field is true, fetch direct group membership and transitive group membership FetchTransitiveGroupMembership bool `json:"fetchTransitiveGroupMembership"` + + // enfore group claim on JWT + EnforceGroupClaim bool } // Open returns a connector which can be used to login users through Google. @@ -128,6 +131,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e domainToAdminEmail: c.DomainToAdminEmail, fetchTransitiveGroupMembership: c.FetchTransitiveGroupMembership, adminSrv: adminSrv, + EnforceGroupClaim: c.EnforceGroupClaim, }, nil } @@ -148,6 +152,7 @@ type googleConnector struct { domainToAdminEmail map[string]string fetchTransitiveGroupMembership bool adminSrv map[string]*admin.Service + EnforceGroupClaim bool } func (c *googleConnector) Close() error { @@ -248,7 +253,13 @@ func (c *googleConnector) createIdentity(ctx context.Context, identity connector } var groups []string - if s.Groups && len(c.adminSrv) > 0 { + + usingGroup := s.Groups + if c.EnforceGroupClaim { + usingGroup = true + } + + if usingGroup && len(c.adminSrv) > 0 { checkedGroups := make(map[string]struct{}) groups, err = c.getGroups(claims.Email, c.fetchTransitiveGroupMembership, checkedGroups) if err != nil { diff --git a/server/oauth2.go b/server/oauth2.go index 3589e493ea..89759ef592 100644 --- a/server/oauth2.go +++ b/server/oauth2.go @@ -440,6 +440,8 @@ func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []str tok.AuthorizingParty = clientID } + tok.Groups = append(tok.Groups, claims.Groups...) + payload, err := json.Marshal(tok) if err != nil { return "", expiry, fmt.Errorf("could not serialize claims: %v", err)