Skip to content

Commit

Permalink
feat (provider/fireworks): Pass generateImage sample count to backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
shaper committed Jan 13, 2025
1 parent 9495438 commit ddc7fd2
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 10 deletions.
5 changes: 5 additions & 0 deletions .changeset/violet-shirts-lick.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/fireworks': patch
---

feat (provider/fireworks): Pass generateImage sample count to backend.
25 changes: 15 additions & 10 deletions examples/ai-core/src/generate-image/fireworks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,28 @@ import { experimental_generateImage as generateImage } from 'ai';
import fs from 'fs';

async function main() {
const { image } = await generateImage({
model: fireworks.image('accounts/fireworks/models/flux-1-dev-fp8'),
const result = await generateImage({
model: fireworks.image(
'accounts/fireworks/models/stable-diffusion-xl-1024-v1-0',
),
prompt: 'A burrito launched through a tunnel',
aspectRatio: '4:3',
seed: 0, // 0 is random seed for this model
size: '1024x1024',
seed: 0,
n: 2,
providerOptions: {
fireworks: {
// https://fireworks.ai/models/fireworks/flux-1-dev-fp8/playground
guidance_scale: 10,
num_inference_steps: 10,
// https://fireworks.ai/models/fireworks/stable-diffusion-xl-1024-v1-0/playground
cfg_scale: 10,
steps: 30,
},
},
});

const filename = `image-${Date.now()}.png`;
fs.writeFileSync(filename, image.uint8Array);
console.log(`Image saved to ${filename}`);
for (const [index, image] of result.images.entries()) {
const filename = `image-${Date.now()}-${index}.png`;
fs.writeFileSync(filename, image.uint8Array);
console.log(`Image saved to ${filename}`);
}
}

main().catch(console.error);
18 changes: 18 additions & 0 deletions packages/fireworks/src/fireworks-image-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ describe('FireworksImageModel', () => {
prompt,
aspect_ratio: '16:9',
seed: 42,
samples: 1,
additional_param: 'value',
});
});
Expand Down Expand Up @@ -182,6 +183,7 @@ describe('FireworksImageModel', () => {
width: '1024',
height: '768',
seed: 42,
samples: 1,
});
});

Expand Down Expand Up @@ -269,6 +271,22 @@ describe('FireworksImageModel', () => {

expect(mockFetch).toHaveBeenCalled();
});

it('should pass samples parameter to API', async () => {
const model = createBasicModel();

await model.doGenerate({
prompt,
n: 1,
size: undefined,
aspectRatio: undefined,
seed: undefined,
providerOptions: {},
});

const requestBody = await server.calls[0].requestBody;
expect(requestBody).toHaveProperty('samples', 1);
});
});

describe('constructor', () => {
Expand Down
3 changes: 3 additions & 0 deletions packages/fireworks/src/fireworks-image-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ interface ImageRequestParams {
aspectRatio?: string;
size?: string;
seed?: number;
samples: number;
providerOptions: Record<string, unknown>;
headers: Record<string, string | undefined>;
abortSignal?: AbortSignal;
Expand All @@ -158,6 +159,7 @@ async function postImageToApi(
prompt: params.prompt,
aspect_ratio: params.aspectRatio,
seed: params.seed,
samples: params.samples,
...(splitSize && { width: splitSize[0], height: splitSize[1] }),
...(params.providerOptions.fireworks ?? {}),
},
Expand Down Expand Up @@ -225,6 +227,7 @@ export class FireworksImageModel implements ImageModelV1 {
size,
seed,
modelId: this.modelId,
samples: n,
providerOptions,
headers: combineHeaders(this.config.headers(), headers),
abortSignal,
Expand Down

0 comments on commit ddc7fd2

Please sign in to comment.