Streaming support for Transformers models
slundberg committed Apr 18, 2023
1 parent 0607b14 commit fb09220
Showing 4 changed files with 127 additions and 48 deletions.
8 changes: 6 additions & 2 deletions guidance/library/_gen.py
@@ -1,9 +1,12 @@
import asyncio
import re
import uuid
import logging
from .._grammar import grammar

async def gen(variable_name="generated", partial_output=None, parse=False, stop=None, max_tokens=500, n=1, temperature=0.0, top_p=1.0, logprobs=None, pattern=None, hidden=False, save_prompt=False, parser_prefix=None, parser=None, prefix="", suffix="", next_node=None, prev_node=None, next_next_node=None, **kwargs):
log = logging.getLogger(__name__)

async def gen(variable_name="generated", partial_output=None, parse=False, stop=None, stop_regex=None, max_tokens=500, n=1, temperature=0.0, top_p=1.0, logprobs=None, pattern=None, hidden=False, save_prompt=False, parser_prefix=None, parser=None, prefix="", suffix="", next_node=None, prev_node=None, next_next_node=None, **kwargs):
''' Use the LM to generate a completion string that is stored in the variable `variable_name`.
'''

@@ -60,7 +63,7 @@ async def gen(variable_name="generated", partial_output=None, parse=False, stop=

# call the LLM
gen_obj = parser.llm_session(
parser_prefix+prefix, stop=stop, max_tokens=max_tokens, n=n, pattern=pattern,
parser_prefix+prefix, stop=stop, stop_regex=stop_regex, max_tokens=max_tokens, n=n, pattern=pattern,
temperature=temperature, top_p=top_p, logprobs=parser.program.logprobs, cache_seed=cache_seed,
echo=parser.program.logprobs is not None, stream=stream_generation
)
@@ -77,6 +80,7 @@ async def gen(variable_name="generated", partial_output=None, parse=False, stop=
if parser.should_stop:
#log("Stopping generation")
break
# log.debug("resp", resp)
generated_value += resp["choices"][0]["text"]
partial_output(resp["choices"][0]["text"])
if logprobs is not None:
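The `stop_regex` argument added to `gen` above is forwarded straight into the LLM session call. A minimal usage sketch, assuming the handlebars-style program syntax guidance used at this point and a locally loaded Transformers backend (the model name and regex below are illustrative placeholders, not part of this commit):

import guidance

guidance.llm = guidance.llms.Transformers("gpt2")  # placeholder model; any Hugging Face causal LM

# stop generating as soon as the output matches sentence-ending punctuation
program = guidance("The answer is{{gen 'answer' stop_regex='[.!?]' max_tokens=50}}")
executed = program()
print(executed["answer"])

Because the matched text is trimmed off before the value is stored, `stop_regex` behaves like `stop`, but for patterns rather than fixed strings.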
5 changes: 3 additions & 2 deletions guidance/llms/_openai.py
@@ -135,11 +135,12 @@ def role_end(self, role=None):
assert self.chat_mode, "role_end() can only be used in chat mode"
return "<|im_end|>"

def __call__(self, prompt, stop=None, temperature=None, n=1, max_tokens=1000, logprobs=None, top_p=1.0, echo=False, logit_bias=None, pattern=None, stream=False, cache_seed=0):
def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, max_tokens=1000, logprobs=None, top_p=1.0, echo=False, logit_bias=None, pattern=None, stream=False, cache_seed=0):
""" Generate a completion of the given prompt.
"""

assert not pattern, "The OpenAI API does not support Guidance pattern controls! Please either switch to an endpoint that does, or don't user the `pattern` argument to `gen`."
assert not pattern, "The OpenAI API does not support Guidance pattern controls! Please either switch to an endpoint that does, or don't use the `pattern` argument to `gen`."
assert not stop_regex, "The OpenAI API does not support Guidance stop_regex controls! Please either switch to an endpoint that does, or don't use the `stop_regex` argument to `gen`."

if temperature is None:
temperature = self.temperature
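Since the OpenAI API cannot enforce regex stops server-side, the new assertion fails fast instead of silently ignoring the argument. A rough sketch of what a caller would see, assuming the `__call__` shown above is invoked directly on the wrapper and using a placeholder model name:

import guidance

llm = guidance.llms.OpenAI("text-davinci-003")  # placeholder model name
try:
    llm("Once upon a time", stop_regex=r"\n\n", max_tokens=20)
except AssertionError as err:
    print(err)  # directs the user to drop `stop_regex` or switch to an endpoint that supports it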
159 changes: 116 additions & 43 deletions guidance/llms/_transformers.py
@@ -4,6 +4,8 @@
import collections
import regex
import pygtrie
import queue
import threading
import logging
from ._llm import LLM, LLMSession

@@ -89,13 +91,7 @@ def _model_and_tokenizer(self, model, tokenizer):

def session(self):
return TransformersSession(self)

def stream_then_save(self, gen, key):
list_out = []
for out in gen:
list_out.append(out)
yield out
self.cache[key] = list_out


class TransformersSession(LLMSession):
def __init__(self, llm):
@@ -173,8 +169,6 @@ def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, ma
if isinstance(stop_regex, str):
stop_regex = [stop_regex]

# assert healing, "Turning off token healing is not yet supported for the Transformers LLM"

# handle caching
key = "_---_".join([str(v) for v in (self.llm.model_name, prompt, stop_regex, temperature, n, max_tokens, logprobs, top_p, echo, logit_bias, token_healing, pattern, cache_seed)])
if key not in self.llm.cache or not self.llm.caching:
@@ -212,7 +206,6 @@ def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, ma
processors.append(healer)
else:
last_token_str = ""


# make sure we don't run off the end of the model
max_context = (getattr(model_config, "max_sequence_length", None) or getattr(model_config, "n_positions"))
@@ -241,8 +234,18 @@ def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, ma
if stop_regex is not None:
stoppers.append(RegexStoppingCriteria(stop_regex, self.llm.decode, len(coded_prompt)))

# call the model
generated_sequence = self.llm.model_obj.generate(
# a streamer to handle potentially partial output
streamer = TransformersStreamer(
input_ids=input_ids,
stop_regex=stop_regex,
last_token_str=last_token_str,
coded_prompt=coded_prompt,
llm=self.llm,
max_new_tokens=max_tokens
)

# the args for the transformers generate call
generate_args = dict(
inputs=input_ids,
attention_mask=attention_mask,
# position_ids=position_ids,
@@ -255,41 +258,28 @@ def __call__(self, prompt, stop=None, stop_regex=None, temperature=None, n=1, ma
past_key_values=self._past_key_values
)

# note what we now have cached and ready for our next call in this session
if self._past_key_values:
self._prefix_cache = generated_sequence[0][:self._past_key_values[0][0].shape[2]] # self._past_key_values is already saved, this just aligns with it

# save the output. note we have to remove the input_ids prefix and the token healing prefix (last token str)
out = {"choices": []}
for i in range(len(input_ids)):
generated_tokens = list(generated_sequence[i][len(input_ids[i]):])
val = self.llm.decode([self.llm._tokenizer.bos_token_id] + generated_tokens)[(len(self.llm._tokenizer.bos_token) + len(last_token_str)):]

# trim off the stop regex matches if needed
stop_pos = len(val) + 1
if stop_regex is not None:
stop_regex_obj = [regex.compile(s) for s in stop_regex]
for s in stop_regex_obj:
m = s.search(val)
if m:
stop_pos = min(m.span()[0], stop_pos)

# record the reason we stopped
if stop_pos <= len(val):
finish_reason = "stop"
elif len(generated_tokens) >= max_tokens:
finish_reason = "length"
elif generated_tokens[-1] == model_config.eos_token_id:
finish_reason = "endoftext"

out["choices"].append({"text": val[:stop_pos], "finish_reason": finish_reason})

# if we are streaming then we need to run the inference process in a separate thread
if stream:
return self.stream_then_save(out, key)
generate_args["streamer"] = streamer
thread = threading.Thread(target=self.llm.model_obj.generate, kwargs=generate_args)
thread.start()
return self._stream_then_save(streamer, key, thread)

# if we are not streaming we still manually use the streamer for consistency
else:
self.llm.cache[key] = out
generated_sequence = self.llm.model_obj.generate(**generate_args)
streamer.put(generated_sequence)
self.llm.cache[key] = streamer.__next__()
return self.llm.cache[key]

def _stream_then_save(self, gen, key, thread):
list_out = []
for out in gen:
list_out.append(out)
yield out
thread.join() # clean up the thread
self.llm.cache[key] = list_out

def __exit__(self, exc_type, exc_value, traceback):
if self.llm.acceleration:
self.llm.model_obj.prepare_inputs_for_generation = self._prev_prepare_method
@@ -462,4 +452,87 @@ def __call__(self, input_ids, scores, **kwargs):

return all_done

class TransformersStreamer():
def __init__(self, input_ids, stop_regex, last_token_str, coded_prompt, llm, max_new_tokens, timeout=None):
self.timeout = timeout
self.input_ids = input_ids
self.stop_regex = stop_regex
self.last_token_str = last_token_str
# self.coded_prompt = coded_prompt
self.llm = llm
self.max_total_tokens = max_new_tokens + len(input_ids[0])
coded_prompt = coded_prompt[:len(coded_prompt)-len(last_token_str)] # strip off the last token which will be regenerated
self.str_pos = len(coded_prompt) + len(self.last_token_str)
self.out_queue = queue.Queue()
self.sequence_pos = [len(self.input_ids[0]) for i in range(len(self.input_ids))]
self.generated_sequence = [[] for i in range(len(self.input_ids))]
self.generated_string = [coded_prompt for i in range(len(self.input_ids))]

def put(self, new_tokens):

# if we are given a single sequence, then make it a batch of size 1
if len(new_tokens.shape) == 1:
new_tokens = new_tokens.unsqueeze(0)

out = {"choices": [None for i in range(len(self.input_ids))]}
put_data = False
for i in range(len(self.input_ids)):
self.generated_sequence[i].extend(list(new_tokens[i]))

if self.sequence_pos[i] < len(self.generated_sequence[i]):
display_tokens = list(self.generated_sequence[i][self.sequence_pos[i]:])
val = self.llm.decode([self.llm._tokenizer.bos_token_id] + display_tokens)[len(self.llm._tokenizer.bos_token):]
self.generated_string[i] += val
if self.str_pos < len(self.generated_string[i]):
val = self.generated_string[i][self.str_pos:]
finish_reason = None

if len(self.generated_sequence[i]) >= self.max_total_tokens:
finish_reason = "length"
elif self.generated_sequence[i][-1] == self.llm.model_obj.config.eos_token_id:
finish_reason = "endoftext"

# trim off the stop regex matches if needed
stop_pos = len(val) + 1
found_partial = False
if self.stop_regex is not None and finish_reason is None:
stop_regex_obj = [regex.compile(s) for s in self.stop_regex]
for s in stop_regex_obj:
m = s.search(val, partial=True)
if m:
if m.partial: # we might be starting a stop sequence, so we can't emit anything yet
found_partial = True
break
else:
stop_pos = min(m.span()[0], stop_pos)

# record the reason we stopped (if we have stopped)
if stop_pos <= len(val):
finish_reason = "stop"

if not found_partial:
out["choices"][i] = {"text": val[:stop_pos], "finish_reason": finish_reason}
self.str_pos = len(self.generated_string[i])
if not found_partial:
put_data = True
self.sequence_pos[i] = len(self.generated_sequence[i])
if put_data:
self.out_queue.put(out)

def end(self):

# make sure we have flushed all of the data
for i in range(len(self.input_ids)):
assert self.str_pos >= len(self.generated_string[i]), "Not all data was flushed, this means generation stopped for an unknown reason!"

self.out_queue.put(None)

def __iter__(self):
return self

def __next__(self):
value = self.out_queue.get(timeout=self.timeout)
if value is None:
raise StopIteration()
else:
return value
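Taken together, the streaming path is a producer/consumer pair: `generate` runs on a worker thread and feeds token batches into `TransformersStreamer.put`, the streamer pushes finished text chunks onto a thread-safe queue (holding back anything that might be the start of a stop sequence), and `_stream_then_save` drains that queue until `end` enqueues a `None` sentinel, then joins the thread and caches the collected output. A minimal standalone sketch of the same pattern, using illustrative names that are not part of the guidance API:

import queue
import threading

class QueueStreamer:
    """Toy stand-in for TransformersStreamer: put() feeds chunks, iteration drains them."""
    def __init__(self):
        self.out_queue = queue.Queue()

    def put(self, chunk):
        self.out_queue.put(chunk)

    def end(self):
        self.out_queue.put(None)  # sentinel: tells the consumer that generation is done

    def __iter__(self):
        return self

    def __next__(self):
        value = self.out_queue.get()
        if value is None:
            raise StopIteration()
        return value

def fake_generate(streamer):
    # stands in for model_obj.generate(..., streamer=streamer) running on the worker thread
    for piece in ["Hello", ", ", "world", "!"]:
        streamer.put(piece)
    streamer.end()

streamer = QueueStreamer()
thread = threading.Thread(target=fake_generate, args=(streamer,))
thread.start()
for piece in streamer:              # the consumer sees partial output as it arrives
    print(piece, end="", flush=True)
thread.join()                       # mirrors _stream_then_save joining the worker afterwards

The hold-back behaviour relies on the third-party `regex` package's partial-match support: a partial match at the end of the pending text means the chunk is withheld until more tokens arrive, roughly:

import regex

m = regex.search(r"STOP", "some generated text ST", partial=True)
if m and m.partial:
    pass  # "ST" could still grow into the stop sequence, so nothing is emitted yet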
3 changes: 2 additions & 1 deletion setup.py
@@ -33,6 +33,7 @@ def find_version(*file_paths):
"pybars3",
"parsimonious",
"pygtrie",
"platformdirs"
"platformdirs",
"tiktoken"
]
)
