Skip to content

Commit

Permalink
fix: further mypy errors
Browse files: browse the repository at this point in the history
  • Loading branch information
Szymon Palucha authored and paluchasz committed Sep 25, 2024
1 parent 5f6da5f commit b8018fc
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 18 deletions.
11 changes: 6 additions & 5 deletions kazu/annotation/acceptance_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import hydra
from hydra.utils import instantiate
from omegaconf.dictconfig import DictConfig

from kazu.data import Entity, Document, IdsAndSource
from kazu.pipeline import Pipeline
Expand All @@ -30,7 +31,7 @@ def acceptance_criteria() -> AcceptanceCriteria:


@hydra.main(version_base=HYDRA_VERSION_BASE, config_path=".", config_name="config")
def execute_full_pipeline_acceptance_test(cfg):
def execute_full_pipeline_acceptance_test(cfg: DictConfig) -> None:
manager = instantiate(cfg.LabelStudioManager)
pipeline: Pipeline = instantiate(cfg.Pipeline)
analyse_full_pipeline(pipeline, manager.export_from_ls(), acceptance_criteria())
Expand Down Expand Up @@ -75,17 +76,17 @@ def group_mappings_by_source(ents: Iterable[Entity]) -> dict[str, IdsAndSource]:
)
return dict(mappings_by_source)

def calculate_ner_matches(self):
def calculate_ner_matches(self) -> None:
combos = itertools.product(self.gold_ents, self.test_ents)
for (gold_ent, test_ent) in combos:
for gold_ent, test_ent in combos:
if (
gold_ent.spans == test_ent.spans or gold_ent.is_partially_overlapped(test_ent)
) and gold_ent.entity_class == test_ent.entity_class:
self.gold_to_test_ent_soft[gold_ent].add(test_ent)
self.ner_fp_soft.discard(test_ent)
self.ner_fn_soft.discard(gold_ent)

def calculate_linking_matches(self):
def calculate_linking_matches(self) -> None:
for gold_ent, test_ents in self.gold_to_test_ent_soft.items():
gold_mappings_by_source = self.group_mappings_by_source([gold_ent])
test_mappings_by_source = self.group_mappings_by_source(test_ents)
Expand Down Expand Up @@ -305,7 +306,7 @@ def analyse_annotation_consistency(docs: list[Document]) -> None:


@hydra.main(version_base=HYDRA_VERSION_BASE, config_path="../../", config_name="conf")
def check_annotation_consistency(cfg):
def check_annotation_consistency(cfg: DictConfig) -> None:

manager = instantiate(cfg.LabelStudioManager)
docs = manager.export_from_ls()
Expand Down
12 changes: 6 additions & 6 deletions kazu/utils/build_and_test_model_packs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class BuildConfiguration:
#: pack is built. If any exceptions are detected, the build will fail.
sanity_test_strings: list[str] = field(default_factory=list)

def __post_init__(self):
def __post_init__(self) -> None:
if len(self.resources) > 0:
self.requires_resources = True
else:
Expand All @@ -73,7 +73,7 @@ def __init__(
maybe_base_configuration_path: Optional[Path],
skip_tests: bool,
zip_pack: bool,
):
) -> None:
"""A ModelPackBuilder is a helper class to assist in the building of a model
pack.
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init__(
os.environ["KAZU_MODEL_PACK"] = str(self.model_pack_build_path)
self.build_config = self.load_build_configuration()

def __repr__(self):
def __repr__(self) -> str:
"""For nice log messages."""
return f"ModelPackBuilder({self.target_model_pack_path.name})"

Expand Down Expand Up @@ -202,7 +202,7 @@ def load_build_configuration(self) -> BuildConfiguration:
data = json.load(f)
return BuildConfiguration(**data)

def apply_merge_configurations(self):
def apply_merge_configurations(self) -> None:

# copy the target pack to the target build dir
shutil.copytree(
Expand Down Expand Up @@ -231,7 +231,7 @@ def apply_merge_configurations(self):
if self.build_config.requires_resources:
self.copy_resources_to_target()

def copy_resources_to_target(self):
def copy_resources_to_target(self) -> None:

for parent_dir_str, resource_list in self.build_config.resources.items():
parent_dir_path = Path(parent_dir_str)
Expand Down Expand Up @@ -302,7 +302,7 @@ def build_caches_and_run_sanity_checks(self, cfg: DictConfig) -> "Pipeline":
)
return pipeline

def report_tested_dependencies(self):
def report_tested_dependencies(self) -> None:
dependencies = subprocess.check_output("pip freeze --exclude-editable", shell=True).decode(
"utf-8"
)
Expand Down
12 changes: 5 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -212,13 +212,11 @@ disallow_untyped_defs = false
# the numbers can be re-calculated with:
# mypy kazu docs conftest.py | cut -f 1,2 -d ':' | sed '$d' | sort | uniq | cut -f 1 -d ':' | uniq -c | sort --reverse | awk '{file_without_dot_py = substr($2, 0, length($2)-3); gsub("/", ".", file_without_dot_py); print " \""file_without_dot_py"\", # "$1}'
module = [
"kazu.distillation.models", # 10
"kazu.utils.build_and_test_model_packs", # 5
"kazu.linking.sapbert.train", # 4
"kazu.annotation.acceptance_test", # 4
"kazu.utils.spacy_pipeline", # 2
"kazu.distillation.tiny_transformers", # 2
"kazu.distillation.metrics", # 1
"kazu.distillation.models", # 10
"kazu.linking.sapbert.train", # 4
"kazu.utils.spacy_pipeline", # 2
"kazu.distillation.tiny_transformers", # 2
"kazu.distillation.metrics", # 1
]
# We had a bunch of these in the codebase before we moved to a 'strict' mypy config, and there were too many
# to be worth fixing at that time for the payoff. Having overrides for the modules that would error rather than
Expand Down

0 comments on commit b8018fc

Please sign in to comment.