diff --git a/examples/asyncio/ping_client.py b/examples/asyncio/ping_client.py new file mode 100644 index 0000000..1e3ceb8 --- /dev/null +++ b/examples/asyncio/ping_client.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- + +import thriftpy +from thriftpy.contrib.async import make_client + +import asyncio + + +pp_thrift = thriftpy.load("pingpong.thrift", module_name="pp_thrift") + +@asyncio.coroutine +def main(): + c = yield from make_client(pp_thrift.PingService) + + pong = yield from c.ping() + print(pong) + + c.close() + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + loop.run_until_complete(main()) + loop.close() diff --git a/examples/asyncio/ping_server.py b/examples/asyncio/ping_server.py new file mode 100644 index 0000000..2954935 --- /dev/null +++ b/examples/asyncio/ping_server.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- + +import thriftpy +import asyncio +from thriftpy.contrib.async import make_server + + +pp_thrift = thriftpy.load("pingpong.thrift", module_name="pp_thrift") + + +class Dispatcher(object): + @asyncio.coroutine + def ping(self): + print("ping pong!") + return 'pong' + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + server = loop.run_until_complete( + make_server(pp_thrift.PingService, Dispatcher())) + + try: + loop.run_forever() + except KeyboardInterrupt: + pass + + server.close() + loop.run_until_complete(server.wait_closed()) + loop.close() diff --git a/examples/asyncio/pingpong.thrift b/examples/asyncio/pingpong.thrift new file mode 100644 index 0000000..0bb9c85 --- /dev/null +++ b/examples/asyncio/pingpong.thrift @@ -0,0 +1,7 @@ +# ping service demo +service PingService { + /* + * Sexy c style comment + */ + string ping(), +} diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..19b837a --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,5 @@ +import sys + +collect_ignore = ["setup.py"] +if sys.version_info < (3, 5): + collect_ignore.append("test_asyncio.py") diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py new file mode 100644 index 0000000..b4b9297 --- /dev/null +++ b/tests/test_asyncio.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import + +from os import path + +import thriftpy +from thriftpy.contrib.async import make_client, make_server +from thriftpy.rpc import make_client as make_sync_client +from thriftpy.transport import TFramedTransportFactory + +import pytest +import asyncio +import threading + +addressbook = thriftpy.load(path.join(path.dirname(__file__), + "addressbook.thrift")) + + +class Dispatcher(object): + def __init__(self): + self.registry = {} + + @asyncio.coroutine + def add(self, person): + """ + bool add(1: Person person); + """ + if person.name in self.registry: + return False + self.registry[person.name] = person + return True + + @asyncio.coroutine + def get(self, name): + """ + Person get(1: string name) throws (1: PersonNotExistsError not_exists); + """ + if name not in self.registry: + raise addressbook.PersonNotExistsError( + 'Person "{0}" does not exist!'.format(name)) + return self.registry[name] + + @asyncio.coroutine + def remove(self, name): + """ + bool remove(1: string name) throws (1: PersonNotExistsError not_exists) + """ + # delay action for later + yield from asyncio.sleep(.1) + if name not in self.registry: + raise addressbook.PersonNotExistsError( + 'Person "{0}" does not exist!'.format(name)) + del self.registry[name] + return True + + +class Server(threading.Thread): + def __init__(self): + self.loop = loop = asyncio.new_event_loop() + self.server = loop.run_until_complete(make_server( + service=addressbook.AddressBookService, + handler=Dispatcher(), + loop=loop + )) + super().__init__() + + def run(self): + loop = self.loop + server = self.server + asyncio.set_event_loop(loop) + + loop.run_forever() + + server.close() + loop.run_until_complete(server.wait_closed()) + + loop.close() + + def stop(self): + self.loop.call_soon_threadsafe(self.loop.stop) + self.join() + + +@pytest.fixture +def server(): + server = Server() + server.start() + yield server + server.stop() + + +class TestAsyncClient: + @pytest.fixture + async def client(self, request, server): + client = await make_client(addressbook.AddressBookService) + request.addfinalizer(client.close) + return client + + @pytest.mark.asyncio + async def test_result(self, client): + dennis = addressbook.Person(name='Dennis Ritchie') + success = await client.add(dennis) + assert success + success = await client.add(dennis) + assert not success + person = await client.get(dennis.name) + assert person.name == dennis.name + + @pytest.mark.asyncio + async def test_exception(self, client): + with pytest.raises(addressbook.PersonNotExistsError): + await client.get('Brian Kernighan') + + +class TestSyncClient: + @pytest.fixture + async def client(self, request, server): + client = make_sync_client(addressbook.AddressBookService, + trans_factory=TFramedTransportFactory()) + request.addfinalizer(client.close) + return client + + def test_result(self, client): + dennis = addressbook.Person(name='Dennis Ritchie') + success = client.add(dennis) + assert success + success = client.add(dennis) + assert not success + person = client.get(dennis.name) + assert person.name == dennis.name + + def test_exception(self, client): + with pytest.raises(addressbook.PersonNotExistsError): + client.get('Brian Kernighan') diff --git a/thriftpy/contrib/async.py b/thriftpy/contrib/async.py new file mode 100644 index 0000000..f0ab6ec --- /dev/null +++ b/thriftpy/contrib/async.py @@ -0,0 +1,190 @@ +from thriftpy.thrift import TType, TMessageType, TApplicationException, TProcessor, TClient, args2kwargs +from thriftpy.transport import TMemoryBuffer +from thriftpy.protocol import TBinaryProtocolFactory + +import asyncio +import struct + +import logging +LOG = logging.getLogger(__name__) + + +class TAsyncTransport(TMemoryBuffer): + def __init__(self, trans): + super().__init__() + self._trans = trans + self._io_lock = asyncio.Lock() + + def flush(self): + buf = self.getvalue() + self._trans.write(struct.pack("!i", len(buf)) + buf) + self.setvalue(b'') + + @asyncio.coroutine + def read_frame(self): + # do not yield the event loop on a single reader + # between reading the frame_size and the buffer + with (yield from self._io_lock): + buff = yield from self._trans.readexactly(4) + sz, = struct.unpack('!i', buff) + + frame = yield from self._trans.readexactly(sz) + self.setvalue(frame) + + @asyncio.coroutine + def drain(self): + # drain cannot be called concurrently + with (yield from self._io_lock): + yield from self._trans.drain() + + +class TAsyncReader(TAsyncTransport): + def close(self): + self._trans.feed_eof() + super().close() + + +class TAsyncWriter(TAsyncTransport): + def close(self): + self._trans.write_eof() + super().close() + + +class TAsyncProcessor(TProcessor): + def __init__(self, service, handler): + self._service = service + self._handler = handler + + @asyncio.coroutine + def process(self, iprot, oprot): + # the standard thrift protocol packs a single request per frame + # note that chunked requests are not supported, and would require + # additional sequence information + yield from iprot.trans.read_frame() + api, seqid, result, call = self.process_in(iprot) + + if isinstance(result, TApplicationException): + self.send_exception(oprot, api, result, seqid) + yield from oprot.trans.drain() + + try: + result.success = yield from call() + except Exception as e: + # raise if api don't have throws + self.handle_exception(e, result) + + if not result.oneway: + self.send_result(oprot, api, result, seqid) + yield from oprot.trans.drain() + + +class TAsyncServer(object): + def __init__(self, processor, + iprot_factory=None, + oprot_factory=None, + timeout=None): + self.processor = processor + self.iprot_factory = iprot_factory or TBinaryProtocolFactory() + self.oprot_factory = oprot_factory or self.iprot_factory + self.timeout = timeout + + @asyncio.coroutine + def __call__(self, reader, writer): + itrans = TAsyncReader(reader) + iproto = self.iprot_factory.get_protocol(itrans) + + otrans = TAsyncWriter(writer) + oproto = self.oprot_factory.get_protocol(otrans) + + while not reader.at_eof(): + try: + fut = self.processor.process(iproto, oproto) + yield from asyncio.wait_for(fut, self.timeout) + except ConnectionError: + LOG.debug('client has closed the connection') + writer.close() + except asyncio.TimeoutError: + LOG.debug('timeout when processing the client request') + writer.close() + except asyncio.IncompleteReadError: + LOG.debug('client has closed the connection') + writer.close() + except Exception: + # app exception + LOG.exception('unhandled app exception') + writer.close() + writer.close() + + +class TAsyncClient(TClient): + def __init__(self, *args, timeout=None, **kwargs): + super().__init__(*args, **kwargs) + self.timeout = timeout + + @asyncio.coroutine + def _req(self, _api, *args, **kwargs): + fut = self._req_impl(_api, *args, **kwargs) + result = yield from asyncio.wait_for(fut, self.timeout) + return result + + @asyncio.coroutine + def _req_impl(self, _api, *args, **kwargs): + args_cls = getattr(self._service, _api + "_args") + _kw = args2kwargs(args_cls.thrift_spec, *args) + + kwargs.update(_kw) + result_cls = getattr(self._service, _api + "_result") + + self._send(_api, **kwargs) + yield from self._oprot.trans.drain() + + # wait result only if non-oneway + if not getattr(result_cls, "oneway"): + yield from self._iprot.trans.read_frame() + return self._recv(_api) + + def close(self): + self._iprot.trans.close() + self._oprot.trans.close() + + +@asyncio.coroutine +def make_server( + service, + handler, + host = 'localhost', + port = 9090, + proto_factory = TBinaryProtocolFactory(), + loop = None, + timeout = None + ): + """ + create a thrift server running on an asyncio event-loop. + """ + processor = TAsyncProcessor(service, handler) + if loop is None: + loop = asyncio.get_event_loop() + server = yield from asyncio.start_server( + TAsyncServer(processor, proto_factory, timeout=timeout), host, port, loop=loop) + return server + + +@asyncio.coroutine +def make_client(service, + host = 'localhost', + port = 9090, + proto_factory = TBinaryProtocolFactory(), + timeout = None, + loop = None): + if loop is None: + loop = asyncio.get_event_loop() + + reader, writer = yield from asyncio.open_connection( + host, port, loop=loop) + + itrans = TAsyncReader(reader) + iproto = proto_factory.get_protocol(itrans) + + otrans = TAsyncWriter(writer) + oproto = proto_factory.get_protocol(otrans) + return TAsyncClient(service, iproto, oproto) diff --git a/tox.ini b/tox.ini index c488270..ac5c120 100644 --- a/tox.ini +++ b/tox.ini @@ -14,6 +14,7 @@ deps = toro cython py26: ordereddict + py,py35: pytest_asyncio [testenv:flake8] deps = flake8