Skip to content

Commit

Permalink
Even more type annotations.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterallenwebb committed Aug 5, 2024
1 parent c9cc99e commit 1d3504c
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 42 deletions.
2 changes: 1 addition & 1 deletion dbt_common/clients/agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def as_matrix(table):
return [r.values() for r in table.rows.values()]


def from_csv(abspath, text_columns, delimiter=","):
def from_csv(abspath, text_columns, delimiter=",") -> agate.Table:
type_tester = build_type_tester(text_columns=text_columns)
with open(abspath, encoding="utf-8") as fp:
if fp.read(1) != BOM:
Expand Down
12 changes: 6 additions & 6 deletions dbt_common/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Union, Set, Type
from typing_extensions import Protocol

import jinja2 # type: ignore
import jinja2.ext # type: ignore
import jinja2.nativetypes # type: ignore
import jinja2.nodes # type: ignore
import jinja2.parser # type: ignore
import jinja2.sandbox # type: ignore
import jinja2
import jinja2.ext
import jinja2.nativetypes
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox

from dbt_common.tests import test_caching_enabled
from dbt_common.utils.jinja import (
Expand Down
6 changes: 3 additions & 3 deletions dbt_common/contracts/config/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class Metadata(Enum):
@classmethod
def from_field(cls: Type[M], fld: Field) -> M:
def from_field(cls: Type[M], fld: Field[Any]) -> M:
default = cls.default_field()
key = cls.metadata_key()

Expand All @@ -28,7 +28,7 @@ def metadata_key(cls) -> str:
raise NotImplementedError("Not implemented")


def _get_meta_value(cls: Type[M], fld: Field, key: str, default: Any) -> M:
def _get_meta_value(cls: Type[M], fld: Field[Any], key: str, default: Any) -> M:
# a metadata field might exist. If it does, it might have a matching key.
# If it has both, make sure the value is valid and return it. If it
# doesn't, return the default.
Expand Down Expand Up @@ -65,5 +65,5 @@ def metadata_key(cls) -> str:
return "show_hide"

@classmethod
def should_show(cls, fld: Field) -> bool:
def should_show(cls, fld: Field[Any]) -> bool:
return cls.from_field(fld) == cls.Show
9 changes: 4 additions & 5 deletions dbt_common/dataclass_schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from typing import Any, cast, ClassVar, Dict, get_type_hints, List, Optional, Tuple
from typing import Any, ClassVar, Dict, get_type_hints, List, Optional, Tuple, Union
import re
import jsonschema
from dataclasses import fields, Field
from enum import Enum
from datetime import datetime
from dateutil.parser import parse

# type: ignore
from mashumaro.config import (
TO_DICT_ADD_OMIT_NONE_FLAG,
ADD_SERIALIZATION_CONTEXT,
Expand All @@ -33,8 +32,8 @@ def serialize(self, value: datetime) -> str:
out += "Z"
return out

def deserialize(self, value) -> datetime:
return value if isinstance(value, datetime) else parse(cast(str, value))
def deserialize(self, value: Union[datetime, str]) -> datetime:
return value if isinstance(value, datetime) else parse(value)


class dbtMashConfig(MashBaseConfig):
Expand Down Expand Up @@ -63,7 +62,7 @@ class dbtClassMixin(DataClassMessagePackMixin):
against the schema
"""

_mapped_fields: ClassVar[Optional[Dict[Any, List[Tuple[Field, str]]]]] = None
_mapped_fields: ClassVar[Optional[Dict[Any, List[Tuple[Field[Any], str]]]]] = None

# Config class used by Mashumaro
class Config(dbtMashConfig):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_connection_retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from dbt_common.utils.connection import connection_exception_retry


def no_retry_fn():
def no_retry_fn() -> str:
return "success"


class TestNoRetries:
def test_no_retry(self):
def test_no_retry(self) -> None:
fn_to_retry = functools.partial(no_retry_fn)
result = connection_exception_retry(fn_to_retry, 3)

Expand Down
36 changes: 19 additions & 17 deletions tests/unit/test_diff.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import json
from typing import Any, Dict
from typing import Any, Dict, List

import pytest
from dbt_common.record import Diff

Case = List[Dict[str, Any]]


@pytest.fixture
def current_query():
def current_query() -> Case:
return [
{
"params": {
Expand All @@ -21,7 +23,7 @@ def current_query():


@pytest.fixture
def query_modified_order():
def query_modified_order() -> Case:
return [
{
"params": {
Expand All @@ -36,7 +38,7 @@ def query_modified_order():


@pytest.fixture
def query_modified_value():
def query_modified_value() -> Case:
return [
{
"params": {
Expand All @@ -51,7 +53,7 @@ def query_modified_value():


@pytest.fixture
def current_simple():
def current_simple() -> Case:
return [
{
"params": {
Expand All @@ -65,7 +67,7 @@ def current_simple():


@pytest.fixture
def current_simple_modified():
def current_simple_modified() -> Case:
return [
{
"params": {
Expand All @@ -79,7 +81,7 @@ def current_simple_modified():


@pytest.fixture
def env_record():
def env_record() -> Case:
return [
{
"params": {},
Expand All @@ -94,7 +96,7 @@ def env_record():


@pytest.fixture
def modified_env_record():
def modified_env_record() -> Case:
return [
{
"params": {},
Expand All @@ -108,30 +110,30 @@ def modified_env_record():
]


def test_diff_query_records_no_diff(current_query, query_modified_order):
def test_diff_query_records_no_diff(current_query: Case, query_modified_order: Case) -> None:
# Setup: Create an instance of Diff
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
result = diff_instance.diff_query_records(current_query, query_modified_order)
# the order changed but the diff should be empty
expected_result = {}
expected_result: Dict[str, Any] = {}
assert result == expected_result # Replace expected_result with what you actually expect


def test_diff_query_records_with_diff(current_query, query_modified_value):
def test_diff_query_records_with_diff(current_query: Case, query_modified_value: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
result = diff_instance.diff_query_records(current_query, query_modified_value)
# the values changed this time
expected_result = {
expected_result: Dict[str, Any] = {
"values_changed": {"root[0]['result']['table'][1]['b']": {"new_value": 7, "old_value": 10}}
}
assert result == expected_result


def test_diff_env_records(env_record, modified_env_record):
def test_diff_env_records(env_record: Case, modified_env_record: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
Expand All @@ -147,17 +149,17 @@ def test_diff_env_records(env_record, modified_env_record):
assert result == expected_result


def test_diff_default_no_diff(current_simple):
def test_diff_default_no_diff(current_simple: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
# use the same list to ensure no diff
result = diff_instance.diff_default(current_simple, current_simple)
expected_result = {}
expected_result: Dict[str, Any] = {}
assert result == expected_result


def test_diff_default_with_diff(current_simple, current_simple_modified):
def test_diff_default_with_diff(current_simple: Case, current_simple_modified: Case) -> None:
diff_instance = Diff(
current_recording_path="path/to/current", previous_recording_path="path/to/previous"
)
Expand All @@ -170,7 +172,7 @@ def test_diff_default_with_diff(current_simple, current_simple_modified):

# Mock out reading the files so we don't have to
class MockFile:
def __init__(self, json_data):
def __init__(self, json_data) -> None:
self.json_data = json_data

def __enter__(self):
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def valid_error_names() -> Set[str]:


class TestWarnOrError:
def test_fires_error(self, valid_error_names: Set[str]):
def test_fires_error(self, valid_error_names: Set[str]) -> None:
functions.WARN_ERROR_OPTIONS = WarnErrorOptions(
include="*", valid_error_names=valid_error_names
)
Expand All @@ -49,8 +49,8 @@ def test_fires_warning(
self,
valid_error_names: Set[str],
event_catcher: EventCatcher,
set_event_manager_with_catcher,
):
set_event_manager_with_catcher: None,
) -> None:
functions.WARN_ERROR_OPTIONS = WarnErrorOptions(
include="*", exclude=list(valid_error_names), valid_error_names=valid_error_names
)
Expand All @@ -62,8 +62,8 @@ def test_silenced(
self,
valid_error_names: Set[str],
event_catcher: EventCatcher,
set_event_manager_with_catcher,
):
set_event_manager_with_catcher: None,
) -> None:
functions.WARN_ERROR_OPTIONS = WarnErrorOptions(
include="*", silence=list(valid_error_names), valid_error_names=valid_error_names
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_incomplete_block_failure(self) -> None:
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body, allowed_blocks={"myblock"})

def test_wrong_end_failure(self):
def test_wrong_end_failure(self) -> None:
body = "{% myblock foo %} {% endotherblock %}"
with self.assertRaises(CompilationError):
extract_toplevel_blocks(body, allowed_blocks={"myblock", "otherblock"})
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from typing import Any, Tuple, Union

import dbt_common.exceptions
import dbt_common.utils.dict
Expand Down Expand Up @@ -68,7 +69,7 @@ def setUp(self) -> None:
}

@staticmethod
def intify_all(value, _):
def intify_all(value, _) -> int:
try:
return int(value)
except (TypeError, ValueError):
Expand Down Expand Up @@ -98,7 +99,7 @@ def test__simple_cases(self) -> None:
self.assertEqual(actual, expected)

@staticmethod
def special_keypath(value, keypath):
def special_keypath(value: Any, keypath: Tuple[Union[str, int], ...]) -> Any:
if tuple(keypath) == ("foo", "baz", 1):
return "hello"
else:
Expand Down

0 comments on commit 1d3504c

Please sign in to comment.