Skip to content

Commit

Permalink
fix _load_model
Browse files Browse the repository at this point in the history
  • Loading branch information
Anindyadeep committed Mar 5, 2024
1 parent 4681243 commit fc26b5d
Showing 1 changed file with 33 additions and 13 deletions.
46 changes: 33 additions & 13 deletions deepeval/experimental/harness/models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from typing import Optional, Union, List
from typing import Optional, Union, List

from deepeval.models.base_model import DeepEvalBaseModel
from deepeval.experimental.harness.config import GeneralConfig, APIEndpointConfig
from deepeval.experimental.harness.config import (
GeneralConfig,
APIEndpointConfig,
)


class DeepEvalHarnessModel(DeepEvalBaseModel):
def __init__(self, model_name_or_path: str, model_backend: str, **kwargs) -> None:
self.model_name_or_path, self.model_backend = model_name_or_path, model_backend
def __init__(
self, model_name_or_path: str, model_backend: str, **kwargs
) -> None:
self.model_name_or_path, self.model_backend = (
model_name_or_path,
model_backend,
)
self.additional_params = kwargs
super().__init__(model_name=model_name_or_path, **self.additional_params)
super().__init__(
model_name=model_name_or_path, **self.additional_params
)

def load_model(self, *args, **kwargs):
try:
from easy_eval import HarnessEvaluator
except ImportError as error:
Expand All @@ -18,17 +29,26 @@ def __init__(self, model_name_or_path: str, model_backend: str, **kwargs) -> Non
"easy_eval is not found."
"You can install it using: pip install easy-evaluator"
)

self.evaluator = HarnessEvaluator(
model_name_or_path=self.model_name_or_path, model_backend=self.model_backend, **self.additional_params
model_name_or_path=self.model_name_or_path,
model_backend=self.model_backend,
**self.additional_params,
)

def load_model(self, *args, **kwargs):
return self.evaluator.llm

def _call(self, tasks: List[str], config: Optional[Union[GeneralConfig, APIEndpointConfig]] = None):
# TODO: Anthropic is not supported in APIEndpointConfig.
return self.evaluator.llm

def _call(
self,
tasks: List[str],
config: Optional[Union[GeneralConfig, APIEndpointConfig]] = None,
):
# TODO: Anthropic is not supported in APIEndpointConfig.
if config is None:
self.config = APIEndpointConfig() if self.model_name_or_path == "openai" else GeneralConfig()
self.config = (
APIEndpointConfig()
if self.model_name_or_path == "openai"
else GeneralConfig()
)
else:
self.config = config
return self.evaluator.evaluate(tasks=tasks, config=self.config)

0 comments on commit fc26b5d

Please sign in to comment.