Skip to content

Commit

Permalink
Add new processors and update type hints for response assessment tasks
Browse files Browse the repository at this point in the history
Signed-off-by: elronbandel <[email protected]>
  • Loading branch information
elronbandel committed Jan 6, 2025
1 parent b32bb80 commit a23a41e
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 15 deletions.
17 changes: 17 additions & 0 deletions prepare/processors/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@
TakeFirstNonEmptyLine,
TakeFirstWord,
TakeLastNonEmptyLine,
TakeUntilPunc,
Title,
ToYesOrNone,
Upper,
YesNoToInt,
YesToOneElseZero,
)
from unitxt.settings_utils import get_constants
from unitxt.string_operators import Strip

constants = get_constants()
logger = get_logger()
Expand Down Expand Up @@ -74,6 +77,12 @@ def add_processor_and_operator_to_catalog(
overwrite=True,
)

add_processor_and_operator_to_catalog(
artifact_name="strip",
operator=Strip(),
overwrite=True,
)

add_processor_and_operator_to_catalog(
artifact_name="take_last_non_empty_line",
operator=TakeLastNonEmptyLine(),
Expand All @@ -91,6 +100,14 @@ def add_processor_and_operator_to_catalog(
artifact_name="lower_case_till_punc", operator=LowerCaseTillPunc(), overwrite=True
)

add_processor_and_operator_to_catalog(
artifact_name="take_until_punc", operator=TakeUntilPunc(), overwrite=True
)

add_processor_and_operator_to_catalog(
artifact_name="title", operator=Title(), overwrite=True
)

add_processor_and_operator_to_catalog(
artifact_name="hate_speech_or_not_hate_speech",
operator=StringEquals(string="hate speech"),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Literal, Tuple

from unitxt.blocks import Task
from unitxt.catalog import add_to_catalog
Expand All @@ -10,9 +10,12 @@
"dialog_b": List[Tuple[str, str]],
},
reference_fields={
"winner": str
}, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"},
"winner": Literal["choice_a", "choice_b", "tie"],
"classes": List[Literal["choice_a", "choice_b", "tie"]],
},
defaults={"classes": ["choice_a", "choice_b", "tie"]},
metrics=["metrics.accuracy", "metrics.f1_micro", "metrics.f1_macro"],
prediction_type=str,
),
"tasks.response_assessment.pairwise_comparison.multi_turn",
overwrite=True,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Literal, Tuple

from unitxt.blocks import Task
from unitxt.catalog import add_to_catalog
Expand All @@ -11,9 +11,12 @@
"reference_dialog": List[Tuple[str, str]],
},
reference_fields={
"winner": str
}, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"},
"winner": Literal["choice_a", "choice_b", "tie"],
"classes": List[Literal["choice_a", "choice_b", "tie"]],
},
defaults={"classes": ["choice_a", "choice_b", "tie"]},
metrics=["metrics.accuracy", "metrics.f1_micro", "metrics.f1_macro"],
prediction_type=str,
),
"tasks.response_assessment.pairwise_comparison.multi_turn_with_reference",
overwrite=True,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Literal

from unitxt.blocks import Task
from unitxt.catalog import add_to_catalog

Expand All @@ -9,9 +11,12 @@
"answer_b": str,
},
reference_fields={
"winner": str
}, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"
"winner": Literal["choice_a", "choice_b", "tie"],
"classes": List[Literal["choice_a", "choice_b", "tie"]],
},
defaults={"classes": ["choice_a", "choice_b", "tie"]},
metrics=["metrics.accuracy", "metrics.f1_micro", "metrics.f1_macro"],
prediction_type=str,
),
"tasks.response_assessment.pairwise_comparison.single_turn",
overwrite=True,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Literal

from unitxt.blocks import Task
from unitxt.catalog import add_to_catalog

Expand All @@ -10,9 +12,12 @@
"reference_answer": str,
},
reference_fields={
"winner": str
}, # TODO: Support and change to "Literal['choice_a', 'choice_b', 'tie']"},
"winner": Literal["choice_a", "choice_b", "tie"],
"classes": List[Literal["choice_a", "choice_b", "tie"]],
},
defaults={"classes": ["choice_a", "choice_b", "tie"]},
metrics=["metrics.accuracy", "metrics.f1_micro", "metrics.f1_macro"],
prediction_type=str,
),
"tasks.response_assessment.pairwise_comparison.single_turn_with_reference",
overwrite=True,
Expand Down
29 changes: 28 additions & 1 deletion src/unitxt/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,27 @@ def process_value(self, text: Any) -> Any:
return str(text).upper()


class Title(FieldOperator):
def process_value(self, text: Any) -> Any:
return str(text).title()


class TakeUntilPunc(FieldOperator):
_requirements_list = ["regex"]

def prepare(self):
super().prepare()
import regex

self.pattern = regex.compile(r"\p{P}+")

def process_value(self, text: Any) -> Any:
match = self.pattern.search(text)
if match:
text = text[: match.start()]
return text


@deprecation("2.0.0", alternative=Lower)
class LowerCase(Lower):
pass
Expand Down Expand Up @@ -294,10 +315,16 @@ def process_value(self, text: Any) -> Any:


class ExtractMtBenchLabelJudgment(FieldOperator):
options = {
"A": "choice_a",
"B": "choice_b",
"C": "tie",
}

def process_value(self, text: Any) -> Any:
match = re.search(r"\[\[([^\]]+)\]\]", text)
try:
return str(match.group(1))
return self.options.get(str(match.group(1)), "None")
except:
return "None"

Expand Down
2 changes: 1 addition & 1 deletion tests/library/test_postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def test_extract_mt_bench_label_judgment(self):
"good",
"bad [[C]]",
]
targets = ["A", "B", "A", "None", "C"]
targets = ["choice_a", "choice_b", "choice_a", "None", "tie"]

check_operator(
operator=postprocessor,
Expand Down
6 changes: 3 additions & 3 deletions utils/.secrets.baseline
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@
"filename": "src/unitxt/inference.py",
"hashed_secret": "aa6cd2a77de22303be80e1f632195d62d211a729",
"is_verified": false,
"line_number": 1250,
"line_number": 1249,
"is_secret": false
},
{
"type": "Secret Keyword",
"filename": "src/unitxt/inference.py",
"hashed_secret": "c8f16a194efc59559549c7bd69f7bea038742e79",
"is_verified": false,
"line_number": 1696,
"line_number": 1663,
"is_secret": false
}
],
Expand Down Expand Up @@ -184,5 +184,5 @@
}
]
},
"generated_at": "2025-01-05T08:58:59Z"
"generated_at": "2025-01-01T10:22:19Z"
}

0 comments on commit a23a41e

Please sign in to comment.