forked from confident-ai/deepeval
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Anindyadeep
committed
Nov 28, 2023
1 parent
d35ac7a
commit e999114
Showing
11 changed files
with
107 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,64 @@ | ||
import torch | ||
from typing import Union, List | ||
import torch | ||
from typing import Union, List | ||
from typing import List, Union, get_origin | ||
from deepeval.models.base import DeepEvalBaseModel | ||
from deepeval.models._summac_model import _SummaCZS | ||
|
||
|
||
class SummaCModels(DeepEvalBaseModel): | ||
def __init__(self, model_name: str | None = None, granularity: str | None = None, device: str | None = None, *args, **kwargs): | ||
def __init__( | ||
self, | ||
model_name: str | None = None, | ||
granularity: str | None = None, | ||
device: str | None = None, | ||
*args, | ||
**kwargs | ||
): | ||
model_name = "vitc" if model_name is None else model_name | ||
self.granularity = "sentence" if granularity is None else granularity | ||
self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu" | ||
self.device = ( | ||
device | ||
if device is not None | ||
else "cuda" | ||
if torch.cuda.is_available() | ||
else "cpu" | ||
) | ||
super().__init__(model_name, *args, **kwargs) | ||
|
||
def load_model(self, op1: str | None = "max", op2: str | None = "mean", use_ent: bool | None = True, use_con: bool | None = True, image_load_cache: bool | None = True, **kwargs): | ||
|
||
def load_model( | ||
self, | ||
op1: str | None = "max", | ||
op2: str | None = "mean", | ||
use_ent: bool | None = True, | ||
use_con: bool | None = True, | ||
image_load_cache: bool | None = True, | ||
**kwargs | ||
): | ||
return _SummaCZS( | ||
model_name=self.model_name, | ||
granularity=self.granularity, | ||
device=self.device, | ||
op1=op1, op2=op2, use_con=use_con, use_ent=use_ent, | ||
op1=op1, | ||
op2=op2, | ||
use_con=use_con, | ||
use_ent=use_ent, | ||
imager_load_cache=image_load_cache, | ||
**kwargs | ||
) | ||
|
||
def _call(self, predictions: Union[str, List[str]], targets: Union[str, List[str]]) -> Union[float, dict]: | ||
|
||
def _call( | ||
self, predictions: Union[str, List[str]], targets: Union[str, List[str]] | ||
) -> Union[float, dict]: | ||
list_type = List[str] | ||
|
||
if get_origin(predictions) is list_type and get_origin(targets) is list_type: | ||
if ( | ||
get_origin(predictions) is list_type | ||
and get_origin(targets) is list_type | ||
): | ||
return self.model.score(targets, predictions) | ||
elif isinstance(predictions, str) and isinstance(targets, str): | ||
return self.model.score_one(targets, predictions) | ||
else: | ||
raise TypeError('Either both predictions and targets should be List or both should be string') | ||
raise TypeError( | ||
"Either both predictions and targets should be List or both should be string" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,18 @@ | ||
from typing import Optional | ||
from deepeval.models.base import DeepEvalBaseModel | ||
|
||
|
||
class UnBiasedModel(DeepEvalBaseModel): | ||
def __init__(self, model_name: str | None = None, *args, **kwargs): | ||
model_name = "original" if model_name is None else model_name | ||
super().__init__(model_name, *args, **kwargs) | ||
|
||
def load_model(self): | ||
try: | ||
from Dbias.bias_classification import classifier | ||
except ImportError as e: | ||
print("Run `pip install deepeval[bias]`") | ||
return classifier | ||
|
||
def _call(self, text): | ||
return self.model(text) | ||
return self.model(text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters