Skip to content

Commit

Permalink
Initial LiteLM implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrusOfEden committed Feb 21, 2024
1 parent 7109f02 commit d10d8d7
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 53 deletions.
9 changes: 5 additions & 4 deletions dsp/primitives/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def _generate(template: Template, **kwargs) -> Callable:
generator = dsp.settings.lm

def do_generate(
example: Example, stage: str, max_depth: int = 2, original_example=None
example: Example,
stage: str,
max_depth: int = 2,
original_example=None,
):
if not dsp.settings.lm:
raise AssertionError("No LM is loaded.")
Expand All @@ -83,9 +86,7 @@ def do_generate(

last_field_idx = 0
for field_idx, key in enumerate(field_names):
completions_ = [
c for c in completions if key in c.keys() and c[key] is not None
]
completions_ = [c for c in completions if c.get(key, None) is not None]

# Filter out completions that are missing fields that are present in at least one completion.
if len(completions_):
Expand Down
Empty file removed dspy/adapters/basic_adapter.py
Empty file.
Empty file removed dspy/adapters/chatml_adapter.py
Empty file.
Empty file removed dspy/adapters/llamachat_adapter.py
Empty file.
Empty file removed dspy/adapters/vicuna_adapter.py
Empty file.
File renamed without changes.
17 changes: 17 additions & 0 deletions dspy/backends/lm/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod

from pydantic import BaseModel


class BaseLM(BaseModel, ABC):
@abstractmethod
def __call__(
self,
prompt: str,
temperature: float,
max_tokens: int,
n: int,
**kwargs,
) -> list[dict[str, str]]:
"""Generates `n` predictions for the signature output."""
...
35 changes: 35 additions & 0 deletions dspy/backends/lm/litelm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import typing as t

from litellm import completion
from pydantic import Field


from .base import BaseLM


class LiteLM(BaseLM):
STANDARD_PARAMS = {
"temperature": 0.0,
"max_tokens": 150,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
}

model: str
default_params: dict[str, t.Any] = Field(default_factory=dict)

def __call__(
self,
prompt: str,
**kwargs,
) -> list[dict[str, str]]:
"""Generates `n` predictions for the signature output."""
options = {**self.STANDARD_PARAMS, **self.default_params, **kwargs}
response = completion(
model=self.model,
messages=[{"role": "user", "content": prompt}],
**options,
)
choices = [c for c in response["choices"] if c["finish_reason"] != "length"]
return [c["message"]["content"] for c in choices]
32 changes: 19 additions & 13 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from copy import deepcopy
import dsp
from pydantic import BaseModel, Field, create_model
from typing import Type, Union, Dict, Tuple
from typing import Optional, Union
import re

from dspy.signatures.field import InputField, OutputField, new_to_old_field


def signature_to_template(signature):
"""Convert from new to legacy format"""
def signature_to_template(signature) -> dsp.Template:
"""Convert from new to legacy format."""
return dsp.Template(
signature.instructions,
**{name: new_to_old_field(field) for name, field in signature.fields.items()},
Expand Down Expand Up @@ -60,7 +60,10 @@ def instructions(cls) -> str:

def with_instructions(cls, instructions: str):
return create_model(
cls.__name__, __base__=Signature, __doc__=instructions, **cls.fields
cls.__name__,
__base__=Signature,
__doc__=instructions,
**cls.fields,
)

@property
Expand All @@ -75,17 +78,19 @@ def with_updated_fields(cls, name, **kwargs):
**fields_copy[name].json_schema_extra,
**kwargs,
}
return create_model(cls.__name__, __base__=Signature, __doc__=cls.instructions, **fields_copy)
return create_model(
cls.__name__, __base__=Signature, __doc__=cls.instructions, **fields_copy
)

@property
def input_fields(cls):
def input_fields(cls) -> dict:
return cls._get_fields_with_type("input")

@property
def output_fields(cls):
def output_fields(cls) -> dict:
return cls._get_fields_with_type("output")

def _get_fields_with_type(cls, field_type):
def _get_fields_with_type(cls, field_type) -> dict:
return {
k: v
for k, v in cls.__fields__.items()
Expand All @@ -98,8 +103,8 @@ def prepend(cls, name, field, type_=None):
def append(cls, name, field, type_=None):
return cls.insert(-1, name, field, type_)

def insert(cls, index: int, name: str, field, type_: Type = None):
# It's posisble to set the type as annotation=type in pydantic.Field(...)
def insert(cls, index: int, name: str, field, type_: Optional[type] = None):
# It's possible to set the type as annotation=type in pydantic.Field(...)
# But this may be annoying for users, so we allow them to pass the type
if type_ is not None:
field.annotation = type_
Expand Down Expand Up @@ -127,7 +132,7 @@ def insert(cls, index: int, name: str, field, type_: Type = None):
new_signature.__doc__ = cls.instructions
return new_signature

def _parse_signature(cls, signature: str) -> Tuple[Type, Field]:
def _parse_signature(cls, signature: str) -> tuple[type, Field]:
pattern = r"^\s*[\w\s,]+\s*->\s*[\w\s,]+\s*$"
if not re.match(pattern, signature):
raise ValueError(f"Invalid signature format: '{signature}'")
Expand All @@ -144,7 +149,9 @@ def _parse_signature(cls, signature: str) -> Tuple[Type, Field]:
return fields

def __call__(
cls, signature: Union[str, Dict[str, Field]], instructions: str = None
cls,
signature: Union[str, dict[str, Field]],
instructions: Optional[str] = None,
):
"""
Creates a new Signature type with the given fields and instructions.
Expand Down Expand Up @@ -215,7 +222,6 @@ def ensure_signature(signature):

def infer_prefix(attribute_name: str) -> str:
"""Infers a prefix from an attribute name."""

# Convert camelCase to snake_case, but handle sequences of capital letters properly
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", attribute_name)
intermediate_name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1)
Expand Down
Loading

0 comments on commit d10d8d7

Please sign in to comment.