From 7ab6312e5de4e41c37c302aac8288d0f5ab82c93 Mon Sep 17 00:00:00 2001 From: Lucas Valente Date: Mon, 21 Oct 2024 14:41:57 +0200 Subject: [PATCH] feat: support for custom grain (#55) * feat: add functionality to deprecate attributes This commit adds a new `BaseModel.DEPRECATED` key that can be added to the dataclasses `metadata` dict to indicate that this field is deprecated. * feat: new `DeprecatedMixin` for whole classes This commit introduces a new `DeprecatedMixin` class that can be added to any model class to mark that class as deprecated. It will throw a warning if the user tries to instantiate the class. * feat: mark `TimeGranularity` as deprecated Since we introduced custom grains, the old `queryable_granularities` is deprecated in favor of the new `queryable_time_granilarities`, which is just a simple list of strings. This commit marks the `TimeGranularity` class and all granularity fields that return it as deprecated. * fix: change `OrderByGroupBy` to use `str` as grain This commit changes the `OrderByGroupBy` class to use `str` instead of the deprecated `TimeGranularity` enum as its input grain. We can do this without a deprecation because we haven't released the SDK since the order by refactor, so we can just change it. * docs: changelog entry * fix: preload models in `__init__` We need to call `BaseModel._register_subclasses` otherwise models will fail to use `camelCase` and raise deprecation warnings. That is done in `dbtsl.models.__init__`. If the user never explicitly imports that, this won't get called, and they might get an error. This fixes that by adding an explicit call to it on the library init. * fix: catch deprecation warnings in GQL client We are raising deprecation warnings from the GQL client when we instantiate the models. To avoid the warning spam, we filter those warnings out. They should only be display if the user uses any deprecated class, not us. --- .../Deprecations-20241017-163158.yaml | 3 + .../unreleased/Features-20241017-163057.yaml | 3 + .../Under the Hood-20241017-163037.yaml | 3 + .changie.yaml | 4 +- dbtsl/__init__.py | 2 + dbtsl/api/adbc/protocol.py | 2 +- dbtsl/api/graphql/client/base.py | 11 ++-- dbtsl/api/shared/query_params.py | 4 +- dbtsl/asyncio.py | 2 + dbtsl/models/__init__.py | 2 +- dbtsl/models/base.py | 58 +++++++++++++++++-- dbtsl/models/dimension.py | 13 ++++- dbtsl/models/metric.py | 13 ++++- dbtsl/models/saved_query.py | 10 +++- dbtsl/models/time.py | 14 ++++- tests/api/adbc/test_protocol.py | 5 +- tests/test_models.py | 49 +++++++++++++++- 17 files changed, 173 insertions(+), 25 deletions(-) create mode 100644 .changes/unreleased/Deprecations-20241017-163158.yaml create mode 100644 .changes/unreleased/Features-20241017-163057.yaml create mode 100644 .changes/unreleased/Under the Hood-20241017-163037.yaml diff --git a/.changes/unreleased/Deprecations-20241017-163158.yaml b/.changes/unreleased/Deprecations-20241017-163158.yaml new file mode 100644 index 0000000..6f6b7bd --- /dev/null +++ b/.changes/unreleased/Deprecations-20241017-163158.yaml @@ -0,0 +1,3 @@ +kind: Deprecations +body: Deprecate `TimeGranularity` enum and all other fields that used it +time: 2024-10-17T16:31:58.091095+02:00 diff --git a/.changes/unreleased/Features-20241017-163057.yaml b/.changes/unreleased/Features-20241017-163057.yaml new file mode 100644 index 0000000..ed2feb7 --- /dev/null +++ b/.changes/unreleased/Features-20241017-163057.yaml @@ -0,0 +1,3 @@ +kind: Features +body: Add support for custom time granularity +time: 2024-10-17T16:30:57.023867+02:00 diff --git a/.changes/unreleased/Under the Hood-20241017-163037.yaml b/.changes/unreleased/Under the Hood-20241017-163037.yaml new file mode 100644 index 0000000..8fb4c5b --- /dev/null +++ b/.changes/unreleased/Under the Hood-20241017-163037.yaml @@ -0,0 +1,3 @@ +kind: Under the Hood +body: Add new mechanisms to deprecate fields and classes +time: 2024-10-17T16:30:37.793294+02:00 diff --git a/.changie.yaml b/.changie.yaml index 8f33ee3..671e88d 100644 --- a/.changie.yaml +++ b/.changie.yaml @@ -8,7 +8,9 @@ kindFormat: '### {{.Kind}}' changeFormat: '* {{.Body}}' kinds: - label: Breaking Changes - auto: major + auto: minor + - label: Deprecations + auto: minor - label: Features auto: minor - label: Fixes diff --git a/dbtsl/__init__.py b/dbtsl/__init__.py index 6ea742f..0ae6d21 100644 --- a/dbtsl/__init__.py +++ b/dbtsl/__init__.py @@ -1,3 +1,4 @@ +# pyright: reportUnusedImport=false try: from dbtsl.client.sync import SyncSemanticLayerClient @@ -13,6 +14,7 @@ def err_factory(*args, **kwargs) -> None: # noqa: D103 SemanticLayerClient = err_factory +import dbtsl.models # noqa: F401 from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric __all__ = ["SemanticLayerClient", "OrderByMetric", "OrderByGroupBy"] diff --git a/dbtsl/api/adbc/protocol.py b/dbtsl/api/adbc/protocol.py index c49d176..b1a58b9 100644 --- a/dbtsl/api/adbc/protocol.py +++ b/dbtsl/api/adbc/protocol.py @@ -32,7 +32,7 @@ def _serialize_val(cls, val: Any) -> str: if isinstance(val, OrderByGroupBy): d = f'Dimension("{val.name}")' if val.grain: - grain_str = val.grain.name.lower() + grain_str = val.grain.lower() d += f'.grain("{grain_str}")' if val.descending: d += ".descending(True)" diff --git a/dbtsl/api/graphql/client/base.py b/dbtsl/api/graphql/client/base.py index 2c8ea19..b904f1a 100644 --- a/dbtsl/api/graphql/client/base.py +++ b/dbtsl/api/graphql/client/base.py @@ -1,4 +1,5 @@ import functools +import warnings from abc import abstractmethod from typing import Any, Dict, Generic, Optional, Protocol, TypeVar, Union @@ -102,10 +103,12 @@ def __getattr__(self, attr: str) -> Any: if op is None: raise AttributeError() - return functools.partial( - self._run, - op=op, - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return functools.partial( + self._run, + op=op, + ) TClient = TypeVar("TClient", bound=BaseGraphQLClient, covariant=True) diff --git a/dbtsl/api/shared/query_params.py b/dbtsl/api/shared/query_params.py index 083065b..538de28 100644 --- a/dbtsl/api/shared/query_params.py +++ b/dbtsl/api/shared/query_params.py @@ -1,8 +1,6 @@ from dataclasses import dataclass from typing import List, Optional, TypedDict, Union -from dbtsl.models.time import TimeGranularity - @dataclass(frozen=True) class OrderByMetric: @@ -20,7 +18,7 @@ class OrderByGroupBy: """ name: str - grain: Optional[TimeGranularity] + grain: Optional[str] descending: bool = False diff --git a/dbtsl/asyncio.py b/dbtsl/asyncio.py index 5692da2..cc1a4ce 100644 --- a/dbtsl/asyncio.py +++ b/dbtsl/asyncio.py @@ -1,3 +1,4 @@ +# pyright: reportUnusedImport=false try: from dbtsl.client.asyncio import AsyncSemanticLayerClient except ImportError: @@ -11,6 +12,7 @@ def err_factory(*args, **kwargs) -> None: # noqa: D103 AsyncSemanticLayerClient = err_factory +import dbtsl.models # noqa: F401 __all__ = [ "AsyncSemanticLayerClient", diff --git a/dbtsl/models/__init__.py b/dbtsl/models/__init__.py index 7884bb2..45394bc 100644 --- a/dbtsl/models/__init__.py +++ b/dbtsl/models/__init__.py @@ -25,7 +25,7 @@ # Only importing this so it registers aliases _ = QueryResult -BaseModel._apply_aliases() +BaseModel._register_subclasses() __all__ = [ "AggregationType", diff --git a/dbtsl/models/base.py b/dbtsl/models/base.py index ac044a3..930d57c 100644 --- a/dbtsl/models/base.py +++ b/dbtsl/models/base.py @@ -1,9 +1,10 @@ import inspect +import warnings from dataclasses import dataclass, fields, is_dataclass from dataclasses import field as dc_field from functools import cache from types import MappingProxyType -from typing import Any, List, Set, Type, Union +from typing import Any, ClassVar, Dict, List, Set, Type, Union from typing import get_args as get_type_args from typing import get_origin as get_type_origin @@ -25,19 +26,68 @@ class BaseModel(DataClassDictMixin): Adds some functionality like automatically creating camelCase aliases. """ + DEPRECATED: ClassVar[str] = "dbtsl_deprecated" + + # Mapping of "subclass.field" to "deprecation reason" + _deprecated_fields: ClassVar[Dict[str, str]] = dict() + + @staticmethod + def _get_deprecation_key(class_name: str, field_name: str) -> str: + return f"{class_name}.{field_name}" + + @classmethod + def _warn_if_deprecated(cls, field_name: str) -> None: + key = BaseModel._get_deprecation_key(cls.__name__, field_name) + reason = BaseModel._deprecated_fields.get(key) + if reason is not None: + warnings.warn(reason, DeprecationWarning) + class Config(BaseConfig): # noqa: D106 lazy_compilation = True @classmethod - def _apply_aliases(cls) -> None: - """Apply camelCase aliases to all subclasses.""" + def _register_subclasses(cls) -> None: + """Process fields of all subclasses. + + This will: + - Apply camelCase aliases + - Pre-populate the _deprecated_fields dict with the deprecated fields + """ for subclass in cls.__subclasses__(): assert is_dataclass(subclass), "Subclass of BaseModel must be dataclass" for field in fields(subclass): camel_name = snake_case_to_camel_case(field.name) if field.name != camel_name: - field.metadata = MappingProxyType(field_options(alias=camel_name)) + opts = field_options(alias=camel_name) + if field.metadata is not None: + opts = {**opts, **field.metadata} + field.metadata = MappingProxyType(opts) + + if cls.DEPRECATED in field.metadata: + reason = field.metadata[cls.DEPRECATED] + key = BaseModel._get_deprecation_key(subclass.__name__, field.name) + cls._deprecated_fields[key] = reason + + def __getattribute__(self, name: str) -> Any: # noqa: D105 + v = object.__getattribute__(self, name) + if not name.startswith("__") and not callable(v): + self._warn_if_deprecated(name) + + return v + + +class DeprecatedMixin: + """Add this to any deprecated model.""" + + @classmethod + def _deprecation_message(cls) -> str: + """The deprecation message that will get displayed.""" + return f"{cls.__name__} is deprecated" + + def __init__(self, *args, **kwargs) -> None: # noqa: D107 + warnings.warn(self._deprecation_message(), DeprecationWarning) + super(DeprecatedMixin, self).__init__() @dataclass(frozen=True, eq=True) diff --git a/dbtsl/models/dimension.py b/dbtsl/models/dimension.py index 05e44fa..c19e1f7 100644 --- a/dbtsl/models/dimension.py +++ b/dbtsl/models/dimension.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import List, Optional @@ -13,6 +13,12 @@ class DimensionType(str, Enum): TIME = "TIME" +QUERYABLE_GRANULARITIES_DEPRECATION = ( + "Since the introduction of custom time granularities, `Dimension.queryable_granularities` is deprecated. " + "Use `queryable_time_granularities` instead." +) + + @dataclass(frozen=True) class Dimension(BaseModel, GraphQLFragmentMixin): """A metric dimension.""" @@ -24,4 +30,7 @@ class Dimension(BaseModel, GraphQLFragmentMixin): label: Optional[str] is_partition: bool expr: Optional[str] - queryable_granularities: List[TimeGranularity] + queryable_granularities: List[TimeGranularity] = field( + metadata={BaseModel.DEPRECATED: QUERYABLE_GRANULARITIES_DEPRECATION} + ) + queryable_time_granularities: List[str] diff --git a/dbtsl/models/metric.py b/dbtsl/models/metric.py index 034d4df..cb2a790 100644 --- a/dbtsl/models/metric.py +++ b/dbtsl/models/metric.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from typing import List, Optional @@ -19,6 +19,12 @@ class MetricType(str, Enum): CONVERSION = "CONVERSION" +QUERYABLE_GRANULARITIES_DEPRECATION = ( + "Since the introduction of custom time granularities, `Metric.queryable_granularities` is deprecated. " + "Use `queryable_time_granularities` instead." +) + + @dataclass(frozen=True) class Metric(BaseModel, GraphQLFragmentMixin): """A metric.""" @@ -29,6 +35,9 @@ class Metric(BaseModel, GraphQLFragmentMixin): dimensions: List[Dimension] measures: List[Measure] entities: List[Entity] - queryable_granularities: List[TimeGranularity] + queryable_granularities: List[TimeGranularity] = field( + metadata={BaseModel.DEPRECATED: QUERYABLE_GRANULARITIES_DEPRECATION} + ) + queryable_time_granularities: List[str] label: str requires_metric_time: bool diff --git a/dbtsl/models/saved_query.py b/dbtsl/models/saved_query.py index ce0068d..478a44e 100644 --- a/dbtsl/models/saved_query.py +++ b/dbtsl/models/saved_query.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from dataclasses import field as dc_field from enum import Enum from typing import List, Optional @@ -37,12 +38,19 @@ class SavedQueryMetricParam(BaseModel, GraphQLFragmentMixin): name: str +GRAIN_DEPRECATION = ( + "Since the introduction of custom time granularities, `SavedQueryGroupByParam.grain` is deprecated. " + "Use `time_granularity` instead." +) + + @dataclass(frozen=True) class SavedQueryGroupByParam(BaseModel, GraphQLFragmentMixin): """The groupBy param of a saved query.""" name: str - grain: Optional[TimeGranularity] + grain: Optional[TimeGranularity] = dc_field(metadata={BaseModel.DEPRECATED: GRAIN_DEPRECATION}) + time_granularity: Optional[str] date_part: Optional[DatePart] diff --git a/dbtsl/models/time.py b/dbtsl/models/time.py index 81631de..d5f77e8 100644 --- a/dbtsl/models/time.py +++ b/dbtsl/models/time.py @@ -1,9 +1,21 @@ from enum import Enum +from typing_extensions import override -class TimeGranularity(str, Enum): +from dbtsl.models.base import DeprecatedMixin + + +class TimeGranularity(str, DeprecatedMixin, Enum): """A time granularity.""" + @override + @classmethod + def _deprecation_message(cls) -> str: + return ( + "Since the introduction of custom time granularity, the `TimeGranularity` enum is deprecated. " + "Please just use strings to represent time grains." + ) + NANOSECOND = "NANOSECOND" MICROSECOND = "MICROSECOND" MILLISECOND = "MILLISECOND" diff --git a/tests/api/adbc/test_protocol.py b/tests/api/adbc/test_protocol.py index fec3e74..b3a40e0 100644 --- a/tests/api/adbc/test_protocol.py +++ b/tests/api/adbc/test_protocol.py @@ -1,6 +1,5 @@ from dbtsl.api.adbc.protocol import ADBCProtocol from dbtsl.api.shared.query_params import OrderByGroupBy, OrderByMetric -from dbtsl.models.time import TimeGranularity def test_serialize_val_basic_values() -> None: @@ -23,11 +22,11 @@ def test_serialize_val_OrderByGroupBy() -> None: == 'Dimension("m").descending(True)' ) assert ( - ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.DAY, descending=False)) + ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain="day", descending=False)) == 'Dimension("m").grain("day")' ) assert ( - ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain=TimeGranularity.WEEK, descending=True)) + ADBCProtocol._serialize_val(OrderByGroupBy(name="m", grain="week", descending=True)) == 'Dimension("m").grain("week").descending(True)' ) diff --git a/tests/test_models.py b/tests/test_models.py index 50f1957..562b957 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,8 +1,11 @@ +import warnings from dataclasses import dataclass +from dataclasses import field as dc_field from typing import List import pytest from mashumaro.codecs.basic import decode +from typing_extensions import override from dbtsl.api.graphql.util import normalize_query from dbtsl.api.shared.query_params import ( @@ -14,7 +17,7 @@ validate_order_by, validate_query_parameters, ) -from dbtsl.models.base import BaseModel, GraphQLFragmentMixin +from dbtsl.models.base import BaseModel, DeprecatedMixin, GraphQLFragmentMixin from dbtsl.models.base import snake_case_to_camel_case as stc @@ -31,7 +34,7 @@ def test_base_model_auto_alias() -> None: class SubModel(BaseModel): hello_world: str - BaseModel._apply_aliases() + BaseModel._register_subclasses() data = { "helloWorld": "asdf", @@ -89,6 +92,48 @@ class B(BaseModel, GraphQLFragmentMixin): assert b_fragments[1] == a_fragment +def test_DeprecatedMixin() -> None: + msg = "i am deprecated :(" + + class MyDeprecatedClass(DeprecatedMixin): + @override + @classmethod + def _deprecation_message(cls) -> str: + return msg + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + _ = MyDeprecatedClass() + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert msg == str(w[0].message) + + +def test_attr_deprecation_warning() -> None: + msg = "i am deprecated :(" + + @dataclass(frozen=True) + class MyClassWithDeprecatedField(BaseModel): + its_fine: bool = True + oh_no: bool = dc_field(default=False, metadata={BaseModel.DEPRECATED: msg}) + + BaseModel._register_subclasses() + + m = MyClassWithDeprecatedField() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + _ = m.its_fine + assert len(w) == 0 + + _ = m.oh_no + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert msg == str(w[0].message) + + def test_validate_order_by_params_passthrough_OrderByMetric() -> None: i = OrderByMetric(name="asdf", descending=True) r = validate_order_by([], [], i)