diff --git a/dbt_common/clients/agate_helper.py b/dbt_common/clients/agate_helper.py index 3aade66d..934b0b11 100644 --- a/dbt_common/clients/agate_helper.py +++ b/dbt_common/clients/agate_helper.py @@ -149,7 +149,7 @@ def as_matrix(table): return [r.values() for r in table.rows.values()] -def from_csv(abspath, text_columns, delimiter=","): +def from_csv(abspath, text_columns, delimiter=",") -> agate.Table: type_tester = build_type_tester(text_columns=text_columns) with open(abspath, encoding="utf-8") as fp: if fp.read(1) != BOM: diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 44d3eade..653231c1 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -9,12 +9,12 @@ from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type from typing_extensions import Protocol -import jinja2 # type: ignore -import jinja2.ext # type: ignore -import jinja2.nativetypes # type: ignore -import jinja2.nodes # type: ignore -import jinja2.parser # type: ignore -import jinja2.sandbox # type: ignore +import jinja2 +import jinja2.ext +import jinja2.nativetypes +import jinja2.nodes +import jinja2.parser +import jinja2.sandbox from dbt_common.tests import test_caching_enabled from dbt_common.utils.jinja import ( diff --git a/dbt_common/contracts/config/metadata.py b/dbt_common/contracts/config/metadata.py index 83f3457e..563e6079 100644 --- a/dbt_common/contracts/config/metadata.py +++ b/dbt_common/contracts/config/metadata.py @@ -9,7 +9,7 @@ class Metadata(Enum): @classmethod - def from_field(cls: Type[M], fld: Field) -> M: + def from_field(cls: Type[M], fld: Field[Any]) -> M: default = cls.default_field() key = cls.metadata_key() @@ -28,7 +28,7 @@ def metadata_key(cls) -> str: raise NotImplementedError("Not implemented") -def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M: +def _get_meta_value(cls: Type[M], fld: Field[Any], key: str, default: Any) -> M: # a metadata field might exist. If it does, it might have a matching key. # If it has both, make sure the value is valid and return it. If it # doesn't, return the default. @@ -65,5 +65,5 @@ def metadata_key(cls) -> str: return "show_hide" @classmethod - def should_show(cls, fld: Field) -> bool: + def should_show(cls, fld: Field[Any]) -> bool: return cls.from_field(fld) == cls.Show diff --git a/dbt_common/dataclass_schema.py b/dbt_common/dataclass_schema.py index 4e003b13..8721fe4b 100644 --- a/dbt_common/dataclass_schema.py +++ b/dbt_common/dataclass_schema.py @@ -1,4 +1,4 @@ -from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple +from typing import Any, ClassVar, Dict, get_type_hints, List, Optional, Tuple, Union import re import jsonschema from dataclasses import fields, Field @@ -6,7 +6,6 @@ from datetime import datetime from dateutil.parser import parse -# type: ignore from mashumaro.config import ( TO_DICT_ADD_OMIT_NONE_FLAG, ADD_SERIALIZATION_CONTEXT, @@ -33,8 +32,8 @@ def serialize(self, value: datetime) -> str: out += "Z" return out - def deserialize(self, value) -> datetime: - return value if isinstance(value, datetime) else parse(cast(str, value)) + def deserialize(self, value: Union[datetime, str]) -> datetime: + return value if isinstance(value, datetime) else parse(value) class dbtMashConfig(MashBaseConfig): @@ -63,7 +62,7 @@ class dbtClassMixin(DataClassMessagePackMixin): against the schema """ - _mapped_fields: ClassVar[Optional[Dict[Any, List[Tuple[Field, str]]]]] = None + _mapped_fields: ClassVar[Optional[Dict[Any, List[Tuple[Field[Any], str]]]]] = None # Config class used by Mashumaro class Config(dbtMashConfig): diff --git a/tests/unit/test_connection_retries.py b/tests/unit/test_connection_retries.py index 44fc72f5..12352352 100644 --- a/tests/unit/test_connection_retries.py +++ b/tests/unit/test_connection_retries.py @@ -5,12 +5,12 @@ from dbt_common.utils.connection import connection_exception_retry -def no_retry_fn(): +def no_retry_fn() -> str: return "success" class TestNoRetries: - def test_no_retry(self): + def test_no_retry(self) -> None: fn_to_retry = functools.partial(no_retry_fn) result = connection_exception_retry(fn_to_retry, 3) diff --git a/tests/unit/test_diff.py b/tests/unit/test_diff.py index 54f735e3..26d9d490 100644 --- a/tests/unit/test_diff.py +++ b/tests/unit/test_diff.py @@ -1,12 +1,14 @@ import json -from typing import Any, Dict +from typing import Any, Dict, List import pytest from dbt_common.record import Diff +Case = List[Dict[str, Any]] + @pytest.fixture -def current_query(): +def current_query() -> Case: return [ { "params": { @@ -21,7 +23,7 @@ def current_query(): @pytest.fixture -def query_modified_order(): +def query_modified_order() -> Case: return [ { "params": { @@ -36,7 +38,7 @@ def query_modified_order(): @pytest.fixture -def query_modified_value(): +def query_modified_value() -> Case: return [ { "params": { @@ -51,7 +53,7 @@ def query_modified_value(): @pytest.fixture -def current_simple(): +def current_simple() -> Case: return [ { "params": { @@ -65,7 +67,7 @@ def current_simple(): @pytest.fixture -def current_simple_modified(): +def current_simple_modified() -> Case: return [ { "params": { @@ -79,7 +81,7 @@ def current_simple_modified(): @pytest.fixture -def env_record(): +def env_record() -> Case: return [ { "params": {}, @@ -94,7 +96,7 @@ def env_record(): @pytest.fixture -def modified_env_record(): +def modified_env_record() -> Case: return [ { "params": {}, @@ -108,30 +110,30 @@ def modified_env_record(): ] -def test_diff_query_records_no_diff(current_query, query_modified_order): +def test_diff_query_records_no_diff(current_query: Case, query_modified_order: Case) -> None: # Setup: Create an instance of Diff diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) result = diff_instance.diff_query_records(current_query, query_modified_order) # the order changed but the diff should be empty - expected_result = {} + expected_result: Dict[str, Any] = {} assert result == expected_result # Replace expected_result with what you actually expect -def test_diff_query_records_with_diff(current_query, query_modified_value): +def test_diff_query_records_with_diff(current_query: Case, query_modified_value: Case) -> None: diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) result = diff_instance.diff_query_records(current_query, query_modified_value) # the values changed this time - expected_result = { + expected_result: Dict[str, Any] = { "values_changed": {"root[0]['result']['table'][1]['b']": {"new_value": 7, "old_value": 10}} } assert result == expected_result -def test_diff_env_records(env_record, modified_env_record): +def test_diff_env_records(env_record: Case, modified_env_record: Case) -> None: diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) @@ -147,17 +149,17 @@ def test_diff_env_records(env_record, modified_env_record): assert result == expected_result -def test_diff_default_no_diff(current_simple): +def test_diff_default_no_diff(current_simple: Case) -> None: diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) # use the same list to ensure no diff result = diff_instance.diff_default(current_simple, current_simple) - expected_result = {} + expected_result: Dict[str, Any] = {} assert result == expected_result -def test_diff_default_with_diff(current_simple, current_simple_modified): +def test_diff_default_with_diff(current_simple: Case, current_simple_modified: Case) -> None: diff_instance = Diff( current_recording_path="path/to/current", previous_recording_path="path/to/previous" ) @@ -170,7 +172,7 @@ def test_diff_default_with_diff(current_simple, current_simple_modified): # Mock out reading the files so we don't have to class MockFile: - def __init__(self, json_data): + def __init__(self, json_data) -> None: self.json_data = json_data def __enter__(self): diff --git a/tests/unit/test_functions.py b/tests/unit/test_functions.py index 372b2bda..9a8a9c22 100644 --- a/tests/unit/test_functions.py +++ b/tests/unit/test_functions.py @@ -38,7 +38,7 @@ def valid_error_names() -> Set[str]: class TestWarnOrError: - def test_fires_error(self, valid_error_names: Set[str]): + def test_fires_error(self, valid_error_names: Set[str]) -> None: functions.WARN_ERROR_OPTIONS = WarnErrorOptions( include="*", valid_error_names=valid_error_names ) @@ -49,8 +49,8 @@ def test_fires_warning( self, valid_error_names: Set[str], event_catcher: EventCatcher, - set_event_manager_with_catcher, - ): + set_event_manager_with_catcher: None, + ) -> None: functions.WARN_ERROR_OPTIONS = WarnErrorOptions( include="*", exclude=list(valid_error_names), valid_error_names=valid_error_names ) @@ -62,8 +62,8 @@ def test_silenced( self, valid_error_names: Set[str], event_catcher: EventCatcher, - set_event_manager_with_catcher, - ): + set_event_manager_with_catcher: None, + ) -> None: functions.WARN_ERROR_OPTIONS = WarnErrorOptions( include="*", silence=list(valid_error_names), valid_error_names=valid_error_names ) diff --git a/tests/unit/test_jinja.py b/tests/unit/test_jinja.py index e906a0ac..cf44eee7 100644 --- a/tests/unit/test_jinja.py +++ b/tests/unit/test_jinja.py @@ -227,7 +227,7 @@ def test_incomplete_block_failure(self) -> None: with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock"}) - def test_wrong_end_failure(self): + def test_wrong_end_failure(self) -> None: body = "{% myblock foo %} {% endotherblock %}" with self.assertRaises(CompilationError): extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"}) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 93c57046..bb5563e2 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,5 @@ import unittest +from typing import Any, Tuple, Union import dbt_common.exceptions import dbt_common.utils.dict @@ -68,7 +69,7 @@ def setUp(self) -> None: } @staticmethod - def intify_all(value, _): + def intify_all(value, _) -> int: try: return int(value) except (TypeError, ValueError): @@ -98,7 +99,7 @@ def test__simple_cases(self) -> None: self.assertEqual(actual, expected) @staticmethod - def special_keypath(value, keypath): + def special_keypath(value: Any, keypath: Tuple[Union[str, int], ...]) -> Any: if tuple(keypath) == ("foo", "baz", 1): return "hello" else: