From f314f61f0aae4a549898bfc0e64fb8102d1c95d0 Mon Sep 17 00:00:00 2001
From: Scott Lundberg
Date: Fri, 19 May 2023 12:45:58 -0700
Subject: [PATCH] Fixed problems with select. #36 #40 #35

---
 guidance/__init__.py         |   2 +-
 guidance/library/_select.py  | 101 ++++++++++++++++++++++++-----------
 tests/library/test_select.py |  13 ++++-
 3 files changed, 83 insertions(+), 33 deletions(-)

diff --git a/guidance/__init__.py b/guidance/__init__.py
index c157f2847..3f5e26612 100644
--- a/guidance/__init__.py
+++ b/guidance/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.0.48"
+__version__ = "0.0.49"
 
 import types
 import sys

diff --git a/guidance/library/_select.py b/guidance/library/_select.py
index f8baf83c9..ee718a493 100644
--- a/guidance/library/_select.py
+++ b/guidance/library/_select.py
@@ -42,7 +42,7 @@ async def select(variable_name="selected", options=None, logprobs=None, list_app
     for i,option in enumerate(options):
         token_map[option] = i
 
-    async def recursive_select(current_prefix):
+    async def recursive_select(current_prefix, allow_token_extension=True):
         """ This returns a dictionary of scores for each option (keyed by the option index).
         """
 
@@ -65,58 +65,97 @@ async def recursive_select(current_prefix):
         for i in range(len(current_prefix), min([len(o[0]) for o in extension_options])):
             if len(set([o[0][i] for o in extension_options])) > 1:
                 break
-        current_prefix = extension_options[0][0][:i]
+        if i > 0:
+            current_prefix += extension_options[0][0][:i]
+            # extension_options = [(option[i:], index) for option,index in extension_options]
 
         # bias the logits towards valid options
-        tmp_prefix = (parser_prefix + current_prefix)[-50:] # this is so we get the right tokenization at the boundary
-        tmp_prefix_token_len = len(parser.program.llm.encode(tmp_prefix))
-        logit_bias = {}
+        tmp_prefix = (parser_prefix[-50:] + current_prefix) # this is so we get the right tokenization at the boundary
+        tmp_prefix_tokens = parser.program.llm.encode(tmp_prefix)
+        logit_bias1 = {} # token extension biases
+        logit_bias2 = {} # next token biases
         for option,index in extension_options:
-            option_tokens = parser.program.llm.encode(tmp_prefix+option)
-            logit_bias[option_tokens[tmp_prefix_token_len-1]] = 50
-            if len(option_tokens) > tmp_prefix_token_len:
-                logit_bias[option_tokens[tmp_prefix_token_len]] = 50
+            option_tokens = parser.program.llm.encode(parser_prefix[-50:] + option)
+
+            # if we extended the last token to a longer one
+            if option_tokens[len(tmp_prefix_tokens)-1] != tmp_prefix_tokens[-1]:
+                if allow_token_extension: # this is a valid extension only if we are allowed to extend the last token
+                    logit_bias1[option_tokens[len(tmp_prefix_tokens)-1]] = 100
+
+            # if we did not extend the last token to a longer one we can bias the next token
+            else:
+                if len(option_tokens) > len(tmp_prefix_tokens):
+                    logit_bias2[option_tokens[len(tmp_prefix_tokens)]] = 100
+
+            # logit_bias[option_tokens[len(tmp_prefix_tokens)-1]] = tmp_prefix_tokens[-1]
+            # if len(option_tokens) > len(tmp_prefix_tokens) and :
+            #     logit_bias[option_tokens[len(tmp_prefix_tokens)]] = 100
+
+        # extend the prefix by extending the last token
+        if len(logit_bias1) > 0:
+            call_prefix = parser.program.llm.decode(tmp_prefix_tokens[:-1])
+            logit_bias = logit_bias1
+            last_token_str = parser.program.llm.decode(tmp_prefix_tokens[-1:])
+
+        # extend the prefix by adding the next token
+        else:
+            call_prefix = tmp_prefix
+            logit_bias = logit_bias2
+            last_token_str = ""
 
-        # extend the prefix by one token using the model
+        # generate the token logprobs
         gen_obj = await parser.llm_session(
-            parser_prefix + current_prefix,
+            call_prefix, # TODO: perhaps we should allow passing of token ids directly?
             max_tokens=1,
             logit_bias=logit_bias,
            logprobs=10,
-            cache_seed=0
+            cache_seed=0,
+            token_healing=False # we manage token boundary healing ourselves for this function
         )
         logprobs_result = gen_obj["choices"][0]["logprobs"]
         top_logprobs = logprobs_result["top_logprobs"][0]
 
-        remove_prefix = len(logprobs_result.get("token_healing_prefix", ""))
-
         # for each possible next token, see if it grows the prefix in a valid way
         for token_str,logprob in top_logprobs.items():
-            if len(token_str[remove_prefix:]) > 0:
-                sub_logprobs = await recursive_select(current_prefix + token_str[remove_prefix:])
-                for k in sub_logprobs:
-                    logprobs_out[k] = sub_logprobs[k] + logprob
+
+            # build our recursive call prefix
+            rec_prefix = call_prefix + token_str
+            if len(last_token_str) > 0:
+                rec_prefix = current_prefix[:-len(last_token_str)]
+            else:
+                rec_prefix = current_prefix
+            rec_prefix += token_str
+
+            # if we did not extend the last token then we recurse while ignoring the possibility of extending the last token
+            if token_str == last_token_str:
+                sub_logprobs = await recursive_select(rec_prefix, allow_token_extension=False)
+            else:
+                sub_logprobs = await recursive_select(rec_prefix)
+
+            # we add the logprob of this token to the logprob of the suffix
+            for k in sub_logprobs:
+                logprobs_out[k] = sub_logprobs[k] + logprob
 
         # if we did token healing and did not extend past our prefix we need to consider the next token
         # TODO: when returning all logprobs we need to consider all the options, which means we should
         #       force the model to not token heal and see what would have happened then on the next token...
-        first_token_str = max(top_logprobs, key=top_logprobs.get)
-        if len(logprobs_result["top_logprobs"]) > 1 and len(first_token_str) == remove_prefix:
-            top_logprobs = logprobs_result["top_logprobs"][1]
-            for token_str,logprob in top_logprobs.items():
-                sub_logprobs = await recursive_select(current_prefix + token_str)
-                for k in sub_logprobs:
-
-                    # compute the probability of a logical OR between the new extension and the previous possible ones
-                    p1 = np.exp(logprobs_out[k])
-                    p2 = np.exp(sub_logprobs[k] + logprob)
-                    or_prob = p1 + p2 - p1*p2
-                    logprobs_out[k] = np.log(or_prob)
+        # first_token_str = max(top_logprobs, key=top_logprobs.get)
+        # if len(logprobs_result["top_logprobs"]) > 1 and len(first_token_str) == remove_prefix:
+        #     top_logprobs = logprobs_result["top_logprobs"][1]
+        #     for token_str,logprob in top_logprobs.items():
+        #         sub_logprobs = await recursive_select(current_prefix + token_str)
+        #         for k in sub_logprobs:
+
+        #             # compute the probability of a logical OR between the new extension and the previous possible ones
+        #             p1 = np.exp(logprobs_out[k])
+        #             p2 = np.exp(sub_logprobs[k] + logprob)
+        #             or_prob = p1 + p2 - p1*p2
+        #             logprobs_out[k] = np.log(or_prob)
 
         return logprobs_out
 
     # recursively compute the logprobs for each option
-    option_logprobs = await recursive_select("")
+    option_logprobs = await recursive_select("")
     selected_option = max(option_logprobs, key=option_logprobs.get)

diff --git a/tests/library/test_select.py b/tests/library/test_select.py
index a1a3fc55c..c027ffb6b 100644
--- a/tests/library/test_select.py
+++ b/tests/library/test_select.py
@@ -48,4 +48,15 @@ def test_select_list_append():
     out = program(options=["Yes", "No"])
     assert len(out["name"]) == 2
     for v in out["name"]:
-        assert v in ["Yes", "No"]
\ No newline at end of file
+        assert v in ["Yes", "No"]
+
+def test_select_odd_spacing():
+    """ Test the behavior of `select` when the options have odd leading spacing.
+    """
+
+    llm = get_openai_llm("text-curie-001")
+    prompt = guidance('''Is the following sentence offensive? Please answer with a single word, either "Yes", "No", or "Maybe".
+Sentence: {{example}}
+Answer: {{#select "answer" logprobs='logprobs'}} Yes{{or}} Nein{{or}} Maybe{{/select}}''', llm=llm)
+    prompt = prompt(example='I hate tacos.')
+    assert prompt["answer"] in [" Yes", " Nein", " Maybe"]
\ No newline at end of file
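
A minimal usage sketch of what this patch is exercising: options that start with a space (" Yes", " Nein", " Maybe") do not line up with the prompt's token boundary, which is the case the recursive_select changes handle by managing token-boundary healing themselves. This mirrors the new test; it assumes an OpenAI completion model constructed with guidance.llms.OpenAI and an API key in the environment (the test instead uses its local get_openai_llm helper), so treat it as a sketch rather than part of the patch.

import guidance

# assumed setup: any OpenAI completion model; the new test uses text-curie-001
llm = guidance.llms.OpenAI("text-curie-001")

# the leading space before each option is the "odd spacing" case fixed above
program = guidance('''Is the following sentence offensive? Please answer with a single word, either "Yes", "No", or "Maybe".
Sentence: {{example}}
Answer: {{#select "answer" logprobs='logprobs'}} Yes{{or}} Nein{{or}} Maybe{{/select}}''', llm=llm)

out = program(example='I hate tacos.')
print(out["answer"])    # one of " Yes", " Nein", " Maybe"
print(out["logprobs"])  # per-option logprobs computed by recursive_select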