From 3253b0611e05975c48a37f619fdb30c7e39809dc Mon Sep 17 00:00:00 2001 From: Mariano Gappa Date: Thu, 15 Aug 2024 11:35:20 +0100 Subject: [PATCH] feat: Enable destination plugins to do reads (for backend use case). --- cloudquery/sdk/internal/memdb/memdb.py | 5 ++ .../sdk/internal/servers/plugin_v3/plugin.py | 7 ++- cloudquery/sdk/message/__init__.py | 1 + cloudquery/sdk/message/read.py | 5 ++ cloudquery/sdk/plugin/plugin.py | 3 ++ tests/serve/plugin.py | 49 +++++++++++++++++++ 6 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 cloudquery/sdk/message/read.py diff --git a/cloudquery/sdk/internal/memdb/memdb.py b/cloudquery/sdk/internal/memdb/memdb.py index aa1c9e8..dd7bd5d 100644 --- a/cloudquery/sdk/internal/memdb/memdb.py +++ b/cloudquery/sdk/internal/memdb/memdb.py @@ -5,6 +5,7 @@ from cloudquery.sdk import schema from typing import List, Generator, Dict import pyarrow as pa +from cloudquery.sdk.schema.table import Table from cloudquery.sdk.types import JSONType from dataclasses import dataclass, field @@ -109,5 +110,9 @@ def write(self, writer: Generator[message.WriteMessage, None, None]) -> None: else: raise NotImplementedError(f"Unknown message type {type(msg)}") + def read(self, table: Table) -> Generator[message.ReadMessage, None, None]: + for table, record in self._db.items(): + yield message.ReadMessage(record) + def close(self) -> None: self._db = {} diff --git a/cloudquery/sdk/internal/servers/plugin_v3/plugin.py b/cloudquery/sdk/internal/servers/plugin_v3/plugin.py index ffbc581..c63dd8b 100644 --- a/cloudquery/sdk/internal/servers/plugin_v3/plugin.py +++ b/cloudquery/sdk/internal/servers/plugin_v3/plugin.py @@ -81,8 +81,11 @@ def Sync(self, request, context): # unknown sync message type raise NotImplementedError() - def Read(self, request, context): - raise NotImplementedError() + def Read(self, request: plugin_pb2.Read.Request, context) -> Generator[plugin_pb2.Read.Response, None, None]: + for msg in self._plugin.read(request): + buf = arrow.record_to_bytes(msg.record) + yield plugin_pb2.Read.Response(record=buf) + def Write( self, request_iterator: Generator[plugin_pb2.Write.Request, None, None], context diff --git a/cloudquery/sdk/message/__init__.py b/cloudquery/sdk/message/__init__.py index 5ddb77a..bbcb84f 100644 --- a/cloudquery/sdk/message/__init__.py +++ b/cloudquery/sdk/message/__init__.py @@ -5,3 +5,4 @@ WriteMigrateTableMessage, WriteDeleteStale, ) +from .read import ReadMessage diff --git a/cloudquery/sdk/message/read.py b/cloudquery/sdk/message/read.py new file mode 100644 index 0000000..8c0c9e7 --- /dev/null +++ b/cloudquery/sdk/message/read.py @@ -0,0 +1,5 @@ +import pyarrow as pa + +class ReadMessage: + def __init__(self, record: pa.RecordBatch): + self.record = record diff --git a/cloudquery/sdk/plugin/plugin.py b/cloudquery/sdk/plugin/plugin.py index 5b3d5f5..d03dbef 100644 --- a/cloudquery/sdk/plugin/plugin.py +++ b/cloudquery/sdk/plugin/plugin.py @@ -93,5 +93,8 @@ def sync(self, options: SyncOptions) -> Generator[message.SyncMessage, None, Non def write(self, writer: Generator[message.WriteMessage, None, None]) -> None: raise NotImplementedError() + def read(self, table: Table) -> Generator[message.ReadMessage, None, None]: + raise NotImplementedError() + def close(self) -> None: raise NotImplementedError() diff --git a/tests/serve/plugin.py b/tests/serve/plugin.py index f7336bc..72fa3fc 100644 --- a/tests/serve/plugin.py +++ b/tests/serve/plugin.py @@ -73,6 +73,55 @@ def writer_iterator(): cmd.stop() pool.shutdown() +def test_plugin_read(): + p = MemDB() + sample_record_1 = pa.RecordBatch.from_arrays( + [ + pa.array([1, 2, 3]), + pa.array(["a", "b", "c"]), + ], + schema=test_table.to_arrow_schema(), + ) + sample_record_2 = pa.RecordBatch.from_arrays( + [ + pa.array([2, 3, 4]), + pa.array(["b", "c", "d"]), + ], + schema=test_table.to_arrow_schema(), + ) + p._db["test_1"] = sample_record_1 + p._db["test_2"] = sample_record_2 + + cmd = serve.PluginCommand(p) + port = random.randint(5000, 50000) + pool = futures.ThreadPoolExecutor(max_workers=1) + pool.submit(cmd.run, ["serve", "--address", f"[::]:{port}"]) + time.sleep(1) + try: + with grpc.insecure_channel(f"localhost:{port}") as channel: + stub = plugin_pb2_grpc.PluginStub(channel) + response = stub.GetName(plugin_pb2.GetName.Request()) + assert response.name == "memdb" + + response = stub.GetVersion(plugin_pb2.GetVersion.Request()) + assert response.version == "development" + + response = stub.Init(plugin_pb2.Init.Request(spec=b"")) + assert response is not None + + request = plugin_pb2.Read.Request(table=arrow.schema_to_bytes(test_table.to_arrow_schema())) + reader_iterator = stub.Read(request) + + output_records = [] + for msg in reader_iterator: + output_records.append(arrow.new_record_from_bytes(msg.record)) + + assert len(output_records) == 2 + assert output_records[0].equals(sample_record_1) + assert output_records[1].equals(sample_record_2) + finally: + cmd.stop() + pool.shutdown() def test_plugin_package(): p = MemDB()