Skip to content

Commit

Permalink
Add proxyHandler custom HTTP endpoint and a config option to use it pe…
Browse files Browse the repository at this point in the history
…r model
  • Loading branch information
justyns committed Jan 23, 2025
1 parent b4256bd commit 1e471f3
Show file tree
Hide file tree
Showing 11 changed files with 7,463 additions and 228 deletions.
7,399 changes: 7,201 additions & 198 deletions silverbullet-ai.plug.js

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions silverbullet-ai.plug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ name: silverbullet-ai
requiredPermissions:
- fetch
functions:
proxyHandler:
path: src/proxyHandler.ts:proxyHandler
events:
- http:request:/ai-proxy/*

aiPromptSlashCommplete:
path: src/prompts.ts:aiPromptSlashComplete
events:
Expand Down
14 changes: 13 additions & 1 deletion src/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,30 @@ function setupAIProvider(model: ModelConfig) {
switch (providerName) {
case Provider.OpenAI:
currentAIProvider = new OpenAIProvider(
model.name,
apiKey,
model.modelName,
model.baseUrl || aiSettings.openAIBaseUrl,
model.requireAuth || aiSettings.requireAuth,
model.proxyOnServer || false,
);
break;
case Provider.Gemini:
currentAIProvider = new GeminiProvider(apiKey, model.modelName);
currentAIProvider = new GeminiProvider(
model.name,
apiKey,
model.modelName,
model.proxyOnServer || false,
);
break;
case Provider.Ollama:
currentAIProvider = new OllamaProvider(
model.name,
apiKey,
model.modelName,
model.baseUrl || "http://localhost:11434/v1",
model.requireAuth,
model.proxyOnServer || false,
);
break;
case Provider.Mock:
Expand All @@ -204,19 +213,22 @@ function setupEmbeddingProvider(model: EmbeddingModelConfig) {
switch (providerName) {
case EmbeddingProvider.OpenAI:
currentEmbeddingProvider = new OpenAIEmbeddingProvider(
model.name,
apiKey,
model.modelName,
model.baseUrl || aiSettings.openAIBaseUrl,
);
break;
case EmbeddingProvider.Gemini:
currentEmbeddingProvider = new GeminiEmbeddingProvider(
model.name,
apiKey,
model.modelName,
);
break;
case EmbeddingProvider.Ollama:
currentEmbeddingProvider = new OllamaEmbeddingProvider(
model.name,
apiKey,
model.modelName,
model.baseUrl || "http://localhost:11434",
Expand Down
24 changes: 20 additions & 4 deletions src/interfaces/EmbeddingProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { EmbeddingGenerationOptions } from "../types.ts";
import * as cache from "../cache.ts";

export interface EmbeddingProviderInterface {
name: string;
fullName: string;
apiKey: string;
baseUrl: string;
modelName: string;
Expand All @@ -19,22 +19,38 @@ export abstract class AbstractEmbeddingProvider
implements EmbeddingProviderInterface {
apiKey: string;
baseUrl: string;
name: string;
fullName: string;
modelName: string;
requireAuth: boolean;
proxyOnServer?: boolean;

constructor(
modelConfigName: string,
apiKey: string,
baseUrl: string,
name: string,
modelName: string,
requireAuth: boolean = true,
proxyOnServer?: boolean
) {
this.apiKey = apiKey;
this.baseUrl = baseUrl;
this.name = name;
this.fullName = modelConfigName;
this.modelName = modelName;
this.requireAuth = requireAuth;
this.proxyOnServer = proxyOnServer;
}

/**
 * Resolve an API path to the full URL an embeddings request should hit.
 *
 * When `proxyOnServer` is enabled, requests are routed through the
 * SilverBullet server-side proxy endpoint (`/_/ai-proxy/<model config name>/…`);
 * otherwise the path is appended directly to `baseUrl`.
 *
 * @param path API path, with or without leading slashes
 * @returns the absolute or server-relative URL to fetch
 */
protected getUrl(path: string): string {
  // Normalize so joining below never produces a double slash.
  const trimmed = path.replace(/^\/+/, "");

  if (!this.proxyOnServer) {
    return `${this.baseUrl}/${trimmed}`;
  }

  // The proxy route does not expect a "v1/" prefix, so drop it if present.
  const proxied = trimmed.replace(/^v1\//, "");
  return `/_/ai-proxy/${this.fullName}/${proxied}`;
}

abstract _generateEmbeddings(
Expand Down
4 changes: 2 additions & 2 deletions src/interfaces/ImageProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export interface ImageProviderInterface {
export abstract class AbstractImageProvider implements ImageProviderInterface {
apiKey: string;
baseUrl: string;
name: string;
fullName: string;
modelName: string;
requireAuth: boolean;

Expand All @@ -26,7 +26,7 @@ export abstract class AbstractImageProvider implements ImageProviderInterface {
) {
this.apiKey = apiKey;
this.baseUrl = baseUrl;
this.name = name;
this.fullName = name;
this.modelName = modelName;
this.requireAuth = requireAuth;
}
Expand Down
21 changes: 18 additions & 3 deletions src/interfaces/Provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { ChatMessage, PostProcessorData, StreamChatOptions } from "../types.ts";
import { enrichChatMessages } from "../utils.ts";

export interface ProviderInterface {
name: string;
fullName: string;
apiKey: string;
baseUrl: string;
modelName: string;
Expand All @@ -27,21 +27,36 @@ export interface ProviderInterface {
}

export abstract class AbstractProvider implements ProviderInterface {
name: string;
fullName: string;
apiKey: string;
baseUrl: string;
modelName: string;
proxyOnServer: boolean;

constructor(
name: string,
apiKey: string,
baseUrl: string,
modelName: string,
proxyOnServer: boolean = false,
) {
this.name = name;
this.fullName = name;
this.apiKey = apiKey;
this.baseUrl = baseUrl;
this.modelName = modelName;
this.proxyOnServer = proxyOnServer;
}

/**
 * Resolve an API path to the full URL a chat/completions request should hit.
 *
 * When `proxyOnServer` is enabled, requests are routed through the
 * SilverBullet server-side proxy endpoint, keyed by the model config name
 * (`this.fullName`); otherwise the path is appended directly to `baseUrl`.
 *
 * NOTE(review): unlike the embedding provider's getUrl, this variant does
 * not strip a leading "v1/" before proxying — presumably because chat
 * baseUrls already end in "/v1"; confirm against proxyHandler's routing.
 *
 * @param path API path, with or without leading slashes
 * @returns the absolute or server-relative URL to fetch
 */
protected getUrl(path: string): string {
  // Normalize so joining below never produces a double slash.
  const cleanPath = path.replace(/^\/+/, "");

  if (this.proxyOnServer) {
    // Fix: build the proxy URL once; the previous version also logged it
    // via console.log on every single request (leftover debug noise).
    return `/_/ai-proxy/${this.fullName}/${cleanPath}`;
  }
  return `${this.baseUrl}/${cleanPath}`;
}

abstract chatWithAI(options: StreamChatOptions): Promise<any>;
Expand Down
15 changes: 7 additions & 8 deletions src/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ type GeminiChatContent = {
};

export class GeminiProvider extends AbstractProvider {
override name = "Gemini";

constructor(
modelConfigName: string,
apiKey: string,
modelName: string,
proxyOnServer: boolean,
) {
const baseUrl = "https://generativelanguage.googleapis.com";
super("Gemini", apiKey, baseUrl, modelName);
super(modelConfigName, apiKey, baseUrl, modelName, proxyOnServer);
}

async listModels(): Promise<any> {
const apiUrl = `${this.baseUrl}/v1beta/models?key=${this.apiKey}`;
const apiUrl = this.getUrl(`v1beta/models?key=${this.apiKey}`);
try {
const response = await fetch(apiUrl);
if (!response.ok) {
Expand Down Expand Up @@ -93,8 +93,7 @@ export class GeminiProvider extends AbstractProvider {
const { messages, onDataReceived } = options;

try {
const sseUrl =
`${this.baseUrl}/v1beta/models/${this.modelName}:streamGenerateContent?key=${this.apiKey}&alt=sse`;
const sseUrl = this.getUrl(`v1beta/models/${this.modelName}:streamGenerateContent?key=${this.apiKey}&alt=sse`);

const headers: HttpHeaders = {
"Content-Type": "application/json",
Expand Down Expand Up @@ -166,7 +165,7 @@ export class GeminiProvider extends AbstractProvider {
);

const response = await nativeFetch(
`${this.baseUrl}/v1beta/models/${this.modelName}:generateContent?key=${this.apiKey}`,
this.getUrl(`v1beta/models/${this.modelName}:generateContent?key=${this.apiKey}`),
{
method: "POST",
headers: {
Expand Down Expand Up @@ -214,7 +213,7 @@ export class GeminiEmbeddingProvider extends AbstractEmbeddingProvider {
}

const response = await nativeFetch(
`${this.baseUrl}/v1beta/models/${this.modelName}:embedContent?key=${this.apiKey}`,
this.getUrl(`v1beta/models/${this.modelName}:embedContent?key=${this.apiKey}`),
{
method: "POST",
headers: headers,
Expand Down
15 changes: 10 additions & 5 deletions src/providers/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,27 @@ type HttpHeaders = {

// For now, the Ollama provider is just a wrapper around the openai provider
export class OllamaProvider extends AbstractProvider {
override name = "Ollama";
requireAuth: boolean;
openaiProvider: OpenAIProvider;

constructor(
modelConfigName: string,
apiKey: string,
modelName: string,
baseUrl: string,
requireAuth: boolean,
proxyOnServer: boolean,
) {
super("Ollama", apiKey, baseUrl, modelName);
super(modelConfigName, apiKey, baseUrl, modelName, proxyOnServer);
this.requireAuth = requireAuth;
this.proxyOnServer = proxyOnServer;
this.openaiProvider = new OpenAIProvider(
modelConfigName,
apiKey,
modelName,
baseUrl,
requireAuth,
proxyOnServer,
);
}

Expand All @@ -54,7 +58,7 @@ export class OllamaProvider extends AbstractProvider {

// List models api isn't behind /v1/ like the other endpoints, but we don't want to force the user to change the config yet
const response = await nativeFetch(
`${this.baseUrl.replace(/\/v1\/?/, "")}/api/tags`,
this.getUrl('api/tags').replace(/\/v1\/?/, ''),
{
method: "GET",
headers: headers,
Expand Down Expand Up @@ -82,12 +86,13 @@ export class OllamaProvider extends AbstractProvider {

export class OllamaEmbeddingProvider extends AbstractEmbeddingProvider {
constructor(
modelConfigName: string,
apiKey: string,
modelName: string,
baseUrl: string,
requireAuth: boolean = false,
) {
super(apiKey, baseUrl, "Ollama", modelName, requireAuth);
super(modelConfigName, apiKey, baseUrl, modelName, requireAuth);
}

// Ollama doesn't have an openai compatible api for embeddings yet, so it gets its own provider
Expand All @@ -108,7 +113,7 @@ export class OllamaEmbeddingProvider extends AbstractEmbeddingProvider {
}

const response = await nativeFetch(
`${this.baseUrl}/api/embeddings`,
this.getUrl('api/embeddings'),
{
method: "POST",
headers: headers,
Expand Down
17 changes: 10 additions & 7 deletions src/providers/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,19 @@ type HttpHeaders = {
};

export class OpenAIProvider extends AbstractProvider {
override name = "OpenAI";
requireAuth: boolean;

constructor(
modelConfigName: string,
apiKey: string,
modelName: string,
baseUrl: string,
requireAuth: boolean,
proxyOnServer: boolean,
) {
super("OpenAI", apiKey, baseUrl, modelName);
super(modelConfigName, apiKey, baseUrl, modelName, proxyOnServer);
this.requireAuth = requireAuth;
this.proxyOnServer = proxyOnServer;
}

async chatWithAI(
Expand All @@ -56,7 +58,7 @@ export class OpenAIProvider extends AbstractProvider {
const { messages, onDataReceived, onResponseComplete } = options;

try {
const sseUrl = `${this.baseUrl}/chat/completions`;
const sseUrl = this.getUrl('chat/completions');

const headers: HttpHeaders = {
"Content-Type": "application/json",
Expand Down Expand Up @@ -133,7 +135,7 @@ export class OpenAIProvider extends AbstractProvider {
}

const response = await nativeFetch(
`${this.baseUrl}/models`,
this.getUrl('models'),
{
method: "GET",
headers: headers,
Expand Down Expand Up @@ -171,7 +173,7 @@ export class OpenAIProvider extends AbstractProvider {
};

const response = await nativeFetch(
this.baseUrl + "/chat/completions",
this.getUrl('chat/completions'),
{
method: "POST",
headers: headers,
Expand Down Expand Up @@ -203,12 +205,13 @@ export class OpenAIProvider extends AbstractProvider {

export class OpenAIEmbeddingProvider extends AbstractEmbeddingProvider {
constructor(
modelConfigName: string,
apiKey: string,
modelName: string,
baseUrl: string,
requireAuth: boolean = true,
) {
super(apiKey, baseUrl, "OpenAI", modelName, requireAuth);
super(modelConfigName, apiKey, baseUrl, modelName, requireAuth);
}

async _generateEmbeddings(
Expand All @@ -229,7 +232,7 @@ export class OpenAIEmbeddingProvider extends AbstractEmbeddingProvider {
}

const response = await nativeFetch(
`${this.baseUrl}/embeddings`,
this.getUrl('embeddings'),
{
method: "POST",
headers: headers,
Expand Down
Loading

0 comments on commit 1e471f3

Please sign in to comment.