From 4ec15bec94985b68024530cc371b910b9a00b27e Mon Sep 17 00:00:00 2001 From: leohoare Date: Sun, 12 Jan 2025 19:42:53 +1100 Subject: [PATCH] refactor, switch to single client with common code and fallback Signed-off-by: leohoare --- openfeature/client.py | 435 ++++++++++++++++++++++++++++--- openfeature/provider/__init__.py | 77 +++++- tests/test_client.py | 34 ++- 3 files changed, 506 insertions(+), 40 deletions(-) diff --git a/openfeature/client.py b/openfeature/client.py index 9e4518ec..448772dd 100644 --- a/openfeature/client.py +++ b/openfeature/client.py @@ -20,7 +20,7 @@ FlagType, Reason, ) -from openfeature.hook import Hook, HookContext +from openfeature.hook import Hook, HookContext, HookHints from openfeature.hook._hook_support import ( after_all_hooks, after_hooks, @@ -55,6 +55,28 @@ FlagResolutionDetails[typing.Union[dict, list]], ], ] +GetDetailCallableAsync = typing.Union[ + typing.Callable[ + [str, bool, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[bool]], + ], + typing.Callable[ + [str, int, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[int]], + ], + typing.Callable[ + [str, float, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[float]], + ], + typing.Callable[ + [str, str, typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[str]], + ], + typing.Callable[ + [str, typing.Union[dict, list], typing.Optional[EvaluationContext]], + typing.Awaitable[FlagResolutionDetails[typing.Union[dict, list]]], + ], +] TypeMap = typing.Dict[ FlagType, typing.Union[ @@ -113,6 +135,21 @@ def get_boolean_value( flag_evaluation_options, ).value + async def get_boolean_value_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> bool: + details = await self.get_boolean_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_boolean_details( self, flag_key: str, @@ -128,6 +165,21 @@ def get_boolean_details( flag_evaluation_options, ) + async def get_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[bool]: + return await self.evaluate_flag_details_async( + FlagType.BOOLEAN, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_string_value( self, flag_key: str, @@ -142,6 +194,21 @@ def get_string_value( flag_evaluation_options, ).value + async def get_string_value_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> str: + details = await self.get_string_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_string_details( self, flag_key: str, @@ -157,6 +224,21 @@ def get_string_details( flag_evaluation_options, ) + async def get_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[str]: + return await self.evaluate_flag_details_async( + FlagType.STRING, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_integer_value( self, flag_key: str, @@ -171,6 +253,21 @@ def get_integer_value( flag_evaluation_options, ).value + async def get_integer_value_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> int: + details = await self.get_integer_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_integer_details( self, flag_key: str, @@ -186,6 +283,21 @@ def get_integer_details( flag_evaluation_options, ) + async def get_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[int]: + return await self.evaluate_flag_details_async( + FlagType.INTEGER, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_float_value( self, flag_key: str, @@ -200,6 +312,21 @@ def get_float_value( flag_evaluation_options, ).value + async def get_float_value_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> float: + details = await self.get_float_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_float_details( self, flag_key: str, @@ -215,6 +342,21 @@ def get_float_details( flag_evaluation_options, ) + async def get_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[float]: + return await self.evaluate_flag_details_async( + FlagType.FLOAT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def get_object_value( self, flag_key: str, @@ -229,6 +371,21 @@ def get_object_value( flag_evaluation_options, ).value + async def get_object_value_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> typing.Union[dict, list]: + details = await self.get_object_details_async( + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + return details.value + def get_object_details( self, flag_key: str, @@ -244,26 +401,35 @@ def get_object_details( flag_evaluation_options, ) - def evaluate_flag_details( # noqa: PLR0915 + async def get_object_details_async( self, - flag_type: FlagType, flag_key: str, - default_value: typing.Any, + default_value: typing.Union[dict, list], evaluation_context: typing.Optional[EvaluationContext] = None, flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, - ) -> FlagEvaluationDetails[typing.Any]: - """ - Evaluate the flag requested by the user from the clients provider. - - :param flag_type: the type of the flag being returned - :param flag_key: the string key of the selected flag - :param default_value: backup value returned if no result found by the provider - :param evaluation_context: Information for the purposes of flag evaluation - :param flag_evaluation_options: Additional flag evaluation information - :return: a FlagEvaluationDetails object with the fully evaluated flag from a - provider - """ + ) -> FlagEvaluationDetails[typing.Union[dict, list]]: + return await self.evaluate_flag_details_async( + FlagType.OBJECT, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + def _establish_hooks_and_provider( + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext], + flag_evaluation_options: typing.Optional[FlagEvaluationOptions], + ) -> typing.Tuple[ + FeatureProvider, + HookContext, + HookHints, + typing.List[Hook], + typing.List[Hook], + ]: if evaluation_context is None: evaluation_context = EvaluationContext() @@ -295,7 +461,17 @@ def evaluate_flag_details( # noqa: PLR0915 reversed_merged_hooks = merged_hooks[:] reversed_merged_hooks.reverse() + return provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks + + def _assert_provider_status( + self, + flag_type: FlagType, + hook_context: HookContext, + reversed_merged_hooks: typing.List[Hook], + hook_hints: HookHints, + ) -> typing.Union[None, ErrorCode]: status = self.get_provider_status() + if status == ProviderStatus.NOT_READY: error_hooks( flag_type, @@ -304,42 +480,191 @@ def evaluate_flag_details( # noqa: PLR0915 reversed_merged_hooks, hook_hints, ) + return ErrorCode.PROVIDER_NOT_READY + if status == ProviderStatus.FATAL: + error_hooks( + flag_type, + hook_context, + ProviderFatalError(), + reversed_merged_hooks, + hook_hints, + ) + return ErrorCode.PROVIDER_FATAL + return None + + def _before_hooks_and_merge_context( + self, + flag_type: FlagType, + hook_context: HookContext, + merged_hooks: typing.List[Hook], + hook_hints: HookHints, + evaluation_context: typing.Optional[EvaluationContext], + ) -> EvaluationContext: + # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md + # Any resulting evaluation context from a before hook will overwrite + # duplicate fields defined globally, on the client, or in the invocation. + # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context + invocation_context = before_hooks( + flag_type, hook_context, merged_hooks, hook_hints + ) + if evaluation_context: + invocation_context = invocation_context.merge(ctx2=evaluation_context) + + # Requirement 3.2.2 merge: API.context->client.context->invocation.context + merged_context = ( + api.get_evaluation_context().merge(self.context).merge(invocation_context) + ) + return merged_context + + async def evaluate_flag_details_async( + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Evaluate the flag requested by the user from the clients provider. + + :param flag_type: the type of the flag being returned + :param flag_key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :param flag_evaluation_options: Additional flag evaluation information + :return: a typing.Awaitable[FlagEvaluationDetails] object with the fully evaluated flag from a + provider + """ + provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = ( + self._establish_hooks_and_provider( + flag_type, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + ) + error_code = self._assert_provider_status( + flag_type, + hook_context, + reversed_merged_hooks, + hook_hints, + ) + if error_code: return FlagEvaluationDetails( flag_key=flag_key, value=default_value, reason=Reason.ERROR, - error_code=ErrorCode.PROVIDER_NOT_READY, + error_code=error_code, ) - if status == ProviderStatus.FATAL: - error_hooks( + + try: + merged_context = self._before_hooks_and_merge_context( flag_type, hook_context, - ProviderFatalError(), + merged_hooks, + hook_hints, + evaluation_context, + ) + + flag_evaluation = await self._create_provider_evaluation_async( + provider, + flag_type, + flag_key, + default_value, + merged_context, + ) + + after_hooks( + flag_type, + hook_context, + flag_evaluation, reversed_merged_hooks, hook_hints, ) + + return flag_evaluation + + except OpenFeatureError as err: + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + return FlagEvaluationDetails( flag_key=flag_key, value=default_value, reason=Reason.ERROR, - error_code=ErrorCode.PROVIDER_FATAL, + error_code=err.error_code, + error_message=err.error_message, + ) + # Catch any type of exception here since the user can provide any exception + # in the error hooks + except Exception as err: # pragma: no cover + logger.exception( + "Unable to correctly evaluate flag with key: '%s'", flag_key ) - try: - # https://github.com/open-feature/spec/blob/main/specification/sections/03-evaluation-context.md - # Any resulting evaluation context from a before hook will overwrite - # duplicate fields defined globally, on the client, or in the invocation. - # Requirement 3.2.2, 4.3.4: API.context->client.context->invocation.context - invocation_context = before_hooks( - flag_type, hook_context, merged_hooks, hook_hints + error_hooks(flag_type, hook_context, err, reversed_merged_hooks, hook_hints) + + error_message = getattr(err, "error_message", str(err)) + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=ErrorCode.GENERAL, + error_message=error_message, + ) + + finally: + after_all_hooks(flag_type, hook_context, reversed_merged_hooks, hook_hints) + + def evaluate_flag_details( + self, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + flag_evaluation_options: typing.Optional[FlagEvaluationOptions] = None, + ) -> FlagEvaluationDetails[typing.Any]: + """ + Evaluate the flag requested by the user from the clients provider. + + :param flag_type: the type of the flag being returned + :param flag_key: the string key of the selected flag + :param default_value: backup value returned if no result found by the provider + :param evaluation_context: Information for the purposes of flag evaluation + :param flag_evaluation_options: Additional flag evaluation information + :return: a FlagEvaluationDetails object with the fully evaluated flag from a + provider + """ + provider, hook_context, hook_hints, merged_hooks, reversed_merged_hooks = ( + self._establish_hooks_and_provider( + flag_type, + flag_key, + default_value, + evaluation_context, + flag_evaluation_options, + ) + ) + error_code = self._assert_provider_status( + flag_type, + hook_context, + reversed_merged_hooks, + hook_hints, + ) + if error_code: + return FlagEvaluationDetails( + flag_key=flag_key, + value=default_value, + reason=Reason.ERROR, + error_code=error_code, ) - invocation_context = invocation_context.merge(ctx2=evaluation_context) - # Requirement 3.2.2 merge: API.context->client.context->invocation.context - merged_context = ( - api.get_evaluation_context() - .merge(self.context) - .merge(invocation_context) + try: + merged_context = self._before_hooks_and_merge_context( + flag_type, + hook_context, + merged_hooks, + hook_hints, + evaluation_context, ) flag_evaluation = self._create_provider_evaluation( @@ -391,6 +716,48 @@ def evaluate_flag_details( # noqa: PLR0915 finally: after_all_hooks(flag_type, hook_context, reversed_merged_hooks, hook_hints) + async def _create_provider_evaluation_async( + self, + provider: FeatureProvider, + flag_type: FlagType, + flag_key: str, + default_value: typing.Any, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagEvaluationDetails[typing.Any]: + args = ( + flag_key, + default_value, + evaluation_context, + ) + get_details_callables_async: typing.Mapping[ + FlagType, GetDetailCallableAsync + ] = { + FlagType.BOOLEAN: provider.resolve_boolean_details_async, + FlagType.INTEGER: provider.resolve_integer_details_async, + FlagType.FLOAT: provider.resolve_float_details_async, + FlagType.OBJECT: provider.resolve_object_details_async, + FlagType.STRING: provider.resolve_string_details_async, + } + get_details_callable = get_details_callables_async.get(flag_type) + if not get_details_callable: + raise GeneralError(error_message="Unknown flag type") + + resolution = await get_details_callable(*args) + resolution.raise_for_error() + + # we need to check the get_args to be compatible with union types. + _typecheck_flag_value(resolution.value, flag_type) + + return FlagEvaluationDetails( + flag_key=flag_key, + value=resolution.value, + variant=resolution.variant, + flag_metadata=resolution.flag_metadata or {}, + reason=resolution.reason, + error_code=resolution.error_code, + error_message=resolution.error_message, + ) + def _create_provider_evaluation( self, provider: FeatureProvider, diff --git a/openfeature/provider/__init__.py b/openfeature/provider/__init__.py index 8927551e..6a782635 100644 --- a/openfeature/provider/__init__.py +++ b/openfeature/provider/__init__.py @@ -11,7 +11,7 @@ from .metadata import Metadata -__all__ = ["AbstractProvider", "ProviderStatus", "FeatureProvider", "Metadata"] +__all__ = ["AbstractProvider", "FeatureProvider", "Metadata", "ProviderStatus"] class ProviderStatus(Enum): @@ -47,6 +47,13 @@ def resolve_boolean_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[bool]: ... + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: ... + def resolve_string_details( self, flag_key: str, @@ -54,6 +61,13 @@ def resolve_string_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[str]: ... + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: ... + def resolve_integer_details( self, flag_key: str, @@ -61,6 +75,13 @@ def resolve_integer_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[int]: ... + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: ... + def resolve_float_details( self, flag_key: str, @@ -68,6 +89,13 @@ def resolve_float_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[float]: ... + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: ... + def resolve_object_details( self, flag_key: str, @@ -75,6 +103,13 @@ def resolve_object_details( evaluation_context: typing.Optional[EvaluationContext] = None, ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: ... + class AbstractProvider(FeatureProvider): def attach( @@ -111,6 +146,14 @@ def resolve_boolean_details( ) -> FlagResolutionDetails[bool]: pass + async def resolve_boolean_details_async( + self, + flag_key: str, + default_value: bool, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[bool]: + return self.resolve_boolean_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_string_details( self, @@ -120,6 +163,14 @@ def resolve_string_details( ) -> FlagResolutionDetails[str]: pass + async def resolve_string_details_async( + self, + flag_key: str, + default_value: str, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[str]: + return self.resolve_string_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_integer_details( self, @@ -129,6 +180,14 @@ def resolve_integer_details( ) -> FlagResolutionDetails[int]: pass + async def resolve_integer_details_async( + self, + flag_key: str, + default_value: int, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[int]: + return self.resolve_integer_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_float_details( self, @@ -138,6 +197,14 @@ def resolve_float_details( ) -> FlagResolutionDetails[float]: pass + async def resolve_float_details_async( + self, + flag_key: str, + default_value: float, + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[float]: + return self.resolve_float_details(flag_key, default_value, evaluation_context) + @abstractmethod def resolve_object_details( self, @@ -147,6 +214,14 @@ def resolve_object_details( ) -> FlagResolutionDetails[typing.Union[dict, list]]: pass + async def resolve_object_details_async( + self, + flag_key: str, + default_value: typing.Union[dict, list], + evaluation_context: typing.Optional[EvaluationContext] = None, + ) -> FlagResolutionDetails[typing.Union[dict, list]]: + return self.resolve_object_details(flag_key, default_value, evaluation_context) + def emit_provider_ready(self, details: ProviderEventDetails) -> None: self.emit(ProviderEvent.PROVIDER_READY, details) diff --git a/tests/test_client.py b/tests/test_client.py index b51c460c..8a56c2e6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,3 +1,4 @@ +import asyncio import time import uuid from concurrent.futures import ThreadPoolExecutor @@ -57,9 +58,13 @@ def test_should_get_flag_value_based_on_method_type( "flag_type, default_value, get_method", ( (bool, True, "get_boolean_details"), + (bool, True, "get_boolean_details_async"), (str, "String", "get_string_details"), + (str, "String", "get_string_details_async"), (int, 100, "get_integer_details"), + (int, 100, "get_integer_details_async"), (float, 10.23, "get_float_details"), + (float, 10.23, "get_float_details_async"), ( dict, { @@ -69,28 +74,47 @@ def test_should_get_flag_value_based_on_method_type( }, "get_object_details", ), + ( + dict, + { + "String": "string", + "Number": 2, + "Boolean": True, + }, + "get_object_details_async", + ), ( list, ["string1", "string2"], "get_object_details", ), + ( + list, + ["string1", "string2"], + "get_object_details_async", + ), ), ) -def test_should_get_flag_detail_based_on_method_type( +@pytest.mark.asyncio +async def test_should_get_flag_detail_based_on_method_type( flag_type, default_value, get_method, no_op_provider_client ): # Given # When - flag = getattr(no_op_provider_client, get_method)( - flag_key="Key", default_value=default_value - ) + method = getattr(no_op_provider_client, get_method) + if asyncio.iscoroutinefunction(method): + flag = await method(flag_key="Key", default_value=default_value) + else: + flag = method(flag_key="Key", default_value=default_value) # Then assert flag is not None assert flag.value == default_value assert isinstance(flag.value, flag_type) -def test_should_raise_exception_when_invalid_flag_type_provided(no_op_provider_client): +def test_should_raise_exception_when_invalid_flag_type_provided( + no_op_provider_client, +): # Given # When flag = no_op_provider_client.evaluate_flag_details(