Skip to content

Commit

Permalink
feat(authentication, types, utils): Add Authentication provider scopes (
Browse files Browse the repository at this point in the history
#6228)

* initial implementation

* add test for invalid scope

* get config from scope not db

* assign config from scope

* fix package.json

* optional providers

* make providers options

* update type
  • Loading branch information
pKorsholm authored Jan 29, 2024
1 parent d1c18a3 commit a41aad4
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,19 @@ describe("AuthenticationModuleService - AuthProvider", () => {
"AuthenticationProvider with for provider: notRegistered wasn't registered in the module. Have you configured your options correctly?"
)
})

it("fails to authenticate using a valid provider with an invalid scope", async () => {
const { success, error } = await service.authenticate(
"usernamePassword",
{
scope: "non-existing",
}
)

expect(success).toBe(false)
expect(error).toEqual(
`Scope "non-existing" is not valid for provider usernamePassword`
)
})
})
})
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ describe("AuthenticationModuleService - AuthProvider", () => {
email: "[email protected]",
password: password,
},
scope: "store",
})

expect(res).toEqual({
Expand All @@ -91,6 +92,7 @@ describe("AuthenticationModuleService - AuthProvider", () => {

const res = await service.authenticate("usernamePassword", {
body: { email: "[email protected]" },
scope: "store",
})

expect(res).toEqual({
Expand All @@ -104,6 +106,7 @@ describe("AuthenticationModuleService - AuthProvider", () => {

const res = await service.authenticate("usernamePassword", {
body: { password: "supersecret" },
scope: "store",
})

expect(res).toEqual({
Expand Down Expand Up @@ -136,6 +139,7 @@ describe("AuthenticationModuleService - AuthProvider", () => {
email: "[email protected]",
password: "password",
},
scope: "store",
})

expect(res).toEqual({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ export function getInitModuleConfig() {
schema: process.env.MEDUSA_AUTHENTICATION_DB_SCHEMA,
},
},
providers: [
{
name: "usernamePassword",
scopes: {
admin: {},
store: {},
},
},
],
}

const injectedDependencies = {}
Expand Down
45 changes: 36 additions & 9 deletions packages/authentication/src/loaders/providers.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@
import * as defaultProviders from "@providers"

import {
asClass,
AwilixContainer,
ClassOrFunctionReturning,
Constructor,
Resolver,
asClass,
} from "awilix"
import { LoaderOptions, ModulesSdkTypes } from "@medusajs/types"
import {
AuthModuleProviderConfig,
AuthProviderScope,
LoaderOptions,
ModulesSdkTypes,
} from "@medusajs/types"

type AuthModuleProviders = {
providers: AuthModuleProviderConfig[]
}

export default async ({
container,
options,
}: LoaderOptions<
| ModulesSdkTypes.ModuleServiceInitializeOptions
| ModulesSdkTypes.ModuleServiceInitializeCustomDataLayerOptions
(
| ModulesSdkTypes.ModuleServiceInitializeOptions
| ModulesSdkTypes.ModuleServiceInitializeCustomDataLayerOptions
) &
AuthModuleProviders
>): Promise<void> => {
// if(options.providers?.length) {
const providerMap = new Map(
options?.providers?.map((provider) => [provider.name, provider.scopes]) ??
[]
)
// if(options?.providers?.length) {
// TODO: implement plugin provider registration
// }

Expand All @@ -25,20 +42,30 @@ export default async ({
container.register({
[`auth_provider_${provider.PROVIDER}`]: asClass(
provider as Constructor<any>
).singleton(),
)
.singleton()
.inject(() => ({ scopes: providerMap.get(provider.PROVIDER) ?? {} })),
})
}

container.register({
[`auth_providers`]: asArray(providersToLoad),
[`auth_providers`]: asArray(providersToLoad, providerMap),
})
}

function asArray(
resolvers: (ClassOrFunctionReturning<unknown> | Resolver<unknown>)[]
resolvers: (ClassOrFunctionReturning<unknown> | Resolver<unknown>)[],
providerScopeMap: Map<string, Record<string, AuthProviderScope>>
): { resolve: (container: AwilixContainer) => unknown[] } {
return {
resolve: (container: AwilixContainer) =>
resolvers.map((resolver) => container.build(resolver)),
resolvers.map((resolver) =>
asClass(resolver as Constructor<any>)
.inject(() => ({
// @ts-ignore
scopes: providerScopeMap.get(resolver.PROVIDER) ?? {},
}))
.resolve(container)
),
}
}
52 changes: 29 additions & 23 deletions packages/authentication/src/providers/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import {
} from "@medusajs/utils"
import { AuthProviderService, AuthUserService } from "@services"
import jwt, { JwtPayload } from "jsonwebtoken"

import { AuthProvider } from "@models"
import { AuthenticationResponse } from "@medusajs/types"
import {
AuthenticationInput,
AuthenticationResponse,
AuthProviderScope,
} from "@medusajs/types"
import { AuthorizationCode } from "simple-oauth2"
import url from "url"

Expand All @@ -15,14 +17,6 @@ type InjectedDependencies = {
authProviderService: AuthProviderService
}

type AuthenticationInput = {
connection: { encrypted: boolean }
url: string
headers: { host: string }
query: Record<string, string>
body: Record<string, string>
}

type ProviderConfig = {
clientID: string
clientSecret: string
Expand All @@ -37,7 +31,7 @@ class GoogleProvider extends AbstractAuthenticationModuleProvider {
protected readonly authProviderService_: AuthProviderService

constructor({ authUserService, authProviderService }: InjectedDependencies) {
super()
super(arguments[0])

this.authUserSerivce_ = authUserService
this.authProviderService_ = authProviderService
Expand Down Expand Up @@ -84,11 +78,11 @@ class GoogleProvider extends AbstractAuthenticationModuleProvider {

const code = req.query?.code ?? req.body?.code

return await this.validateCallbackToken(code, config)
return await this.validateCallbackToken(code, req.scope, config)
}

// abstractable
async verify_(refreshToken: string) {
async verify_(refreshToken: string, scope: string) {
const jwtData = (await jwt.decode(refreshToken, {
complete: true,
})) as JwtPayload
Expand All @@ -108,6 +102,7 @@ class GoogleProvider extends AbstractAuthenticationModuleProvider {
entity_id,
provider_id: GoogleProvider.PROVIDER,
user_metadata: jwtData!.payload,
app_metadata: { scope },
},
])
} else {
Expand All @@ -121,6 +116,7 @@ class GoogleProvider extends AbstractAuthenticationModuleProvider {
// abstractable
private async validateCallbackToken(
code: string,
scope: string,
{ clientID, callbackURL, clientSecret }: ProviderConfig
) {
const client = this.getAuthorizationCodeHandler({ clientID, clientSecret })
Expand All @@ -133,24 +129,34 @@ class GoogleProvider extends AbstractAuthenticationModuleProvider {
try {
const accessToken = await client.getToken(tokenParams)

return await this.verify_(accessToken.token.id_token)
return await this.verify_(accessToken.token.id_token, scope)
} catch (error) {
return { success: false, error: error.message }
}
}

private async validateConfig(config: Partial<ProviderConfig>) {
if (!config.clientID) {
private getConfigFromScope(config: AuthProviderScope): ProviderConfig {
const providerConfig: Partial<ProviderConfig> = {}

if (config.clientId) {
providerConfig.clientID = config.clientId
} else {
throw new Error("Google clientID is required")
}

if (!config.clientSecret) {
if (config.clientSecret) {
providerConfig.clientSecret = config.clientSecret
} else {
throw new Error("Google clientSecret is required")
}

if (!config.callbackURL) {
if (config.callbackURL) {
providerConfig.callbackURL = config.callbackUrl
} else {
throw new Error("Google callbackUrl is required")
}

return providerConfig as ProviderConfig
}

private originalURL(req: AuthenticationInput) {
Expand All @@ -165,11 +171,11 @@ class GoogleProvider extends AbstractAuthenticationModuleProvider {
private async getProviderConfig(
req: AuthenticationInput
): Promise<ProviderConfig> {
const { config } = (await this.authProviderService_.retrieve(
GoogleProvider.PROVIDER
)) as AuthProvider & { config: ProviderConfig }
await this.authProviderService_.retrieve(GoogleProvider.PROVIDER)

const scopeConfig = this.scopes_[req.scope]

this.validateConfig(config || {})
const config = this.getConfigFromScope(scopeConfig)

const { callbackURL } = config

Expand Down
12 changes: 6 additions & 6 deletions packages/authentication/src/providers/username-password.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { AbstractAuthenticationModuleProvider, isString } from "@medusajs/utils"

import { AuthUserService } from "@services"
import { AuthenticationResponse } from "@medusajs/types"
import { AuthenticationInput, AuthenticationResponse } from "@medusajs/types"
import Scrypt from "scrypt-kdf"

class UsernamePasswordProvider extends AbstractAuthenticationModuleProvider {
Expand All @@ -10,14 +10,14 @@ class UsernamePasswordProvider extends AbstractAuthenticationModuleProvider {

protected readonly authUserSerivce_: AuthUserService

constructor({ authUserService: AuthUserService }) {
super()
constructor({ authUserService }: { authUserService: AuthUserService }) {
super(arguments[0])

this.authUserSerivce_ = AuthUserService
this.authUserSerivce_ = authUserService
}

async authenticate(
userData: Record<string, any>
userData: AuthenticationInput
): Promise<AuthenticationResponse> {
const { email, password } = userData.body

Expand All @@ -43,7 +43,7 @@ class UsernamePasswordProvider extends AbstractAuthenticationModuleProvider {
const password_hash = authUser.provider_metadata?.password

if (isString(password_hash)) {
const buf = Buffer.from(password_hash, "base64")
const buf = Buffer.from(password_hash as string, "base64")

const success = await Scrypt.verify(buf, password)

Expand Down
22 changes: 15 additions & 7 deletions packages/authentication/src/services/authentication-module.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import {
AuthenticationInput,
AuthenticationResponse,
AuthenticationTypes,
Context,
Expand Down Expand Up @@ -353,7 +354,8 @@ export default class AuthenticationModuleService<
}

protected getRegisteredAuthenticationProvider(
provider: string
provider: string,
{ scope }: AuthenticationInput
): AbstractAuthenticationModuleProvider {
let containerProvider: AbstractAuthenticationModuleProvider
try {
Expand All @@ -365,18 +367,22 @@ export default class AuthenticationModuleService<
)
}

containerProvider.validateScope(scope)

return containerProvider
}

async authenticate(
provider: string,
authenticationData: Record<string, unknown>
authenticationData: AuthenticationInput
): Promise<AuthenticationResponse> {
try {
await this.retrieveAuthProvider(provider, {})

const registeredProvider =
this.getRegisteredAuthenticationProvider(provider)
const registeredProvider = this.getRegisteredAuthenticationProvider(
provider,
authenticationData
)

return await registeredProvider.authenticate(authenticationData)
} catch (error) {
Expand All @@ -386,13 +392,15 @@ export default class AuthenticationModuleService<

async validateCallback(
provider: string,
authenticationData: Record<string, unknown>
authenticationData: AuthenticationInput
): Promise<AuthenticationResponse> {
try {
await this.retrieveAuthProvider(provider, {})

const registeredProvider =
this.getRegisteredAuthenticationProvider(provider)
const registeredProvider = this.getRegisteredAuthenticationProvider(
provider,
authenticationData
)

return await registeredProvider.validateCallback(authenticationData)
} catch (error) {
Expand Down
17 changes: 17 additions & 0 deletions packages/types/src/authentication/common/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,21 @@ export type AuthenticationResponse = {
success: boolean
authUser?: any
error?: string
location?: string
}

export type AuthModuleProviderConfig = {
name: string
scopes: Record<string, AuthProviderScope>
}

export type AuthProviderScope = Record<string, string>

export type AuthenticationInput = {
connection: { encrypted: boolean }
url: string
headers: Record<string, string>
query: Record<string, string>
body: Record<string, string>
scope: string
}
Loading

0 comments on commit a41aad4

Please sign in to comment.