WIP - refactor LM backend
isaacbmiller committed Feb 20, 2024
1 parent 46cc4bc commit 82061c1
Showing 4 changed files with 339 additions and 6 deletions.
11 changes: 5 additions & 6 deletions dspy/__init__.py
@@ -1,19 +1,18 @@
-from dsp.modules.hf_client import ChatModuleClient
-from dsp.modules.hf_client import HFServerTGI, HFClientVLLM, HFClientSGLang
-from .signatures import *
+import dsp
-from dsp.modules.hf_client import HFClientVLLM
+
-from .retrieve import *
+from .backends import *
 from .predict import *
 from .primitives import *
+from .retrieve import *
+from .signatures import *

 # from .evaluation import *


 # FIXME:


-import dsp

 settings = dsp.settings

 OpenAI = dsp.GPT3
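A quick smoke test of the reshuffled top-level imports (hypothetical driver code, not part of this commit):

import dspy

# The new backends package is re-exported at the top level.
lm_cls = dspy.OpenAILM
print(dspy.settings)  # dsp.settings, re-exported
print(dspy.OpenAI)    # alias for dsp.GPT3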
1 change: 1 addition & 0 deletions dspy/backends/__init__.py
@@ -0,0 +1 @@
from .base_model import *
74 changes: 74 additions & 0 deletions dspy/backends/base_model.py
@@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

from litellm import text_completion


@dataclass
class BaseModel(ABC):
@abstractmethod
def __call__(self, prompt: str) -> list[str]:
"""Generate completions for the prompt."""

@abstractmethod
def finetune(self, examples: list[tuple[str, str]]) -> "Model":
"""Finetune on examples and return a new model."""


@dataclass
class BaseLM(BaseModel, ABC):
temperature: float
n: int
max_tokens: int


# This kwarg set works for all LiteLLM providers except Anyscale, VertexAI, and Petals.
# Remaining translated OpenAI params are still to be implemented:
# https://docs.litellm.ai/docs/completion/input#translated-openai-params
@dataclass
class LiteLLM(BaseLM, ABC):
top_p: float
stream: bool = False


@dataclass
class OpenAILM(LiteLLM):
    # Every field added here needs a default: the inherited
    # `stream: bool = False` already has one, and dataclasses forbid
    # non-default fields after a default field.
    model: str = "gpt-3.5-turbo"
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0

    def __post_init__(self):
        # Collect the generation parameters once so __call__ can forward
        # them directly to litellm.
        self.kwargs = {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "n": self.n,
            "model": self.model,
            "stream": self.stream,
        }

    def __call__(self, prompt: str) -> list[str]:
        # litellm returns an OpenAI-style response; pull the generated
        # text out of each choice.
        response = text_completion(prompt, **self.kwargs)
        return [choice["text"] for choice in response["choices"]]

    def finetune(self, examples: list[tuple[str, str]]) -> "OpenAILM":
        # Finetuning is not implemented yet; the list() call only keeps
        # the pre-commit unused-argument check happy.
        examples = list(examples)
        return self
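For reference, a minimal usage sketch of the new OpenAILM backend (hypothetical driver code, not part of this commit; assumes litellm can reach the OpenAI API via an OPENAI_API_KEY in the environment):

from dspy.backends import OpenAILM

lm = OpenAILM(
    temperature=0.7,
    n=1,
    max_tokens=150,
    top_p=1.0,
)

# __call__ forwards the collected kwargs to litellm.text_completion
# and returns one string per requested completion.
completions = lm("Write a one-sentence summary of unified LM backends.")
print(completions[0])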
(diff for the fourth changed file did not load)
