diff --git a/src/grpc_requests/aio.py b/src/grpc_requests/aio.py index 00afe9d..53edca0 100644 --- a/src/grpc_requests/aio.py +++ b/src/grpc_requests/aio.py @@ -18,10 +18,16 @@ import grpc from google.protobuf import ( descriptor_pb2, + message_factory, +) +from google.protobuf import ( descriptor_pool as _descriptor_pool, +) +from google.protobuf import ( symbol_database as _symbol_database, - message_factory, -) # noqa: E501 +) + +# noqa: E501 from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor from google.protobuf.descriptor_pb2 import ServiceDescriptorProto from google.protobuf.json_format import MessageToDict, ParseDict @@ -34,11 +40,13 @@ if sys.version_info >= (3, 8): import importlib.metadata + from typing import Protocol def get_metadata(package_name: str): return importlib.metadata.version(package_name) else: import pkg_resources + from typing_extensions import Protocol def get_metadata(package_name: str): return pkg_resources.get_distribution(package_name).version @@ -146,27 +154,67 @@ def __del__(self): pass -def parse_request_data(reqeust_data, input_type): - _data = reqeust_data or {} - if isinstance(_data, dict): - request = ParseDict(_data, input_type()) - else: - request = _data - return request +class MessageParsersProtocol(Protocol): + def parse_request_data(self, request_data, input_type): ... + + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): ... + + async def parse_response(self, response): ... + + async def parse_stream_responses(self, responses: AsyncIterable): ... + + +class MessageParsers(MessageParsersProtocol): + def parse_request_data(self, request_data, input_type): + _data = request_data or {} + if isinstance(_data, dict): + request = ParseDict(_data, input_type()) + else: + request = _data + return request + + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): + for request_data in stream_requests_data: + yield self.parse_request_data(request_data or {}, input_type) + + async def parse_response(self, response): + return MessageToDict(response, preserving_proto_field_name=True) + + async def parse_stream_responses(self, responses: AsyncIterable): + async for resp in responses: + yield await self.parse_response(resp) + +class CustomArgumentParsers(MessageParsersProtocol): + _message_to_dict_kwargs: Dict[str, Any] + _parse_dict_kwargs: Dict[str, Any] -def parse_stream_requests(stream_requests_data: Iterable, input_type): - for request_data in stream_requests_data: - yield parse_request_data(request_data or {}, input_type) + def __init__( + self, + message_to_dict_kwargs: Dict[str, Any] = dict(), + parse_dict_kwargs: Dict[str, Any] = dict(), + ): + self._message_to_dict_kwargs = message_to_dict_kwargs or {} + self._parse_dict_kwargs = parse_dict_kwargs or {} + def parse_request_data(self, request_data, input_type): + _data = request_data or {} + if isinstance(_data, dict): + request = ParseDict(_data, input_type(), **self._parse_dict_kwargs) + else: + request = _data + return request -async def parse_response(response): - return MessageToDict(response, preserving_proto_field_name=True) + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): + for request_data in stream_requests_data: + yield self.parse_request_data(request_data or {}, input_type) + async def parse_response(self, response): + return MessageToDict(response, **self._message_to_dict_kwargs) -async def parse_stream_responses(responses: AsyncIterable): - async for resp in responses: - yield await parse_response(resp) + async def parse_stream_responses(self, responses: AsyncIterable): + async for resp in responses: + yield await self.parse_response(resp) class MethodType(Enum): @@ -179,18 +227,10 @@ class MethodType(Enum): def is_unary_request(self): return "unary_" in self.value - @property - def request_parser(self): - return parse_request_data if self.is_unary_request else parse_stream_requests - @property def is_unary_response(self): return "_unary" in self.value - @property - def response_parser(self): - return parse_response if self.is_unary_response else parse_stream_responses - class MethodMetaData(NamedTuple): input_type: Any @@ -198,6 +238,21 @@ class MethodMetaData(NamedTuple): method_type: MethodType handler: Any descriptor: MethodDescriptor + parsers: MessageParsersProtocol + + @property + def request_parser(self): + if self.method_type.is_unary_request: + return self.parsers.parse_request_data + else: + return self.parsers.parse_stream_requests + + @property + def response_parser(self): + if self.method_type.is_unary_response: + return self.parsers.parse_response + else: + return self.parsers.parse_stream_responses IS_REQUEST_STREAM = TypeVar("IS_REQUEST_STREAM") @@ -220,6 +275,7 @@ def __init__( ssl=False, compression=None, skip_check_method_available=False, + message_parsers: MessageParsersProtocol = MessageParsers(), **kwargs, ): super().__init__( @@ -233,6 +289,7 @@ def __init__( self._service_names: list = None self.has_server_registered = False self._skip_check_method_available = skip_check_method_available + self._message_parsers = message_parsers self._services_module_name = {} self._service_methods_meta: Dict[str, Dict[str, MethodMetaData]] = {} @@ -309,6 +366,7 @@ def _register_methods( output_type=output_type, handler=handler, descriptor=method_desc, + parsers=self._message_parsers, ) return metadata @@ -348,19 +406,17 @@ async def _request(self, service, method, request, raw_output=False, **kwargs): # does not check request is available method_meta = self.get_method_meta(service, method) - _request = method_meta.method_type.request_parser( - request, method_meta.input_type - ) + _request = method_meta.request_parser(request, method_meta.input_type) if method_meta.method_type.is_unary_response: result = await method_meta.handler(_request, **kwargs) if raw_output: return result else: - return await method_meta.method_type.response_parser(result) + return await method_meta.response_parser(result) else: result = method_meta.handler(_request, **kwargs) - return method_meta.method_type.response_parser(result) + return method_meta.response_parser(result) async def request(self, service, method, request=None, raw_output=False, **kwargs): await self.check_method_available(service, method) @@ -427,6 +483,7 @@ def __init__( descriptor_pool=None, ssl=False, compression=None, + message_parsers: MessageParsersProtocol = MessageParsers(), **kwargs, ): super().__init__( @@ -435,6 +492,7 @@ def __init__( descriptor_pool, ssl=ssl, compression=compression, + message_parsers=message_parsers, **kwargs, ) self.reflection_stub = reflection_pb2_grpc.ServerReflectionStub(self.channel) diff --git a/src/grpc_requests/client.py b/src/grpc_requests/client.py index 9e31c09..f3743c4 100644 --- a/src/grpc_requests/client.py +++ b/src/grpc_requests/client.py @@ -16,9 +16,8 @@ ) import grpc -from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor_pb2, message_factory from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import message_factory from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor from google.protobuf.descriptor_pb2 import ServiceDescriptorProto from google.protobuf.json_format import MessageToDict, ParseDict @@ -28,13 +27,16 @@ if sys.version_info >= (3, 8): import importlib.metadata - from typing import TypedDict # pylint: disable=no-name-in-module + from typing import ( + Protocol, + TypedDict, # pylint: disable=no-name-in-module + ) def get_metadata(package_name: str): return importlib.metadata.version(package_name) else: import pkg_resources - from typing_extensions import TypedDict + from typing_extensions import Protocol, TypedDict def get_metadata(package_name: str): return pkg_resources.get_distribution(package_name).version @@ -147,24 +149,67 @@ def __del__(self): logger.warning("can not delete channel", exc_info=e) -def parse_request_data(request_data, input_type): - _data = request_data or {} - request = ParseDict(_data, input_type()) if isinstance(_data, dict) else _data - return request +class MessageParsersProtocol(Protocol): + def parse_request_data(self, request_data, input_type): ... + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): ... -def parse_stream_requests(stream_requests_data: Iterable, input_type): - for request_data in stream_requests_data: - yield parse_request_data(request_data or {}, input_type) + def parse_response(self, response): ... + def parse_stream_responses(self, responses: Iterable): ... -def parse_response(response): - return MessageToDict(response, preserving_proto_field_name=True) +class MessageParsers(MessageParsersProtocol): + def parse_request_data(self, request_data, input_type): + _data = request_data or {} + if isinstance(_data, dict): + request = ParseDict(_data, input_type()) + else: + request = _data + return request + + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): + for request_data in stream_requests_data: + yield self.parse_request_data(request_data or {}, input_type) + + def parse_response(self, response): + return MessageToDict(response, preserving_proto_field_name=True) + + def parse_stream_responses(self, responses: Iterable): + for resp in responses: + yield self.parse_response(resp) + + +class CustomArgumentParsers(MessageParsersProtocol): + _message_to_dict_kwargs: Dict[str, Any] + _parse_dict_kwargs: Dict[str, Any] + + def __init__( + self, + message_to_dict_kwargs: Dict[str, Any] = dict(), + parse_dict_kwargs: Dict[str, Any] = dict(), + ): + self._message_to_dict_kwargs = message_to_dict_kwargs or {} + self._parse_dict_kwargs = parse_dict_kwargs or {} + + def parse_request_data(self, request_data, input_type): + _data = request_data or {} + if isinstance(_data, dict): + request = ParseDict(_data, input_type(), **self._parse_dict_kwargs) + else: + request = _data + return request + + def parse_stream_requests(self, stream_requests_data: Iterable, input_type): + for request_data in stream_requests_data: + yield self.parse_request_data(request_data or {}, input_type) -def parse_stream_responses(responses: Iterable): - for resp in responses: - yield parse_response(resp) + def parse_response(self, response): + return MessageToDict(response, **self._message_to_dict_kwargs) + + def parse_stream_responses(self, responses: Iterable): + for resp in responses: + yield self.parse_response(resp) class MethodType(Enum): @@ -177,18 +222,10 @@ class MethodType(Enum): def is_unary_request(self): return "unary_" in self.value - @property - def request_parser(self): - return parse_request_data if self.is_unary_request else parse_stream_requests - @property def is_unary_response(self): return "_unary" in self.value - @property - def response_parser(self): - return parse_response if self.is_unary_response else parse_stream_responses - class MethodMetaData(NamedTuple): input_type: Any @@ -196,6 +233,21 @@ class MethodMetaData(NamedTuple): method_type: MethodType handler: Any descriptor: MethodDescriptor + parsers: MessageParsersProtocol + + @property + def request_parser(self): + if self.method_type.is_unary_request: + return self.parsers.parse_request_data + else: + return self.parsers.parse_stream_requests + + @property + def response_parser(self): + if self.method_type.is_unary_response: + return self.parsers.parse_response + else: + return self.parsers.parse_stream_responses IS_REQUEST_STREAM = TypeVar("IS_REQUEST_STREAM") @@ -219,6 +271,7 @@ def __init__( ssl=False, compression=None, skip_check_method_available=False, + message_parsers: MessageParsersProtocol = MessageParsers(), **kwargs, ): super().__init__( @@ -233,6 +286,7 @@ def __init__( self._lazy = lazy self.has_server_registered = False self._skip_check_method_available = skip_check_method_available + self._message_parsers = message_parsers self._services_module_name = {} self._service_methods_meta: Dict[str, Dict[str, MethodMetaData]] = {} @@ -304,6 +358,7 @@ def _register_methods( output_type=output_type, handler=handler, descriptor=method_desc, + parsers=self._message_parsers, ) return metadata @@ -350,15 +405,13 @@ def _request(self, service, method, request, raw_output=False, **kwargs): # does not check request is available method_meta = self.get_method_meta(service, method) - _request = method_meta.method_type.request_parser( - request, method_meta.input_type - ) + _request = method_meta.request_parser(request, method_meta.input_type) result = method_meta.handler(_request, **kwargs) if raw_output: return result else: - return method_meta.method_type.response_parser(result) + return method_meta.response_parser(result) def request(self, service, method, request=None, raw_output=False, **kwargs): self.check_method_available(service, method) diff --git a/src/tests/async_reflection_client_test.py b/src/tests/async_reflection_client_test.py index 366310f..b16902b 100644 --- a/src/tests/async_reflection_client_test.py +++ b/src/tests/async_reflection_client_test.py @@ -1,17 +1,16 @@ import logging -import pytest -from grpc_requests.aio import AsyncClient, MethodType -from google.protobuf.json_format import ParseError import grpc.aio - +import pytest +from google.protobuf import descriptor_pb2, descriptor_pool +from google.protobuf.json_format import ParseError +from grpc_requests.aio import AsyncClient, CustomArgumentParsers, MethodType from tests.common import AsyncMetadataClientInterceptor from tests.test_servers.dependencies import ( dependencies_pb2, dependency1_pb2, dependency2_pb2, ) -from google.protobuf import descriptor_pool, descriptor_pb2 """ Test cases for async reflection based client @@ -269,3 +268,107 @@ async def test_register_file_descriptors_incomplete_dependencies(): file_descriptors.append(proto) with pytest.raises(grpc.aio._call.AioRpcError): await client.register_file_descriptors(file_descriptors) + + +@pytest.mark.asyncio +async def test_unary_unary_defaults(): + client = AsyncClient( + "localhost:50054", + descriptor_pool=descriptor_pool.DescriptorPool(), + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + greeter_service = await client.service("helloworld.Greeter") + response = await greeter_service.SayHello({"name": "sinsky"}) + assert isinstance(response, dict) + assert response == {"message": ""} + + +@pytest.mark.asyncio +async def test_stream_unary_defaults(): + client = AsyncClient( + "localhost:50054", + descriptor_pool=descriptor_pool.DescriptorPool(), + ) + greeter_service = await client.service("helloworld.Greeter") + name_list = ["sinsky", "viridianforge", "jack", "harry"] + response = await greeter_service.HelloEveryone( + [{"name": name} for name in name_list] + ) + assert isinstance(response, dict) + assert response == {} + + +@pytest.mark.asyncio +async def test_stream_unary_empty(): + client = AsyncClient( + "localhost:50054", + descriptor_pool=descriptor_pool.DescriptorPool(), + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + greeter_service = await client.service("helloworld.Greeter") + name_list = ["sinsky", "viridianforge", "jack", "harry"] + response = await greeter_service.HelloEveryone( + [{"name": name} for name in name_list] + ) + assert isinstance(response, dict) + assert response == {"message": ""} + + +@pytest.mark.asyncio +async def test_stream_stream_empty(): + client = AsyncClient( + "localhost:50054", + descriptor_pool=descriptor_pool.DescriptorPool(), + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + greeter_service = await client.service("helloworld.Greeter") + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = [ + x + async for x in await greeter_service.SayHelloOneByOne( + [{"name": name} for name in name_list] + ) + ] + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {"message": ""} + + +@pytest.mark.asyncio +async def test_unary_stream_empty(): + client = AsyncClient( + "localhost:50051", + descriptor_pool=descriptor_pool.DescriptorPool(), + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + greeter_service = await client.service("helloworld.Greeter") + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = [ + x + async for x in await greeter_service.SayHelloGroup( + [{"name": name} for name in name_list] + ) + ] + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {"message": ""} diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 09e9b9f..b71f841 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -1,12 +1,13 @@ import multiprocessing -import pytest import time -from test_servers.helloworld.helloworld_server import HelloWorldServer +import pytest from test_servers.client_tester.client_tester_server import ClientTesterServer from test_servers.dependencies.dependencies_server import ( HelloWorldServer as DependencyServer, ) +from test_servers.helloworld.helloworld_server import HelloWorldServer +from tests.test_servers.helloworld.helloworld_server import EmptyGreeter def helloworld_server_starter(): @@ -24,6 +25,11 @@ def dependency_server_starter(): server.serve() +def helloworld_empty_server_starter(): + server = HelloWorldServer("50054", servicer=EmptyGreeter()) + server.serve() + + @pytest.fixture(scope="session", autouse=True) def helloworld_server(): helloworld_server_process = multiprocessing.Process( @@ -55,3 +61,14 @@ def dependency_server(): time.sleep(1) yield dependency_server_process.terminate() + + +@pytest.fixture(scope="session", autouse=True) +def helloworld_empty_server(): + helloworld_empty_server_process = multiprocessing.Process( + target=helloworld_empty_server_starter + ) + helloworld_empty_server_process.start() + time.sleep(1) + yield + helloworld_empty_server_process.terminate() diff --git a/src/tests/reflection_client_test.py b/src/tests/reflection_client_test.py index 5ffb4f1..33edcf7 100644 --- a/src/tests/reflection_client_test.py +++ b/src/tests/reflection_client_test.py @@ -1,18 +1,17 @@ import logging -import pytest -from grpc_requests.client import Client, MethodType -from google.protobuf.json_format import ParseError -from google.protobuf.descriptor import MethodDescriptor import grpc - +import pytest +from google.protobuf import descriptor_pb2, descriptor_pool +from google.protobuf.descriptor import MethodDescriptor +from google.protobuf.json_format import ParseError +from grpc_requests.client import Client, CustomArgumentParsers, MethodType from tests.common import MetadataClientInterceptor from tests.test_servers.dependencies import ( dependencies_pb2, dependency1_pb2, dependency2_pb2, ) -from google.protobuf import descriptor_pool, descriptor_pb2 """ Test cases for reflection based client @@ -49,6 +48,34 @@ def client_tester_reflection_client(): pytest.fail("Could not connect to local Test server") +@pytest.fixture(scope="module") +def helloworld_empty_reflection_client(): + try: + # Don't use get_by_endpoint here so we don't cache parsers + client = Client("localhost:50054") + yield client + except: # noqa: E722 + pytest.fail("Could not connect to local Empty HelloWorld server") + + +@pytest.fixture(scope="module") +def helloworld_empty_reflection_client_custom_parsers(): + try: + # Don't use get_by_endpoint here so we don't cache parsers + client = Client( + "localhost:50054", + message_parsers=CustomArgumentParsers( + message_to_dict_kwargs={ + "preserving_proto_field_name": True, + "including_default_value_fields": True, + } + ), + ) + yield client + except: # noqa: E722 + pytest.fail("Could not connect to local Empty HelloWorld server") + + def test_metadata_usage(helloworld_reflection_client): response = helloworld_reflection_client.request( "helloworld.Greeter", @@ -304,3 +331,77 @@ def test_register_file_descriptors_incomplete_dependencies(): file_descriptors.append(proto) with pytest.raises(grpc.RpcError): client.register_file_descriptors(file_descriptors) + + +def test_unary_unary_empty_default(helloworld_empty_reflection_client): + response = helloworld_empty_reflection_client.request( + "helloworld.Greeter", "SayHello", {"name": "sinsky"} + ) + assert isinstance(response, dict) + assert response == {} + + +def test_unary_unary_empty_custom(helloworld_empty_reflection_client_custom_parsers): + response = helloworld_empty_reflection_client_custom_parsers.request( + "helloworld.Greeter", "SayHello", {"name": "unary_unary_custom"} + ) + assert isinstance(response, dict) + assert response == {"message": ""} + + +def test_unary_stream_empty_default(helloworld_empty_reflection_client): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = helloworld_empty_reflection_client.request( + "helloworld.Greeter", "SayHelloGroup", {"name": "".join(name_list)} + ) + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {} + + +def test_unary_stream_empty_custom(helloworld_empty_reflection_client_custom_parsers): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = helloworld_empty_reflection_client_custom_parsers.request( + "helloworld.Greeter", "SayHelloGroup", {"name": "".join(name_list)} + ) + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {"message": ""} + + +def test_stream_unary_empty_default(helloworld_empty_reflection_client): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + response = helloworld_empty_reflection_client.request( + "helloworld.Greeter", "HelloEveryone", [{"name": name} for name in name_list] + ) + assert isinstance(response, dict) + assert response == {} + + +def test_stream_unary_empty_custom(helloworld_empty_reflection_client_custom_parsers): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + response = helloworld_empty_reflection_client_custom_parsers.request( + "helloworld.Greeter", "HelloEveryone", [{"name": name} for name in name_list] + ) + assert isinstance(response, dict) + assert response == {"message": ""} + + +def test_stream_stream_empty_default(helloworld_empty_reflection_client): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = helloworld_empty_reflection_client.request( + "helloworld.Greeter", "SayHelloOneByOne", [{"name": name} for name in name_list] + ) + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {} + + +def test_stream_stream_empty_custom(helloworld_empty_reflection_client_custom_parsers): + name_list = ["sinsky", "viridianforge", "jack", "harry"] + responses = helloworld_empty_reflection_client_custom_parsers.request( + "helloworld.Greeter", "SayHelloOneByOne", [{"name": name} for name in name_list] + ) + assert all(isinstance(response, dict) for response in responses) + for response, name in zip(responses, name_list): + assert response == {"message": ""} diff --git a/src/tests/test_servers/helloworld/helloworld_server.py b/src/tests/test_servers/helloworld/helloworld_server.py index 22e0056..0c55495 100644 --- a/src/tests/test_servers/helloworld/helloworld_server.py +++ b/src/tests/test_servers/helloworld/helloworld_server.py @@ -1,13 +1,14 @@ +import logging from concurrent import futures -from grpc_reflection.v1alpha import reflection import grpc -import logging +from grpc_reflection.v1alpha import reflection + +from .helloworld_pb2 import DESCRIPTOR, HelloReply from .helloworld_pb2_grpc import GreeterServicer, add_GreeterServicer_to_server -from .helloworld_pb2 import HelloReply, DESCRIPTOR -class Greeter(GreeterServicer): +class Greeter(GreeterServicer): def SayHello(self, request, context): """ Unary-Unary @@ -16,9 +17,13 @@ def SayHello(self, request, context): if context.invocation_metadata(): for key, value in context.invocation_metadata(): if key == "password" and value == "12345": - return HelloReply(message=f"Hello, {request.name}, password accepted!") + return HelloReply( + message=f"Hello, {request.name}, password accepted!" + ) if key == "interceptor" and value == "true": - return HelloReply(message=f"Hello, {request.name}, interceptor accepted!") + return HelloReply( + message=f"Hello, {request.name}, interceptor accepted!" + ) return HelloReply(message=f"Hello, {request.name}!") def SayHelloGroup(self, request, context): @@ -51,24 +56,60 @@ def SayHelloOneByOne(self, request_iterator, context): yield HelloReply(message=f"Hello {request.name}") -class HelloWorldServer(): +class EmptyGreeter(GreeterServicer): + def SayHello(self, request, context): + """ + Unary-Unary + Sends a HelloReply based on a HelloRequest. + """ + return HelloReply() + + def SayHelloGroup(self, request, context): + """ + Unary-Stream + Streams a series of HelloReplies based on the names in a HelloRequest. + """ + names = request.name + for name in names.split(): + yield HelloReply() + + def HelloEveryone(self, request_iterator, context): + """ + Stream-Unary + Sends a HelloReply based on the name recieved from a stream of + HelloRequests. + """ + names = [] + for request in request_iterator: + names.append(request.name) + return HelloReply() + + def SayHelloOneByOne(self, request_iterator, context): + """ + Stream-Stream + Streams HelloReplies in response to a stream of HelloRequests. + """ + for request in request_iterator: + yield HelloReply() + +class HelloWorldServer: server = None - def __init__(self, port: str): + def __init__(self, port: str, servicer: Greeter = Greeter()): self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - add_GreeterServicer_to_server(Greeter(), self.server) + add_GreeterServicer_to_server(servicer, self.server) SERVICE_NAMES = ( - DESCRIPTOR.services_by_name['Greeter'].full_name, + DESCRIPTOR.services_by_name["Greeter"].full_name, reflection.SERVICE_NAME, ) reflection.enable_server_reflection(SERVICE_NAMES, self.server) - self.server.add_insecure_port(f'[::]:{port}') + self.server.add_insecure_port(f"[::]:{port}") def serve(self): - logging.debug('Server starting...') + logging.debug("Server starting...") self.server.start() - logging.debug('Server running...') + logging.debug("Server running...") self.server.wait_for_termination() def shutdown(self):