Skip to content

Commit

Permalink
Fixes infinite redirect issue on some errors in the OAuth flow (#1089)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihalbhatnagar authored Jan 13, 2025
1 parent 066fb48 commit 0166d14
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 30 deletions.
5 changes: 5 additions & 0 deletions .changeset/fluffy-mails-taste.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@osdk/oauth": patch
---

Fixes an infinite redirect issue on some errors in the OAuth flow
38 changes: 28 additions & 10 deletions packages/oauth/src/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,23 @@ declare const process: {
env: Record<string, string | undefined>;
};

export type LocalStorageState =
export type LocalStorageState = { refresh_token?: string };

export type SessionStorageState =
// when we are going to the login page
| {
refresh_token?: never;
codeVerifier?: never;
state?: never;
oldUrl: string;
}
// when we are redirecting to oauth login
| {
refresh_token?: never;
codeVerifier: string;
state: string;
oldUrl: string;
}
// when we have the refresh token
| {
refresh_token?: string;
codeVerifier?: never;
state?: never;
oldUrl?: never;
}
| {
refresh_token?: never;
codeVerifier?: never;
state?: never;
oldUrl?: never;
Expand Down Expand Up @@ -103,6 +96,31 @@ export function readLocal(client: Client): LocalStorageState {
);
}

export function saveSession(client: Client, x: SessionStorageState) {
// MUST `sessionStorage?` as nodejs does not have sessionStorage
globalThis.sessionStorage?.setItem(
`@osdk/oauth : refresh : ${client.client_id}`,
JSON.stringify(x),
);
}

export function removeSession(client: Client) {
// MUST `sessionStorage?` as nodejs does not have sessionStorage
globalThis.sessionStorage?.removeItem(
`@osdk/oauth : refresh : ${client.client_id}`,
);
}

export function readSession(client: Client): SessionStorageState {
return JSON.parse(
// MUST `sessionStorage?` as nodejs does not have sessionStorage
globalThis.sessionStorage?.getItem(
`@osdk/oauth : refresh : ${client.client_id}`,
)
?? "{}",
);
}

export function common<
R extends undefined | (() => Promise<Token | undefined>),
>(
Expand Down
17 changes: 12 additions & 5 deletions packages/oauth/src/createPublicOauthClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ import {
common,
createAuthorizationServer,
readLocal,
readSession,
removeLocal,
removeSession,
saveLocal,
saveSession,
} from "./common.js";
import type { PublicOauthClient } from "./PublicOauthClient.js";
import { throwIfError } from "./throwIfError.js";
Expand Down Expand Up @@ -201,7 +204,7 @@ export function createPublicOauthClient(
if (
result && window.location.pathname === new URL(redirect_uri).pathname
) {
const { oldUrl } = readLocal(client);
const { oldUrl } = readSession(client);
// don't block on the redirect
void go(oldUrl ?? "/");
}
Expand All @@ -222,7 +225,7 @@ export function createPublicOauthClient(
}

async function maybeHandleAuthReturn() {
const { codeVerifier, state, oldUrl } = readLocal(client);
const { state, oldUrl, codeVerifier } = readSession(client);
if (!codeVerifier) return;

try {
Expand Down Expand Up @@ -262,6 +265,8 @@ export function createPublicOauthClient(
);
}
removeLocal(client);
removeSession(client);
throw e;
}
}

Expand All @@ -272,15 +277,17 @@ export function createPublicOauthClient(
&& window.location.href !== loginPage
&& window.location.pathname !== loginPage
) {
saveLocal(client, { oldUrl: postLoginPage });
saveLocal(client, {});
saveSession(client, { oldUrl: postLoginPage });
await go(loginPage);
return;
}

const state = generateRandomState()!;
const codeVerifier = generateRandomCodeVerifier();
const oldUrl = readLocal(client).oldUrl ?? window.location.toString();
saveLocal(client, { codeVerifier, state, oldUrl });
const oldUrl = readSession(client).oldUrl ?? window.location.toString();
saveLocal(client, {});
saveSession(client, { codeVerifier, state, oldUrl });

window.location.assign(`${authServer
.authorization_endpoint!}?${new URLSearchParams({
Expand Down
53 changes: 38 additions & 15 deletions packages/oauth/src/publicOauth.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import {
} from "vitest";

import * as commonJs from "./common.js";
import { LocalStorageState } from "./common.js";
import { LocalStorageState, SessionStorageState } from "./common.js";
import { createPublicOauthClient } from "./createPublicOauthClient.js";
import { PublicOauthClient } from "./PublicOauthClient.js";

Expand Down Expand Up @@ -59,6 +59,7 @@ vi.mock("./common.js", async (importOriginal) => {
};
}),
readLocal: vi.spyOn(original, "readLocal"),
readSession: vi.spyOn(original, "readSession"),
// vi.fn(original.readLocal),
};
});
Expand Down Expand Up @@ -87,6 +88,12 @@ describe(createPublicOauthClient, () => {
removeItem: vi.fn(),
};

const mockSessionStorage = {
setItem: vi.fn(),
getItem: vi.fn(),
removeItem: vi.fn(),
};

beforeEach((context) => {
vi.restoreAllMocks();
fetchFn.mockRestore();
Expand All @@ -100,6 +107,7 @@ describe(createPublicOauthClient, () => {
beforeAll(() => {
vi.stubGlobal("window", mockWindow);
vi.stubGlobal("localStorage", mockLocalStorage);
vi.stubGlobal("sessionStorage", mockSessionStorage);
});

afterAll(() => {
Expand Down Expand Up @@ -223,22 +231,34 @@ describe(createPublicOauthClient, () => {
);
});

describe.each<LocalStorageState>([
describe.each<
{ localStorage: LocalStorageState; sessionStorage: SessionStorageState }
>([
{
refresh_token: "a-refresh-token",
localStorage: {
refresh_token: "a-refresh-token",
},
sessionStorage: {},
},
{
codeVerifier: "hi",
state: "mom",
oldUrl: "https://someoldurl.local",
localStorage: {},
sessionStorage: {
codeVerifier: "hi",
state: "mom",
oldUrl: "https://someoldurl.local",
},
},
{},
])("Initial Local State: %s", (initialLocalState) => {
{ localStorage: {}, sessionStorage: {} },
])("Initial Local State: %s", (initialState) => {
const ACCESS_TOKEN = (Math.random() + 1).toString(36).substring(7);

beforeEach(() => {
vi.mocked(commonJs.readLocal).mockImplementation(() =>
initialLocalState
initialState.localStorage
);

vi.mocked(commonJs.readSession).mockImplementation(() =>
initialState.sessionStorage
);

hoistedMocks.makeTokenAndSaveRefresh.mockImplementation(
Expand All @@ -250,7 +270,10 @@ describe(createPublicOauthClient, () => {
);
});

if (Object.keys(initialLocalState).length === 0) {
if (
Object.keys(initialState.localStorage).length === 0
&& Object.keys(initialState.sessionStorage).length === 0
) {
if (should.redirectToLoginPage) {
it("redirects to login page", async () => {
const tokenPromise = client!();
Expand All @@ -259,7 +282,7 @@ describe(createPublicOauthClient, () => {
if (should.redirectToLoginPage) {
// expect save local
await expect(tokenPromise).resolves.toBeUndefined();
expect(mockLocalStorage.setItem).toBeCalledWith(
expect(mockSessionStorage.setItem).toBeCalledWith(
`@osdk/oauth : refresh : ${clientArgs.clientId}`,
JSON.stringify({ oldUrl: window.location.toString() }),
);
Expand All @@ -279,7 +302,7 @@ describe(createPublicOauthClient, () => {
if (should.redirectToLoginPage) {
// expect save local
await expect(tokenPromise).resolves.toBeUndefined();
expect(mockLocalStorage.setItem).toBeCalledWith(
expect(mockSessionStorage.setItem).toBeCalledWith(
`@osdk/oauth : refresh : ${clientArgs.clientId}`,
JSON.stringify({ oldUrl: window.location.toString() }),
);
Expand Down Expand Up @@ -319,7 +342,7 @@ describe(createPublicOauthClient, () => {
}
}

if (initialLocalState.codeVerifier) {
if (initialState.sessionStorage.codeVerifier) {
it("tries to auth with return results", async () => {
await expect(client()).resolves.toEqual(ACCESS_TOKEN);
expect(hoistedMocks.makeTokenAndSaveRefresh).toHaveBeenCalledTimes(
Expand All @@ -334,12 +357,12 @@ describe(createPublicOauthClient, () => {
expect(mockWindow.history.replaceState).toBeCalledWith(
expect.anything(),
expect.anything(),
initialLocalState.oldUrl,
initialState.sessionStorage.oldUrl,
);
});
}

if (initialLocalState.refresh_token) {
if (initialState.localStorage.refresh_token) {
it("refreshes", async () => {
await expect(client()).resolves.toEqual(ACCESS_TOKEN);

Expand Down

0 comments on commit 0166d14

Please sign in to comment.