diff --git a/src/config.ts b/src/config.ts index 40aaf0c..e851ae8 100644 --- a/src/config.ts +++ b/src/config.ts @@ -235,6 +235,7 @@ export interface Config extends PromptConfig, ParamConfig { requestTimeout?: number recallTimeout?: number maxConcurrency?: number + globalConcurrency?: number pollInterval?: number trustedWorkers?: boolean workflowText2Image?: string @@ -413,6 +414,7 @@ export const Config = Schema.intersect([ requestTimeout: Schema.number().role('time').description('当请求超过这个时间时会中止并提示超时。').default(Time.minute), recallTimeout: Schema.number().role('time').description('图片发送后自动撤回的时间 (设置为 0 以禁用此功能)。').default(0), maxConcurrency: Schema.number().description('单个频道下的最大并发数量 (设置为 0 以禁用此功能)。').default(0), + globalConcurrency: Schema.number().min(0).description('全局的最大并发数量 (设置为 0 以禁用此功能)。').default(0), }).description('高级设置'), ]) as Schema diff --git a/src/index.ts b/src/index.ts index 101f0ac..74cbacf 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,4 +1,4 @@ -import { Computed, Context, Dict, h, Logger, omit, Quester, Session, SessionError, trimSlash } from 'koishi' +import { Computed, Context, Dict, h, Logger, omit, Quester, Session, SessionError, sleep, trimSlash } from 'koishi' import { Config, modelMap, models, orientMap, parseInput, sampler, upscalers, scheduler } from './config' import { ImageData, StableDiffusionWebUI } from './types' import { closestMultiple, download, forceDataPrefix, getImageSize, login, NetworkError, project, resizeInput, Size } from './utils' @@ -10,6 +10,12 @@ import { readFile } from 'fs/promises' export * from './config' +declare module 'koishi' { + interface Events { + 'novelai/finish'(id: string): void + } +} + export const reactive = true export const name = 'novelai' @@ -45,6 +51,7 @@ export function apply(ctx: Context, config: Config) { const tasks: Dict> = Object.create(null) const globalTasks = new Set() + const globalPending = new Set() let tokenTask: Promise = null const getToken = () => tokenTask ||= login(ctx) @@ -298,10 +305,29 @@ export function apply(ctx: Context, config: Config) { ? session.text('.pending', [globalTasks.size]) : session.text('.waiting')) + if (config.globalConcurrency) { + if (globalTasks.size >= config.globalConcurrency) { + const pendingId = container.pop() + globalPending.add(pendingId) + await new Promise((resolve) => { + const dispose = ctx.on('novelai/finish', (id) => { + if (id !== pendingId) return + resolve() + dispose() + } + })) + } + } + container.forEach((id) => globalTasks.add(id)) const cleanUp = (id: string) => { tasks[session.cid]?.delete(id) globalTasks.delete(id) + if (globalPending.size) { + const id = globalPending.values().next().value + globalPending.delete(id) + ctx.parallel('novelai/finish', id) + } } const path = (() => { @@ -418,8 +444,8 @@ export function apply(ctx: Context, config: Config) { const body = new FormData() const capture = /^data:([\w/.+-]+);base64,(.*)$/.exec(image.dataUrl) const [, mime,] = capture - - let name = Date.now().toString() + + let name = Date.now().toString() const ext = mime === 'image/jpeg' ? 'jpg' : mime === 'image/png' ? 'png' : '' if (ext) name += `.${ext}` const imageFile = new Blob([image.buffer], {type:mime}) @@ -435,7 +461,7 @@ export function apply(ctx: Context, config: Config) { const data = res.data let imagePath = data.name if (data.subfolder) imagePath = data.subfolder + '/' + imagePath - + for (const nodeId in prompt) { if (prompt[nodeId].class_type === 'LoadImage') { prompt[nodeId].inputs.image = imagePath @@ -460,7 +486,7 @@ export function apply(ctx: Context, config: Config) { const negativeeNodeId = prompt[nodeId].inputs.negative[0] const latentImageNodeId = prompt[nodeId].inputs.latent_image[0] prompt[positiveNodeId].inputs.text = parameters.prompt - prompt[negativeeNodeId].inputs.text = parameters.uc + prompt[negativeeNodeId].inputs.text = parameters.uc prompt[latentImageNodeId].inputs.width = parameters.width prompt[latentImageNodeId].inputs.height = parameters.height prompt[latentImageNodeId].inputs.batch_size = parameters.n_samples @@ -522,7 +548,6 @@ export function apply(ctx: Context, config: Config) { const uuid = res.data.id const check = () => ctx.http.get(trimSlash(config.endpoint) + '/api/v2/generate/check/' + uuid).then((res) => res.done) - const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms)) while (await check() === false) { await sleep(config.pollInterval) }