feat(lm): add support for o3-mini and openai reasoning models (#7649)
kalanyuz authored Feb 3, 2025
1 parent 5d9a8b4 commit 80cdcbe
Showing 2 changed files with 60 additions and 6 deletions.
20 changes: 14 additions & 6 deletions dspy/clients/lm.py
@@ -1,6 +1,7 @@
 import functools
 import logging
 import os
+import re
 import threading
 import uuid
 from datetime import datetime
@@ -74,19 +75,26 @@ def __init__(
         self.cache_in_memory = cache_in_memory
         self.provider = provider or self.infer_provider()
         self.callbacks = callbacks or []
-        self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
         self.history = []
         self.callbacks = callbacks or []
         self.num_retries = num_retries
         self.finetuning_model = finetuning_model
         self.launch_kwargs = launch_kwargs
 
-        # TODO(bug): Arbitrary model strings could include the substring "o1-".
-        # We should find a more robust way to check for the "o1-" family models.
-        if "o1-" in model:
+        # Handle model-specific configuration for different model families
+        model_family = model.split("/")[-1].lower() if "/" in model else model.lower()
+
+        # Match pattern: o[1,3] at the start, optionally followed by -mini and anything else
+        model_pattern = re.match(r"^o([13])(?:-mini)?", model_family)
+
+        if model_pattern:
+            # Handle OpenAI reasoning models (o1, o3)
             assert (
                 max_tokens >= 5000 and temperature == 1.0
-            ), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
+            ), "OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"
+            self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)
+        else:
+            self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs)
 
     @with_callbacks
     def __call__(self, prompt=None, messages=None, **kwargs):
@@ -101,7 +109,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         if cache_in_memory:
             completion = cached_litellm_completion if self.model_type == "chat" else cached_litellm_text_completion
 
-        response = completion(
+        response = completion(
             request=dict(model=self.model, messages=messages, **kwargs),
             num_retries=self.num_retries,
         )
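The new check first strips any provider prefix from the model string, then prefix-matches the o1/o3 family, so dated variants such as o1-2023-01-01 are caught as well. A minimal standalone sketch of that detection logic, with the regex copied verbatim from the diff above (`is_openai_reasoning_model` is a hypothetical helper name; the commit inlines this logic in `LM.__init__`):

import re

def is_openai_reasoning_model(model: str) -> bool:
    # Strip a provider prefix like "openai/", lowercase, then prefix-match
    # o1 or o3, optionally followed by "-mini" and any suffix.
    model_family = model.split("/")[-1].lower() if "/" in model else model.lower()
    return re.match(r"^o([13])(?:-mini)?", model_family) is not None

# Sample checks mirroring the new tests:
assert is_openai_reasoning_model("openai/o1")
assert is_openai_reasoning_model("openai/o3-mini-2023-01-01")
assert not is_openai_reasoning_model("openai/gpt-4")
assert not is_openai_reasoning_model("anthropic/claude-2")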
46 changes: 46 additions & 0 deletions tests/clients/test_lm.py
@@ -180,3 +180,49 @@ def test_lm_text_calls_are_retried_for_expected_failures(
 
     request_logs = read_litellm_test_server_request_logs(server_log_file_path)
     assert len(request_logs) == expected_num_retries + 1  # 1 initial request + expected_num_retries retries
+
+
+def test_reasoning_model_token_parameter():
+    test_cases = [
+        ("openai/o1", True),
+        ("openai/o1-mini", True),
+        ("openai/o1-2023-01-01", True),
+        ("openai/o3", True),
+        ("openai/o3-mini-2023-01-01", True),
+        ("openai/gpt-4", False),
+        ("anthropic/claude-2", False),
+    ]
+
+    for model_name, is_reasoning_model in test_cases:
+        lm = dspy.LM(
+            model=model_name,
+            temperature=1.0 if is_reasoning_model else 0.7,
+            max_tokens=5000 if is_reasoning_model else 1000,
+        )
+        if is_reasoning_model:
+            assert "max_completion_tokens" in lm.kwargs
+            assert "max_tokens" not in lm.kwargs
+            assert lm.kwargs["max_completion_tokens"] == 5000
+        else:
+            assert "max_completion_tokens" not in lm.kwargs
+            assert "max_tokens" in lm.kwargs
+            assert lm.kwargs["max_tokens"] == 1000
+
+
+def test_reasoning_model_requirements():
+    # Should raise assertion error if temperature or max_tokens requirements not met
+    with pytest.raises(AssertionError) as exc_info:
+        dspy.LM(
+            model="openai/o1",
+            temperature=0.7,  # Should be 1.0
+            max_tokens=1000,  # Should be >= 5000
+        )
+    assert "reasoning models require passing temperature=1.0 and max_tokens >= 5000" in str(exc_info.value)
+
+    # Should pass with correct parameters
+    lm = dspy.LM(
+        model="openai/o1",
+        temperature=1.0,
+        max_tokens=5000,
+    )
+    assert lm.kwargs["max_completion_tokens"] == 5000
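With the constructor change, callers keep passing `max_tokens` to `dspy.LM(...)`, and for matched reasoning models the value is forwarded as `max_completion_tokens` instead. A usage sketch mirroring the assertions above (model names are illustrative):

import dspy

# Matched by the new prefix regex: temperature must be 1.0 and
# max_tokens >= 5000; the budget is stored as max_completion_tokens.
reasoning_lm = dspy.LM(model="openai/o3-mini", temperature=1.0, max_tokens=5000)
assert reasoning_lm.kwargs["max_completion_tokens"] == 5000

# Unmatched models keep the plain max_tokens keyword.
chat_lm = dspy.LM(model="openai/gpt-4", temperature=0.7, max_tokens=1000)
assert chat_lm.kwargs["max_tokens"] == 1000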
