diff --git a/package.json b/package.json index 957db874..c6b0e8fd 100644 --- a/package.json +++ b/package.json @@ -11,8 +11,8 @@ "format": "biome format src", "format:fix": "biome format src --write", "start": "node .next/standalone/server.js", - "test:dev": "jest --config jest.pages.config.ts && jest --config jest.api.config.ts", - "test": "jest --config jest.pages.config.ts && jest --config jest.api.config.ts --ci --coverage", + "test:dev": "jest --detectOpenHandles --verbose --config jest.pages.config.ts && jest --config jest.api.config.ts", + "test": "jest --config jest.pages.config.ts && jest --verbose --config jest.api.config.ts --ci --coverage", "studio": "prisma studio" }, "dependencies": { @@ -109,4 +109,4 @@ "prisma": { "seed": "ts-node --compiler-options {\"module\":\"CommonJS\"} prisma/seed.ts" } -} +} \ No newline at end of file diff --git a/src/pages/api/__tests__/v1/application/statistic.test.ts b/src/pages/api/__tests__/v1/application/statistic.test.ts index b6d1b71f..e580c41b 100644 --- a/src/pages/api/__tests__/v1/application/statistic.test.ts +++ b/src/pages/api/__tests__/v1/application/statistic.test.ts @@ -4,12 +4,20 @@ import { NextApiRequest, NextApiResponse } from "next"; describe("/api/stats", () => { it("should allow only GET method", async () => { const methods = ["DELETE", "POST", "PUT", "PATCH", "OPTIONS", "HEAD"]; - const req = {} as NextApiRequest; + const req = { + method: "GET", + headers: { + "x-ztnet-auth": "validApiKey", + }, + query: {}, + body: {}, + } as unknown as NextApiRequest; + const res = { status: jest.fn().mockReturnThis(), end: jest.fn(), json: jest.fn().mockReturnThis(), - setHeader: jest.fn(), // Mock `setHeader` rate limiter uses it + setHeader: jest.fn(), } as unknown as NextApiResponse; for (const method of methods) { @@ -29,7 +37,10 @@ describe("/api/stats", () => { const req = { method: "GET", headers: { "x-ztnet-auth": "invalidApiKey" }, + query: {}, + body: {}, } as unknown as NextApiRequest; + const res = { status: jest.fn().mockReturnThis(), end: jest.fn(), diff --git a/src/pages/api/__tests__/v1/network/network.test.ts b/src/pages/api/__tests__/v1/network/network.test.ts index e0cda60c..d08caae5 100644 --- a/src/pages/api/__tests__/v1/network/network.test.ts +++ b/src/pages/api/__tests__/v1/network/network.test.ts @@ -3,7 +3,7 @@ import { NextApiRequest, NextApiResponse } from "next"; describe("/api/createNetwork", () => { it("should respond 405 to unsupported methods", async () => { - const req = { method: "PUT" } as NextApiRequest; + const req = { method: "PUT", query: {} } as NextApiRequest; const res = { status: jest.fn().mockReturnThis(), end: jest.fn(), @@ -20,12 +20,13 @@ describe("/api/createNetwork", () => { const req = { method: "POST", headers: { "x-ztnet-auth": "invalidApiKey" }, + query: {}, } as unknown as NextApiRequest; const res = { status: jest.fn().mockReturnThis(), end: jest.fn(), json: jest.fn().mockReturnThis(), - setHeader: jest.fn(), // Mock `setHeader` rate limiter uses it + setHeader: jest.fn(), } as unknown as NextApiResponse; await apiNetworkHandler(req, res); @@ -37,12 +38,13 @@ describe("/api/createNetwork", () => { const req = { method: "GET", headers: { "x-ztnet-auth": "invalidApiKey" }, + query: {}, } as unknown as NextApiRequest; const res = { status: jest.fn().mockReturnThis(), end: jest.fn(), json: jest.fn().mockReturnThis(), - setHeader: jest.fn(), // Mock `setHeader` rate limiter uses it + setHeader: jest.fn(), } as unknown as NextApiResponse; await apiNetworkHandler(req, res); diff --git a/src/pages/api/__tests__/v1/networkMembers/updateMember.test.ts b/src/pages/api/__tests__/v1/networkMembers/updateMember.test.ts index 04d40008..3ce9b882 100644 --- a/src/pages/api/__tests__/v1/networkMembers/updateMember.test.ts +++ b/src/pages/api/__tests__/v1/networkMembers/updateMember.test.ts @@ -77,7 +77,7 @@ describe("Update Network Members", () => { method: "POST", headers: { "x-ztnet-auth": "validApiKey" }, query: { id: "networkId", memberId: "memberId" }, - body: { name: "New Name", authorized: "true" }, + body: { name: "New Name", authorized: true }, } as unknown as NextApiRequest; // Mock the database to return a network @@ -114,7 +114,7 @@ describe("Update Network Members", () => { method: "POST", headers: { "x-ztnet-auth": "validApiKey" }, query: { id: "networkId", memberId: "memberId" }, - body: { name: "New Name", authorized: "true" }, + body: { name: "New Name", authorized: true }, } as unknown as NextApiRequest; const res = createMockRes(); @@ -160,7 +160,7 @@ describe("Update Network Members", () => { method: "POST", headers: { "x-ztnet-auth": "invalidApiKey" }, query: { id: "networkId", memberId: "memberId" }, - body: { name: "New Name", authorized: "true" }, + body: { name: "New Name", authorized: true }, } as unknown as NextApiRequest; const res = createMockRes(); diff --git a/src/pages/api/__tests__/v1/org/org.test.ts b/src/pages/api/__tests__/v1/org/org.test.ts index b2efbb77..c0792a2b 100644 --- a/src/pages/api/__tests__/v1/org/org.test.ts +++ b/src/pages/api/__tests__/v1/org/org.test.ts @@ -68,6 +68,7 @@ describe("organization api validation", () => { .mockResolvedValue({ id: "newUserId", name: "Ztnet", email: "post@ztnet.network" }); mockRequest.headers["x-ztnet-auth"] = "not valid token"; + mockRequest.query = {}; await GET_userOrganization( mockRequest as NextApiRequest, diff --git a/src/pages/api/__tests__/v1/org/orgid.test.ts b/src/pages/api/__tests__/v1/org/orgid.test.ts index d2d21f5a..cfc3fce5 100644 --- a/src/pages/api/__tests__/v1/org/orgid.test.ts +++ b/src/pages/api/__tests__/v1/org/orgid.test.ts @@ -156,8 +156,8 @@ describe("organization api validation", () => { const validToken = encrypt(validTokenData, generateInstanceSecret(API_TOKEN_SECRET)); mockRequest.headers["x-ztnet-auth"] = validToken; - // add organizationId to the request - mockRequest.query = undefined; + // add empty query + mockRequest.query = {}; await apiNetworkHandler( mockRequest as NextApiRequest, mockResponse as NextApiResponse, diff --git a/src/pages/api/__tests__/v1/user/user.test.ts b/src/pages/api/__tests__/v1/user/user.test.ts index d1582ba5..d8e2e587 100644 --- a/src/pages/api/__tests__/v1/user/user.test.ts +++ b/src/pages/api/__tests__/v1/user/user.test.ts @@ -1,5 +1,5 @@ import { NextApiRequest, NextApiResponse } from "next"; -import createUserHandler, { POST_createUser } from "~/pages/api/v1/user"; +import createUserHandler from "~/pages/api/v1/user"; import { prisma } from "~/server/db"; import { appRouter } from "~/server/api/root"; import { API_TOKEN_SECRET, encrypt, generateInstanceSecret } from "~/utils/encryption"; @@ -18,7 +18,12 @@ jest.mock("~/server/api/root", () => ({ })), }, })); - +jest.mock("~/utils/rateLimit", () => ({ + __esModule: true, + default: () => ({ + check: jest.fn().mockResolvedValue(true), + }), +})); jest.mock("~/server/api/trpc"); jest.mock("~/server/db", () => ({ @@ -126,9 +131,19 @@ describe("createUserHandler", () => { }), })); + mockRequest.method = "POST"; mockRequest.headers["x-ztnet-auth"] = "not defined"; + mockRequest.body = { + email: "ztnet@example.com", + password: "password123", + name: "Ztnet", + }; + + await createUserHandler( + mockRequest as NextApiRequest, + mockResponse as NextApiResponse, + ); - await POST_createUser(mockRequest as NextApiRequest, mockResponse as NextApiResponse); expect(mockResponse.status).toHaveBeenCalledWith(200); // Check if the response is as expected @@ -166,6 +181,7 @@ describe("createUserHandler", () => { method: "POST", headers: { "x-ztnet-auth": tokenWithIdHash }, body: { email: "test@example.com", password: "password123", name: "Test User" }, + query: {}, } as unknown as NextApiRequest; const res = { @@ -208,7 +224,9 @@ describe("createUserHandler", () => { it("should allow only POST method", async () => { const methods = ["GET", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"]; - const req = {} as NextApiRequest; + const req = { + query: {}, + } as NextApiRequest; const res = createMockRes(); for (const method of methods) { diff --git a/src/pages/api/v1/network/[id]/member/[memberId]/_schema.ts b/src/pages/api/v1/network/[id]/member/[memberId]/_schema.ts new file mode 100644 index 00000000..9d1fbaea --- /dev/null +++ b/src/pages/api/v1/network/[id]/member/[memberId]/_schema.ts @@ -0,0 +1,40 @@ +import { z } from "zod"; + +// Schema for updateable fields metadata +export const updateableFieldsMetaSchema = z + .object({ + name: z.string().optional(), + authorized: z.boolean().optional(), + }) + .strict(); + +// Schema for the context passed to the handler +export const handlerContextSchema = z.object({ + body: z.record(z.unknown()), + userId: z.string(), + networkId: z.string(), + memberId: z.string(), + ctx: z.object({ + prisma: z.any(), + session: z.object({ + user: z.object({ + id: z.string(), + }), + }), + }), +}); + +// Schema for the context passed to the DELETE handler +export const deleteHandlerContextSchema = z.object({ + userId: z.string(), + networkId: z.string(), + memberId: z.string(), + ctx: z.object({ + prisma: z.any(), + session: z.object({ + user: z.object({ + id: z.string(), + }), + }), + }), +}); diff --git a/src/pages/api/v1/network/[id]/member/[memberId]/index.ts b/src/pages/api/v1/network/[id]/member/[memberId]/index.ts index 3c0ede6a..dc24695d 100644 --- a/src/pages/api/v1/network/[id]/member/[memberId]/index.ts +++ b/src/pages/api/v1/network/[id]/member/[memberId]/index.ts @@ -6,6 +6,11 @@ import { SecuredPrivateApiRoute } from "~/utils/apiRouteAuth"; import { handleApiErrors } from "~/utils/errors"; import rateLimit from "~/utils/rateLimit"; import * as ztController from "~/utils/ztApi"; +import { + deleteHandlerContextSchema, + handlerContextSchema, + updateableFieldsMetaSchema, +} from "./_schema"; // Number of allowed requests per minute const limiter = rateLimit({ @@ -15,20 +20,6 @@ const limiter = rateLimit({ const REQUEST_PR_MINUTE = 50; -// Function to parse and validate fields based on the expected type -// biome-ignore lint/suspicious/noExplicitAny: -const parseField = (key: string, value: any, expectedType: string) => { - if (expectedType === "string") { - return value; // Assume all strings are valid - } - if (expectedType === "boolean") { - if (value === "true" || value === "false") { - return value === "true"; - } - throw new Error(`Field '${key}' expected to be boolean, got: ${value}`); - } -}; - export default async function apiNetworkUpdateMembersHandler( req: NextApiRequest, res: NextApiResponse, @@ -66,39 +57,39 @@ const POST_updateNetworkMember = SecuredPrivateApiRoute( requireNetworkId: true, requireMemberId: true, }, - async (_req, res, { body, userId, networkId, memberId, ctx }) => { - if (Object.keys(body).length === 0) { - return res.status(400).json({ error: "No data provided for update" }); - } + async (_req, res, context) => { + const validatedContext = handlerContextSchema.parse(context); + const { body, userId, networkId, memberId, ctx } = validatedContext; + + // Validate the input data + const validatedInput = updateableFieldsMetaSchema.parse(body); - // structure of the updateableFields object: const updateableFields = { - name: { type: "string", destinations: ["database"] }, + name: { type: "string", destinations: ["controller", "database"] }, authorized: { type: "boolean", destinations: ["controller"] }, }; + if (Object.keys(body).length === 0) { + return res.status(400).json({ error: "No data provided for update" }); + } + const databasePayload: Partial = {}; const controllerPayload: Partial = {}; // Iterate over keys in the request body - for (const key in body) { - // Check if the key is not in updateableFields - if (!(key in updateableFields)) { - return res.status(400).json({ error: `Invalid field: ${key}` }); - } - + for (const [key, value] of Object.entries(validatedInput)) { try { - const parsedValue = parseField(key, body[key], updateableFields[key].type); if (updateableFields[key].destinations.includes("database")) { - databasePayload[key] = parsedValue; + databasePayload[key] = value; } if (updateableFields[key].destinations.includes("controller")) { - controllerPayload[key] = parsedValue; + controllerPayload[key] = value; } } catch (error) { return res.status(400).json({ error: error.message }); } } + try { // make sure the member is valid const network = await prisma.network.findUnique({ @@ -184,7 +175,10 @@ const DELETE_deleteNetworkMember = SecuredPrivateApiRoute( requireNetworkId: true, requireMemberId: true, }, - async (_req, res, { userId, networkId, memberId, ctx }) => { + async (_req, res, context) => { + const validatedContext = deleteHandlerContextSchema.parse(context); + const { userId, networkId, memberId, ctx } = validatedContext; + try { // make sure the member is valid const network = await prisma.network.findUnique({ diff --git a/src/pages/api/v1/network/_schema.ts b/src/pages/api/v1/network/_schema.ts new file mode 100644 index 00000000..5973cdab --- /dev/null +++ b/src/pages/api/v1/network/_schema.ts @@ -0,0 +1,21 @@ +import { z } from "zod"; + +// Schema for the request body when creating a new network +export const createNetworkBodySchema = z + .object({ + name: z.string().optional(), + }) + .strict(); + +// Schema for the context passed to the handler +export const createNetworkContextSchema = z.object({ + body: createNetworkBodySchema, + ctx: z.object({ + prisma: z.any(), + session: z.object({ + user: z.object({ + id: z.string(), + }), + }), + }), +}); diff --git a/src/pages/api/v1/network/index.ts b/src/pages/api/v1/network/index.ts index db7fa8a6..6ed177e6 100644 --- a/src/pages/api/v1/network/index.ts +++ b/src/pages/api/v1/network/index.ts @@ -5,6 +5,7 @@ import { SecuredPrivateApiRoute } from "~/utils/apiRouteAuth"; import { handleApiErrors } from "~/utils/errors"; import rateLimit from "~/utils/rateLimit"; import * as ztController from "~/utils/ztApi"; +import { createNetworkContextSchema } from "./_schema"; // Number of allowed requests per minute const limiter = rateLimit({ @@ -42,9 +43,13 @@ const POST_createNewNetwork = SecuredPrivateApiRoute( { requireNetworkId: false, }, - async (_req, res, { body, ctx }) => { - // If there are users, verify the API key + async (_req, res, context) => { try { + // Validate the context (which includes the body) + const validatedContext = createNetworkContextSchema.parse(context); + const { body, ctx } = validatedContext; + + // If there are users, verify the API key const { name } = body; const newNetworkId = await networkProvisioningFactory({ diff --git a/src/pages/api/v1/org/[orgid]/network/[nwid]/_schema.ts b/src/pages/api/v1/org/[orgid]/network/[nwid]/_schema.ts new file mode 100644 index 00000000..7b994403 --- /dev/null +++ b/src/pages/api/v1/org/[orgid]/network/[nwid]/_schema.ts @@ -0,0 +1,39 @@ +import { z } from "zod"; + +// Schema for updateable fields +export const NetworkUpdateSchema = z + .object({ + name: z.string().optional(), + description: z.string().optional(), + flowRule: z.string().optional(), + mtu: z.string().optional(), + private: z.boolean().optional(), + dns: z + .object({ + domain: z.string(), + servers: z.array(z.string()), + }) + .optional(), + ipAssignmentPools: z.array(z.unknown()).optional(), + routes: z.array(z.unknown()).optional(), + v4AssignMode: z.record(z.unknown()).optional(), + v6AssignMode: z.record(z.unknown()).optional(), + }) + .strict(); + +// Schema for POST request body +const PostBodySchema = z.record(z.unknown()); + +// Schema for the context passed to the handler +export const HandlerContextSchema = z.object({ + networkId: z.string(), + ctx: z.object({ + prisma: z.any(), + session: z.object({ + user: z.object({ + id: z.string(), + }), + }), + }), + body: PostBodySchema, +}); diff --git a/src/pages/api/v1/org/[orgid]/network/[nwid]/index.ts b/src/pages/api/v1/org/[orgid]/network/[nwid]/index.ts index 7fdc5691..85771caa 100644 --- a/src/pages/api/v1/org/[orgid]/network/[nwid]/index.ts +++ b/src/pages/api/v1/org/[orgid]/network/[nwid]/index.ts @@ -6,6 +6,7 @@ import { SecuredOrganizationApiRoute } from "~/utils/apiRouteAuth"; import { handleApiErrors } from "~/utils/errors"; import rateLimit from "~/utils/rateLimit"; import * as ztController from "~/utils/ztApi"; +import { HandlerContextSchema, NetworkUpdateSchema } from "./_schema"; // Number of allowed requests per minute const limiter = rateLimit({ @@ -13,26 +14,6 @@ const limiter = rateLimit({ uniqueTokenPerInterval: 500, // Max 500 users per second }); -// Function to parse and validate fields based on the expected type -// biome-ignore lint/suspicious/noExplicitAny: -const parseField = (key: string, value: any, expectedType: string) => { - if (expectedType === "object") { - return value; - } - if (expectedType === "array") { - return value; - } - if (expectedType === "string") { - return value; - } - if (expectedType === "boolean") { - if (value === "true" || value === "false") { - return value === "true"; - } - throw new Error(`Field '${key}' expected to be boolean, got: ${value}`); - } -}; - export const REQUEST_PR_MINUTE = 50; export default async function apiNetworkByIdHandler( @@ -60,21 +41,23 @@ export default async function apiNetworkByIdHandler( export const POST_network = SecuredOrganizationApiRoute( { requiredRole: Role.READ_ONLY, requireNetworkId: true }, - async (_req, res, { networkId, ctx, body }) => { + async (_req, res, context) => { try { - // structure of the updateableFields object: + const validatedContext = HandlerContextSchema.parse(context); + const { networkId, ctx, body } = validatedContext; + + // Validate the body against the NetworkUpdateSchema + const validatedBody = NetworkUpdateSchema.parse(body); + const updateableFields = { name: { type: "string", destinations: ["controller", "database"] }, description: { type: "string", destinations: ["database"] }, flowRule: { type: "string", destinations: ["custom"] }, mtu: { type: "string", destinations: ["controller"] }, private: { type: "boolean", destinations: ["controller"] }, - // capabilities: { type: "array", destinations: ["controller"] }, dns: { type: "array", destinations: ["controller"] }, ipAssignmentPools: { type: "array", destinations: ["controller"] }, routes: { type: "array", destinations: ["controller"] }, - // rules: { type: "array", destinations: ["controller"] }, - // tags: { type: "array", destinations: ["controller"] }, v4AssignMode: { type: "object", destinations: ["controller"] }, v6AssignMode: { type: "object", destinations: ["controller"] }, }; @@ -86,14 +69,13 @@ export const POST_network = SecuredOrganizationApiRoute( const caller = appRouter.createCaller(ctx); // Iterate over keys in the request body - for (const key in body) { + for (const [key, value] of Object.entries(validatedBody)) { // Check if the key is not in updateableFields if (!(key in updateableFields)) { return res.status(400).json({ error: `Invalid field: ${key}` }); } try { - const parsedValue = parseField(key, body[key], updateableFields[key].type); // if custom and flowRule call the caller.setFlowRule if (key === "flowRule") { // @ts-expect-error @@ -101,15 +83,15 @@ export const POST_network = SecuredOrganizationApiRoute( await caller.network.setFlowRule({ nwid: networkId, updateParams: { - flowRoute: parsedValue, + flowRoute: value as string, }, }); } if (updateableFields[key].destinations.includes("database")) { - databasePayload[key] = parsedValue; + databasePayload[key] = value; } if (updateableFields[key].destinations.includes("controller")) { - controllerPayload[key] = parsedValue; + controllerPayload[key] = value; } } catch (error) { return res.status(400).json({ error: error.message }); diff --git a/src/pages/api/v1/org/[orgid]/network/[nwid]/member/[memberId]/_schema.ts b/src/pages/api/v1/org/[orgid]/network/[nwid]/member/[memberId]/_schema.ts new file mode 100644 index 00000000..d6761991 --- /dev/null +++ b/src/pages/api/v1/org/[orgid]/network/[nwid]/member/[memberId]/_schema.ts @@ -0,0 +1,26 @@ +import { z } from "zod"; + +// Schema for POST request body +export const PostBodySchema = z + .object({ + name: z.string().optional(), + authorized: z.boolean().optional(), + }) + .strict(); + +// Schema for the context passed to the handler +export const HandlerContextSchema = z.object({ + networkId: z.string(), + orgId: z.string(), + memberId: z.string(), + userId: z.string(), + body: z.record(z.unknown()), + ctx: z.object({ + prisma: z.any(), + session: z.object({ + user: z.object({ + id: z.string(), + }), + }), + }), +}); diff --git a/src/pages/api/v1/org/[orgid]/network/[nwid]/member/[memberId]/index.ts b/src/pages/api/v1/org/[orgid]/network/[nwid]/member/[memberId]/index.ts index 98a1b507..288e574b 100644 --- a/src/pages/api/v1/org/[orgid]/network/[nwid]/member/[memberId]/index.ts +++ b/src/pages/api/v1/org/[orgid]/network/[nwid]/member/[memberId]/index.ts @@ -7,6 +7,7 @@ import { handleApiErrors } from "~/utils/errors"; import rateLimit from "~/utils/rateLimit"; import { checkUserOrganizationRole } from "~/utils/role"; import * as ztController from "~/utils/ztApi"; +import { HandlerContextSchema, PostBodySchema } from "./_schema"; // Number of allowed requests per minute const limiter = rateLimit({ @@ -16,20 +17,6 @@ const limiter = rateLimit({ export const REQUEST_PR_MINUTE = 50; -// Function to parse and validate fields based on the expected type -// biome-ignore lint/suspicious/noExplicitAny: -const parseField = (key: string, value: any, expectedType: string) => { - if (expectedType === "string") { - return value; // Assume all strings are valid - } - if (expectedType === "boolean") { - if (value === "true" || value === "false") { - return value === "true"; - } - throw new Error(`Field '${key}' expected to be boolean, got: ${value}`); - } -}; - export default async function apiNetworkUpdateMembersHandler( req: NextApiRequest, res: NextApiResponse, @@ -67,9 +54,14 @@ export default async function apiNetworkUpdateMembersHandler( */ export const POST_orgUpdateNetworkMember = SecuredOrganizationApiRoute( { requiredRole: Role.USER, requireNetworkId: true }, - async (_req, res, { networkId, orgId, body, userId, memberId }) => { + async (_req, res, context) => { try { + const validatedContext = HandlerContextSchema.parse(context); + const { networkId, orgId, body, userId, memberId } = validatedContext; + + const validatedBody = PostBodySchema.parse(body); // structure of the updateableFields object: + const updateableFields = { name: { type: "string", destinations: ["database"] }, authorized: { type: "boolean", destinations: ["controller"] }, @@ -79,19 +71,18 @@ export const POST_orgUpdateNetworkMember = SecuredOrganizationApiRoute( const controllerPayload: Partial = {}; // Iterate over keys in the request body - for (const key in body) { + for (const [key, value] of Object.entries(validatedBody)) { // Check if the key is not in updateableFields if (!(key in updateableFields)) { return res.status(400).json({ error: `Invalid field: ${key}` }); } try { - const parsedValue = parseField(key, body[key], updateableFields[key].type); if (updateableFields[key].destinations.includes("database")) { - databasePayload[key] = parsedValue; + databasePayload[key] = value; } if (updateableFields[key].destinations.includes("controller")) { - controllerPayload[key] = parsedValue; + controllerPayload[key] = value; } } catch (error) { return res.status(400).json({ error: error.message }); @@ -197,8 +188,11 @@ export const POST_orgUpdateNetworkMember = SecuredOrganizationApiRoute( */ export const DELETE_orgStashNetworkMember = SecuredOrganizationApiRoute( { requiredRole: Role.USER, requireNetworkId: true }, - async (_req, res, { networkId, orgId, memberId, ctx }) => { + async (_req, res, context) => { try { + const validatedContext = HandlerContextSchema.parse(context); + const { networkId, orgId, memberId, ctx } = validatedContext; + // @ts-expect-error const caller = appRouter.createCaller(ctx); const networkAndMembers = await caller.networkMember.stash({ @@ -226,8 +220,11 @@ export const DELETE_orgStashNetworkMember = SecuredOrganizationApiRoute( */ export const GET_orgNetworkMemberById = SecuredOrganizationApiRoute( { requiredRole: Role.USER, requireNetworkId: true }, - async (_req, res, { networkId, memberId, ctx }) => { + async (_req, res, context) => { try { + const validatedContext = HandlerContextSchema.parse(context); + const { networkId, memberId, ctx } = validatedContext; + // @ts-expect-error const caller = appRouter.createCaller(ctx); const networkAndMembers = await caller.networkMember.getMemberById({ diff --git a/src/pages/api/v1/org/[orgid]/network/_schema.ts b/src/pages/api/v1/org/[orgid]/network/_schema.ts new file mode 100644 index 00000000..4c63c37e --- /dev/null +++ b/src/pages/api/v1/org/[orgid]/network/_schema.ts @@ -0,0 +1,22 @@ +import { z } from "zod"; + +// Schema for the request body when creating a new network +export const createNetworkBodySchema = z + .object({ + name: z.string().optional(), + }) + .strict(); + +// Schema for the context passed to the handler +export const createNetworkContextSchema = z.object({ + body: createNetworkBodySchema, + orgId: z.string(), + ctx: z.object({ + prisma: z.any(), + session: z.object({ + user: z.object({ + id: z.string(), + }), + }), + }), +}); diff --git a/src/pages/api/v1/org/[orgid]/network/index.ts b/src/pages/api/v1/org/[orgid]/network/index.ts index d2360489..12cab7ca 100644 --- a/src/pages/api/v1/org/[orgid]/network/index.ts +++ b/src/pages/api/v1/org/[orgid]/network/index.ts @@ -5,6 +5,7 @@ import { SecuredOrganizationApiRoute } from "~/utils/apiRouteAuth"; import { handleApiErrors } from "~/utils/errors"; import rateLimit from "~/utils/rateLimit"; import * as ztController from "~/utils/ztApi"; +import { createNetworkContextSchema } from "./_schema"; // Number of allowed requests per minute const limiter = rateLimit({ @@ -40,8 +41,12 @@ export default async function apiNetworkHandler( export const POST_orgCreateNewNetwork = SecuredOrganizationApiRoute( { requiredRole: Role.USER }, - async (_req, res, { body, orgId, ctx }) => { + async (_req, res, context) => { try { + // Validate the context (which includes the body) + const validatedContext = createNetworkContextSchema.parse(context); + const { body, orgId, ctx } = validatedContext; + // organization name const { name } = body; diff --git a/src/pages/api/v1/user/_schema.ts b/src/pages/api/v1/user/_schema.ts new file mode 100644 index 00000000..a96bf335 --- /dev/null +++ b/src/pages/api/v1/user/_schema.ts @@ -0,0 +1,14 @@ +import { z } from "zod"; +import { passwordSchema } from "~/server/api/routers/_schema"; + +// Input validation schema +export const createUserSchema = z.object({ + email: z + .string() + .email() + .transform((val) => val.trim()), + password: passwordSchema("password does not meet the requirements!"), + name: z.string().min(3, "Name must contain at least 3 character(s)").max(40), + expiresAt: z.string().datetime().optional(), + generateApiToken: z.boolean().optional(), +}); diff --git a/src/pages/api/v1/user/index.ts b/src/pages/api/v1/user/index.ts index 7efb2ff5..945e835b 100644 --- a/src/pages/api/v1/user/index.ts +++ b/src/pages/api/v1/user/index.ts @@ -7,6 +7,7 @@ import { AuthorizationType } from "~/types/apiTypes"; import { decryptAndVerifyToken } from "~/utils/encryption"; import { handleApiErrors } from "~/utils/errors"; import rateLimit from "~/utils/rateLimit"; +import { createUserSchema } from "./_schema"; // Number of allowed requests per minute const limiter = rateLimit({ @@ -61,8 +62,10 @@ export const POST_createUser = async (req: NextApiRequest, res: NextApiResponse) }); } + // Input validation + const validatedInput = createUserSchema.parse(req.body); // get data from the post request - const { email, password, name, expiresAt, generateApiToken } = req.body; + const { email, password, name, expiresAt, generateApiToken } = validatedInput; if (userCount === 0 && expiresAt !== undefined) { return res.status(400).json({ message: "Cannot add expiresAt for Admin user!" }); @@ -77,7 +80,6 @@ export const POST_createUser = async (req: NextApiRequest, res: NextApiResponse) return res.status(400).json({ message: "Invalid expiresAt date" }); } } - /** * * Create a transaction to make sure the user and API token are created together @@ -119,9 +121,6 @@ export const POST_createUser = async (req: NextApiRequest, res: NextApiResponse) let apiToken: string; if (generateApiToken !== undefined) { - if (typeof generateApiToken !== "boolean") { - throw new Error("generateApiToken must be a boolean"); - } if (generateApiToken) { const tokenResponse = await transactionCallerWithUserCtx.auth.addApiToken({ name: "Generated Token via API", diff --git a/src/server/api/routers/_schema.ts b/src/server/api/routers/_schema.ts new file mode 100644 index 00000000..0596aae9 --- /dev/null +++ b/src/server/api/routers/_schema.ts @@ -0,0 +1,21 @@ +import { z } from "zod"; + +// This regular expression (regex) is used to validate a password based on the following criteria: +// - The password must be at least 6 characters long. +// - The password must contain at least two of the following three character types: +// - Lowercase letters (a-z) +// - Uppercase letters (A-Z) +// - Digits (0-9) +export const mediumPassword = new RegExp( + "^(((?=.*[a-z])(?=.*[A-Z]))|((?=.*[a-z])(?=.*[0-9]))|((?=.*[A-Z])(?=.*[0-9])))(?=.{6,})", +); + +// create a zod password schema +export const passwordSchema = (errorMessage: string) => + z + .string() + .max(40, { message: "Password must not exceed 40 characters" }) + .refine((val) => mediumPassword.test(val), { + message: errorMessage, + }) + .optional(); diff --git a/src/server/api/routers/authRouter.ts b/src/server/api/routers/authRouter.ts index 3d2e2cda..71aee569 100644 --- a/src/server/api/routers/authRouter.ts +++ b/src/server/api/routers/authRouter.ts @@ -24,16 +24,7 @@ import { validateOrganizationToken } from "../services/organizationAuthService"; import rateLimit from "~/utils/rateLimit"; import { ErrorCode } from "~/utils/errorCode"; import { MailTemplateKey } from "~/utils/enums"; - -// This regular expression (regex) is used to validate a password based on the following criteria: -// - The password must be at least 6 characters long. -// - The password must contain at least two of the following three character types: -// - Lowercase letters (a-z) -// - Uppercase letters (A-Z) -// - Digits (0-9) -const mediumPassword = new RegExp( - "^(((?=.*[a-z])(?=.*[A-Z]))|((?=.*[a-z])(?=.*[0-9]))|((?=.*[A-Z])(?=.*[0-9])))(?=.{6,})", -); +import { mediumPassword, passwordSchema } from "./_schema"; // allow 15 requests per 10 minutes const limiter = rateLimit({ @@ -44,19 +35,6 @@ const limiter = rateLimit({ const GENERAL_REQUEST_LIMIT = 60; const SHORT_REQUEST_LIMIT = 5; -// create a zod password schema -const passwordSchema = (errorMessage: string) => - z - .string() - .max(40) - .refine((val) => { - if (!mediumPassword.test(val)) { - throw new Error(errorMessage); - } - return true; - }) - .optional(); - export const authRouter = createTRPCRouter({ register: publicProcedure .input( diff --git a/src/utils/apiRouteAuth.ts b/src/utils/apiRouteAuth.ts index 7a5fb5e9..5ecbc62e 100644 --- a/src/utils/apiRouteAuth.ts +++ b/src/utils/apiRouteAuth.ts @@ -5,6 +5,21 @@ import { Role } from "@prisma/client"; import { prisma } from "~/server/db"; import { decryptAndVerifyToken } from "./encryption"; import { AuthorizationType } from "~/types/apiTypes"; +import { z } from "zod"; + +// Schema for API request headers and query parameters +const ApiRequestSchema = z.object({ + headers: z.object({ + "x-ztnet-auth": z.string(), + }), + query: z.object({ + orgid: z.string().optional(), + nwid: z.string().optional(), + memberId: z.string().optional(), + id: z.string().optional(), + }), + body: z.any(), +}); /** * Organization API handler wrapper for apir routes that require authentication @@ -30,6 +45,11 @@ type OrgApiHandler = ( }, ) => Promise; +/** + * Wrapper for organization API routes + * @param options - Options for the API route + * @param handler - The API route handler + */ export const SecuredOrganizationApiRoute = ( options: { requiredRole: Role; @@ -39,19 +59,21 @@ export const SecuredOrganizationApiRoute = ( handler: OrgApiHandler, ) => { return async (req: NextApiRequest, res: NextApiResponse) => { - const apiKey = req.headers["x-ztnet-auth"] as string; - const orgId = req.query?.orgid as string; - const networkId = req.query?.nwid as string; - const memberId = req.query?.memberId as string; - const body = req.body; - - const mergedOptions = { - // Set orgid as required by default - requireOrgId: true, - ...options, - }; - try { + const validatedRequest = ApiRequestSchema.parse(req); + + const apiKey = validatedRequest.headers["x-ztnet-auth"] as string; + const orgId = validatedRequest.query?.orgid as string; + const networkId = validatedRequest.query?.nwid as string; + const memberId = validatedRequest.query?.memberId as string; + const body = validatedRequest.body; + + const mergedOptions = { + // Set orgid as required by default + requireOrgId: true, + ...options, + }; + if (!apiKey) { return res.status(400).json({ error: "API Key is required" }); } @@ -127,10 +149,12 @@ export const SecuredPrivateApiRoute = ( handler: UserApiHandler, ) => { return async (req: NextApiRequest, res: NextApiResponse) => { - const apiKey = req.headers["x-ztnet-auth"] as string; - const networkId = req.query?.id as string; - const memberId = req.query?.memberId as string; - const body = req.body; + const validatedRequest = ApiRequestSchema.parse(req); + + const apiKey = validatedRequest.headers["x-ztnet-auth"] as string; + const networkId = validatedRequest.query?.id as string; + const memberId = validatedRequest.query?.memberId as string; + const body = validatedRequest.body; const mergedOptions = { // Set networkId as required by default diff --git a/src/utils/errors.tsx b/src/utils/errors.tsx index f4cbb238..4c3515b2 100644 --- a/src/utils/errors.tsx +++ b/src/utils/errors.tsx @@ -2,8 +2,31 @@ import { TRPCError } from "@trpc/server"; import { getHTTPStatusCodeFromError } from "@trpc/server/http"; import { NextApiResponse } from "next"; import toast from "react-hot-toast"; +import { ZodError } from "zod"; import { ErrorData } from "~/types/errorHandling"; +interface FieldErrors { + [field: string]: string[]; +} + +function handleZodError(error: ZodError, res: NextApiResponse) { + const fieldErrors = error.issues.reduce((acc, issue) => { + const path = issue.path.join("."); + if (!acc[path]) { + acc[path] = []; + } + acc[path].push(issue.message); + return acc; + }, {}); + + return res.status(400).json({ + error: { + message: "Validation error", + fieldErrors, + }, + }); +} + // biome-ignore lint/suspicious/noExplicitAny: export const handleErrors = (error: any) => { if ((error.data as ErrorData)?.zodError) { @@ -20,6 +43,11 @@ export const handleErrors = (error: any) => { }; export const handleApiErrors = (cause, res: NextApiResponse) => { + // check if cause is an zod error + if (cause instanceof ZodError) { + return handleZodError(cause, res); + } + if (cause instanceof TRPCError) { const httpCode = getHTTPStatusCodeFromError(cause); try { @@ -29,6 +57,7 @@ export const handleApiErrors = (cause, res: NextApiResponse) => { return res.status(httpCode).json({ error: cause.message }); } } + // Check if the error is an instance of Error and has a message indicating an invalid token if (cause instanceof Error && cause.message === "Invalid token") { return res.status(401).json({ error: "Invalid token" });