Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: normalize LLM parameter case and improve type handling #1830

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion src/crewai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,28 @@ class Agent(BaseAgent):

@model_validator(mode="after")
def post_init_setup(self):
# Handle case-insensitive LLM parameter
if hasattr(self, 'LLM'):
import warnings
warnings.warn(
"Using 'LLM' parameter is deprecated. Use lowercase 'llm' instead.",
DeprecationWarning,
stacklevel=2
)
# Transfer LLM value to llm
self.llm = getattr(self, 'LLM')
delattr(self, 'LLM')

self._set_knowledge()
self.agent_ops_agent_name = self.role
unaccepted_attributes = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
]

# Initialize LLM parameters
llm_params: Dict[str, Any] = {}

# Handle different cases for self.llm
if isinstance(self.llm, str):
Expand Down Expand Up @@ -190,7 +205,71 @@ def post_init_setup(self):
if key not in ["prompt", "key_name", "default"]:
# Only add default if the key is already set in os.environ
if key in os.environ:
llm_params[key] = value
# Convert environment variables to proper types
try:
param_value = None

# Integer parameters
if key in ['timeout', 'max_tokens', 'n', 'max_completion_tokens']:
try:
param_value = int(str(value)) if value else None
except (ValueError, TypeError):
continue

# Float parameters
elif key in ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty']:
try:
param_value = float(str(value)) if value else None
except (ValueError, TypeError):
continue

# Boolean parameters
elif key == 'logprobs':
if isinstance(value, bool):
param_value = value
elif isinstance(value, str):
param_value = value.lower() == 'true'

# Dict parameters
elif key == 'logit_bias' and value:
try:
if isinstance(value, dict):
param_value = {int(k): float(v) for k, v in value.items()}
elif isinstance(value, str):
import json
bias_dict = json.loads(value)
param_value = {int(k): float(v) for k, v in bias_dict.items()}
except (ValueError, TypeError, json.JSONDecodeError):
continue

elif key == 'response_format' and value:
try:
if isinstance(value, dict):
param_value = value
elif isinstance(value, str):
import json
param_value = json.loads(value)
except (ValueError, json.JSONDecodeError):
continue

# List parameters
elif key == 'callbacks':
if isinstance(value, (list, tuple)):
param_value = list(value)
elif isinstance(value, str):
param_value = [cb.strip() for cb in value.split(',') if cb.strip()]
else:
param_value = []

# String and other parameters
else:
param_value = value

if param_value is not None:
llm_params[key] = param_value
except Exception:
# Skip any invalid values
continue

self.llm = LLM(**llm_params)
else:
Expand Down
65 changes: 65 additions & 0 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
from unittest import mock

import pytest

from crewai.agent import Agent
from crewai.llm import LLM


def test_agent_with_custom_llm():
    """An explicitly supplied LLM instance survives post_init_setup unchanged."""
    agent = Agent()
    agent.role = "test"
    agent.goal = "test"
    agent.backstory = "test"
    agent.allow_delegation = False
    # Assign a concrete LLM object (not a model-name string) before setup.
    agent.llm = LLM(model="gpt-4")
    agent.post_init_setup()

    assert isinstance(agent.llm, LLM)
    assert agent.llm.model == "gpt-4"

def test_agent_with_uppercase_llm_param():
    """Deprecated uppercase 'LLM' warns, migrates its value to 'llm', and is removed."""
    custom_llm = LLM(model="gpt-4")
    agent = Agent()
    agent.role = "test"
    agent.goal = "test"
    agent.backstory = "test"
    agent.allow_delegation = False
    setattr(agent, 'LLM', custom_llm)  # deliberately use the deprecated uppercase spelling

    # Only post_init_setup emits the DeprecationWarning, so scope the
    # pytest.warns context to that single call. Wrapping the construction
    # and assertions as well (as the original did) could mask a warning
    # raised from the wrong place and obscures what is actually under test.
    with pytest.warns(DeprecationWarning):
        agent.post_init_setup()

    # The value must have been transferred to the canonical lowercase field...
    assert isinstance(agent.llm, LLM)
    assert agent.llm.model == "gpt-4"
    # ...and the deprecated attribute removed entirely.
    assert not hasattr(agent, 'LLM')

def test_agent_llm_parameter_types():
    """Environment-sourced LLM parameters are coerced to their proper Python types."""
    env_vars = {
        "temperature": "0.7",
        "max_tokens": "100",
        "presence_penalty": "0.5",
        "logprobs": "true",
        "logit_bias": '{"50256": -100}',
        "callbacks": "callback1,callback2",
    }

    # post_init_setup must run while the patched environment is active,
    # since that is where the string values are read and converted.
    with mock.patch.dict(os.environ, env_vars):
        agent = Agent()
        agent.role = "test"
        agent.goal = "test"
        agent.backstory = "test"
        agent.allow_delegation = False
        agent.llm = "gpt-4"
        agent.post_init_setup()

    # The resulting LLM object persists after the patch is undone.
    llm = agent.llm
    assert isinstance(llm, LLM)
    assert llm.temperature == 0.7          # str -> float
    assert llm.max_tokens == 100           # str -> int
    assert llm.presence_penalty == 0.5     # str -> float
    assert llm.logprobs is True            # "true" -> bool
    assert llm.logit_bias == {50256: -100.0}  # JSON str -> {int: float}
    assert llm.callbacks == ["callback1", "callback2"]  # CSV str -> list
Loading