Skip to content

Commit

Permalink
feat: Enable destination plugins to do reads (for backend use case).
Browse the repository at this point in the history
  • Loading branch information
marianogappa committed Aug 15, 2024
1 parent b3f898a commit 3253b06
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 2 deletions.
5 changes: 5 additions & 0 deletions cloudquery/sdk/internal/memdb/memdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from cloudquery.sdk import schema
from typing import List, Generator, Dict
import pyarrow as pa
from cloudquery.sdk.schema.table import Table
from cloudquery.sdk.types import JSONType
from dataclasses import dataclass, field

Expand Down Expand Up @@ -109,5 +110,9 @@ def write(self, writer: Generator[message.WriteMessage, None, None]) -> None:
else:
raise NotImplementedError(f"Unknown message type {type(msg)}")

def read(self, table: Table) -> Generator[message.ReadMessage, None, None]:
    """Stream every stored record batch back as ReadMessages.

    Note: this in-memory backend ignores the requested ``table`` and yields
    the record batches of all tables it holds (the serve test relies on
    receiving every stored batch).
    """
    # Iterate values only: the original loop variable shadowed the ``table``
    # parameter, and the key was never used.
    for record in self._db.values():
        yield message.ReadMessage(record)

def close(self) -> None:
    """Discard all stored data, releasing the in-memory database."""
    self._db = dict()
7 changes: 5 additions & 2 deletions cloudquery/sdk/internal/servers/plugin_v3/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,11 @@ def Sync(self, request, context):
# unknown sync message type
raise NotImplementedError()

def Read(self, request, context):
raise NotImplementedError()
def Read(self, request: plugin_pb2.Read.Request, context) -> Generator[plugin_pb2.Read.Response, None, None]:
    """Serve the Read RPC: stream the plugin's records as serialized Arrow batches."""
    # NOTE(review): Plugin.read is declared to take a schema ``Table``, but the
    # raw gRPC request object is forwarded here. MemDB ignores its argument, so
    # this works today — confirm whether the table should be deserialized from
    # ``request.table`` before calling read.
    for msg in self._plugin.read(request):
        # Each message carries a pyarrow RecordBatch; serialize it for the wire.
        buf = arrow.record_to_bytes(msg.record)
        yield plugin_pb2.Read.Response(record=buf)


def Write(
self, request_iterator: Generator[plugin_pb2.Write.Request, None, None], context
Expand Down
1 change: 1 addition & 0 deletions cloudquery/sdk/message/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
WriteMigrateTableMessage,
WriteDeleteStale,
)
from .read import ReadMessage
5 changes: 5 additions & 0 deletions cloudquery/sdk/message/read.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import pyarrow as pa

class ReadMessage:
    """A single record batch produced by a plugin's ``read`` call."""

    def __init__(self, record: "pa.RecordBatch") -> None:
        # The Arrow record batch carried by this message.
        self.record = record
3 changes: 3 additions & 0 deletions cloudquery/sdk/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,8 @@ def sync(self, options: SyncOptions) -> Generator[message.SyncMessage, None, Non
def write(self, writer: Generator[message.WriteMessage, None, None]) -> None:
    """Consume a stream of write messages. Destination plugins must override."""
    raise NotImplementedError()

def read(self, table: Table) -> Generator[message.ReadMessage, None, None]:
    """Yield previously written records for ``table``.

    Destination plugins that support reads must override this; the base
    implementation is intentionally unimplemented.
    """
    raise NotImplementedError()

def close(self) -> None:
    """Release any resources held by the plugin. Subclasses must override."""
    raise NotImplementedError()
49 changes: 49 additions & 0 deletions tests/serve/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,55 @@ def writer_iterator():
cmd.stop()
pool.shutdown()

def test_plugin_read():
    # End-to-end check of the Read RPC: seed the in-memory destination with
    # two record batches, serve it over gRPC, and read them back.
    p = MemDB()
    sample_record_1 = pa.RecordBatch.from_arrays(
        [
            pa.array([1, 2, 3]),
            pa.array(["a", "b", "c"]),
        ],
        schema=test_table.to_arrow_schema(),
    )
    sample_record_2 = pa.RecordBatch.from_arrays(
        [
            pa.array([2, 3, 4]),
            pa.array(["b", "c", "d"]),
        ],
        schema=test_table.to_arrow_schema(),
    )
    # Seed the plugin's private store directly; dict insertion order
    # determines the order in which records are streamed back.
    p._db["test_1"] = sample_record_1
    p._db["test_2"] = sample_record_2

    cmd = serve.PluginCommand(p)
    # NOTE(review): a random port may collide with one already in use,
    # making this test flaky — consider binding port 0 if the server
    # supports reporting the chosen port.
    port = random.randint(5000, 50000)
    pool = futures.ThreadPoolExecutor(max_workers=1)
    pool.submit(cmd.run, ["serve", "--address", f"[::]:{port}"])
    # Give the server thread a moment to start listening.
    time.sleep(1)
    try:
        with grpc.insecure_channel(f"localhost:{port}") as channel:
            stub = plugin_pb2_grpc.PluginStub(channel)
            response = stub.GetName(plugin_pb2.GetName.Request())
            assert response.name == "memdb"

            response = stub.GetVersion(plugin_pb2.GetVersion.Request())
            assert response.version == "development"

            response = stub.Init(plugin_pb2.Init.Request(spec=b""))
            assert response is not None

            # The request carries the serialized Arrow schema of the table to
            # read; MemDB ignores it and streams every stored batch.
            request = plugin_pb2.Read.Request(table=arrow.schema_to_bytes(test_table.to_arrow_schema()))
            reader_iterator = stub.Read(request)

            output_records = []
            for msg in reader_iterator:
                # Each response holds one serialized record batch.
                output_records.append(arrow.new_record_from_bytes(msg.record))

            assert len(output_records) == 2
            assert output_records[0].equals(sample_record_1)
            assert output_records[1].equals(sample_record_2)
    finally:
        # Always stop the server and worker thread, even on assertion failure.
        cmd.stop()
        pool.shutdown()

def test_plugin_package():
p = MemDB()
Expand Down

0 comments on commit 3253b06

Please sign in to comment.