Skip to content

Commit

Permalink
refactor: update backend interface
Browse files Browse the repository at this point in the history
The interface was a bit inconsistent and loose. These refactors just
tighten things up a bit.

Signed-off-by: Daniel Bluhm <[email protected]>
  • Loading branch information
dbluhm committed May 7, 2024
1 parent 6464f51 commit 1ac1996
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 50 deletions.
3 changes: 1 addition & 2 deletions docker-compose-local.yaml → demo/docker-compose-local.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ version: '3'

services:
websocket-gateway:
build: .
build: ..
ports:
- "8765:8765"
volumes:
- ./server:/code
- ./wait-for-tunnel.sh:/wait-for-tunnel.sh:ro,z
entrypoint: /wait-for-tunnel.sh
command: >
Expand Down
4 changes: 2 additions & 2 deletions docker-compose.yaml → demo/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ version: '3'

services:
websocket-gateway:
build: .
build: ..
ports:
- "8765:8765"
volumes:
- ./socketdock:/usr/src/app/socketdock:z
- ../socketdock:/usr/src/app/socketdock:z
command: >
--bindip 0.0.0.0
--backend http
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion wait-for-tunnel.sh → demo/wait-for-tunnel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ done
WS_ENDPOINT=$(curl --silent "${TUNNEL_ENDPOINT}/start" | python -c "import sys, json; print(json.load(sys.stdin)['url'])" | sed -rn 's#https?://([^/]+).*#\1#p')
echo "fetched hostname and port [$WS_ENDPOINT]"

exec "$@" --externalhostandport ${WS_ENDPOINT}
exec "$@" --externalhostandport ${WS_ENDPOINT}
7 changes: 4 additions & 3 deletions socketdock/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import argparse
from sanic import Sanic

from .api import api, backend_var, endpoint_var
from .api import api, backend_var


def config() -> argparse.Namespace:
Expand Down Expand Up @@ -38,12 +38,13 @@ def main():
elif args.backend == "http":
from .httpbackend import HTTPBackend

backend = HTTPBackend(args.connect_uri, args.message_uri, args.disconnect_uri)
backend = HTTPBackend(
args.endpoint, args.connect_uri, args.message_uri, args.disconnect_uri
)
else:
raise ValueError("Invalid backend type")

backend_var.set(backend)
endpoint_var.set(args.endpoint)

logging.basicConfig(level=args.log_level)

Expand Down
22 changes: 5 additions & 17 deletions socketdock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .backend import Backend

backend_var: ContextVar[Backend] = ContextVar("backend")
endpoint_var: ContextVar[str] = ContextVar("endpoint")

api = Blueprint("api", url_prefix="/")

Expand Down Expand Up @@ -78,9 +77,6 @@ async def socket_handler(request: Request, websocket: Websocket):
global lifetime_connections
backend = backend_var.get()
socket_id = None
endpoint = endpoint_var.get()
send = f"{endpoint}/socket/{socket_id}/send"
disconnect = f"{endpoint_var.get()}/socket/{socket_id}/disconnect"
try:
# register user
LOGGER.info("new client connected")
Expand All @@ -92,23 +88,15 @@ async def socket_handler(request: Request, websocket: Websocket):
LOGGER.info("Request headers: %s", dict(request.headers.items()))

await backend.socket_connected(
{
"connection_id": socket_id,
"headers": dict(request.headers.items()),
"send": send,
"disconnect": disconnect,
},
connection_id=socket_id,
headers=dict(request.headers.items()),
)

async for message in websocket:
if message:
await backend.inbound_socket_message(
{
"connection_id": socket_id,
"send": send,
"disconnect": disconnect,
},
message,
connection_id=socket_id,
message=message,
)
else:
LOGGER.warning("empty message received")
Expand All @@ -118,4 +106,4 @@ async def socket_handler(request: Request, websocket: Websocket):
if socket_id:
del active_connections[socket_id]
LOGGER.info("Removed connection: %s", socket_id)
await backend.socket_disconnected({"connection_id": socket_id})
await backend.socket_disconnected(socket_id)
17 changes: 10 additions & 7 deletions socketdock/backend.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
"""Backend interface for SocketDock."""

from abc import ABC, abstractmethod
from typing import Union
from typing import Dict, Union


class Backend(ABC):
"""Backend interface for SocketDock."""

@abstractmethod
async def socket_connected(self, callback_uris: dict):
async def socket_connected(
self,
connection_id: str,
headers: Dict[str, str],
):
"""Handle new socket connections, with calback provided."""
raise NotImplementedError()

@abstractmethod
async def inbound_socket_message(
self, callback_uris: dict, message: Union[str, bytes]
self,
connection_id: str,
message: Union[str, bytes],
):
"""Handle inbound socket message, with calback provided."""
raise NotImplementedError()

@abstractmethod
async def socket_disconnected(self, bundle: dict):
async def socket_disconnected(self, connection_id: str):
"""Handle socket disconnected."""
raise NotImplementedError()
57 changes: 48 additions & 9 deletions socketdock/httpbackend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""HTTP backend for SocketDock."""

import logging
from typing import Union
from typing import Dict, Union

import aiohttp

Expand All @@ -14,16 +14,46 @@
class HTTPBackend(Backend):
"""HTTP backend for SocketDock."""

def __init__(self, connect_uri: str, message_uri: str, disconnect_uri: str):
def __init__(
self,
socket_base_uri: str,
connect_uri: str,
message_uri: str,
disconnect_uri: str,
):
"""Initialize HTTP backend."""
self._connect_uri = connect_uri
self._message_uri = message_uri
self._disconnect_uri = disconnect_uri
self.socket_base_uri = socket_base_uri

def send_callback(self, connection_id: str) -> str:
"""Return the callback URI for sending a message to a connected socket."""
return f"{self.socket_base_uri}/{connection_id}/send"

def disconnect_callback(self, connection_id: str) -> str:
"""Return the callback URI for disconnecting a connected socket."""
return f"{self.socket_base_uri}/{connection_id}/disconnect"

async def socket_connected(self, callback_uris: dict):
def callback_uris(self, connection_id: str) -> Dict[str, str]:
"""Return labelled callback URIs."""
return {
"send": self.send_callback(connection_id),
"disconnect": self.disconnect_callback(connection_id),
}

async def socket_connected(
self,
connection_id: str,
headers: Dict[str, str],
):
"""Handle inbound socket message, with calback provided."""
http_body = {
"meta": callback_uris,
"meta": {
**self.callback_uris(connection_id),
"headers": headers,
"connection_id": connection_id,
},
}

if self._connect_uri:
Expand All @@ -37,11 +67,16 @@ async def socket_connected(self, callback_uris: dict):
LOGGER.debug("Response: %s", response)

async def inbound_socket_message(
self, callback_uris: dict, message: Union[str, bytes]
self,
connection_id: str,
message: Union[str, bytes],
):
"""Handle inbound socket message, with calback provided."""
http_body = {
"meta": callback_uris,
"meta": {
**self.callback_uris(connection_id),
"connection_id": connection_id,
},
"message": message.decode("utf-8") if isinstance(message, bytes) else message,
}

Expand All @@ -54,11 +89,15 @@ async def inbound_socket_message(
else:
LOGGER.debug("Response: %s", response)

async def socket_disconnected(self, bundle: dict):
async def socket_disconnected(self, connection_id: str):
"""Handle socket disconnected."""
async with aiohttp.ClientSession() as session:
LOGGER.info("Notifying of disconnect: %s %s", self._disconnect_uri, bundle)
async with session.post(self._disconnect_uri, json=bundle) as resp:
LOGGER.info(
"Notifying of disconnect: %s %s", self._disconnect_uri, connection_id
)
async with session.post(
self._disconnect_uri, json={"connection_id": connection_id}
) as resp:
response = await resp.text()
if resp.status != 200:
LOGGER.error("Error posting to disconnect uri: %s", response)
Expand Down
24 changes: 15 additions & 9 deletions socketdock/testbackend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test backend for SocketDock."""

from typing import Union
from typing import Dict, Union
import aiohttp

from .backend import Backend
Expand All @@ -9,27 +9,33 @@
class TestBackend(Backend):
"""Test backend for SocketDock."""

async def socket_connected(self, callback_uris: dict):
def __init__(self, base_uri: str):
"""Initialize backend."""
self.base_uri = base_uri

async def socket_connected(
self,
connection_id: str,
headers: Dict[str, str],
):
"""Socket connected.
This test backend doesn't care, but can be useful to clean up state.
"""

async def inbound_socket_message(
self, callback_uris: dict, message: Union[str, bytes]
self,
connection_id: str,
message: Union[str, bytes],
):
"""Receive socket message."""
# send three backend messages in response
# TODO: send response message via callback URI for sending a message
send_uri = callback_uris["send"]
send_uri = f"{self.base_uri}/{connection_id}/send"
async with aiohttp.ClientSession() as session:
async with session.post(send_uri, data="Hello yourself") as resp:
response = await resp.text()
print(response)

# response = requests.post(send_uri, data="Hello yourself!")

async def socket_disconnected(self, bundle: dict):
async def socket_disconnected(self, connection_id: str):
"""Socket disconnected.
This test backend doesn't care, but can be useful to clean up state.
Expand Down

0 comments on commit 1ac1996

Please sign in to comment.