Skip to content

Commit

Permalink
Merge pull request #76 from artificialinc/aidan/custom-parsers
Browse files Browse the repository at this point in the history
Add support for custom message parsing.
  • Loading branch information
ViridianForge authored Apr 16, 2024
2 parents 58bbf11 + 0e10fd6 commit 0731b5d
Show file tree
Hide file tree
Showing 6 changed files with 457 additions and 84 deletions.
118 changes: 88 additions & 30 deletions src/grpc_requests/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@
import grpc
from google.protobuf import (
descriptor_pb2,
message_factory,
)
from google.protobuf import (
descriptor_pool as _descriptor_pool,
)
from google.protobuf import (
symbol_database as _symbol_database,
message_factory,
) # noqa: E501
)

# noqa: E501
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor
from google.protobuf.descriptor_pb2 import ServiceDescriptorProto
from google.protobuf.json_format import MessageToDict, ParseDict
Expand All @@ -34,11 +40,13 @@

if sys.version_info >= (3, 8):
import importlib.metadata
from typing import Protocol

def get_metadata(package_name: str):
return importlib.metadata.version(package_name)
else:
import pkg_resources
from typing_extensions import Protocol

def get_metadata(package_name: str):
return pkg_resources.get_distribution(package_name).version
Expand Down Expand Up @@ -146,27 +154,67 @@ def __del__(self):
pass


def parse_request_data(reqeust_data, input_type):
_data = reqeust_data or {}
if isinstance(_data, dict):
request = ParseDict(_data, input_type())
else:
request = _data
return request
class MessageParsersProtocol(Protocol):
def parse_request_data(self, request_data, input_type): ...

def parse_stream_requests(self, stream_requests_data: Iterable, input_type): ...

async def parse_response(self, response): ...

async def parse_stream_responses(self, responses: AsyncIterable): ...


class MessageParsers(MessageParsersProtocol):
def parse_request_data(self, request_data, input_type):
_data = request_data or {}
if isinstance(_data, dict):
request = ParseDict(_data, input_type())
else:
request = _data
return request

def parse_stream_requests(self, stream_requests_data: Iterable, input_type):
for request_data in stream_requests_data:
yield self.parse_request_data(request_data or {}, input_type)

async def parse_response(self, response):
return MessageToDict(response, preserving_proto_field_name=True)

async def parse_stream_responses(self, responses: AsyncIterable):
async for resp in responses:
yield await self.parse_response(resp)


class CustomArgumentParsers(MessageParsersProtocol):
_message_to_dict_kwargs: Dict[str, Any]
_parse_dict_kwargs: Dict[str, Any]

def parse_stream_requests(stream_requests_data: Iterable, input_type):
for request_data in stream_requests_data:
yield parse_request_data(request_data or {}, input_type)
def __init__(
self,
message_to_dict_kwargs: Dict[str, Any] = dict(),
parse_dict_kwargs: Dict[str, Any] = dict(),
):
self._message_to_dict_kwargs = message_to_dict_kwargs or {}
self._parse_dict_kwargs = parse_dict_kwargs or {}

def parse_request_data(self, request_data, input_type):
_data = request_data or {}
if isinstance(_data, dict):
request = ParseDict(_data, input_type(), **self._parse_dict_kwargs)
else:
request = _data
return request

async def parse_response(response):
return MessageToDict(response, preserving_proto_field_name=True)
def parse_stream_requests(self, stream_requests_data: Iterable, input_type):
for request_data in stream_requests_data:
yield self.parse_request_data(request_data or {}, input_type)

async def parse_response(self, response):
return MessageToDict(response, **self._message_to_dict_kwargs)

async def parse_stream_responses(responses: AsyncIterable):
async for resp in responses:
yield await parse_response(resp)
async def parse_stream_responses(self, responses: AsyncIterable):
async for resp in responses:
yield await self.parse_response(resp)


class MethodType(Enum):
Expand All @@ -179,25 +227,32 @@ class MethodType(Enum):
def is_unary_request(self):
return "unary_" in self.value

@property
def request_parser(self):
return parse_request_data if self.is_unary_request else parse_stream_requests

@property
def is_unary_response(self):
return "_unary" in self.value

@property
def response_parser(self):
return parse_response if self.is_unary_response else parse_stream_responses


class MethodMetaData(NamedTuple):
input_type: Any
output_type: Any
method_type: MethodType
handler: Any
descriptor: MethodDescriptor
parsers: MessageParsersProtocol

@property
def request_parser(self):
if self.method_type.is_unary_request:
return self.parsers.parse_request_data
else:
return self.parsers.parse_stream_requests

@property
def response_parser(self):
if self.method_type.is_unary_response:
return self.parsers.parse_response
else:
return self.parsers.parse_stream_responses


IS_REQUEST_STREAM = TypeVar("IS_REQUEST_STREAM")
Expand All @@ -220,6 +275,7 @@ def __init__(
ssl=False,
compression=None,
skip_check_method_available=False,
message_parsers: MessageParsersProtocol = MessageParsers(),
**kwargs,
):
super().__init__(
Expand All @@ -233,6 +289,7 @@ def __init__(
self._service_names: list = None
self.has_server_registered = False
self._skip_check_method_available = skip_check_method_available
self._message_parsers = message_parsers
self._services_module_name = {}
self._service_methods_meta: Dict[str, Dict[str, MethodMetaData]] = {}

Expand Down Expand Up @@ -309,6 +366,7 @@ def _register_methods(
output_type=output_type,
handler=handler,
descriptor=method_desc,
parsers=self._message_parsers,
)
return metadata

Expand Down Expand Up @@ -348,19 +406,17 @@ async def _request(self, service, method, request, raw_output=False, **kwargs):
# does not check request is available
method_meta = self.get_method_meta(service, method)

_request = method_meta.method_type.request_parser(
request, method_meta.input_type
)
_request = method_meta.request_parser(request, method_meta.input_type)
if method_meta.method_type.is_unary_response:
result = await method_meta.handler(_request, **kwargs)

if raw_output:
return result
else:
return await method_meta.method_type.response_parser(result)
return await method_meta.response_parser(result)
else:
result = method_meta.handler(_request, **kwargs)
return method_meta.method_type.response_parser(result)
return method_meta.response_parser(result)

async def request(self, service, method, request=None, raw_output=False, **kwargs):
await self.check_method_available(service, method)
Expand Down Expand Up @@ -427,6 +483,7 @@ def __init__(
descriptor_pool=None,
ssl=False,
compression=None,
message_parsers: MessageParsersProtocol = MessageParsers(),
**kwargs,
):
super().__init__(
Expand All @@ -435,6 +492,7 @@ def __init__(
descriptor_pool,
ssl=ssl,
compression=compression,
message_parsers=message_parsers,
**kwargs,
)
self.reflection_stub = reflection_pb2_grpc.ServerReflectionStub(self.channel)
Expand Down
Loading

0 comments on commit 0731b5d

Please sign in to comment.