Added LiteLLM support and fixed remote calling bugs
slundberg committed Nov 29, 2023
1 parent d1bbce1 commit a5b5db7
Showing 6 changed files with 285 additions and 34 deletions.
17 changes: 8 additions & 9 deletions guidance/library/_gen.py
@@ -1,11 +1,6 @@
import types
import regex as lregex
import uuid

import regex as regex_module
import logging
import guidance
import ast

# from guidance import select, any_char, zero_or_more, commit_point, hide
from ._silent import silent
from .._grammar import select
from ._zero_or_more import zero_or_more
@@ -18,6 +13,8 @@
from .._grammar import model_variable
from ._tool import Tool

logger = logging.getLogger(__name__)

# TODO: make this stateless!
@guidance(stateless=lambda *args, **kwargs: kwargs.get("tools", None) is None) # TODO: uncomment this once we get temperature stateless
def gen(lm, name=None, *, max_tokens=1000, list_append=False, regex=None,
@@ -27,6 +24,7 @@ def gen(lm, name=None, *, max_tokens=1000, list_append=False, regex=None,
TODO: document this
tools is a list of guidance.Tool or python functions (which will be converted to guidance.Tool)
"""
logger.debug(f'start gen(name="{name}")')

# set stream if we are interactive
# if stream_tokens is None and not lm.is_silent() and n == 1:
@@ -124,6 +122,7 @@ def gen(lm, name=None, *, max_tokens=1000, list_append=False, regex=None,
elif n == 1:
lm += with_temperature(pattern + stop_pattern + suffix, temperature)

logger.debug(f'finish gen')
return lm


@@ -178,9 +177,9 @@ def will_gen(lm, stop=None, stop_regex=None, ignore_spaces=False, max_tokens=30)
stop = []
if not stop_regex:
stop_regex = []
regexes = [lregex.escape(x) for x in stop + stop_regex]
regexes = [regex_module.escape(x) for x in stop + stop_regex]
optional_space = '\\s*' if ignore_spaces else ''
pattern = lregex.compile(f'{optional_space}({"|".join(regexes)})')
pattern = regex_module.compile(f'{optional_space}({"|".join(regexes)})')
lm2 = lm
with silent():
for _ in range(max_tokens):
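Note on the new logging calls: both guidance/library/_gen.py and guidance/models/_model.py now create module-level loggers via logging.getLogger(__name__), so the debug messages added in this commit can be surfaced with an ordinary standard-library logging setup. A minimal sketch (nothing here beyond the standard logging module and the logger names implied by the diff):

import logging

# Either enable debug output globally, which shows the new
# 'start gen(...)' and 'start Model._run_stateless' messages ...
logging.basicConfig(level=logging.DEBUG)

# ... or keep other libraries quiet and only raise the guidance loggers,
# which are named after their modules via logging.getLogger(__name__):
# logging.basicConfig(level=logging.WARNING)
# logging.getLogger("guidance").setLevel(logging.DEBUG)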
1 change: 1 addition & 0 deletions guidance/models/__init__.py
@@ -4,4 +4,5 @@
from .transformers._transformers import Transformers, TransformersChat
from ._llama_cpp import LlamaCpp, LlamaCppChat
from ._mock import Mock, MockChat
from ._lite_llm import LiteLLMChat, LiteLLMInstruct, LiteLLMCompletion
from . import transformers
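With these exports in place, the new classes can be constructed like guidance's other remote models. A hedged usage sketch follows: the LiteLLMChat constructor arguments come from the diff below, while the model name "gpt-3.5-turbo" and the user()/assistant() role blocks are assumptions about the surrounding guidance API, not part of this commit.

from guidance import models, gen, user, assistant

# Illustrative model name; litellm routes it to the matching provider. Note the
# constructor below derives a tokenizer via tiktoken.encoding_for_model(model).
lm = models.LiteLLMChat("gpt-3.5-turbo", temperature=0.0, max_streaming_tokens=1000)

with user():
    lm += "Give me one word that rhymes with 'code'."

with assistant():
    lm += gen("answer", max_tokens=10)

print(lm["answer"])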
170 changes: 170 additions & 0 deletions guidance/models/_lite_llm.py
@@ -0,0 +1,170 @@
import os
from pathlib import Path
import multiprocessing
from itertools import takewhile
import operator
import threading
import numpy as np
import queue
import time
import tiktoken
import re

from ._model import Chat, Instruct
from ._remote import Remote

# chat_model_pattern = r'^(ft:)?(gpt-3\.5-turbo|gpt-4)((-\w+)+)?(:[\w-]+(?:[:\w-]+)*)?(::\w+)?$'

class LiteLLM(Remote):
def __init__(self, model, tokenizer=None, echo=True, caching=True, api_base=None, api_key=None, custom_llm_provider=None, temperature=0.0, max_streaming_tokens=1000, **kwargs):
try:
import litellm
except ImportError:
raise Exception("Please install the litellm package version >= 1.7 using `pip install litellm -U` in order to use guidance.models.LiteLLM!")

# if we are called directly (as opposed to through super()) then we convert ourselves to a more specific subclass if possible
if self.__class__ is LiteLLM:
raise Exception("The LightLLM class is not meant to be used directly! Please use LiteLLMChat, LiteLLMInstruct, or LiteLLMCompletion depending on the model you are using.")

# # Configure an AsyncOpenAI Client with user params.
# if api_key is None:
# api_key = os.environ.get("OPENAI_API_KEY")

# if organization is None:
# organization = os.environ.get("OPENAI_ORG_ID")

self.litellm = litellm

# self.client = openai_package.OpenAI(api_key=api_key, organization=organization, base_url=base_url)
self.model_name = model

# self.tokenizer = tiktoken.encoding_for_model(model)
# self.eos_token = b"<|endoftext|>"

super().__init__(
model, tokenizer=tiktoken.encoding_for_model(model), echo=echo,
caching=caching, temperature=temperature,
max_streaming_tokens=max_streaming_tokens, **kwargs
)



class LiteLLMCompletion(LiteLLM, Instruct):

def _generator(self, prompt):
self._shared_state["not_running_stream"].clear() # so we know we are running
self._shared_state["data"] = prompt # we start with this data

try:
generator = self.litellm.completion(
model=self.model_name,
messages=[{"content": prompt.decode("utf8"), "role": "system"}], # note that role=system is just ignored by litellm but used by them to match chat syntax
max_tokens=self.max_streaming_tokens,
n=1,
top_p=1,
temperature=0,
stream=True
)
except Exception as e: # TODO: add retry logic
raise e

for part in generator:
# chunk = part.choices[0].text or ""
# yield chunk.encode("utf8")
chunk = part.choices[0].delta.content or ""
yield chunk.encode("utf8")

class LiteLLMInstruct(LiteLLM, Instruct):

def get_role_start(self, name):
return ""

def get_role_end(self, name):
if name == "instruction":
return "<|endofprompt|>"
else:
raise Exception(f"The LiteLLMInstruct model does not know about the {name} role type!")

def _generator(self, prompt):
# start the new stream
prompt_end = prompt.find(b'<|endofprompt|>')
if prompt_end >= 0:
stripped_prompt = prompt[:prompt_end]
else:
raise Exception("This model cannot handle prompts that don't match the instruct format!")

# make sure you don't try and instruct the same model twice
if b'<|endofprompt|>' in prompt[prompt_end + len(b'<|endofprompt|>'):]:
raise Exception("This model has been given two separate instruct blocks, but this is not allowed!")

self._shared_state["not_running_stream"].clear() # so we know we are running
self._shared_state["data"] = stripped_prompt + b'<|endofprompt|>'# we start with this data

try:
generator = self.litellm.completion(
model=self.model_name,
messages=[{"content": self._shared_state["data"].decode("utf8"), "role": "system"}], # note that role=system is just ignored by litellm but used by them to match chat syntax
prompt=self._shared_state["data"].decode("utf8"),
max_tokens=self.max_streaming_tokens,
n=1,
top_p=1,
temperature=0,
stream=True
)
except Exception as e: # TODO: add retry logic
raise e

for part in generator:
# chunk = part.choices[0].text or ""
# yield chunk.encode("utf8")
chunk = part.choices[0].delta.content or ""
yield chunk.encode("utf8")

class LiteLLMChat(LiteLLM, Chat):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _generator(self, prompt):

# find the system text
pos = 0
role_end = b'<|im_end|>'

# find the user/assistant pairs
messages = []
found = True
while found:

# find the user text
found = False
for role_name,start_bytes in (("system", b'<|im_start|>system\n'), ("user", b'<|im_start|>user\n'), ("assistant", b'<|im_start|>assistant\n')):
if prompt[pos:].startswith(start_bytes):
pos += len(start_bytes)
end_pos = prompt[pos:].find(role_end)
if end_pos < 0:
assert role_name == "assistant", "Bad chat format! Last role before gen needs to be assistant!"
break
btext = prompt[pos:pos+end_pos]
pos += end_pos + len(role_end)
messages.append({"role": role_name, "content": btext.decode("utf8")})
found = True
break

self._shared_state["data"] = prompt[:pos]

try:
generator = self.litellm.completion(
model=self.model_name,
messages=messages,
max_tokens=self.max_streaming_tokens,
n=1,
top_p=1,
temperature=0,
stream=True
)
except Exception as e: # TODO: add retry logic
raise e

for part in generator:
chunk = part.choices[0].delta.content or ""
yield chunk.encode("utf8")
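For reference, the parsing loop in LiteLLMChat._generator above converts guidance's serialized chat prompt into the litellm messages list. A small hand-written illustration of the mapping (the prompt bytes are an assumed example, not output captured from the library):

# Serialized chat prompt as _generator would receive it:
prompt = (b"<|im_start|>system\nYou are concise.<|im_end|>"
          b"<|im_start|>user\nName one prime number.<|im_end|>"
          b"<|im_start|>assistant\n")

# The loop above yields:
#   messages == [
#       {"role": "system", "content": "You are concise."},
#       {"role": "user", "content": "Name one prime number."},
#   ]
# The trailing, unclosed assistant block (no <|im_end|>) is what ends the
# parse; the streamed completion is appended at that point.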
14 changes: 11 additions & 3 deletions guidance/models/_model.py
@@ -9,10 +9,13 @@
import copy
import time
import numpy as np
import logging
from .._utils import ByteTrie, log_softmax, softmax
from .._parser import EarleyCommitParser
from .._grammar import StatelessFunction, string, _call_pool, _tag_pattern, Null, replace_model_variables, unreplace_model_variables, select, Terminal

logger = logging.getLogger(__name__)

# define some constants we will reuse many times
_null_grammar = string('')
format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL)
@@ -80,15 +83,16 @@ def default_end_patterns(self):
what the current open roles are, which is something
'''

# add the eos token
parts = [self.eos_token]

# add any active non-empty role ends. Ignore role ends that are spaces
parts = []
for role_end_str in self.opened_blocks.values():
role_end_str = format_pattern.sub("", role_end_str)
if len(role_end_str) > 0 and not re.fullmatch(r'\s+', role_end_str):
parts.append(role_end_str)

# add the eos token
parts.append(self.eos_token)

return select(parts)

def _html(self):
@@ -373,6 +377,8 @@ def tool_def(self, functions):

def _run_stateless(lm, stateless_function, max_tokens=1000, temperature=0.0, top_p=1.0, n=1):
assert Model._grammar_only == 0, "We can't run grammar parsing while in context free mode! (for example inside a block closer)"

logger.debug("start Model._run_stateless")

# This needs to be here for streaming
# if name is not None:
@@ -451,6 +457,8 @@ def _run_stateless(lm, stateless_function, max_tokens=1000, temperature=0.0, top

unreplace_model_variables(replacements)

logger.debug("finish Model._run_stateless")

return lm

def _get_logits(self, token_ids, forced_bytes):