Skip to content

Commit

Permalink
Merge pull request #19 from wesky93/viridianforge/protobuf-4.22-upgrades
Browse files Browse the repository at this point in the history
Viridianforge/protobuf 4.23 upgrades
  • Loading branch information
ViridianForge authored May 17, 2023
2 parents df98b66 + 9ddc297 commit c18e499
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 39 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ jobs:
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics
# - name: Test with pytest
# run: |
# pytest --cov-report=xml --cov=src/grpc_requests
# - name: Upload coverage to Codecov
# uses: codecov/codecov-action@v3
- name: Test with pytest
run: |
pytest --cov-report=xml --cov=src/grpc_requests
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
- name: Build and publish
env:
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.7', '3.8', '3.9', '3.10' ]
python-version: [ '3.7', '3.8', '3.9', '3.10', '3.11' ]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ grpcio>=1.49.1
grpcio-reflection>=1.49.1
google-api-core>=2.9.0
cryptography>=39.0.1
protobuf<4.22
protobuf<=4.23.0
21 changes: 11 additions & 10 deletions src/grpc_requests/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple, TypeVar, Union

import grpc
from google.protobuf import descriptor_pb2, descriptor_pool as _descriptor_pool, symbol_database as _symbol_database
from google.protobuf import descriptor_pb2, descriptor_pool as _descriptor_pool, message_factory
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor
from google.protobuf.descriptor_pb2 import ServiceDescriptorProto
from google.protobuf.json_format import MessageToDict, ParseDict
Expand Down Expand Up @@ -54,7 +54,6 @@ def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, channel_optio
compression=None, credentials: Optional[CredentialsInfo] = None, **kwargs):
self.endpoint = endpoint
self._desc_pool = descriptor_pool or _descriptor_pool.Default()
self._symbol_db = symbol_db or _symbol_database.Default()
self.compression = compression
self.channel_options = channel_options
if ssl:
Expand Down Expand Up @@ -86,7 +85,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
try:
self._channel._close()
except Exception as e: # pylint: disable=bare-except
logger.warning('can not closed channel', exc_info=e)
logger.warning('can not close channel', exc_info=e)
return False

def __del__(self):
Expand All @@ -97,8 +96,8 @@ def __del__(self):
logger.warning('can not delete channel', exc_info=e)


def parse_request_data(reqeust_data, input_type):
_data = reqeust_data or {}
def parse_request_data(request_data, input_type):
_data = request_data or {}
request = ParseDict(_data, input_type()) if isinstance(_data, dict) else _data
return request

Expand Down Expand Up @@ -202,8 +201,10 @@ def _register_methods(self, service_descriptor: ServiceDescriptor) -> Dict[str,
method_name = method_proto.name
method_desc: MethodDescriptor = service_descriptor.methods_by_name[method_name]

input_type = self._symbol_db.GetPrototype(method_desc.input_type)
output_type = self._symbol_db.GetPrototype(method_desc.output_type)
msg_factory = message_factory.MessageFactory(method_proto)

input_type = msg_factory.GetPrototype(method_desc.input_type)
output_type = msg_factory.GetPrototype(method_desc.output_type)
method_type = MethodTypeMatch[(method_proto.client_streaming, method_proto.server_streaming)]

method_register_func = getattr(self.channel, method_type.value)
Expand Down Expand Up @@ -306,7 +307,7 @@ def service(self, name):
if name in self.service_names:
return ServiceClient(client=self, service_name=name)
else:
raise ValueError(f"{name} doesn't support. Available service {self.service_names}")
raise ValueError(f"{name} is not a supported service. Available services are {self.service_names}")


class ReflectionClient(BaseGrpcClient):
Expand Down Expand Up @@ -362,7 +363,7 @@ def _register_file_descriptor(self, file_descriptor):
try:
self._desc_pool.Add(file_descriptor)
except TypeError:
logger.warning(f"{file_descriptor.name} already present in pool. Skipping.")
logger.debug(f"{file_descriptor.name} already present in pool. Skipping.")
logger.debug(f"end {file_descriptor.name} register")

def register_service(self, service_name):
Expand All @@ -377,7 +378,7 @@ def register_service(self, service_name):
file_descriptor = self._get_file_descriptor_by_symbol(service_name)
self._register_file_descriptor(file_descriptor)
except Exception as e:
logger.warning(f"registered {service_name} failed, may be already registered", exc_info=e)
logger.debug(f"registered {service_name} failed, may be already registered", exc_info=e)
logger.debug(f"end {service_name} register")
else:
logger.debug(f"{service_name} is already register")
Expand Down
18 changes: 18 additions & 0 deletions src/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import multiprocessing
import pytest
import time

from .test_servers.helloworld_server import HelloWorldServer


def helloworld_server_starter():
server = HelloWorldServer('50051')
server.serve()

@pytest.fixture(scope="session", autouse=True)
def helloworld_server():
helloworld_server_process = multiprocessing.Process(target=helloworld_server_starter)
helloworld_server_process.start()
time.sleep(1)
yield
helloworld_server_process.terminate()
52 changes: 33 additions & 19 deletions src/tests/reflection_client_test.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,14 @@
import multiprocessing
import logging
import pytest
import time

from ..grpc_requests.client import Client
from .test_servers.helloworld_server import HelloWorldServer
from google.protobuf.json_format import ParseError

"""
Test cases for reflection based client
"""

def helloworld_server_starter():
server = HelloWorldServer('50051')
server.serve()

@pytest.fixture(scope="module")
def helloworld_server():
helloworld_server_process = multiprocessing.Process(target=helloworld_server_starter)
helloworld_server_process.start()
time.sleep(1)
yield
helloworld_server_process.terminate()

logger = logging.getLogger('name')

@pytest.fixture(scope="module")
def helloworld_reflection_client():
Expand All @@ -31,12 +19,29 @@ def helloworld_reflection_client():
pytest.fail("Could not connect to local HelloWorld server")


def test_unary_unary(helloworld_server, helloworld_reflection_client):
def test_unary_unary(helloworld_reflection_client):
response = helloworld_reflection_client.request('helloworld.Greeter', 'SayHello', {"name": "sinsky"})
assert type(response) == dict
assert response == {"message": "Hello, sinsky!"}

def test_unary_stream(helloworld_server, helloworld_reflection_client):
def test_empty_body_request(helloworld_reflection_client):
response = helloworld_reflection_client.request('helloworld.Greeter', 'SayHello', {})
logger.warning(f"Response: {response}")
assert type(response) == dict

def test_nonexistent_service(helloworld_reflection_client):
with pytest.raises(ValueError):
helloworld_reflection_client.request('helloworld.Speaker', 'SingHello', {})

def test_nonexistent_method(helloworld_reflection_client):
with pytest.raises(ValueError):
helloworld_reflection_client.request('helloworld.Greeter', 'SayGoodbye', {})

def test_unsupported_argument(helloworld_reflection_client):
with pytest.raises(ParseError):
helloworld_reflection_client.request('helloworld.Greeter', 'SayHello', {"foo": "bar"})

def test_unary_stream(helloworld_reflection_client):
name_list = ["sinsky", "viridianforge", "jack", "harry"]
responses = helloworld_reflection_client.request(
'helloworld.Greeter',
Expand All @@ -47,7 +52,7 @@ def test_unary_stream(helloworld_server, helloworld_reflection_client):
for response, name in zip(responses, name_list):
assert response == {"message": f"Hello, {name}!"}

def test_stream_unary(helloworld_server, helloworld_reflection_client):
def test_stream_unary(helloworld_reflection_client):
name_list = ["sinsky", "viridianforge", "jack", "harry"]
response = helloworld_reflection_client.request(
'helloworld.Greeter',
Expand All @@ -57,7 +62,7 @@ def test_stream_unary(helloworld_server, helloworld_reflection_client):
assert type(response) == dict
assert response == {'message': f'Hello, {" ".join(name_list)}!'}

def test_stream_stream(helloworld_server, helloworld_reflection_client):
def test_stream_stream(helloworld_reflection_client):
name_list = ["sinsky", "viridianforge", "jack", "harry"]
responses = helloworld_reflection_client.request(
'helloworld.Greeter',
Expand All @@ -67,3 +72,12 @@ def test_stream_stream(helloworld_server, helloworld_reflection_client):
assert all(type(response) == dict for response in responses)
for response, name in zip(responses, name_list):
assert response == {"message": f"Hello, {name}!"}

def test_reflection_service_client(helloworld_reflection_client):
svc_client = helloworld_reflection_client.service('helloworld.Greeter')
method_names = svc_client.method_names
assert method_names == ('SayHello', 'SayHelloGroup', 'HelloEveryone', 'SayHelloOneByOne')

def test_reflection_service_client_invalid_service(helloworld_reflection_client):
with pytest.raises(ValueError):
helloworld_reflection_client.service('helloWorld.Singer')
22 changes: 22 additions & 0 deletions src/tests/service_client_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import logging
import pytest

from ..grpc_requests.client import Client, ServiceClient

"""
Test cases for ServiceClient
"""

logger = logging.getLogger('name')

@pytest.fixture(scope="module")
def helloworld_service_client():
try:
client = ServiceClient(Client('localhost:50051'), "helloworld.Greeter")
yield client
except: # noqa: E722
pytest.fail("Could not connect to local HelloWorld server")

def test_method_names(helloworld_service_client):
method_names = helloworld_service_client.method_names
assert method_names == ('SayHello', 'SayHelloGroup', 'HelloEveryone', 'SayHelloOneByOne')
75 changes: 75 additions & 0 deletions src/tests/stub_client_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import logging
import pytest

from ..grpc_requests.client import StubClient
from .test_protos.helloworld_pb2 import _GREETER
from google.protobuf.json_format import ParseError

"""
Test cases for reflection based client
"""

logger = logging.getLogger('name')

@pytest.fixture(scope="module")
def helloworld_stub_client():
try:
client = StubClient('localhost:50051', [_GREETER])
yield client
except: # noqa: E722
pytest.fail("Could not connect to local HelloWorld server")


def test_unary_unary(helloworld_stub_client):
response = helloworld_stub_client.unary_unary('helloworld.Greeter', 'SayHello', {"name": "sinsky"})
assert type(response) == dict
assert response == {"message": "Hello, sinsky!"}

def test_empty_body_request(helloworld_stub_client):
response = helloworld_stub_client.unary_unary('helloworld.Greeter', 'SayHello', {})
logger.warning(f"Response: {response}")
assert type(response) == dict

def test_nonexistent_service(helloworld_stub_client):
with pytest.raises(ValueError):
helloworld_stub_client.unary_unary('helloworld.Speaker', 'SingHello', {})

def test_nonexistent_method(helloworld_stub_client):
with pytest.raises(ValueError):
helloworld_stub_client.unary_unary('helloworld.Greeter', 'SayGoodbye', {})

def test_unsupported_argument(helloworld_stub_client):
with pytest.raises(ParseError):
helloworld_stub_client.unary_unary('helloworld.Greeter', 'SayHello', {"foo": "bar"})

def test_unary_stream(helloworld_stub_client):
name_list = ["sinsky", "viridianforge", "jack", "harry"]
responses = helloworld_stub_client.unary_stream(
'helloworld.Greeter',
'SayHelloGroup',
{"name": "".join(name_list)}
)
assert all(type(response) == dict for response in responses)
for response, name in zip(responses, name_list):
assert response == {"message": f"Hello, {name}!"}

def test_stream_unary(helloworld_stub_client):
name_list = ["sinsky", "viridianforge", "jack", "harry"]
response = helloworld_stub_client.stream_unary(
'helloworld.Greeter',
'HelloEveryone',
[{"name": name} for name in name_list]
)
assert type(response) == dict
assert response == {'message': f'Hello, {" ".join(name_list)}!'}

def test_stream_stream(helloworld_stub_client):
name_list = ["sinsky", "viridianforge", "jack", "harry"]
responses = helloworld_stub_client.stream_stream(
'helloworld.Greeter',
'SayHelloOneByOne',
[{"name": name} for name in name_list]
)
assert all(type(response) == dict for response in responses)
for response, name in zip(responses, name_list):
assert response == {"message": f"Hello, {name}!"}
4 changes: 2 additions & 2 deletions src/tests/test_servers/helloworld_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def __init__(self, port: str):
self.server.add_insecure_port(f'[::]:{port}')

def serve(self):
logging.warning('Start the server?')
logging.debug('Server starting...')
self.server.start()
logging.warning('Server running')
logging.debug('Server running...')
self.server.wait_for_termination()

def shutdown(self):
Expand Down
1 change: 0 additions & 1 deletion tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,4 @@

flake8 . --count --statistics

# Tests currently being re-worked - re-enable when finished
pytest --cov-report=xml --cov=src/grpc_requests

0 comments on commit c18e499

Please sign in to comment.