diff --git a/CHANGELOG.md b/CHANGELOG.md index ff9ad46..8cba293 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.1.17](https://github.com/wesky93/grpc_requests/releases/tag/v0.1.17) - 2024-04-22 + +## Added + +- Support for custom message parsing in both async and sync clients + +## Removed + +- Removed singular FileDescriptor getter methods and Method specific field descriptor + methods as laid out previously. + ## [0.1.16](https://github.com/wesky93/grpc_requests/releases/tag/v0.1.16) - 2024-03-03 ## Added diff --git a/src/examples/README.md b/src/examples/README.md index 1ed2d45..964f36e 100644 --- a/src/examples/README.md +++ b/src/examples/README.md @@ -123,6 +123,26 @@ result = await greeter.HelloEveryone(requests_data) results = [x async for x in await greeter.SayHelloOneByOne(requests_data)] ``` +## Setting a Client's message_to_dict behavior + +By utilizing `CustomArgumentParsers`, behavioral arguments can be passed to +message_to_dict at time of Client instantiation. This is available for both +synchronous and asynchronous clients. + +```python +client = Client( + "localhost:50051", + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) +``` + +[Review the json_format documentation for what kwargs are available to message_to_dict.](https://googleapis.dev/python/protobuf/latest/google/protobuf/json_format.html) + ## Retrieving Information about a Server All forms of clients expose methods to allow a user to query a server about its diff --git a/src/grpc_requests/__init__.py b/src/grpc_requests/__init__.py index 1254b6f..642ea0e 100644 --- a/src/grpc_requests/__init__.py +++ b/src/grpc_requests/__init__.py @@ -7,4 +7,4 @@ ) from .client import Client, ReflectionClient, StubClient, get_by_endpoint -__version__ = "0.1.16" +__version__ = "0.1.17" diff --git a/src/grpc_requests/aio.py b/src/grpc_requests/aio.py index 00afe9d..36d84fd 100644 --- a/src/grpc_requests/aio.py +++ b/src/grpc_requests/aio.py @@ -1,6 +1,5 @@ import logging import sys -import warnings from enum import Enum from functools import partial from typing import ( @@ -18,10 +17,16 @@ import grpc from google.protobuf import ( descriptor_pb2, + message_factory, +) +from google.protobuf import ( descriptor_pool as _descriptor_pool, +) +from google.protobuf import ( symbol_database as _symbol_database, - message_factory, -) # noqa: E501 +) + +# noqa: E501 from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor from google.protobuf.descriptor_pb2 import ServiceDescriptorProto from google.protobuf.json_format import MessageToDict, ParseDict @@ -34,11 +39,13 @@ if sys.version_info >= (3, 8): import importlib.metadata + from typing import Protocol def get_metadata(package_name: str): return importlib.metadata.version(package_name) else: import pkg_resources + from typing_extensions import Protocol def get_metadata(package_name: str): return pkg_resources.get_distribution(package_name).version @@ -146,27 +153,67 @@ def __del__(self): pass -def parse_request_data(reqeust_data, input_type): - _data = reqeust_data or {} - if isinstance(_data, dict): - request = ParseDict(_data, input_type()) - else: - request = _data - return request +class MessageParsersProtocol(Protocol): + def parse_request_data(self, request_data, input_type): ... + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): ... -def parse_stream_requests(stream_requests_data: Iterable, input_type): - for request_data in stream_requests_data: - yield parse_request_data(request_data or {}, input_type) + async def parse_response(self, response): ... + async def parse_stream_responses(self, responses: AsyncIterable): ... + + +class MessageParsers(MessageParsersProtocol): + def parse_request_data(self, request_data, input_type): + _data = request_data or {} + if isinstance(_data, dict): + request = ParseDict(_data, input_type()) + else: + request = _data + return request + + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): + for request_data in stream_requests_data: + yield self.parse_request_data(request_data or {}, input_type) + + async def parse_response(self, response): + return MessageToDict(response, preserving_proto_field_name=True) + + async def parse_stream_responses(self, responses: AsyncIterable): + async for resp in responses: + yield await self.parse_response(resp) + + +class CustomArgumentParsers(MessageParsersProtocol): + _message_to_dict_kwargs: Dict[str, Any] + _parse_dict_kwargs: Dict[str, Any] + + def __init__( + self, + message_to_dict_kwargs: Dict[str, Any] = dict(), + parse_dict_kwargs: Dict[str, Any] = dict(), + ): + self._message_to_dict_kwargs = message_to_dict_kwargs or {} + self._parse_dict_kwargs = parse_dict_kwargs or {} + + def parse_request_data(self, request_data, input_type): + _data = request_data or {} + if isinstance(_data, dict): + request = ParseDict(_data, input_type(), **self._parse_dict_kwargs) + else: + request = _data + return request -async def parse_response(response): - return MessageToDict(response, preserving_proto_field_name=True) + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): + for request_data in stream_requests_data: + yield self.parse_request_data(request_data or {}, input_type) + async def parse_response(self, response): + return MessageToDict(response, **self._message_to_dict_kwargs) -async def parse_stream_responses(responses: AsyncIterable): - async for resp in responses: - yield await parse_response(resp) + async def parse_stream_responses(self, responses: AsyncIterable): + async for resp in responses: + yield await self.parse_response(resp) class MethodType(Enum): @@ -179,18 +226,10 @@ class MethodType(Enum): def is_unary_request(self): return "unary_" in self.value - @property - def request_parser(self): - return parse_request_data if self.is_unary_request else parse_stream_requests - @property def is_unary_response(self): return "_unary" in self.value - @property - def response_parser(self): - return parse_response if self.is_unary_response else parse_stream_responses - class MethodMetaData(NamedTuple): input_type: Any @@ -198,6 +237,21 @@ class MethodMetaData(NamedTuple): method_type: MethodType handler: Any descriptor: MethodDescriptor + parsers: MessageParsersProtocol + + @property + def request_parser(self): + if self.method_type.is_unary_request: + return self.parsers.parse_request_data + else: + return self.parsers.parse_stream_requests + + @property + def response_parser(self): + if self.method_type.is_unary_response: + return self.parsers.parse_response + else: + return self.parsers.parse_stream_responses IS_REQUEST_STREAM = TypeVar("IS_REQUEST_STREAM") @@ -220,6 +274,7 @@ def __init__( ssl=False, compression=None, skip_check_method_available=False, + message_parsers: MessageParsersProtocol = MessageParsers(), **kwargs, ): super().__init__( @@ -233,6 +288,7 @@ def __init__( self._service_names: list = None self.has_server_registered = False self._skip_check_method_available = skip_check_method_available + self._message_parsers = message_parsers self._services_module_name = {} self._service_methods_meta: Dict[str, Dict[str, MethodMetaData]] = {} @@ -309,6 +365,7 @@ def _register_methods( output_type=output_type, handler=handler, descriptor=method_desc, + parsers=self._message_parsers, ) return metadata @@ -348,19 +405,17 @@ async def _request(self, service, method, request, raw_output=False, **kwargs): # does not check request is available method_meta = self.get_method_meta(service, method) - _request = method_meta.method_type.request_parser( - request, method_meta.input_type - ) + _request = method_meta.request_parser(request, method_meta.input_type) if method_meta.method_type.is_unary_response: result = await method_meta.handler(_request, **kwargs) if raw_output: return result else: - return await method_meta.method_type.response_parser(result) + return await method_meta.response_parser(result) else: result = method_meta.handler(_request, **kwargs) - return method_meta.method_type.response_parser(result) + return method_meta.response_parser(result) async def request(self, service, method, request=None, raw_output=False, **kwargs): await self.check_method_available(service, method) @@ -427,6 +482,7 @@ def __init__( descriptor_pool=None, ssl=False, compression=None, + message_parsers: MessageParsersProtocol = MessageParsers(), **kwargs, ): super().__init__( @@ -435,6 +491,7 @@ def __init__( descriptor_pool, ssl=ssl, compression=compression, + message_parsers=message_parsers, **kwargs, ) self.reflection_stub = reflection_pb2_grpc.ServerReflectionStub(self.channel) @@ -453,26 +510,6 @@ async def _get_service_names(self): services = tuple([s.name for s in resp.list_services_response.service]) return services - async def get_file_descriptor_by_name(self, name): - warnings.warn( - "This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_name() instead.", - DeprecationWarning, - ) - request = reflection_pb2.ServerReflectionRequest(file_by_filename=name) - result = await self._reflection_single_request(request) - proto = result.file_descriptor_response.file_descriptor_proto[0] - return descriptor_pb2.FileDescriptorProto.FromString(proto) - - async def get_file_descriptor_by_symbol(self, symbol): - warnings.warn( - "This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_symbol() instead.", - DeprecationWarning, - ) - request = reflection_pb2.ServerReflectionRequest(file_containing_symbol=symbol) - result = await self._reflection_single_request(request) - proto = result.file_descriptor_response.file_descriptor_proto[0] - return descriptor_pb2.FileDescriptorProto.FromString(proto) - async def get_file_descriptors_by_name(self, name): request = reflection_pb2.ServerReflectionRequest(file_by_filename=name) result = await self._reflection_single_request(request) diff --git a/src/grpc_requests/client.py b/src/grpc_requests/client.py index 9e31c09..c954a25 100644 --- a/src/grpc_requests/client.py +++ b/src/grpc_requests/client.py @@ -1,6 +1,5 @@ import logging import sys -import warnings from enum import Enum from functools import partial from typing import ( @@ -16,25 +15,27 @@ ) import grpc -from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor_pb2, message_factory from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message_factory from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor from google.protobuf.descriptor_pb2 import ServiceDescriptorProto from google.protobuf.json_format import MessageToDict, ParseDict from grpc_reflection.v1alpha import reflection_pb2, reflection_pb2_grpc -from .utils import describe_descriptor, describe_request, load_data +from .utils import describe_descriptor, load_data if sys.version_info >= (3, 8): import importlib.metadata - from typing import TypedDict # pylint: disable=no-name-in-module + from typing import ( + Protocol, + TypedDict, # pylint: disable=no-name-in-module + ) def get_metadata(package_name: str): return importlib.metadata.version(package_name) else: import pkg_resources - from typing_extensions import TypedDict + from typing_extensions import Protocol, TypedDict def get_metadata(package_name: str): return pkg_resources.get_distribution(package_name).version @@ -147,24 +148,67 @@ def __del__(self): logger.warning("can not delete channel", exc_info=e) -def parse_request_data(request_data, input_type): - _data = request_data or {} - request = ParseDict(_data, input_type()) if isinstance(_data, dict) else _data - return request +class MessageParsersProtocol(Protocol): + def parse_request_data(self, request_data, input_type): ... + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): ... -def parse_stream_requests(stream_requests_data: Iterable, input_type): - for request_data in stream_requests_data: - yield parse_request_data(request_data or {}, input_type) + def parse_response(self, response): ... + def parse_stream_responses(self, responses: Iterable): ... -def parse_response(response): - return MessageToDict(response, preserving_proto_field_name=True) + +class MessageParsers(MessageParsersProtocol): + def parse_request_data(self, request_data, input_type): + _data = request_data or {} + if isinstance(_data, dict): + request = ParseDict(_data, input_type()) + else: + request = _data + return request + + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): + for request_data in stream_requests_data: + yield self.parse_request_data(request_data or {}, input_type) + + def parse_response(self, response): + return MessageToDict(response, preserving_proto_field_name=True) + + def parse_stream_responses(self, responses: Iterable): + for resp in responses: + yield self.parse_response(resp) -def parse_stream_responses(responses: Iterable): - for resp in responses: - yield parse_response(resp) +class CustomArgumentParsers(MessageParsersProtocol): + _message_to_dict_kwargs: Dict[str, Any] + _parse_dict_kwargs: Dict[str, Any] + + def __init__( + self, + message_to_dict_kwargs: Dict[str, Any] = dict(), + parse_dict_kwargs: Dict[str, Any] = dict(), + ): + self._message_to_dict_kwargs = message_to_dict_kwargs or {} + self._parse_dict_kwargs = parse_dict_kwargs or {} + + def parse_request_data(self, request_data, input_type): + _data = request_data or {} + if isinstance(_data, dict): + request = ParseDict(_data, input_type(), **self._parse_dict_kwargs) + else: + request = _data + return request + + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): + for request_data in stream_requests_data: + yield self.parse_request_data(request_data or {}, input_type) + + def parse_response(self, response): + return MessageToDict(response, **self._message_to_dict_kwargs) + + def parse_stream_responses(self, responses: Iterable): + for resp in responses: + yield self.parse_response(resp) class MethodType(Enum): @@ -177,18 +221,10 @@ class MethodType(Enum): def is_unary_request(self): return "unary_" in self.value - @property - def request_parser(self): - return parse_request_data if self.is_unary_request else parse_stream_requests - @property def is_unary_response(self): return "_unary" in self.value - @property - def response_parser(self): - return parse_response if self.is_unary_response else parse_stream_responses - class MethodMetaData(NamedTuple): input_type: Any @@ -196,6 +232,21 @@ class MethodMetaData(NamedTuple): method_type: MethodType handler: Any descriptor: MethodDescriptor + parsers: MessageParsersProtocol + + @property + def request_parser(self): + if self.method_type.is_unary_request: + return self.parsers.parse_request_data + else: + return self.parsers.parse_stream_requests + + @property + def response_parser(self): + if self.method_type.is_unary_response: + return self.parsers.parse_response + else: + return self.parsers.parse_stream_responses IS_REQUEST_STREAM = TypeVar("IS_REQUEST_STREAM") @@ -219,6 +270,7 @@ def __init__( ssl=False, compression=None, skip_check_method_available=False, + message_parsers: MessageParsersProtocol = MessageParsers(), **kwargs, ): super().__init__( @@ -233,6 +285,7 @@ def __init__( self._lazy = lazy self.has_server_registered = False self._skip_check_method_available = skip_check_method_available + self._message_parsers = message_parsers self._services_module_name = {} self._service_methods_meta: Dict[str, Dict[str, MethodMetaData]] = {} @@ -304,6 +357,7 @@ def _register_methods( output_type=output_type, handler=handler, descriptor=method_desc, + parsers=self._message_parsers, ) return metadata @@ -350,15 +404,13 @@ def _request(self, service, method, request, raw_output=False, **kwargs): # does not check request is available method_meta = self.get_method_meta(service, method) - _request = method_meta.method_type.request_parser( - request, method_meta.input_type - ) + _request = method_meta.request_parser(request, method_meta.input_type) result = method_meta.handler(_request, **kwargs) if raw_output: return result else: - return method_meta.method_type.response_parser(result) + return method_meta.response_parser(result) def request(self, service, method, request=None, raw_output=False, **kwargs): self.check_method_available(service, method) @@ -390,13 +442,6 @@ def get_service_descriptor(self, service): """ return self._desc_pool.FindServiceByName(service) - def describe_method_request(self, service, method): - warnings.warn( - "This function is deprecated, and will be removed in the 0.1.17 release. Use describe_descriptor() instead.", - DeprecationWarning, - ) - return describe_request(self.get_method_descriptor(service, method)) - def describe_request(self, service, method): return describe_descriptor( self.get_method_descriptor(service, method).input_type @@ -472,26 +517,6 @@ def _get_service_names(self): services = tuple([s.name for s in resp.list_services_response.service]) return services - def get_file_descriptor_by_name(self, name): - warnings.warn( - "This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_name() instead.", - DeprecationWarning, - ) - request = reflection_pb2.ServerReflectionRequest(file_by_filename=name) - result = self._reflection_single_request(request) - proto = result.file_descriptor_response.file_descriptor_proto[0] - return descriptor_pb2.FileDescriptorProto.FromString(proto) - - def get_file_descriptor_by_symbol(self, symbol): - warnings.warn( - "This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_symbol() instead.", - DeprecationWarning, - ) - request = reflection_pb2.ServerReflectionRequest(file_containing_symbol=symbol) - result = self._reflection_single_request(request) - proto = result.file_descriptor_response.file_descriptor_proto[0] - return descriptor_pb2.FileDescriptorProto.FromString(proto) - def get_file_descriptors_by_name(self, name): request = reflection_pb2.ServerReflectionRequest(file_by_filename=name) result = self._reflection_single_request(request) diff --git a/src/grpc_requests/utils.py b/src/grpc_requests/utils.py index b84c0c3..8a6e80c 100644 --- a/src/grpc_requests/utils.py +++ b/src/grpc_requests/utils.py @@ -2,11 +2,9 @@ from google.protobuf.descriptor import ( Descriptor, EnumDescriptor, - MethodDescriptor, OneofDescriptor, ) -import warnings # String descriptions of protobuf field types FIELD_TYPES = [ @@ -37,23 +35,6 @@ def load_data(_path): return data -def describe_request(method_descriptor: MethodDescriptor) -> dict: - """ - Provide a dictionary that describes the fields of a Method request - with a string description of their types. - :param method_descriptor: MethodDescriptor - :return: dict - a mapping of field names to their types - """ - warnings.warn( - "This function is deprecated, and will be removed in the 0.1.17 release. Use describe_descriptor() instead.", - DeprecationWarning, - ) - description = {} - for field in method_descriptor.input_type.fields: - description[field.name] = FIELD_TYPES[field.type - 1] - return description - - def describe_descriptor(descriptor: Descriptor, indent: int = 0) -> str: """ Prints a human readable description of a protobuf descriptor. diff --git a/src/tests/async_reflection_client_test.py b/src/tests/async_reflection_client_test.py index 366310f..cbe6e48 100644 --- a/src/tests/async_reflection_client_test.py +++ b/src/tests/async_reflection_client_test.py @@ -1,17 +1,16 @@ import logging -import pytest -from grpc_requests.aio import AsyncClient, MethodType -from google.protobuf.json_format import ParseError import grpc.aio - +import pytest +from google.protobuf import descriptor_pb2, descriptor_pool +from google.protobuf.json_format import ParseError +from grpc_requests.aio import AsyncClient, CustomArgumentParsers, MethodType from tests.common import AsyncMetadataClientInterceptor from tests.test_servers.dependencies import ( dependencies_pb2, dependency1_pb2, dependency2_pb2, ) -from google.protobuf import descriptor_pool, descriptor_pb2 """ Test cases for async reflection based client @@ -172,24 +171,6 @@ async def test_get_service_descriptor(): assert service_descriptor.name == "Greeter" -@pytest.mark.asyncio -async def test_get_file_descriptor_by_name(): - client = AsyncClient("localhost:50051") - file_descriptor = await client.get_file_descriptor_by_name("helloworld.proto") - assert file_descriptor.name == "helloworld.proto" - assert file_descriptor.package == "helloworld" - assert file_descriptor.syntax == "proto3" - - -@pytest.mark.asyncio -async def test_get_file_descriptor_by_symbol(): - client = AsyncClient("localhost:50051") - file_descriptor = await client.get_file_descriptor_by_symbol("helloworld.Greeter") - assert file_descriptor.name == "helloworld.proto" - assert file_descriptor.package == "helloworld" - assert file_descriptor.syntax == "proto3" - - @pytest.mark.asyncio async def test_get_file_descriptors_by_name(): client = AsyncClient( @@ -269,3 +250,107 @@ async def test_register_file_descriptors_incomplete_dependencies(): file_descriptors.append(proto) with pytest.raises(grpc.aio._call.AioRpcError): await client.register_file_descriptors(file_descriptors) + + +@pytest.mark.asyncio +async def test_unary_unary_defaults(): + client = AsyncClient( + "localhost:50054", + descriptor_pool=descriptor_pool.DescriptorPool(), + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + greeter_service = await client.service("helloworld.Greeter") + response = await greeter_service.SayHello({"name": "sinsky"}) + assert isinstance(response, dict) + assert response == {"message": ""} + + +@pytest.mark.asyncio +async def test_stream_unary_defaults(): + client = AsyncClient( + "localhost:50054", + descriptor_pool=descriptor_pool.DescriptorPool(), + ) + greeter_service = await client.service("helloworld.Greeter") + name_list = ["sinsky", "viridianforge", "jack", "harry"] + response = await greeter_service.HelloEveryone( + [{"name": name} for name in name_list] + ) + assert isinstance(response, dict) + assert response == {} + + +@pytest.mark.asyncio +async def test_stream_unary_empty(): + client = AsyncClient( + "localhost:50054", + descriptor_pool=descriptor_pool.DescriptorPool(), + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + greeter_service = await client.service("helloworld.Greeter") + name_list = ["sinsky", "viridianforge", "jack", "harry"] + response = await greeter_service.HelloEveryone( + [{"name": name} for name in name_list] + ) + assert isinstance(response, dict) + assert response == {"message": ""} + + +@pytest.mark.asyncio +async def test_stream_stream_empty(): + client = AsyncClient( + "localhost:50054", + descriptor_pool=descriptor_pool.DescriptorPool(), + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + greeter_service = await client.service("helloworld.Greeter") + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = [ + x + async for x in await greeter_service.SayHelloOneByOne( + [{"name": name} for name in name_list] + ) + ] + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {"message": ""} + + +@pytest.mark.asyncio +async def test_unary_stream_empty(): + client = AsyncClient( + "localhost:50051", + descriptor_pool=descriptor_pool.DescriptorPool(), + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + greeter_service = await client.service("helloworld.Greeter") + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = [ + x + async for x in await greeter_service.SayHelloGroup( + [{"name": name} for name in name_list] + ) + ] + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {"message": ""} diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 09e9b9f..b71f841 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -1,12 +1,13 @@ import multiprocessing -import pytest import time -from test_servers.helloworld.helloworld_server import HelloWorldServer +import pytest from test_servers.client_tester.client_tester_server import ClientTesterServer from test_servers.dependencies.dependencies_server import ( HelloWorldServer as DependencyServer, ) +from test_servers.helloworld.helloworld_server import HelloWorldServer +from tests.test_servers.helloworld.helloworld_server import EmptyGreeter def helloworld_server_starter(): @@ -24,6 +25,11 @@ def dependency_server_starter(): server.serve() +def helloworld_empty_server_starter(): + server = HelloWorldServer("50054", servicer=EmptyGreeter()) + server.serve() + + @pytest.fixture(scope="session", autouse=True) def helloworld_server(): helloworld_server_process = multiprocessing.Process( @@ -55,3 +61,14 @@ def dependency_server(): time.sleep(1) yield dependency_server_process.terminate() + + +@pytest.fixture(scope="session", autouse=True) +def helloworld_empty_server(): + helloworld_empty_server_process = multiprocessing.Process( + target=helloworld_empty_server_starter + ) + helloworld_empty_server_process.start() + time.sleep(1) + yield + helloworld_empty_server_process.terminate() diff --git a/src/tests/reflection_client_test.py b/src/tests/reflection_client_test.py index 5ffb4f1..16f6d02 100644 --- a/src/tests/reflection_client_test.py +++ b/src/tests/reflection_client_test.py @@ -1,18 +1,17 @@ import logging -import pytest -from grpc_requests.client import Client, MethodType -from google.protobuf.json_format import ParseError -from google.protobuf.descriptor import MethodDescriptor import grpc - +import pytest +from google.protobuf import descriptor_pb2, descriptor_pool +from google.protobuf.descriptor import MethodDescriptor +from google.protobuf.json_format import ParseError +from grpc_requests.client import Client, CustomArgumentParsers, MethodType from tests.common import MetadataClientInterceptor from tests.test_servers.dependencies import ( dependencies_pb2, dependency1_pb2, dependency2_pb2, ) -from google.protobuf import descriptor_pool, descriptor_pb2 """ Test cases for reflection based client @@ -49,6 +48,34 @@ def client_tester_reflection_client(): pytest.fail("Could not connect to local Test server") +@pytest.fixture(scope="module") +def helloworld_empty_reflection_client(): + try: + # Don't use get_by_endpoint here so we don't cache parsers + client = Client("localhost:50054") + yield client + except: # noqa: E722 + pytest.fail("Could not connect to local Empty HelloWorld server") + + +@pytest.fixture(scope="module") +def helloworld_empty_reflection_client_custom_parsers(): + try: + # Don't use get_by_endpoint here so we don't cache parsers + client = Client( + "localhost:50054", + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + yield client + except: # noqa: E722 + pytest.fail("Could not connect to local Empty HelloWorld server") + + def test_metadata_usage(helloworld_reflection_client): response = helloworld_reflection_client.request( "helloworld.Greeter", @@ -84,23 +111,6 @@ def test_unary_unary(helloworld_reflection_client): assert response == {"message": "Hello, sinsky!"} -def test_describe_method_request(client_tester_reflection_client): - request_description = client_tester_reflection_client.describe_method_request( - "client_tester.ClientTester", "TestUnaryUnary" - ) - expected_request_description = { - "factor": "INT32", - "readings": "FLOAT", - "uuid": "UINT64", - "sample_flag": "BOOL", - "request_name": "STRING", - "extra_data": "BYTES", - } - assert ( - request_description == expected_request_description - ), f"Expected: {expected_request_description}, Actual: {request_description}" - - def test_describe_request(client_tester_reflection_client): request_description = client_tester_reflection_client.describe_request( "client_tester.ClientTester", "TestUnaryUnary" @@ -212,24 +222,6 @@ def test_get_service_descriptor(helloworld_reflection_client): assert service_descriptor.name == "Greeter" -def test_get_file_descriptor_by_name(helloworld_reflection_client): - file_descriptor = helloworld_reflection_client.get_file_descriptor_by_name( - "helloworld.proto" - ) - assert file_descriptor.name == "helloworld.proto" - assert file_descriptor.package == "helloworld" - assert file_descriptor.syntax == "proto3" - - -def test_get_file_descriptor_by_symbol(helloworld_reflection_client): - file_descriptor = helloworld_reflection_client.get_file_descriptor_by_symbol( - "helloworld.Greeter" - ) - assert file_descriptor.name == "helloworld.proto" - assert file_descriptor.package == "helloworld" - assert file_descriptor.syntax == "proto3" - - def test_get_file_descriptors_by_name(): client = Client("localhost:50053", descriptor_pool=descriptor_pool.DescriptorPool()) file_descriptor = client.get_file_descriptors_by_name("dependencies.proto") @@ -304,3 +296,77 @@ def test_register_file_descriptors_incomplete_dependencies(): file_descriptors.append(proto) with pytest.raises(grpc.RpcError): client.register_file_descriptors(file_descriptors) + + +def test_unary_unary_empty_default(helloworld_empty_reflection_client): + response = helloworld_empty_reflection_client.request( + "helloworld.Greeter", "SayHello", {"name": "sinsky"} + ) + assert isinstance(response, dict) + assert response == {} + + +def test_unary_unary_empty_custom(helloworld_empty_reflection_client_custom_parsers): + response = helloworld_empty_reflection_client_custom_parsers.request( + "helloworld.Greeter", "SayHello", {"name": "unary_unary_custom"} + ) + assert isinstance(response, dict) + assert response == {"message": ""} + + +def test_unary_stream_empty_default(helloworld_empty_reflection_client): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = helloworld_empty_reflection_client.request( + "helloworld.Greeter", "SayHelloGroup", {"name": "".join(name_list)} + ) + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {} + + +def test_unary_stream_empty_custom(helloworld_empty_reflection_client_custom_parsers): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = helloworld_empty_reflection_client_custom_parsers.request( + "helloworld.Greeter", "SayHelloGroup", {"name": "".join(name_list)} + ) + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {"message": ""} + + +def test_stream_unary_empty_default(helloworld_empty_reflection_client): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + response = helloworld_empty_reflection_client.request( + "helloworld.Greeter", "HelloEveryone", [{"name": name} for name in name_list] + ) + assert isinstance(response, dict) + assert response == {} + + +def test_stream_unary_empty_custom(helloworld_empty_reflection_client_custom_parsers): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + response = helloworld_empty_reflection_client_custom_parsers.request( + "helloworld.Greeter", "HelloEveryone", [{"name": name} for name in name_list] + ) + assert isinstance(response, dict) + assert response == {"message": ""} + + +def test_stream_stream_empty_default(helloworld_empty_reflection_client): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = helloworld_empty_reflection_client.request( + "helloworld.Greeter", "SayHelloOneByOne", [{"name": name} for name in name_list] + ) + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {} + + +def test_stream_stream_empty_custom(helloworld_empty_reflection_client_custom_parsers): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = helloworld_empty_reflection_client_custom_parsers.request( + "helloworld.Greeter", "SayHelloOneByOne", [{"name": name} for name in name_list] + ) + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {"message": ""} diff --git a/src/tests/test_servers/helloworld/helloworld_server.py b/src/tests/test_servers/helloworld/helloworld_server.py index 22e0056..0c55495 100644 --- a/src/tests/test_servers/helloworld/helloworld_server.py +++ b/src/tests/test_servers/helloworld/helloworld_server.py @@ -1,13 +1,14 @@ +import logging from concurrent import futures -from grpc_reflection.v1alpha import reflection import grpc -import logging +from grpc_reflection.v1alpha import reflection + +from .helloworld_pb2 import DESCRIPTOR, HelloReply from .helloworld_pb2_grpc import GreeterServicer, add_GreeterServicer_to_server -from .helloworld_pb2 import HelloReply, DESCRIPTOR -class Greeter(GreeterServicer): +class Greeter(GreeterServicer): def SayHello(self, request, context): """ Unary-Unary @@ -16,9 +17,13 @@ def SayHello(self, request, context): if context.invocation_metadata(): for key, value in context.invocation_metadata(): if key == "password" and value == "12345": - return HelloReply(message=f"Hello, {request.name}, password accepted!") + return HelloReply( + message=f"Hello, {request.name}, password accepted!" + ) if key == "interceptor" and value == "true": - return HelloReply(message=f"Hello, {request.name}, interceptor accepted!") + return HelloReply( + message=f"Hello, {request.name}, interceptor accepted!" + ) return HelloReply(message=f"Hello, {request.name}!") def SayHelloGroup(self, request, context): @@ -51,24 +56,60 @@ def SayHelloOneByOne(self, request_iterator, context): yield HelloReply(message=f"Hello {request.name}") -class HelloWorldServer(): +class EmptyGreeter(GreeterServicer): + def SayHello(self, request, context): + """ + Unary-Unary + Sends a HelloReply based on a HelloRequest. + """ + return HelloReply() + + def SayHelloGroup(self, request, context): + """ + Unary-Stream + Streams a series of HelloReplies based on the names in a HelloRequest. + """ + names = request.name + for name in names.split(): + yield HelloReply() + + def HelloEveryone(self, request_iterator, context): + """ + Stream-Unary + Sends a HelloReply based on the name recieved from a stream of + HelloRequests. + """ + names = [] + for request in request_iterator: + names.append(request.name) + return HelloReply() + + def SayHelloOneByOne(self, request_iterator, context): + """ + Stream-Stream + Streams HelloReplies in response to a stream of HelloRequests. + """ + for request in request_iterator: + yield HelloReply() + +class HelloWorldServer: server = None - def __init__(self, port: str): + def __init__(self, port: str, servicer: Greeter = Greeter()): self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - add_GreeterServicer_to_server(Greeter(), self.server) + add_GreeterServicer_to_server(servicer, self.server) SERVICE_NAMES = ( - DESCRIPTOR.services_by_name['Greeter'].full_name, + DESCRIPTOR.services_by_name["Greeter"].full_name, reflection.SERVICE_NAME, ) reflection.enable_server_reflection(SERVICE_NAMES, self.server) - self.server.add_insecure_port(f'[::]:{port}') + self.server.add_insecure_port(f"[::]:{port}") def serve(self): - logging.debug('Server starting...') + logging.debug("Server starting...") self.server.start() - logging.debug('Server running...') + logging.debug("Server running...") self.server.wait_for_termination() def shutdown(self):