Skip to content

Commit

Permalink
Fix package extras for watsonx support (#2426)
Browse files Browse the repository at this point in the history
* Update pyproject.toml with watsonx package extra

Signed-off-by: kiersten-stokes <[email protected]>

* Remove unused function

Signed-off-by: kiersten-stokes <[email protected]>

---------

Signed-off-by: kiersten-stokes <[email protected]>
  • Loading branch information
kiersten-stokes authored Oct 25, 2024
1 parent 1185e89 commit 7882043
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
8 changes: 1 addition & 7 deletions lm_eval/models/ibm_watsonx_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __init__(
project_id = watsonx_credentials.get("project_id", None)
deployment_id = watsonx_credentials.get("deployment_id", None)
client.set.default_project(project_id)
self.generate_params = generate_params
self.generate_params = generate_params or {}
self.model = ModelInference(
model_id=model_id,
deployment_id=deployment_id,
Expand All @@ -167,12 +167,6 @@ def __init__(
)
self._model_id = model_id

def dump_parameters(self):
"""
Dumps the model's parameters into a serializable format.
"""
return self._parameters.model_dump()

@staticmethod
def _has_stop_token(response_tokens: List[str], context_tokens: List[str]) -> bool:
"""
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ dev = ["pytest", "pytest-cov", "pytest-xdist", "pre-commit", "mypy"]
deepsparse = ["deepsparse-nightly[llm]>=1.8.0.20240404"]
gptq = ["auto-gptq[triton]>=0.6.0"]
hf_transfer = ["hf_transfer"]
ibm_watsonx_ai = ["ibm_watsonx_ai"]
ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
neuronx = ["optimum[neuronx]"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2"]
Expand All @@ -81,6 +82,7 @@ all = [
"lm_eval[deepsparse]",
"lm_eval[gptq]",
"lm_eval[hf_transfer]",
"lm_eval[ibm_watsonx_ai]",
"lm_eval[ifeval]",
"lm_eval[mamba]",
"lm_eval[math]",
Expand All @@ -93,7 +95,6 @@ all = [
"lm_eval[vllm]",
"lm_eval[zeno]",
"lm_eval[wandb]",
"lm_eval[ibm_watsonx_ai]"
]

[tool.ruff.lint]
Expand Down

0 comments on commit 7882043

Please sign in to comment.