From ddc7fd2e5857dbfb6f57f9c6237d99347348a760 Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Sun, 12 Jan 2025 16:21:33 -0800 Subject: [PATCH] feat (provider/fireworks): Pass generateImage sample count to backend. --- .changeset/violet-shirts-lick.md | 5 ++++ .../ai-core/src/generate-image/fireworks.ts | 25 +++++++++++-------- .../src/fireworks-image-model.test.ts | 18 +++++++++++++ .../fireworks/src/fireworks-image-model.ts | 3 +++ 4 files changed, 41 insertions(+), 10 deletions(-) create mode 100644 .changeset/violet-shirts-lick.md diff --git a/.changeset/violet-shirts-lick.md b/.changeset/violet-shirts-lick.md new file mode 100644 index 000000000000..a1fdd2c56057 --- /dev/null +++ b/.changeset/violet-shirts-lick.md @@ -0,0 +1,5 @@ +--- +'@ai-sdk/fireworks': patch +--- + +feat (provider/fireworks): Pass generateImage sample count to backend. diff --git a/examples/ai-core/src/generate-image/fireworks.ts b/examples/ai-core/src/generate-image/fireworks.ts index 109bd15426b0..2d9286c5184c 100644 --- a/examples/ai-core/src/generate-image/fireworks.ts +++ b/examples/ai-core/src/generate-image/fireworks.ts @@ -4,23 +4,28 @@ import { experimental_generateImage as generateImage } from 'ai'; import fs from 'fs'; async function main() { - const { image } = await generateImage({ - model: fireworks.image('accounts/fireworks/models/flux-1-dev-fp8'), + const result = await generateImage({ + model: fireworks.image( + 'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0', + ), prompt: 'A burrito launched through a tunnel', - aspectRatio: '4:3', - seed: 0, // 0 is random seed for this model + size: '1024x1024', + seed: 0, + n: 2, providerOptions: { fireworks: { - // https://fireworks.ai/models/fireworks/flux-1-dev-fp8/playground - guidance_scale: 10, - num_inference_steps: 10, + // https://fireworks.ai/models/fireworks/stable-diffusion-xl-1024-v1-0/playground + cfg_scale: 10, + steps: 30, }, }, }); - const filename = `image-${Date.now()}.png`; - fs.writeFileSync(filename, image.uint8Array); - console.log(`Image saved to ${filename}`); + for (const [index, image] of result.images.entries()) { + const filename = `image-${Date.now()}-${index}.png`; + fs.writeFileSync(filename, image.uint8Array); + console.log(`Image saved to ${filename}`); + } } main().catch(console.error); diff --git a/packages/fireworks/src/fireworks-image-model.test.ts b/packages/fireworks/src/fireworks-image-model.test.ts index cce2035d656a..ada5ac2589f4 100644 --- a/packages/fireworks/src/fireworks-image-model.test.ts +++ b/packages/fireworks/src/fireworks-image-model.test.ts @@ -64,6 +64,7 @@ describe('FireworksImageModel', () => { prompt, aspect_ratio: '16:9', seed: 42, + samples: 1, additional_param: 'value', }); }); @@ -182,6 +183,7 @@ describe('FireworksImageModel', () => { width: '1024', height: '768', seed: 42, + samples: 1, }); }); @@ -269,6 +271,22 @@ describe('FireworksImageModel', () => { expect(mockFetch).toHaveBeenCalled(); }); + + it('should pass samples parameter to API', async () => { + const model = createBasicModel(); + + await model.doGenerate({ + prompt, + n: 1, + size: undefined, + aspectRatio: undefined, + seed: undefined, + providerOptions: {}, + }); + + const requestBody = await server.calls[0].requestBody; + expect(requestBody).toHaveProperty('samples', 1); + }); }); describe('constructor', () => { diff --git a/packages/fireworks/src/fireworks-image-model.ts b/packages/fireworks/src/fireworks-image-model.ts index b29b5057168e..f5b5e2e7b151 100644 --- a/packages/fireworks/src/fireworks-image-model.ts +++ b/packages/fireworks/src/fireworks-image-model.ts @@ -141,6 +141,7 @@ interface ImageRequestParams { aspectRatio?: string; size?: string; seed?: number; + samples: number; providerOptions: Record; headers: Record; abortSignal?: AbortSignal; @@ -158,6 +159,7 @@ async function postImageToApi( prompt: params.prompt, aspect_ratio: params.aspectRatio, seed: params.seed, + samples: params.samples, ...(splitSize && { width: splitSize[0], height: splitSize[1] }), ...(params.providerOptions.fireworks ?? {}), }, @@ -225,6 +227,7 @@ export class FireworksImageModel implements ImageModelV1 { size, seed, modelId: this.modelId, + samples: n, providerOptions, headers: combineHeaders(this.config.headers(), headers), abortSignal,