Skip to content

Commit

Permalink
Fixed problems with select. #36 #40 #35
Browse files Browse the repository at this point in the history
  • Loading branch information
slundberg committed May 19, 2023
1 parent 3a743bd commit f314f61
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 33 deletions.
2 changes: 1 addition & 1 deletion guidance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.48"
__version__ = "0.0.49"

import types
import sys
Expand Down
101 changes: 70 additions & 31 deletions guidance/library/_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def select(variable_name="selected", options=None, logprobs=None, list_app
for i,option in enumerate(options):
token_map[option] = i

async def recursive_select(current_prefix):
async def recursive_select(current_prefix, allow_token_extension=True):
""" This returns a dictionary of scores for each option (keyed by the option index).
"""

Expand All @@ -65,58 +65,97 @@ async def recursive_select(current_prefix):
for i in range(len(current_prefix), min([len(o[0]) for o in extension_options])):
if len(set([o[0][i] for o in extension_options])) > 1:
break
current_prefix = extension_options[0][0][:i]
if i > 0:
current_prefix += extension_options[0][0][:i]
# extension_options = [(option[i:], index) for option,index in extension_options]

# bias the logits towards valid options
tmp_prefix = (parser_prefix + current_prefix)[-50:] # this is so we get the right tokenization at the boundary
tmp_prefix_token_len = len(parser.program.llm.encode(tmp_prefix))
logit_bias = {}
tmp_prefix = (parser_prefix[-50:] + current_prefix) # this is so we get the right tokenization at the boundary
tmp_prefix_tokens = parser.program.llm.encode(tmp_prefix)
logit_bias1 = {} # token extension biases
logit_bias2 = {} # next token biases
for option,index in extension_options:
option_tokens = parser.program.llm.encode(tmp_prefix+option)
logit_bias[option_tokens[tmp_prefix_token_len-1]] = 50
if len(option_tokens) > tmp_prefix_token_len:
logit_bias[option_tokens[tmp_prefix_token_len]] = 50
option_tokens = parser.program.llm.encode(parser_prefix[-50:] + option)

# if we extended the last token to a longer one
if option_tokens[len(tmp_prefix_tokens)-1] != tmp_prefix_tokens[-1]:
if allow_token_extension: # this is a valid extension only if we are not allowed to extend the token
logit_bias1[option_tokens[len(tmp_prefix_tokens)-1]] = 100

# if we did not extend the last token to a longer one we can bias the next token
else:
if len(option_tokens) > len(tmp_prefix_tokens):
logit_bias2[option_tokens[len(tmp_prefix_tokens)]] = 100

# logit_bias[option_tokens[len(tmp_prefix_tokens)-1]] = tmp_prefix_tokens[-1]
# if len(option_tokens) > len(tmp_prefix_tokens) and :
# logit_bias[option_tokens[len(tmp_prefix_tokens)]] = 100

# extend the prefix by extending the last token
if len(logit_bias1) > 0:
call_prefix = parser.program.llm.decode(tmp_prefix_tokens[:-1])
logit_bias = logit_bias1
last_token_str = parser.program.llm.decode(tmp_prefix_tokens[-1:])

# extend the prefix by adding the next token
else:
call_prefix = tmp_prefix
logit_bias = logit_bias2
last_token_str = ""

# extend the prefix by one token using the model
# generate the token logprobs
gen_obj = await parser.llm_session(
parser_prefix + current_prefix,
call_prefix, # TODO: perhaps we should allow passing of token ids directly?
max_tokens=1,
logit_bias=logit_bias,
logprobs=10,
cache_seed=0
cache_seed=0,
token_healing=False # we manage token boundary healing ourselves for this function
)
logprobs_result = gen_obj["choices"][0]["logprobs"]
top_logprobs = logprobs_result["top_logprobs"][0]

remove_prefix = len(logprobs_result.get("token_healing_prefix", ""))

# for each possible next token, see if it grows the prefix in a valid way
for token_str,logprob in top_logprobs.items():
if len(token_str[remove_prefix:]) > 0:
sub_logprobs = await recursive_select(current_prefix + token_str[remove_prefix:])
for k in sub_logprobs:
logprobs_out[k] = sub_logprobs[k] + logprob

# build our recursive call prefix
rec_prefix = call_prefix + token_str
if len(last_token_str) > 0:
rec_prefix = current_prefix[:-len(last_token_str)]
else:
rec_prefix = current_prefix
rec_prefix += token_str

# if we did not extend the last token then we recurse while ignoring the possibility of extending the last token
if token_str == last_token_str:
sub_logprobs = await recursive_select(rec_prefix, allow_token_extension=False)
else:
sub_logprobs = await recursive_select(rec_prefix)

# we add the logprob of this token to the logprob of the suffix
for k in sub_logprobs:
logprobs_out[k] = sub_logprobs[k] + logprob

# if we did token healing and did not extend past our prefix we need to consider the next token
# TODO: when returning all logprobs we need to consider all the options, which means we should
# force the model to not token heal and see what would have happened then on the next token...
first_token_str = max(top_logprobs, key=top_logprobs.get)
if len(logprobs_result["top_logprobs"]) > 1 and len(first_token_str) == remove_prefix:
top_logprobs = logprobs_result["top_logprobs"][1]
for token_str,logprob in top_logprobs.items():
sub_logprobs = await recursive_select(current_prefix + token_str)
for k in sub_logprobs:

# compute the probability of a logical OR between the new extension and the previous possible ones
p1 = np.exp(logprobs_out[k])
p2 = np.exp(sub_logprobs[k] + logprob)
or_prob = p1 + p2 - p1*p2
logprobs_out[k] = np.log(or_prob)
# first_token_str = max(top_logprobs, key=top_logprobs.get)
# if len(logprobs_result["top_logprobs"]) > 1 and len(first_token_str) == remove_prefix:
# top_logprobs = logprobs_result["top_logprobs"][1]
# for token_str,logprob in top_logprobs.items():
# sub_logprobs = await recursive_select(current_prefix + token_str)
# for k in sub_logprobs:

# # compute the probability of a logical OR between the new extension and the previous possible ones
# p1 = np.exp(logprobs_out[k])
# p2 = np.exp(sub_logprobs[k] + logprob)
# or_prob = p1 + p2 - p1*p2
# logprobs_out[k] = np.log(or_prob)

return logprobs_out

# recursively compute the logprobs for each option
option_logprobs = await recursive_select("")
option_logprobs = await recursive_select("")

selected_option = max(option_logprobs, key=option_logprobs.get)

Expand Down
13 changes: 12 additions & 1 deletion tests/library/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,15 @@ def test_select_list_append():
out = program(options=["Yes", "No"])
assert len(out["name"]) == 2
for v in out["name"]:
assert v in ["Yes", "No"]
assert v in ["Yes", "No"]

def test_select_odd_spacing():
    """ Test `select` when the options start with odd spacing/token boundaries.

    Every option begins with a leading space (" Yes", " Nein", " Maybe"), so the
    prompt/option boundary falls inside a token.  " Nein" is also a word the
    model would not produce unprompted, so the test only passes if select's
    option biasing is actually applied rather than free-form generation.
    """

    # NOTE(review): requires OpenAI API access via the test helper — presumably
    # get_openai_llm caches/configures credentials; confirm in tests/utils.
    llm = get_openai_llm("text-curie-001")
    prompt = guidance('''Is the following sentence offensive? Please answer with a single word, either "Yes", "No", or "Maybe".
Sentence: {{example}}
Answer: {{#select "answer" logprobs='logprobs'}} Yes{{or}} Nein{{or}} Maybe{{/select}}''', llm=llm)
    prompt = prompt(example='I hate tacos.')
    # The selected value must be one of the literal options, leading space included.
    assert prompt["answer"] in [" Yes", " Nein", " Maybe"]

0 comments on commit f314f61

Please sign in to comment.