Skip to content

Commit

Permalink
feat: add client to receiver authentication, optional
Browse files Browse the repository at this point in the history
  • Loading branch information
edwinkys committed Nov 7, 2024
1 parent 2a788d7 commit 81f19c4
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
PHANTASM_SECRET=xxx
27 changes: 21 additions & 6 deletions clients/python/phantasmpy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,27 @@ class Phantasm:
Args:
- host: Hostname of the receiver service.
- port: Port where the receiver listens for requests.
- secret: Key used to authenticate the client with the receiver.
"""

def __init__(self, host: str = "localhost", port: int = 2505):
def __init__(
self,
host: str = "localhost",
port: int = 2505,
secret: str = "",
):
channel = grpc.insecure_channel(f"{host}:{port}")
self.connection = ReceiverStub(channel)
self.metadata = [("authorization", secret)]

def heartbeat(self) -> HeartbeatResponse:
"""Check if the client can connect to the receiver service."""

response = self.connection.Heartbeat(request=Empty())
response = self.connection.Heartbeat(
request=Empty(),
metadata=self.metadata,
)

return HeartbeatResponse(version=response.version)

def get_approval(
Expand Down Expand Up @@ -52,11 +63,15 @@ def get_approval(
except Exception as e:
raise ValueError(f"Invalid parameters: {e}")

response = self.connection.GetApproval(request=request)
approved = response.approved
parameters = response.parameters or ""
response = self.connection.GetApproval(
request=request,
metadata=self.metadata,
)

return GetApprovalResponse(approved=approved, parameters=parameters)
return GetApprovalResponse(
approved=response.approved,
parameters=response.parameters or "",
)


def emulate_get_approval():
Expand Down
26 changes: 25 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@ use clap::{arg, ArgMatches, Command};
use futures::{SinkExt, StreamExt};
use protos::receiver_server::ReceiverServer;
use services::Phantasm;
use std::env;
use std::net::{IpAddr, Ipv4Addr};
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tokio_tungstenite::accept_async;
use tonic::transport::Server;
use tonic::{Request, Status};
use types::{Connection, ConnectionID};

const START_COMMAND: &str = "start";
Expand Down Expand Up @@ -75,16 +77,38 @@ async fn start_handler(args: &ArgMatches) {
}

async fn start_receiver_server(service: Arc<Phantasm>, port: u16) {
let service = ReceiverServer::with_interceptor(service, auth_interceptor);
let addr = format!("[::]:{port}").parse().unwrap();
tracing::info!("Receiver server is ready on port {port}");

Server::builder()
.add_service(ReceiverServer::new(service))
.add_service(service)
.serve(addr)
.await
.expect("Failed to start the receiver server");
}

fn auth_interceptor(request: Request<()>) -> Result<Request<()>, Status> {
let secret = env::var("PHANTASM_SECRET").unwrap_or_default();
if secret.is_empty() {
return Ok(request);
}

let unauthorized = {
let message = "Invalid or missing authorization key";
Status::unauthenticated(message)
};

if let Some(auth) = request.metadata().get("authorization") {
let auth = auth.to_str().map_err(|_| unauthorized.clone())?;
if auth == secret.as_str() {
return Ok(request);
}
}

Err(unauthorized)
}

async fn start_coordinator_server(service: Arc<Phantasm>, port: u16) {
let addr = format!("[::]:{port}");
let listener = TcpListener::bind(addr).await.unwrap();
Expand Down

0 comments on commit 81f19c4

Please sign in to comment.