diff --git a/guidance/models/__init__.py b/guidance/models/__init__.py index 48c8bf99f..05837444b 100644 --- a/guidance/models/__init__.py +++ b/guidance/models/__init__.py @@ -3,5 +3,5 @@ from ._openai import OpenAI, OpenAIChat, OpenAIInstruct, OpenAICompletion from .transformers._transformers import Transformers, TransformersChat from ._llama_cpp import LlamaCpp, LlamaCppChat -from ._local_mock import LocalMock, LocalMockChat +from ._mock import Mock, MockChat from . import transformers \ No newline at end of file diff --git a/guidance/models/_llama_cpp.py b/guidance/models/_llama_cpp.py index 1059aee4c..e012625fd 100644 --- a/guidance/models/_llama_cpp.py +++ b/guidance/models/_llama_cpp.py @@ -5,8 +5,7 @@ import numpy as np -from ._model import Chat -from ._local import Local +from ._model import Model, Chat from .._utils import normalize_notebook_stdout_stderr try: @@ -15,7 +14,7 @@ except ImportError: is_llama_cpp = False -class LlamaCpp(Local): +class LlamaCpp(Model): def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperature=0.0, **kwargs): if not is_llama_cpp: diff --git a/guidance/models/_local.py b/guidance/models/_local.py deleted file mode 100644 index 9f3df3fff..000000000 --- a/guidance/models/_local.py +++ /dev/null @@ -1,498 +0,0 @@ -import numpy as np -from .._utils import ByteTrie, log_softmax, softmax -from ._model import Model -from .._parser import EarleyCommitParser -from .._grammar import Terminal - -class Local(Model): - def __init__(self, tokens, bos_token_id, eos_token_id=None, echo=True): - super().__init__(echo) - - assert isinstance(tokens[0], bytes), "The tokens need to be provided as bytes!" - - self.tokens = tokens - self.bos_token_id = bos_token_id - self.bos_token = None if self.bos_token_id is None else self.tokens[self.bos_token_id] - self.eos_token_id = eos_token_id if eos_token_id is not None else bos_token_id - self.eos_token = None if self.eos_token_id is None else self.tokens[self.eos_token_id] - - # build a prefix tree of the tokens - self._token_trie = ByteTrie(tokens, np.arange(len(tokens))) - self._token_trie.match = True - self._token_trie.match_version = 0 - - self._max_token_bytes = max([len(t) for t in self.tokens]) - - def _get_logits(self, token_ids, forced_bytes): - '''A fake method designed to be overriden by subclasses.''' - - # pretend to extend the KV cache and update the log probs - return np.randn(len(self.tokens)) - - def _joint_tokenize(self, token_ids): - # an abstract method. Should return what a full joint tokenizer would give for a given byte string - return token_ids - - def _tokenize_prefix(self, byte_string): - '''This is used to speed up the tokenization of long prompts without using the parser.''' - token_ids = [] - token_byte_positions = [] - - # loop trying to decode a new token at each iteration - pos = 0 - while True: - - # walk down the token trie looking for a unique token match - trie = self._token_trie - valid_pos = -1 - valid_value = -1 - while True: - if pos >= len(byte_string): - if len(trie.children) > 0: - valid_pos = -1 - break - - # check if we can keep going or are at a dead end - if byte_string[pos:pos+1] in trie.children: - trie = trie.children[byte_string[pos:pos+1]] - pos += 1 - - # record the last valid token down this path as we go - if trie.value is not None: - valid_pos = pos - valid_value = trie.value - else: - break # we can't go any farther - - if valid_pos == -1: - break - else: - token_ids.append(valid_value) - token_byte_positions.append(valid_pos) - pos = valid_pos - - return token_ids,token_byte_positions - - def _cleanup_tokens(self, token_ids, token_byte_positions): - - # compute a joint tokenization - joint_token_ids = self._joint_tokenize(token_ids) - - # see if we need to redo the tokenization - redo = False - if len(joint_token_ids) != len(token_ids): - redo = True - else: - for i,id in enumerate(joint_token_ids): - if token_ids[i] != id: - redo = True - break - - if redo: - token_ids = joint_token_ids - last_pos = token_byte_positions[-1] - token_byte_positions = [] - pos = 0 - for i,id in enumerate(joint_token_ids): - pos += len(self.tokens[id]) - token_byte_positions.append(pos) - assert token_byte_positions[-1] == last_pos - - return token_ids, token_byte_positions - - - def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensure_bos_token=True, log_probs=False): - assert n == 1, "Still need to add support for n > 1!" - - # get our current context in bytes - prompt = self._current_prompt() - prompt = bytes(prompt, encoding="utf8") - - # add the beginning of sequence token if needed - if ensure_bos_token and self.bos_token is not None and not prompt.startswith(self.bos_token): - prompt = self.bos_token + prompt - - # run a simple tokenizer (that does not use a grammar) on the prefix for better performance - token_ids,token_byte_positions = self._tokenize_prefix(prompt) - token_ids,token_byte_positions = self._cleanup_tokens(token_ids,token_byte_positions) - if len(token_byte_positions) > 0: - pre_parser_bytes = token_byte_positions[-1] - prompt = prompt[token_byte_positions[-1]:] - else: - pre_parser_bytes = 0 - - # create a parser with a grammar that includes both our context and the passed grammar - parser = EarleyCommitParser(prompt + grammar) - - # loop until we have generated a complete pattern - hidden_count = len(prompt) # we don't emit the prompt - generated_pos = 0 - sampled_token_ind = None - token_count = 0 - last_token_count = 0 - was_forced = False - captured_data = {} - captured_log_prob_data = {} - while True: # each iteration generates one more token (and some of the associated bytes) - - # enforce the token limit - if token_count >= max_tokens: - break - - # note where we are starting for this token - start_pos = parser.pos - - # let the parser know that we have advanced another token (used ofr tracking max token limits) - parser.mark_new_token() - - # walk down the trie as far as possible before computing the logits - retry_token_gen = False - trie = self._token_trie - trie.match_version += 1 # this invalidates all the match caches from the previous token - while True: - next_byte_mask = parser.next_byte_mask() - next_byte_mask_sum = next_byte_mask.sum() - - # see if we reached a dead end of the grammar - if next_byte_mask_sum == 0: - break - - # if there is more than one option we cannot advance without computing the logits - elif next_byte_mask_sum != 1: - break - - # we are not forced if we are at the end of the grammar - elif parser.matched(): - break - - # if there is only one possible next byte we can keep forcing - elif next_byte_mask_sum == 1: - - # look for valid children - next_byte = None - for byte in trie.children: - - # mark this trie node with an up-to-date match flag (may save work later) - node = trie.children[byte] - node.match_version = self._token_trie.match_version - node.match = next_byte_mask[byte[0]] - - # see if we found a match - if node.match: - next_byte = byte - break - - # if we can't extend then this token is forced - if next_byte is None: - break - - # otherwise since there is only one possible next byte we keep going - else: - commit_point = parser.consume_byte(next_byte, log_prob=0.0) - - # if we are at a hidden commit point then we need to hide the bytes that match that node - if commit_point is not None and commit_point.node.hidden: - - # This takes the item and commits to it as part of the parse and then shrinks it to zero width - # in other words this hides the item - parser.commit_and_collapse_item(commit_point) - - # keep the bytes we still need to emit - if start_pos < commit_point.start: - parser.shadow_rewind(start_pos) - - else: - # pop off any tokens that overlap the hidden bytes - i = len(token_byte_positions) - 1 - while i >= 0 and token_byte_positions[i] - pre_parser_bytes > commit_point.start: - token_ids.pop() - token_byte_positions.pop() - token_count -= 1 - i -= 1 - # re-add any bytes we cut too far on - parser.shadow_rewind(token_byte_positions[-1] - pre_parser_bytes) - retry_token_gen = True # this restarts us at the top of the outer token gen loop - break - - trie = trie.children[next_byte] - - forced_pos = parser.pos # record how far the bytes are forced - - if retry_token_gen: - continue - - # back up if we got forced up to a point that is not a valid token - if next_byte_mask_sum <= 1: - while trie.value is None and trie.parent is not None: - trie = trie.parent - forced_pos -= 1 - parser.pos = forced_pos - - # if we walked all the way to a forced token then we advance without computing the logits - # we are forced if there are no more options and we are either in the middle of the grammar or at a trie leaf - is_forced = next_byte_mask_sum <= 1 and (len(trie.children) == 0 if parser.matched() else trie != self._token_trie) - if is_forced: - sampled_token_ind = trie.value - sampled_token = self.tokens[sampled_token_ind] - new_bytes_log_prob = 0.0 - was_forced = True - - # we are at the end of the grammar - elif next_byte_mask_sum == 0: - token_pos = 0 - - # mark the token we "sampled" if we have comsumed some bytes - if trie != self._token_trie: - sampled_token_ind = trie.value - sampled_token = self.tokens[sampled_token_ind] - new_bytes_log_prob = 0.0 - - # otherwise we need to compute the logits and sample a valid token - else: - - # if we were forced we might need to clean up the greedy tokenization to match the global tokenization behavior as seen in training - if was_forced: - token_ids,token_byte_positions = self._cleanup_tokens(token_ids, token_byte_positions) - was_forced = False - logits = self._get_logits(token_ids, parser.bytes[start_pos:forced_pos]) - - # if requested we compute the log probabilities so we can track the probabilities of each node - # TODO: we should lower this step to C++ with pybind11 - if log_probs: - _compute_log_probs(trie, log_softmax(logits, axis=-1)) - - # get the sampling order - grammar_temp = parser.next_byte_temperature() - current_temp = grammar_temp if grammar_temp >= 0 else temperature # we prefer to use the grammar temp when it is specified - if current_temp == 0: - sampling_order = np.argsort(-logits) # we need numpy so the enumerate below does not get really slow... - else: - assert top_p == 1, "Still need to add support for top_p!" - probs = softmax(logits / current_temp, axis=-1) - probs += 1e-10 # ensure we have no zero probs that mess up numpy - probs /= np.sum(probs) - sampling_order = np.random.choice(len(probs), size=len(probs), p=probs, replace=False) # the 1e-10 is ensure we have no zero probs, which numpy does not like - - # loop over the tokens looking for a valid one - for i,sampled_token_ind in enumerate(sampling_order): - sampled_token = self.tokens[sampled_token_ind] - - # make sure the parse is backed up to the position we want to start checking from TODO: make this account for shared prefixes with the last token - parser.pos = forced_pos - new_bytes_log_prob = 0.0 - - # make sure it matches any forced prefix - if start_pos < forced_pos and not sampled_token.startswith(parser.bytes[start_pos:forced_pos]): - continue - offset = forced_pos - start_pos - - # check to see if the sampled token is allowed - token_pos = offset - node = trie # this is the Trie node we were left at when we could force the next byte above - - while token_pos < len(sampled_token): - next_byte = sampled_token[token_pos:token_pos+1] - next_node = node.children[next_byte] - - # if we don't have a cached match flag compute it using the grammar - if next_node.match_version < self._token_trie.match_version: - next_byte_mask = parser.next_byte_mask() - for byte in node.children: # we update all the children since the parser knows the full mask - child = node.children[byte] - child.match_version = self._token_trie.match_version - child.match = next_byte_mask[byte[0]] - - # advance or fail according to the (now up-to-date) match cache - if next_node.match: - - # mark that we accepted this byte - node = next_node - token_pos += 1 - - # get the parser to consume the next byte - log_prob_delta = next_node.log_prob - node.log_prob - new_bytes_log_prob += log_prob_delta - commit_point = parser.consume_byte(next_byte, log_prob=log_prob_delta) - - # if we are at a hidden commit point then we need to hide the bytes that match that node - if commit_point is not None and commit_point.node.hidden: - - # if we are capturing the data from this node we need to do that now since we are about to remove it - # TODO: build a whole parse tree under this commit_point node so we can record child node captures - if commit_point.node.capture_name: - captured_data[commit_point.node.capture_name] = parser.bytes[commit_point.start:] - - # This takes the item and commits to it as part of the parse and then shrinks it to zero width - # in other words this hides the item - parser.commit_and_collapse_item(commit_point) - - # keep the bytes we still need to emit - if forced_pos < commit_point.start: - parser.shadow_rewind(forced_pos) - - else: - # pop off any tokens that overlap the hidden bytes - i = len(token_byte_positions) - 1 - while i >= 0 and token_byte_positions[i] - pre_parser_bytes > commit_point.start: - token_ids.pop() - token_byte_positions.pop() - token_count -= 1 - i -= 1 - # re-add any bytes we cut too far on - parser.shadow_rewind(token_byte_positions[-1] - pre_parser_bytes) - retry_token_gen = True # this restarts us at the top of the outer token gen loop - break - - elif token_pos == len(sampled_token): - break # this token is valid - else: - # partially valid tokens are okay if we are running off the end of a grammar, but not otherwise - if not parser.matched(): - token_pos = -1 - - break # this token is no longer valid - - # see if we are breaking out of the whole loop - if retry_token_gen: - break - - # check if this token is dominated by other longer valid tokens (and hence would never be consistent with greedy tokenization) - # TODO: disabled for now because of sentencepeice non-local issues - # if token_pos == len(sampled_token) and not parser.matched(): # not we don't check if we have matched, because then we can generate anything afterwards - # if _check_dominated(node, parser, self._token_trie.match_version, parser.next_byte_mask()): - # token_pos = -1 - - if token_pos > 0: - break # we found a valid token - - if parser.matched(): - break # if we already have a full match we don't try more tokens we just give up as soon as the model deviates from the grammar - - # if we just collpased a hidden commit point then we start over looking for a new token - if retry_token_gen: - continue - - # emit whatever we know will not be hidden - new_bytes = parser.bytes[generated_pos:parser.earliest_hidden_start()] - - # if we cannot consume any more tokens then we are done - if not is_forced and token_pos < len(sampled_token) and trie == self._token_trie: - assert parser.matched(), "We can't consume any more tokens, but we are not yet done! Perhaps your model's token set is incomplete?" - - # TODO: if we exactly match the end of the pattern then we can commit to this last token - # if m.span()[1] == len(generated_text): - # self._cache_state["new_token_ids"].append(sampled_token_ind) - - # capture the named groups from the parse tree - parse_tree = parser.parse_tree() - _record_captures(parse_tree, captured_data, captured_log_prob_data, parser.bytes) - - # we have no valid log prob data if we didn't compute it - if not log_probs: - captured_log_prob_data = {k: None for k in captured_data} - yield new_bytes[hidden_count:], not is_forced, new_bytes_log_prob, captured_data, captured_log_prob_data, token_count - last_token_count - last_token_count = token_count - break # we are done! - else: - generated_pos += len(new_bytes) - - # yeild the snippet of text created by the next token - out = new_bytes[hidden_count:] - if len(out) > 0: - yield out, not is_forced, new_bytes_log_prob, {}, {}, token_count - last_token_count # note that we don't capture groups until a complete parse right now... - last_token_count = token_count - hidden_count = 0 - token_count += 1 # note we only update this for tokens that emit non-hidden content - else: - hidden_count -= len(new_bytes) - - token_ids.append(sampled_token_ind) - - # track the byte position of each token - if len(token_byte_positions) == 0: - token_byte_positions.append(len(sampled_token)) - else: - token_byte_positions.append(token_byte_positions[-1] + len(sampled_token)) - -def _record_captures(initial_item, data, log_prob_data, byte_data): - stack = [(initial_item, 0)] - used_names = set() # track which capture names have been used so self-recursive children don't overwrite their parents - - while stack: - item, byte_pos = stack.pop() - # terminal nodes - if isinstance(item, Terminal): - - # if we are at a capture group node then we save the matched terminal byte - if item.capture_name is not None: - data[item.capture_name] = item.byte - log_prob_data[item.capture_name] = 0 - - # internal nodes - else: - start_byte_pos = byte_pos - - # recurse for all our non-null children - for child in item.children: - if child is not None: - stack.append((child, byte_pos)) - # _record_captures(child, data, log_prob_data, byte_data, byte_pos) - if isinstance(child, Terminal): - byte_pos += len(child) - else: - byte_pos = child.start # note that "start" means "end" since this is a reversed state set - - # if we are at a capture group node then we save the matched bytes range - # note that we record this after calling our children so that we save the outermost version of self-recursive calls - cname = item.node.capture_name - if cname is not None and cname not in used_names and not item.node.hidden: - - # see if we are doing a list append - if cname.startswith("__LIST_APPEND:"): - cname = cname[14:] # trim off the list append tag - if cname not in data or not isinstance(data[cname], list): - data[cname] = [] - log_prob_data[cname] = [] - data[cname].append(byte_data[start_byte_pos:item.start]) - log_prob_data[cname].append(item.log_prob) - - # or just a regular assignment - else: - data[cname] = byte_data[start_byte_pos:item.start] # note that "start" means "end" since this is a reversed state set - log_prob_data[cname] = item.log_prob - - used_names.add(cname) - -def _compute_log_probs(trie, log_probs): - '''Computes the log probabilities for each internal trie node.''' - if trie.value is not None: - trie.log_prob += log_probs[trie.value] - - if len(trie.children) > 0: - child_log_probs = [] - for b in trie.children: - child = trie.children[b] - _compute_log_probs(child, log_probs) - child_log_probs.append(child.log_prob) - trie.log_prob = np.logaddexp.reduce(child_log_probs) - -def _check_dominated(node, parser, match_version, next_byte_mask): - curr_pos = parser.pos - for byte_num in next_byte_mask.nonzero()[0]: - next_byte = bytes((byte_num,)) - if next_byte not in node.children: - return False # no possible exension this direction, so we are not dominated - child = node.children[next_byte] - if child.match_version < match_version: - child.match_version = match_version - child.match = next_byte_mask[next_byte[0]] - - if not child.match: - return False # this child does not dominate the node, so the node is not dominated - elif child.value is None: # this child might not dominate the node - parser.consume_byte(next_byte, log_prob=0.0) - child_dominate = _check_dominated(child, parser, match_version, parser.next_byte_mask()) - parser.pos = curr_pos - if not child_dominate: - return False - return True diff --git a/guidance/models/_local_mock.py b/guidance/models/_mock.py similarity index 95% rename from guidance/models/_local_mock.py rename to guidance/models/_mock.py index 5d8790e12..b0553794e 100644 --- a/guidance/models/_local_mock.py +++ b/guidance/models/_mock.py @@ -1,10 +1,9 @@ import numpy as np -from ._model import Chat -from ._local import Local +from ._model import Model, Chat -class LocalMock(Local): +class Mock(Model): def __init__(self, byte_patterns=[], echo=True): super().__init__( @@ -67,6 +66,6 @@ def _get_next_tokens(self, byte_string): yield i -class LocalMockChat(LocalMock, Chat): +class MockChat(Mock, Chat): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/guidance/models/_model.py b/guidance/models/_model.py index 87f0e47e2..68eae1552 100644 --- a/guidance/models/_model.py +++ b/guidance/models/_model.py @@ -9,7 +9,9 @@ import copy import time import numpy as np -from .._grammar import StatelessFunction, string, _call_pool, _tag_pattern, Null, replace_model_variables, unreplace_model_variables, select +from .._utils import ByteTrie, log_softmax, softmax +from .._parser import EarleyCommitParser +from .._grammar import StatelessFunction, string, _call_pool, _tag_pattern, Null, replace_model_variables, unreplace_model_variables, select, Terminal # define some constants we will reuse many times _null_grammar = string('') @@ -29,19 +31,32 @@ class Model: _grammar_only = 0 # a flag that tracks when we are forced to be executing only compiled grammars (like when we are inside a select) _throttle_refresh = 0 # a flag that tracks when we can throttle our display since we know future display calls are going to happen - def __init__(self, echo=True): - '''Build a new model state object. + def __init__(self, tokens, bos_token_id=None, eos_token_id=None, echo=True): + '''Build a new model object that represents a model in a given state. Parameters ---------- + tokens : list + This is a list of all the tokens in byte-string form. The index of the token in the list is the token's id. + bos_token_id : int + The index of the special beginning-of-sequence token (if used for this model). + eos_token_id : int + The index of the special end-of-sequence token (if used for this model). echo : bool If true the final result of creating this model state will be displayed (as HTML in a notebook). ''' + assert isinstance(tokens[0], bytes), "The tokens need to be provided as bytes!" + self.echo = echo self.token_count = 0 # tracks how many tokens our byte state represents self.max_display_rate = 0.2 # this controls how frequently we are allowed to redraw the display (in seconds) self.opened_blocks = {} # what context blocks have been opened but not closed + self.tokens = tokens # the token byte strings indexed by their token id + self.bos_token_id = bos_token_id + self.bos_token = None if self.bos_token_id is None else self.tokens[self.bos_token_id] + self.eos_token_id = eos_token_id if eos_token_id is not None else bos_token_id + self.eos_token = None if self.eos_token_id is None else self.tokens[self.eos_token_id] # private attributes self._variables = {} # these are the state variables stored with the model @@ -51,9 +66,10 @@ def __init__(self, echo=True): self._event_parent = None self._last_display = 0 # used to track the last display call to enable throttling - def __call__(self, grammar=None, max_tokens=100, n=1, top_p=1, temperature=0.0, ensure_bos_token=True): - # TODO: turn this into "append" and make the models keep the grammar and parse as current state - pass # meant to be overriden by subclasses + # build a prefix tree of the tokens + self._token_trie = ByteTrie(tokens, np.arange(len(tokens))) + self._token_trie.match = True + self._token_trie.match_version = 0 @property def default_end_patterns(self): @@ -355,7 +371,7 @@ def tool_def(self, functions): return self - def _run_stateless(lm, stateless_function, max_tokens=2000, temperature=0.0, top_p=1.0, n=1): + def _run_stateless(lm, stateless_function, max_tokens=1000, temperature=0.0, top_p=1.0, n=1): assert Model._grammar_only == 0, "We can't run grammar parsing while in context free mode! (for example inside a block closer)" # This needs to be here for streaming @@ -436,6 +452,397 @@ def _run_stateless(lm, stateless_function, max_tokens=2000, temperature=0.0, top unreplace_model_variables(replacements) return lm + + def _get_logits(self, token_ids, forced_bytes): + '''A fake method designed to be overriden by subclasses.''' + + # pretend to extend the KV cache and update the log probs + return np.randn(len(self.tokens)) + + def _joint_tokenize(self, token_ids): + # an abstract method. Should return what a full joint tokenizer would give for a given byte string + return token_ids + + def _tokenize_prefix(self, byte_string): + '''This is used to speed up the tokenization of long prompts without using the parser.''' + token_ids = [] + token_byte_positions = [] + + # loop trying to decode a new token at each iteration + pos = 0 + while True: + + # walk down the token trie looking for a unique token match + trie = self._token_trie + valid_pos = -1 + valid_value = -1 + while True: + if pos >= len(byte_string): + if len(trie.children) > 0: + valid_pos = -1 + break + + # check if we can keep going or are at a dead end + if byte_string[pos:pos+1] in trie.children: + trie = trie.children[byte_string[pos:pos+1]] + pos += 1 + + # record the last valid token down this path as we go + if trie.value is not None: + valid_pos = pos + valid_value = trie.value + else: + break # we can't go any farther + + if valid_pos == -1: + break + else: + token_ids.append(valid_value) + token_byte_positions.append(valid_pos) + pos = valid_pos + + return token_ids,token_byte_positions + + def _cleanup_tokens(self, token_ids, token_byte_positions): + + # compute a joint tokenization + joint_token_ids = self._joint_tokenize(token_ids) + + # see if we need to redo the tokenization + redo = False + if len(joint_token_ids) != len(token_ids): + redo = True + else: + for i,id in enumerate(joint_token_ids): + if token_ids[i] != id: + redo = True + break + + if redo: + token_ids = joint_token_ids + last_pos = token_byte_positions[-1] + token_byte_positions = [] + pos = 0 + for i,id in enumerate(joint_token_ids): + pos += len(self.tokens[id]) + token_byte_positions.append(pos) + assert token_byte_positions[-1] == last_pos + + return token_ids, token_byte_positions + + + def __call__(self, grammar, max_tokens=100, n=1, top_p=1, temperature=0.0, ensure_bos_token=True, log_probs=False): + assert n == 1, "Still need to add support for n > 1!" + + # get our current context in bytes + prompt = self._current_prompt() + prompt = bytes(prompt, encoding="utf8") + + # add the beginning of sequence token if needed + if ensure_bos_token and self.bos_token is not None and not prompt.startswith(self.bos_token): + prompt = self.bos_token + prompt + + # run a simple tokenizer (that does not use a grammar) on the prefix for better performance + token_ids,token_byte_positions = self._tokenize_prefix(prompt) + token_ids,token_byte_positions = self._cleanup_tokens(token_ids,token_byte_positions) + if len(token_byte_positions) > 0: + pre_parser_bytes = token_byte_positions[-1] + prompt = prompt[token_byte_positions[-1]:] + else: + pre_parser_bytes = 0 + + # create a parser with a grammar that includes both our context and the passed grammar + parser = EarleyCommitParser(prompt + grammar) + + # loop until we have generated a complete pattern + hidden_count = len(prompt) # we don't emit the prompt + generated_pos = 0 + sampled_token_ind = None + token_count = 0 + last_token_count = 0 + was_forced = False + captured_data = {} + captured_log_prob_data = {} + while True: # each iteration generates one more token (and some of the associated bytes) + + # enforce the token limit + if token_count >= max_tokens: + break + + # note where we are starting for this token + start_pos = parser.pos + + # let the parser know that we have advanced another token (used ofr tracking max token limits) + parser.mark_new_token() + + # walk down the trie as far as possible before computing the logits + retry_token_gen = False + trie = self._token_trie + trie.match_version += 1 # this invalidates all the match caches from the previous token + while True: + next_byte_mask = parser.next_byte_mask() + next_byte_mask_sum = next_byte_mask.sum() + + # see if we reached a dead end of the grammar + if next_byte_mask_sum == 0: + break + + # if there is more than one option we cannot advance without computing the logits + elif next_byte_mask_sum != 1: + break + + # we are not forced if we are at the end of the grammar + elif parser.matched(): + break + + # if there is only one possible next byte we can keep forcing + elif next_byte_mask_sum == 1: + + # look for valid children + next_byte = None + for byte in trie.children: + + # mark this trie node with an up-to-date match flag (may save work later) + node = trie.children[byte] + node.match_version = self._token_trie.match_version + node.match = next_byte_mask[byte[0]] + + # see if we found a match + if node.match: + next_byte = byte + break + + # if we can't extend then this token is forced + if next_byte is None: + break + + # otherwise since there is only one possible next byte we keep going + else: + commit_point = parser.consume_byte(next_byte, log_prob=0.0) + + # if we are at a hidden commit point then we need to hide the bytes that match that node + if commit_point is not None and commit_point.node.hidden: + + # This takes the item and commits to it as part of the parse and then shrinks it to zero width + # in other words this hides the item + parser.commit_and_collapse_item(commit_point) + + # keep the bytes we still need to emit + if start_pos < commit_point.start: + parser.shadow_rewind(start_pos) + + else: + # pop off any tokens that overlap the hidden bytes + i = len(token_byte_positions) - 1 + while i >= 0 and token_byte_positions[i] - pre_parser_bytes > commit_point.start: + token_ids.pop() + token_byte_positions.pop() + token_count -= 1 + i -= 1 + # re-add any bytes we cut too far on + parser.shadow_rewind(token_byte_positions[-1] - pre_parser_bytes) + retry_token_gen = True # this restarts us at the top of the outer token gen loop + break + + trie = trie.children[next_byte] + + forced_pos = parser.pos # record how far the bytes are forced + + if retry_token_gen: + continue + + # back up if we got forced up to a point that is not a valid token + if next_byte_mask_sum <= 1: + while trie.value is None and trie.parent is not None: + trie = trie.parent + forced_pos -= 1 + parser.pos = forced_pos + + # if we walked all the way to a forced token then we advance without computing the logits + # we are forced if there are no more options and we are either in the middle of the grammar or at a trie leaf + is_forced = next_byte_mask_sum <= 1 and (len(trie.children) == 0 if parser.matched() else trie != self._token_trie) + if is_forced: + sampled_token_ind = trie.value + sampled_token = self.tokens[sampled_token_ind] + new_bytes_log_prob = 0.0 + was_forced = True + + # we are at the end of the grammar + elif next_byte_mask_sum == 0: + token_pos = 0 + + # mark the token we "sampled" if we have comsumed some bytes + if trie != self._token_trie: + sampled_token_ind = trie.value + sampled_token = self.tokens[sampled_token_ind] + new_bytes_log_prob = 0.0 + + # otherwise we need to compute the logits and sample a valid token + else: + + # if we were forced we might need to clean up the greedy tokenization to match the global tokenization behavior as seen in training + if was_forced: + token_ids,token_byte_positions = self._cleanup_tokens(token_ids, token_byte_positions) + was_forced = False + logits = self._get_logits(token_ids, parser.bytes[start_pos:forced_pos]) + + # if requested we compute the log probabilities so we can track the probabilities of each node + # TODO: we should lower this step to C++ with pybind11 + if log_probs: + _compute_log_probs(trie, log_softmax(logits, axis=-1)) + + # get the sampling order + grammar_temp = parser.next_byte_temperature() + current_temp = grammar_temp if grammar_temp >= 0 else temperature # we prefer to use the grammar temp when it is specified + if current_temp == 0: + sampling_order = np.argsort(-logits) # we need numpy so the enumerate below does not get really slow... + else: + assert top_p == 1, "Still need to add support for top_p!" + probs = softmax(logits / current_temp, axis=-1) + probs += 1e-10 # ensure we have no zero probs that mess up numpy + probs /= np.sum(probs) + sampling_order = np.random.choice(len(probs), size=len(probs), p=probs, replace=False) # the 1e-10 is ensure we have no zero probs, which numpy does not like + + # loop over the tokens looking for a valid one + for i,sampled_token_ind in enumerate(sampling_order): + sampled_token = self.tokens[sampled_token_ind] + + # make sure the parse is backed up to the position we want to start checking from TODO: make this account for shared prefixes with the last token + parser.pos = forced_pos + new_bytes_log_prob = 0.0 + + # make sure it matches any forced prefix + if start_pos < forced_pos and not sampled_token.startswith(parser.bytes[start_pos:forced_pos]): + continue + offset = forced_pos - start_pos + + # check to see if the sampled token is allowed + token_pos = offset + node = trie # this is the Trie node we were left at when we could force the next byte above + + while token_pos < len(sampled_token): + next_byte = sampled_token[token_pos:token_pos+1] + next_node = node.children[next_byte] + + # if we don't have a cached match flag compute it using the grammar + if next_node.match_version < self._token_trie.match_version: + next_byte_mask = parser.next_byte_mask() + for byte in node.children: # we update all the children since the parser knows the full mask + child = node.children[byte] + child.match_version = self._token_trie.match_version + child.match = next_byte_mask[byte[0]] + + # advance or fail according to the (now up-to-date) match cache + if next_node.match: + + # mark that we accepted this byte + node = next_node + token_pos += 1 + + # get the parser to consume the next byte + log_prob_delta = next_node.log_prob - node.log_prob + new_bytes_log_prob += log_prob_delta + commit_point = parser.consume_byte(next_byte, log_prob=log_prob_delta) + + # if we are at a hidden commit point then we need to hide the bytes that match that node + if commit_point is not None and commit_point.node.hidden: + + # if we are capturing the data from this node we need to do that now since we are about to remove it + # TODO: build a whole parse tree under this commit_point node so we can record child node captures + if commit_point.node.capture_name: + captured_data[commit_point.node.capture_name] = parser.bytes[commit_point.start:] + + # This takes the item and commits to it as part of the parse and then shrinks it to zero width + # in other words this hides the item + parser.commit_and_collapse_item(commit_point) + + # keep the bytes we still need to emit + if forced_pos < commit_point.start: + parser.shadow_rewind(forced_pos) + + else: + # pop off any tokens that overlap the hidden bytes + i = len(token_byte_positions) - 1 + while i >= 0 and token_byte_positions[i] - pre_parser_bytes > commit_point.start: + token_ids.pop() + token_byte_positions.pop() + token_count -= 1 + i -= 1 + # re-add any bytes we cut too far on + parser.shadow_rewind(token_byte_positions[-1] - pre_parser_bytes) + retry_token_gen = True # this restarts us at the top of the outer token gen loop + break + + elif token_pos == len(sampled_token): + break # this token is valid + else: + # partially valid tokens are okay if we are running off the end of a grammar, but not otherwise + if not parser.matched(): + token_pos = -1 + + break # this token is no longer valid + + # see if we are breaking out of the whole loop + if retry_token_gen: + break + + # check if this token is dominated by other longer valid tokens (and hence would never be consistent with greedy tokenization) + # TODO: disabled for now because of sentencepeice non-local issues + # if token_pos == len(sampled_token) and not parser.matched(): # not we don't check if we have matched, because then we can generate anything afterwards + # if _check_dominated(node, parser, self._token_trie.match_version, parser.next_byte_mask()): + # token_pos = -1 + + if token_pos > 0: + break # we found a valid token + + if parser.matched(): + break # if we already have a full match we don't try more tokens we just give up as soon as the model deviates from the grammar + + # if we just collpased a hidden commit point then we start over looking for a new token + if retry_token_gen: + continue + + # emit whatever we know will not be hidden + new_bytes = parser.bytes[generated_pos:parser.earliest_hidden_start()] + + # if we cannot consume any more tokens then we are done + if not is_forced and token_pos < len(sampled_token) and trie == self._token_trie: + assert parser.matched(), "We can't consume any more tokens, but we are not yet done! Perhaps your model's token set is incomplete?" + + # TODO: if we exactly match the end of the pattern then we can commit to this last token + # if m.span()[1] == len(generated_text): + # self._cache_state["new_token_ids"].append(sampled_token_ind) + + # capture the named groups from the parse tree + parse_tree = parser.parse_tree() + _record_captures(parse_tree, captured_data, captured_log_prob_data, parser.bytes) + + # we have no valid log prob data if we didn't compute it + if not log_probs: + captured_log_prob_data = {k: None for k in captured_data} + yield new_bytes[hidden_count:], not is_forced, new_bytes_log_prob, captured_data, captured_log_prob_data, token_count - last_token_count + last_token_count = token_count + break # we are done! + else: + generated_pos += len(new_bytes) + + # yeild the snippet of text created by the next token + out = new_bytes[hidden_count:] + if len(out) > 0: + yield out, not is_forced, new_bytes_log_prob, {}, {}, token_count - last_token_count # note that we don't capture groups until a complete parse right now... + last_token_count = token_count + hidden_count = 0 + token_count += 1 # note we only update this for tokens that emit non-hidden content + else: + hidden_count -= len(new_bytes) + + token_ids.append(sampled_token_ind) + + # track the byte position of each token + if len(token_byte_positions) == 0: + token_byte_positions.append(len(sampled_token)) + else: + token_byte_positions.append(token_byte_positions[-1] + len(sampled_token)) class Chat(Model): '''The base class for all chat-tuned models.''' @@ -497,4 +904,87 @@ def __exit__(self, exc_type, exc_value, traceback): def throttle_refresh(): '''Returns a context manager that allows the print statement to drop display calls above the throttle rate.''' - return ThrottleRefresh() \ No newline at end of file + return ThrottleRefresh() + +def _record_captures(initial_item, data, log_prob_data, byte_data): + stack = [(initial_item, 0)] + used_names = set() # track which capture names have been used so self-recursive children don't overwrite their parents + + while stack: + item, byte_pos = stack.pop() + # terminal nodes + if isinstance(item, Terminal): + + # if we are at a capture group node then we save the matched terminal byte + if item.capture_name is not None: + data[item.capture_name] = item.byte + log_prob_data[item.capture_name] = 0 + + # internal nodes + else: + start_byte_pos = byte_pos + + # recurse for all our non-null children + for child in item.children: + if child is not None: + stack.append((child, byte_pos)) + # _record_captures(child, data, log_prob_data, byte_data, byte_pos) + if isinstance(child, Terminal): + byte_pos += len(child) + else: + byte_pos = child.start # note that "start" means "end" since this is a reversed state set + + # if we are at a capture group node then we save the matched bytes range + # note that we record this after calling our children so that we save the outermost version of self-recursive calls + cname = item.node.capture_name + if cname is not None and cname not in used_names and not item.node.hidden: + + # see if we are doing a list append + if cname.startswith("__LIST_APPEND:"): + cname = cname[14:] # trim off the list append tag + if cname not in data or not isinstance(data[cname], list): + data[cname] = [] + log_prob_data[cname] = [] + data[cname].append(byte_data[start_byte_pos:item.start]) + log_prob_data[cname].append(item.log_prob) + + # or just a regular assignment + else: + data[cname] = byte_data[start_byte_pos:item.start] # note that "start" means "end" since this is a reversed state set + log_prob_data[cname] = item.log_prob + + used_names.add(cname) + +def _compute_log_probs(trie, log_probs): + '''Computes the log probabilities for each internal trie node.''' + if trie.value is not None: + trie.log_prob += log_probs[trie.value] + + if len(trie.children) > 0: + child_log_probs = [] + for b in trie.children: + child = trie.children[b] + _compute_log_probs(child, log_probs) + child_log_probs.append(child.log_prob) + trie.log_prob = np.logaddexp.reduce(child_log_probs) + +def _check_dominated(node, parser, match_version, next_byte_mask): + curr_pos = parser.pos + for byte_num in next_byte_mask.nonzero()[0]: + next_byte = bytes((byte_num,)) + if next_byte not in node.children: + return False # no possible exension this direction, so we are not dominated + child = node.children[next_byte] + if child.match_version < match_version: + child.match_version = match_version + child.match = next_byte_mask[next_byte[0]] + + if not child.match: + return False # this child does not dominate the node, so the node is not dominated + elif child.value is None: # this child might not dominate the node + parser.consume_byte(next_byte, log_prob=0.0) + child_dominate = _check_dominated(child, parser, match_version, parser.next_byte_mask()) + parser.pos = curr_pos + if not child_dominate: + return False + return True diff --git a/guidance/models/_remote.py b/guidance/models/_remote.py index c2df765c1..ad0b10d04 100644 --- a/guidance/models/_remote.py +++ b/guidance/models/_remote.py @@ -10,8 +10,7 @@ import tiktoken import re -from ._model import Chat, Instruct -from ._local import Local +from ._model import Model, Chat, Instruct # try: @@ -21,7 +20,7 @@ # except ImportError: # is_vertexai = False -class Remote(Local): +class Remote(Model): def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0.0, max_streaming_tokens=500, **kwargs): self.caching = caching self.temperature = temperature diff --git a/guidance/models/transformers/_transformers.py b/guidance/models/transformers/_transformers.py index d66bc6dc6..78b58a24a 100644 --- a/guidance/models/transformers/_transformers.py +++ b/guidance/models/transformers/_transformers.py @@ -9,11 +9,10 @@ except ImportError: pass -from .._model import Chat -from .._local import Local +from .._model import Model, Chat -class Transformers(Local): +class Transformers(Model): def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperature=0.0, device=None, **kwargs): # fill in default model value diff --git a/guidance/models/vertexai/_PaLM2.py b/guidance/models/vertexai/_PaLM2.py index 4eeb1edfb..b6c95e21a 100644 --- a/guidance/models/vertexai/_PaLM2.py +++ b/guidance/models/vertexai/_PaLM2.py @@ -11,15 +11,7 @@ from ._vertexai import VertexAIInstruct, VertexAIChat -# try: -# # TODO: can we eliminate the torch requirement for llama.cpp by using numpy in the caller instead? -# import torch -# is_torch = True -# except ImportError: -# is_torch = False - try: - # TODO: can we eliminate the torch requirement for llama.cpp by using numpy in the caller instead? from vertexai.preview.language_models import TextGenerationModel from vertexai.language_models import ChatModel, InputOutputTextPair is_vertexai = True diff --git a/guidance/models/vertexai/_vertexai.py b/guidance/models/vertexai/_vertexai.py index 44fb7198c..fc6e75015 100644 --- a/guidance/models/vertexai/_vertexai.py +++ b/guidance/models/vertexai/_vertexai.py @@ -1,5 +1,4 @@ from .._model import Chat, Instruct -# from ._local import Local from .._remote import Remote diff --git a/tests/library/test_any_char.py b/tests/library/test_any_char.py index 33eeb661c..e8912aafb 100644 --- a/tests/library/test_any_char.py +++ b/tests/library/test_any_char.py @@ -1,5 +1,5 @@ from guidance import models, any_char def test_single_char(): - model = models.LocalMock("abc") + model = models.Mock("abc") assert str(model + '' + any_char()) == "a" \ No newline at end of file diff --git a/tests/library/test_any_char_but.py b/tests/library/test_any_char_but.py index fefbbd75a..4607f42a2 100644 --- a/tests/library/test_any_char_but.py +++ b/tests/library/test_any_char_but.py @@ -1,12 +1,12 @@ from guidance import models, any_char_but def test_single_char(): - model = models.LocalMock("abc") + model = models.Mock("abc") assert str(model + '' + any_char_but('a')) != "a" assert str(model + '' + any_char_but('!')) == "a" def test_multi_char(): - model = models.LocalMock(["abc", "bbc"]) + model = models.Mock(["abc", "bbc"]) assert str(model + '' + any_char_but('ab')) not in ("a", "b") assert str(model + '' + any_char_but('a!')) == "b" assert str(model + '' + any_char_but('5b')) == "a" \ No newline at end of file diff --git a/tests/library/test_block.py b/tests/library/test_block.py index 3f8997173..1604b7bd1 100644 --- a/tests/library/test_block.py +++ b/tests/library/test_block.py @@ -1,26 +1,26 @@ from guidance import models, block, any_char def test_text_opener(): - model = models.LocalMock("open texta") + model = models.Mock("open texta") with block(opener="open text"): model += any_char() assert str(model) == "open texta" def test_text_closer(): - model = models.LocalMock("aclose text") + model = models.Mock("aclose text") model += "" with block(closer="close text"): model += any_char() assert str(model) == "aclose text" def test_grammar_opener(): - model = models.LocalMock("open texta") + model = models.Mock("open texta") with block(opener="open tex" + any_char()): model += any_char() assert str(model) == "open texta" def test_grammar_closer(): - model = models.LocalMock(["aclose text", "close text"]) + model = models.Mock(["aclose text", "close text"]) model += "" try: with block(closer=any_char() + "lose text"): diff --git a/tests/library/test_capture.py b/tests/library/test_capture.py index b11d5a0a9..92cd8ad91 100644 --- a/tests/library/test_capture.py +++ b/tests/library/test_capture.py @@ -2,12 +2,12 @@ from ..utils import get_model def test_capture(): - model = models.LocalMock() + model = models.Mock() model += 'This is' + capture(select(options=['bad', 'quite bad']), name="my_var") assert model["my_var"] in ["bad", "quite bad"] def test_capture_star(): - lm = models.LocalMock(b"1234233234") + lm = models.Mock(b"1234233234") grammar = capture(one_or_more(select(['1', '2'])), name='test') lm2 = lm + grammar assert lm2['test'] == '12' \ No newline at end of file diff --git a/tests/library/test_char_set.py b/tests/library/test_char_set.py index 8bfda679d..2bf34484b 100644 --- a/tests/library/test_char_set.py +++ b/tests/library/test_char_set.py @@ -1,14 +1,14 @@ from guidance import models, char_set def test_single_char(): - model = models.LocalMock("abc") + model = models.Mock("abc") assert str(model + '' + char_set("a")) == "a" assert str(model + '' + char_set("ab")) == "a" assert str(model + '' + char_set("ba")) == "a" assert str(model + '' + char_set("b")) == "b" def test_char_range(): - model = models.LocalMock("bac") + model = models.Mock("bac") assert str(model + '' + char_set("a-c")) == "b" assert str(model + '' + char_set("b-z")) == "b" assert str(model + '' + char_set("0-9")) != "b" diff --git a/tests/library/test_commit_point.py b/tests/library/test_commit_point.py index c4f5f3062..29d9a7dc9 100644 --- a/tests/library/test_commit_point.py +++ b/tests/library/test_commit_point.py @@ -2,7 +2,7 @@ from ..utils import get_model def test_hidden(): - model = models.LocalMock() + model = models.Mock() model += " one" + commit_point(" two", hidden=True) + " three" assert str(model) == " one three" diff --git a/tests/library/test_gen.py b/tests/library/test_gen.py index a72823377..3f6386f96 100644 --- a/tests/library/test_gen.py +++ b/tests/library/test_gen.py @@ -4,22 +4,22 @@ import re def test_basic(): - lm = models.LocalMock() + lm = models.Mock() lm += "Write a number: " + gen('text', max_tokens=3) assert len(lm["text"]) > 0 def test_stop_string(): - lm = models.LocalMock(b"Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10") + lm = models.Mock(b"Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10") lm += "Count to 10: 1, 2, 3, 4, 5, 6, 7, " + gen('text', stop=", 9") assert lm["text"] == "8" def test_stop_char(): - lm = models.LocalMock(b"Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10") + lm = models.Mock(b"Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10") lm += "Count to 10: 1, 2, 3, 4, 5, 6, 7, " + gen('text', stop=",") assert lm["text"] == "8" def test_save_stop(): - lm = models.LocalMock(b"Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10") + lm = models.Mock(b"Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10") lm += "Count to 10: 1, 2, 3, 4, 5, 6, 7, " + gen('text', stop=",", save_stop_text='stop_text') assert lm["stop_text"] == "," @@ -44,7 +44,7 @@ def test_unicode2(): assert True def test_gsm8k(): - lm = models.LocalMock() + lm = models.Mock() lm + '''Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? Answer: ''' + gen(max_tokens=30) assert True @@ -63,60 +63,60 @@ def test_non_token_force(): assert len(str(lm)) == 6 def test_pattern_optional(): - lm = models.LocalMock(b"12342333") + lm = models.Mock(b"12342333") pattern = '.?233' lm2 = lm + '123' + gen(name='numbers', regex=pattern, max_tokens=10) assert lm2['numbers'] == '4233' - lm = models.LocalMock(b"1232333") + lm = models.Mock(b"1232333") pattern = '.?233' lm2 = lm + '123' + gen(name='numbers', regex=pattern, max_tokens=10) assert lm2['numbers'] == '233' pattern = r'(Scott is bad)?(\d+)?o' - lm = models.LocalMock(b"John was a little man full of things") + lm = models.Mock(b"John was a little man full of things") lm2 = lm + 'J' + gen(name='test', regex=pattern, max_tokens=30) assert lm2['test'] == 'o' def test_pattern_stops_when_fulfilled(): - lm = models.LocalMock(b"123abc") + lm = models.Mock(b"123abc") lm += gen(regex=r'\d+', max_tokens=10, name='test') assert lm['test'] == '123' def test_pattern_star(): - # lm = models.LocalMock(b"1234233234") # commented out because it is not a valid test + # lm = models.Mock(b"1234233234") # commented out because it is not a valid test # patterns = ['\d+233', '\d*233', '.+233', '.*233'] # for pattern in patterns: # lm2 = lm + '123' + gen(name='numbers', regex=pattern, max_tokens=10) # assert lm2['numbers'] == '4233' - lm = models.LocalMock(b"123233") + lm = models.Mock(b"123233") patterns = [r'\d*233','.*233'] for pattern in patterns: lm2 = lm + '123' + gen(name='numbers', regex=pattern, max_tokens=10) assert lm2['numbers'].startswith('233') pattern = '.*(\n|little)' - lm = models.LocalMock(b"John was a little") + lm = models.Mock(b"John was a little") lm2 = lm + 'J' + gen(name='test', regex=pattern, max_tokens=30) assert lm2['test'].startswith('ohn was a little') - lm = models.LocalMock(b"John was a litt\n") + lm = models.Mock(b"John was a litt\n") lm2 = lm + 'J' + gen(name='test', regex=pattern, max_tokens=30) assert lm2['test'].startswith('ohn was a litt\n') def test_stop_regex(): - lm = models.LocalMock(b"123a3233") + lm = models.Mock(b"123a3233") lm2 = lm + '123' + gen(name='test', stop_regex=r'\d233', max_tokens=10) assert lm2['test'] == 'a' - lm = models.LocalMock(b"123aegalera3233") + lm = models.Mock(b"123aegalera3233") lm2 = lm + '123' + gen(name='test', stop_regex=r'\d', max_tokens=30) assert lm2['test'] == 'aegalera' def test_stop_regex_star(): - lm = models.LocalMock(b"123a3233") + lm = models.Mock(b"123a3233") pattern = r'\d+233' lm2 = lm + '123' + gen(name='test', stop_regex=pattern, max_tokens=10) assert lm2['test'] == 'a' def test_empty_pattern(): pattern = r'(Scott is bad)?(\d+)?' - lm = models.LocalMock(b"J") + lm = models.Mock(b"J") lm2 = lm + 'J' + gen(name='test', regex=pattern, max_tokens=30) assert lm2['test'] == '' @@ -184,7 +184,7 @@ def test_long_prompt(): def test_list_append(): '''This tests is list append works across grammar appends.''' - lm = models.LocalMock(b"bababababa") + lm = models.Mock(b"bababababa") lm += "" for _ in range(3): lm += gen("my_list", list_append=True, stop="a") + "a" @@ -193,19 +193,19 @@ def test_list_append(): def test_list_append_in_grammar(): '''This tests is list append works within the same grammar.''' - lm = models.LocalMock(b"bababababa") + lm = models.Mock(b"bababababa") lm += "" lm += gen("my_list", list_append=True, stop="a") + "a" + gen("my_list", list_append=True, stop="a") + "a" + gen("my_list", list_append=True, stop="a") assert isinstance(lm['my_list'], list) assert len(lm['my_list']) == 3 def test_one_char_suffix_and_regex(): - model = models.LocalMock(b"this is\na test") + model = models.Mock(b"this is\na test") model += gen(regex=".*", suffix="\n", max_tokens=20) assert str(model) == "this is\n" def test_one_char_stop_and_regex(): - model = models.LocalMock(b"this is\na test") + model = models.Mock(b"this is\na test") model += gen(regex=".*", stop="\n", max_tokens=20) assert str(model) == "this is" diff --git a/tests/library/test_one_or_more.py b/tests/library/test_one_or_more.py index 9565ac42b..7ce71a487 100644 --- a/tests/library/test_one_or_more.py +++ b/tests/library/test_one_or_more.py @@ -1,13 +1,13 @@ from guidance import models, one_or_more, char_set def test_string(): - model = models.LocalMock("aaabc") + model = models.Mock("aaabc") assert str(model + '' + one_or_more("a")) == "aaa" def test_grammar(): - model = models.LocalMock("bac") + model = models.Mock("bac") assert str(model + '' + one_or_more(char_set("ab"))) == "ba" def test_at_least_one(): - model = models.LocalMock("cbac") + model = models.Mock("cbac") assert not str(model + '' + one_or_more(char_set("ab"))).startswith("c") \ No newline at end of file diff --git a/tests/library/test_silent.py b/tests/library/test_silent.py index 3a55ea616..a749dc679 100644 --- a/tests/library/test_silent.py +++ b/tests/library/test_silent.py @@ -2,7 +2,7 @@ from ..utils import get_model def test_basic(): - lm = models.LocalMock() + lm = models.Mock() lm += "Start text" with silent(): lm += "silent text" diff --git a/tests/models/test_local.py b/tests/models/test_local.py deleted file mode 100644 index 00b410cc1..000000000 --- a/tests/models/test_local.py +++ /dev/null @@ -1,9 +0,0 @@ -import guidance -from guidance import zero_or_more, byte_range -from ..utils import get_model - -def test_token_healing(): - '''Tests a bug where the space is incorrectly forced as token 220, while it should be not forced it might be extended''' - gpt2 = get_model("transformers:gpt2") - lm = gpt2 + ("This is a story of 10 or 5 or " + zero_or_more(byte_range(b'0', b'9'))) - assert len(lm) > len("This is a story of 10 or 5 or ") \ No newline at end of file diff --git a/tests/models/test_model.py b/tests/models/test_model.py index bab3a918d..41aa7aa6e 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -1,5 +1,5 @@ import guidance -from guidance import select, models, gen +from guidance import select, models, gen, zero_or_more, byte_range from ..utils import get_model def test_fstring(): @@ -24,7 +24,7 @@ def test_token_count(): def test_call_embeddings(): '''This tests calls embedded in strings.''' - model = models.LocalMock() + model = models.Mock() @guidance(dedent=False) def bla(lm, bla): @@ -38,4 +38,10 @@ def ble(lm): let's do more stuff!!''' + gen(max_tokens=10) return lm - assert "{{G|" not in str(model + ble()) \ No newline at end of file + assert "{{G|" not in str(model + ble()) + +def test_token_healing(): + '''Tests a bug where the space is incorrectly forced as token 220, while it should be not forced it might be extended''' + gpt2 = get_model("transformers:gpt2") + lm = gpt2 + ("This is a story of 10 or 5 or " + zero_or_more(byte_range(b'0', b'9'))) + assert len(lm) > len("This is a story of 10 or 5 or ") \ No newline at end of file diff --git a/tests/test_grammar.py b/tests/test_grammar.py index 4c351bca3..953d20f92 100644 --- a/tests/test_grammar.py +++ b/tests/test_grammar.py @@ -2,7 +2,7 @@ from .utils import get_model def test_select_reset_pos(): - model = models.LocalMock() + model = models.Mock() model += 'This is' + select(options=['bad', 'quite bad']) assert str(model) in ["This isbad", "This isquite bad"] @@ -14,6 +14,6 @@ def test_select_simple(): def test_select_longer(): '''This tests to ensure that the grammar is extended greedily.''' - lm = models.LocalMock(b"Scott is a very nice man.") + lm = models.Mock(b"Scott is a very nice man.") lm += "Scott is a very " + select(name='text', options=['nice', 'nice man.']) assert lm["text"] == 'nice man.'