Skip to content

Commit

Permalink
Merge Local into Model
Browse files Browse the repository at this point in the history
  • Loading branch information
slundberg committed Nov 29, 2023
1 parent 4ea5d48 commit d1bbce1
Show file tree
Hide file tree
Showing 21 changed files with 556 additions and 580 deletions.
2 changes: 1 addition & 1 deletion guidance/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
from ._openai import OpenAI, OpenAIChat, OpenAIInstruct, OpenAICompletion
from .transformers._transformers import Transformers, TransformersChat
from ._llama_cpp import LlamaCpp, LlamaCppChat
from ._local_mock import LocalMock, LocalMockChat
from ._mock import Mock, MockChat
from . import transformers
5 changes: 2 additions & 3 deletions guidance/models/_llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

import numpy as np

from ._model import Chat
from ._local import Local
from ._model import Model, Chat
from .._utils import normalize_notebook_stdout_stderr

try:
Expand All @@ -15,7 +14,7 @@
except ImportError:
is_llama_cpp = False

class LlamaCpp(Local):
class LlamaCpp(Model):
def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperature=0.0, **kwargs):

if not is_llama_cpp:
Expand Down
498 changes: 0 additions & 498 deletions guidance/models/_local.py

This file was deleted.

7 changes: 3 additions & 4 deletions guidance/models/_local_mock.py → guidance/models/_mock.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np

from ._model import Chat
from ._local import Local
from ._model import Model, Chat


class LocalMock(Local):
class Mock(Model):
def __init__(self, byte_patterns=[], echo=True):

super().__init__(
Expand Down Expand Up @@ -67,6 +66,6 @@ def _get_next_tokens(self, byte_string):
yield i


class LocalMockChat(LocalMock, Chat):
class MockChat(Mock, Chat):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
506 changes: 498 additions & 8 deletions guidance/models/_model.py

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions guidance/models/_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import tiktoken
import re

from ._model import Chat, Instruct
from ._local import Local
from ._model import Model, Chat, Instruct


# try:
Expand All @@ -21,7 +20,7 @@
# except ImportError:
# is_vertexai = False

class Remote(Local):
class Remote(Model):
def __init__(self, model, tokenizer=None, echo=True, caching=True, temperature=0.0, max_streaming_tokens=500, **kwargs):
self.caching = caching
self.temperature = temperature
Expand Down
5 changes: 2 additions & 3 deletions guidance/models/transformers/_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
except ImportError:
pass

from .._model import Chat
from .._local import Local
from .._model import Model, Chat


class Transformers(Local):
class Transformers(Model):
def __init__(self, model=None, tokenizer=None, echo=True, caching=True, temperature=0.0, device=None, **kwargs):

# fill in default model value
Expand Down
8 changes: 0 additions & 8 deletions guidance/models/vertexai/_PaLM2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,7 @@

from ._vertexai import VertexAIInstruct, VertexAIChat

# try:
# # TODO: can we eliminate the torch requirement for llama.cpp by using numpy in the caller instead?
# import torch
# is_torch = True
# except ImportError:
# is_torch = False

try:
# TODO: can we eliminate the torch requirement for llama.cpp by using numpy in the caller instead?
from vertexai.preview.language_models import TextGenerationModel
from vertexai.language_models import ChatModel, InputOutputTextPair
is_vertexai = True
Expand Down
1 change: 0 additions & 1 deletion guidance/models/vertexai/_vertexai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .._model import Chat, Instruct
# from ._local import Local
from .._remote import Remote


Expand Down
2 changes: 1 addition & 1 deletion tests/library/test_any_char.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from guidance import models, any_char

def test_single_char():
model = models.LocalMock("<s>abc")
model = models.Mock("<s>abc")
assert str(model + '<s>' + any_char()) == "<s>a"
4 changes: 2 additions & 2 deletions tests/library/test_any_char_but.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from guidance import models, any_char_but

def test_single_char():
model = models.LocalMock("<s>abc")
model = models.Mock("<s>abc")
assert str(model + '<s>' + any_char_but('a')) != "<s>a"
assert str(model + '<s>' + any_char_but('!')) == "<s>a"

def test_multi_char():
model = models.LocalMock(["<s>abc", "<s>bbc"])
model = models.Mock(["<s>abc", "<s>bbc"])
assert str(model + '<s>' + any_char_but('ab')) not in ("<s>a", "<s>b")
assert str(model + '<s>' + any_char_but('a!')) == "<s>b"
assert str(model + '<s>' + any_char_but('5b')) == "<s>a"
8 changes: 4 additions & 4 deletions tests/library/test_block.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
from guidance import models, block, any_char

def test_text_opener():
model = models.LocalMock("<s>open texta")
model = models.Mock("<s>open texta")
with block(opener="open text"):
model += any_char()
assert str(model) == "open texta"

def test_text_closer():
model = models.LocalMock("<s>aclose text")
model = models.Mock("<s>aclose text")
model += "<s>"
with block(closer="close text"):
model += any_char()
assert str(model) == "<s>aclose text"

def test_grammar_opener():
model = models.LocalMock("<s>open texta")
model = models.Mock("<s>open texta")
with block(opener="open tex" + any_char()):
model += any_char()
assert str(model) == "open texta"

def test_grammar_closer():
model = models.LocalMock(["<s>aclose text", "<s>close text"])
model = models.Mock(["<s>aclose text", "<s>close text"])
model += "<s>"
try:
with block(closer=any_char() + "lose text"):
Expand Down
4 changes: 2 additions & 2 deletions tests/library/test_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
from ..utils import get_model

def test_capture():
model = models.LocalMock()
model = models.Mock()
model += 'This is' + capture(select(options=['bad', 'quite bad']), name="my_var")
assert model["my_var"] in ["bad", "quite bad"]

def test_capture_star():
lm = models.LocalMock(b"<s>1234233234<s>")
lm = models.Mock(b"<s>1234233234<s>")
grammar = capture(one_or_more(select(['1', '2'])), name='test')
lm2 = lm + grammar
assert lm2['test'] == '12'
4 changes: 2 additions & 2 deletions tests/library/test_char_set.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from guidance import models, char_set

def test_single_char():
model = models.LocalMock("<s>abc")
model = models.Mock("<s>abc")
assert str(model + '<s>' + char_set("a")) == "<s>a"
assert str(model + '<s>' + char_set("ab")) == "<s>a"
assert str(model + '<s>' + char_set("ba")) == "<s>a"
assert str(model + '<s>' + char_set("b")) == "<s>b"

def test_char_range():
model = models.LocalMock("<s>bac")
model = models.Mock("<s>bac")
assert str(model + '<s>' + char_set("a-c")) == "<s>b"
assert str(model + '<s>' + char_set("b-z")) == "<s>b"
assert str(model + '<s>' + char_set("0-9")) != "<s>b"
Expand Down
2 changes: 1 addition & 1 deletion tests/library/test_commit_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ..utils import get_model

def test_hidden():
model = models.LocalMock()
model = models.Mock()
model += " one" + commit_point(" two", hidden=True) + " three"
assert str(model) == " one three"

Expand Down
42 changes: 21 additions & 21 deletions tests/library/test_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,22 @@
import re

def test_basic():
lm = models.LocalMock()
lm = models.Mock()
lm += "Write a number: " + gen('text', max_tokens=3)
assert len(lm["text"]) > 0

def test_stop_string():
lm = models.LocalMock(b"<s>Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
lm = models.Mock(b"<s>Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
lm += "Count to 10: 1, 2, 3, 4, 5, 6, 7, " + gen('text', stop=", 9")
assert lm["text"] == "8"

def test_stop_char():
lm = models.LocalMock(b"<s>Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
lm = models.Mock(b"<s>Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
lm += "Count to 10: 1, 2, 3, 4, 5, 6, 7, " + gen('text', stop=",")
assert lm["text"] == "8"

def test_save_stop():
lm = models.LocalMock(b"<s>Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
lm = models.Mock(b"<s>Count to 10: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
lm += "Count to 10: 1, 2, 3, 4, 5, 6, 7, " + gen('text', stop=",", save_stop_text='stop_text')
assert lm["stop_text"] == ","

Expand All @@ -44,7 +44,7 @@ def test_unicode2():
assert True

def test_gsm8k():
lm = models.LocalMock()
lm = models.Mock()
lm + '''Question: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Answer: ''' + gen(max_tokens=30)
assert True
Expand All @@ -63,60 +63,60 @@ def test_non_token_force():
assert len(str(lm)) == 6

def test_pattern_optional():
lm = models.LocalMock(b"<s>12342333")
lm = models.Mock(b"<s>12342333")
pattern = '.?233'
lm2 = lm + '123' + gen(name='numbers', regex=pattern, max_tokens=10)
assert lm2['numbers'] == '4233'
lm = models.LocalMock(b"<s>1232333")
lm = models.Mock(b"<s>1232333")
pattern = '.?233'
lm2 = lm + '123' + gen(name='numbers', regex=pattern, max_tokens=10)
assert lm2['numbers'] == '233'
pattern = r'(Scott is bad)?(\d+)?o'
lm = models.LocalMock(b"<s>John was a little man full of things")
lm = models.Mock(b"<s>John was a little man full of things")
lm2 = lm + 'J' + gen(name='test', regex=pattern, max_tokens=30)
assert lm2['test'] == 'o'

def test_pattern_stops_when_fulfilled():
lm = models.LocalMock(b"<s>123abc")
lm = models.Mock(b"<s>123abc")
lm += gen(regex=r'\d+', max_tokens=10, name='test')
assert lm['test'] == '123'

def test_pattern_star():
# lm = models.LocalMock(b"<s>1234233234<s>") # commented out because it is not a valid test
# lm = models.Mock(b"<s>1234233234<s>") # commented out because it is not a valid test
# patterns = ['\d+233', '\d*233', '.+233', '.*233']
# for pattern in patterns:
# lm2 = lm + '123' + gen(name='numbers', regex=pattern, max_tokens=10)
# assert lm2['numbers'] == '4233'
lm = models.LocalMock(b"<s>123233")
lm = models.Mock(b"<s>123233")
patterns = [r'\d*233','.*233']
for pattern in patterns:
lm2 = lm + '123' + gen(name='numbers', regex=pattern, max_tokens=10)
assert lm2['numbers'].startswith('233')
pattern = '.*(\n|little)'
lm = models.LocalMock(b"<s>John was a little")
lm = models.Mock(b"<s>John was a little")
lm2 = lm + 'J' + gen(name='test', regex=pattern, max_tokens=30)
assert lm2['test'].startswith('ohn was a little')
lm = models.LocalMock(b"<s>John was a litt\n")
lm = models.Mock(b"<s>John was a litt\n")
lm2 = lm + 'J' + gen(name='test', regex=pattern, max_tokens=30)
assert lm2['test'].startswith('ohn was a litt\n')

def test_stop_regex():
lm = models.LocalMock(b"<s>123a3233")
lm = models.Mock(b"<s>123a3233")
lm2 = lm + '123' + gen(name='test', stop_regex=r'\d233', max_tokens=10)
assert lm2['test'] == 'a'
lm = models.LocalMock(b"<s>123aegalera3233")
lm = models.Mock(b"<s>123aegalera3233")
lm2 = lm + '123' + gen(name='test', stop_regex=r'\d', max_tokens=30)
assert lm2['test'] == 'aegalera'

def test_stop_regex_star():
lm = models.LocalMock(b"<s>123a3233")
lm = models.Mock(b"<s>123a3233")
pattern = r'\d+233'
lm2 = lm + '123' + gen(name='test', stop_regex=pattern, max_tokens=10)
assert lm2['test'] == 'a'

def test_empty_pattern():
pattern = r'(Scott is bad)?(\d+)?'
lm = models.LocalMock(b"<s>J<s>")
lm = models.Mock(b"<s>J<s>")
lm2 = lm + 'J' + gen(name='test', regex=pattern, max_tokens=30)
assert lm2['test'] == ''

Expand Down Expand Up @@ -184,7 +184,7 @@ def test_long_prompt():

def test_list_append():
'''This tests is list append works across grammar appends.'''
lm = models.LocalMock(b"<s>bababababa")
lm = models.Mock(b"<s>bababababa")
lm += "<s>"
for _ in range(3):
lm += gen("my_list", list_append=True, stop="a") + "a"
Expand All @@ -193,19 +193,19 @@ def test_list_append():

def test_list_append_in_grammar():
'''This tests is list append works within the same grammar.'''
lm = models.LocalMock(b"<s>bababababa")
lm = models.Mock(b"<s>bababababa")
lm += "<s>"
lm += gen("my_list", list_append=True, stop="a") + "a" + gen("my_list", list_append=True, stop="a") + "a" + gen("my_list", list_append=True, stop="a")
assert isinstance(lm['my_list'], list)
assert len(lm['my_list']) == 3

def test_one_char_suffix_and_regex():
model = models.LocalMock(b"<s>this is\na test")
model = models.Mock(b"<s>this is\na test")
model += gen(regex=".*", suffix="\n", max_tokens=20)
assert str(model) == "this is\n"

def test_one_char_stop_and_regex():
model = models.LocalMock(b"<s>this is\na test")
model = models.Mock(b"<s>this is\na test")
model += gen(regex=".*", stop="\n", max_tokens=20)
assert str(model) == "this is"

Expand Down
6 changes: 3 additions & 3 deletions tests/library/test_one_or_more.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from guidance import models, one_or_more, char_set

def test_string():
model = models.LocalMock("<s>aaabc")
model = models.Mock("<s>aaabc")
assert str(model + '<s>' + one_or_more("a")) == "<s>aaa"

def test_grammar():
model = models.LocalMock("<s>bac")
model = models.Mock("<s>bac")
assert str(model + '<s>' + one_or_more(char_set("ab"))) == "<s>ba"

def test_at_least_one():
model = models.LocalMock("<s>cbac")
model = models.Mock("<s>cbac")
assert not str(model + '<s>' + one_or_more(char_set("ab"))).startswith("<s>c")
2 changes: 1 addition & 1 deletion tests/library/test_silent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ..utils import get_model

def test_basic():
lm = models.LocalMock()
lm = models.Mock()
lm += "Start text"
with silent():
lm += "silent text"
Expand Down
9 changes: 0 additions & 9 deletions tests/models/test_local.py

This file was deleted.

Loading

0 comments on commit d1bbce1

Please sign in to comment.