Skip to content

Commit

Permalink
feat: make auth delegation optional (#1083)
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyrag authored Dec 27, 2023
1 parent 804743c commit 5191305
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 56 deletions.
44 changes: 22 additions & 22 deletions ee/auth/cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,6 @@ func newServeCommand() *cobra.Command {
return errors.New("base url must be defined")
}

delegatedClientID := viper.GetString(delegatedClientIDFlag)
if delegatedClientID == "" {
return errors.New("delegated client id must be defined")
}

delegatedClientSecret := viper.GetString(delegatedClientSecretFlag)
if delegatedClientSecret == "" {
return errors.New("delegated client secret must be defined")
}

delegatedIssuer := viper.GetString(delegatedIssuerFlag)
if delegatedIssuer == "" {
return errors.New("delegated issuer must be defined")
}

signingKey := viper.GetString(signingKeyFlag)
if signingKey == "" {
return errors.New("signing key must be defined")
Expand Down Expand Up @@ -145,26 +130,41 @@ func newServeCommand() *cobra.Command {
options := []fx.Option{
otlpHttpClientModule(viper.GetBool(service.DebugFlag)),
fx.Supply(fx.Annotate(cmd.Context(), fx.As(new(context.Context)))),
fx.Supply(delegatedauth.Config{
Issuer: delegatedIssuer,
ClientID: delegatedClientID,
ClientSecret: delegatedClientSecret,
RedirectURL: fmt.Sprintf("%s/authorize/callback", viper.GetString(baseUrlFlag)),
}),
sqlstorage.Module(sqlstorage.KindPostgres, viper.GetString(postgresUriFlag), key, o.Clients...),
api.Module(viper.GetString(listenFlag), viper.GetString(baseUrlFlag), sharedapi.ServiceInfo{
Version: Version,
}),
oidc.Module(key, viper.GetString(baseUrlFlag), o.Clients...),
authorization.Module(),
delegatedauth.Module(),
fx.Decorate(func(logger logging.Logger) *gorm.Config {
return &gorm.Config{
Logger: sqlstorage.NewLogger(logger),
}
}),
}

if delegatedIssuer := viper.GetString(delegatedIssuerFlag); delegatedIssuer != "" {
delegatedClientID := viper.GetString(delegatedClientIDFlag)
if delegatedClientID == "" {
return errors.New("delegated client id must be defined")
}

delegatedClientSecret := viper.GetString(delegatedClientSecretFlag)
if delegatedClientSecret == "" {
return errors.New("delegated client secret must be defined")
}

options = append(options,
fx.Supply(delegatedauth.Config{
Issuer: delegatedIssuer,
ClientID: delegatedClientID,
ClientSecret: delegatedClientSecret,
RedirectURL: fmt.Sprintf("%s/authorize/callback", viper.GetString(baseUrlFlag)),
}),
delegatedauth.Module(),
)
}

options = append(options, otlptraces.CLITracesModule(viper.GetViper()))

return service.New(cmd.OutOrStdout(), options...).Run(cmd.Context())
Expand Down
2 changes: 1 addition & 1 deletion ee/auth/pkg/api/authorization/accesstoken_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestVerifyAccessToken(t *testing.T) {
})
require.NoError(t, err)

provider, err := authoidc.NewOpenIDProvider(storageFacade, serverURL, mockOIDC.Issuer(), *keySet)
provider, err := authoidc.NewOpenIDProvider(storageFacade, serverURL, mockOIDC.Issuer(), keySet)
require.NoError(t, err)

ar := &oidc.AuthRequest{
Expand Down
27 changes: 18 additions & 9 deletions ee/auth/pkg/oidc/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"crypto/rsa"
"net/http"

"gopkg.in/square/go-jose.v2"

auth "github.com/formancehq/auth/pkg"
"github.com/formancehq/auth/pkg/delegatedauth"
"github.com/gorilla/mux"
Expand All @@ -15,19 +17,26 @@ import (

func Module(privateKey *rsa.PrivateKey, issuer string, staticClients ...auth.StaticClient) fx.Option {
return fx.Options(
fx.Invoke(func(router *mux.Router, provider op.OpenIDProvider, storage Storage, relyingParty rp.RelyingParty) {
fx.Invoke(fx.Annotate(func(router *mux.Router, provider op.OpenIDProvider,
storage Storage, relyingParty rp.RelyingParty) {
AddRoutes(router, provider, storage, relyingParty)
}),
}, fx.ParamTags(``, ``, ``, `optional:"true"`))),
fx.Provide(fx.Annotate(func(storage Storage, relyingParty rp.RelyingParty) *storageFacade {
return NewStorageFacade(storage, relyingParty, privateKey, staticClients...)
}, fx.As(new(op.Storage)))),
fx.Provide(func(httpClient *http.Client, storage op.Storage, configuration delegatedauth.Config) (op.OpenIDProvider, error) {
keySet, err := ReadKeySet(httpClient, context.TODO(), configuration)
if err != nil {
return nil, err
}, fx.As(new(op.Storage)), fx.ParamTags(``, `optional:"true"`))),
fx.Provide(fx.Annotate(func(httpClient *http.Client, storage op.Storage, configuration delegatedauth.Config) (op.OpenIDProvider, error) {
var (
keySet *jose.JSONWebKeySet
err error
)
if configuration.Issuer != "" {
keySet, err = ReadKeySet(httpClient, context.TODO(), configuration)
if err != nil {
return nil, err
}
}

return NewOpenIDProvider(storage, issuer, configuration.Issuer, *keySet)
}),
return NewOpenIDProvider(storage, issuer, configuration.Issuer, keySet)
}, fx.ParamTags(``, ``, `optional:"true"`))),
)
}
2 changes: 1 addition & 1 deletion ee/auth/pkg/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func withServer(t *testing.T, fn func(m *mockoidc.MockOIDC, storage *sqlstorage.
require.NoError(t, err)

// Construct our oidc provider
provider, err := oidc.NewOpenIDProvider(storageFacade, serverUrl, mockOIDC.Issuer(), *keySet)
provider, err := oidc.NewOpenIDProvider(storageFacade, serverUrl, mockOIDC.Issuer(), keySet)
require.NoError(t, err)

// Create the router
Expand Down
45 changes: 26 additions & 19 deletions ee/auth/pkg/oidc/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,32 @@ func (p provider) JWTProfileVerifier() JWTProfileVerifier {

var _ JWTAuthorizationGrantExchanger = (*provider)(nil)

func NewOpenIDProvider(storage op.Storage, issuer, delegatedIssuer string, delegatedIssuerJsonWebKeySet jose.JSONWebKeySet) (op.OpenIDProvider, error) {
func NewOpenIDProvider(storage op.Storage, issuer, delegatedIssuer string, delegatedIssuerJsonWebKeySet *jose.JSONWebKeySet) (op.OpenIDProvider, error) {
var p op.OpenIDProvider

interceptors := make([]op.Option, 0)
if delegatedIssuer != "" {
interceptors = append(interceptors, op.WithHttpInterceptors(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Intercept token requests with grant_type of type bearer assertion
// as the library does not implement what we needs
if r.URL.Path == op.DefaultEndpoints.Token.Relative() &&
r.FormValue("grant_type") == string(oidc.GrantTypeBearer) {
grantTypeBearer(issuer, &provider{
issuer: issuer,
OpenIDProvider: p,
delegatedIssuerJsonWebKeySet: *delegatedIssuerJsonWebKeySet,
delegatedIssuer: delegatedIssuer,
}).ServeHTTP(w, r)
return
}
handler.ServeHTTP(w, r)
})

}))
}
interceptors = append(interceptors, op.WithAllowInsecure())

p, err := op.NewOpenIDProvider(issuer, &op.Config{
CryptoKey: sha256.Sum256([]byte("test")),
DefaultLogoutRedirectURI: pathLoggedOut,
Expand All @@ -73,23 +97,6 @@ func NewOpenIDProvider(storage op.Storage, issuer, delegatedIssuer string, deleg
GrantTypeRefreshToken: true,
RequestObjectSupported: true,
SupportedUILocales: []language.Tag{language.English},
}, storage, op.WithHttpInterceptors(func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Intercept token requests with grant_type of type bearer assertion
// as the library does not implement what we needs
if r.URL.Path == op.DefaultEndpoints.Token.Relative() &&
r.FormValue("grant_type") == string(oidc.GrantTypeBearer) {
grantTypeBearer(issuer, &provider{
issuer: issuer,
OpenIDProvider: p,
delegatedIssuerJsonWebKeySet: delegatedIssuerJsonWebKeySet,
delegatedIssuer: delegatedIssuer,
}).ServeHTTP(w, r)
return
}
handler.ServeHTTP(w, r)
})

}), op.WithAllowInsecure())
}, storage, interceptors...)
return p, err
}
10 changes: 6 additions & 4 deletions ee/auth/pkg/oidc/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ const AuthorizeCallbackPath = "/authorize/callback"

func AddRoutes(router *mux.Router, provider op.OpenIDProvider, storage Storage, relyingParty rp.RelyingParty) {
authorizationRouter := router.NewRoute().Subrouter()
authorizationRouter.NewRoute().Path(AuthorizeCallbackPath).Queries("code", "{code}").
Handler(authorizeCallbackHandler(provider, storage, relyingParty))
authorizationRouter.NewRoute().Path(AuthorizeCallbackPath).Queries("error", "{error}").
Handler(authorizeErrorHandler())
if relyingParty != nil {
authorizationRouter.NewRoute().Path(AuthorizeCallbackPath).Queries("code", "{code}").
Handler(authorizeCallbackHandler(provider, storage, relyingParty))
authorizationRouter.NewRoute().Path(AuthorizeCallbackPath).Queries("error", "{error}").
Handler(authorizeErrorHandler())
}

oidcLibRouter := router.PathPrefix("/").Subrouter()
oidcLibRouter.Use(func(handler http.Handler) http.Handler {
Expand Down

0 comments on commit 5191305

Please sign in to comment.