Skip to content

Commit

Permalink
Add more tests to signaling and streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
ghubnerr committed Dec 24, 2024
1 parent 288c787 commit 2dc8259
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 9 deletions.
74 changes: 68 additions & 6 deletions aura/tests/test_signaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
from aura.webrtc import SignalingServer
import sys
print(sys.path)

import time

Expand Down Expand Up @@ -62,7 +61,6 @@ async def test_client_connection(signaling_server, unused_tcp_port):

uri = f'ws://localhost:{unused_tcp_port}/signaling'
async with websockets.connect(uri) as websocket:
# Wait briefly for connection to be registered
await asyncio.sleep(0.1)
assert signaling_server.get_client_count() == 1

Expand All @@ -72,10 +70,8 @@ async def test_signaling_message_exchange(signaling_server, unused_tcp_port):
uri = f'ws://localhost:{unused_tcp_port}/signaling'

async with websockets.connect(uri) as client1, websockets.connect(uri) as client2:
# Wait for connections to be established
await asyncio.sleep(0.1)

# Test SDP offer exchange
offer_message = {
"type": "offer",
"sdp": "test_sdp_offer"
Expand All @@ -94,16 +90,82 @@ async def test_broadcast_message(signaling_server, unused_tcp_port):
uri = f'ws://localhost:{unused_tcp_port}/signaling'

async with websockets.connect(uri) as client1, websockets.connect(uri) as client2:
# Wait for connections to be established
await asyncio.sleep(0.1)

test_message = "broadcast test message"
signaling_server.broadcast_message(test_message)

# Both clients should receive the message
msg1 = await client1.recv()
msg2 = await client2.recv()

assert msg1 == test_message
assert msg2 == test_message

@pytest.mark.asyncio
async def test_disconnect_client(signaling_server, unused_tcp_port):
"""Test disconnecting a client"""
uri = f'ws://localhost:{unused_tcp_port}/signaling'

async with websockets.connect(uri) as client:
await asyncio.sleep(0.1)

connected_clients = signaling_server.get_connected_clients()
assert len(connected_clients) == 1
client_id = connected_clients[0]

assert signaling_server.disconnect_client(client_id) == True
await asyncio.sleep(0.1)

assert signaling_server.get_client_count() == 0
assert signaling_server.disconnect_client(client_id) == False

@pytest.mark.asyncio
async def test_send_to_client(signaling_server, unused_tcp_port):
"""Test sending messages to specific clients"""
uri = f'ws://localhost:{unused_tcp_port}/signaling'

async with websockets.connect(uri) as client1, websockets.connect(uri) as client2:
await asyncio.sleep(0.1)

clients = signaling_server.get_connected_clients()
client_id = clients[0]

test_message = "test message"
assert signaling_server.send_to_client(client_id, test_message) == True

assert signaling_server.send_to_client("invalid_id", test_message) == False

received_msg = await client1.recv()
assert received_msg == test_message

@pytest.mark.asyncio
async def test_server_status(signaling_server, unused_tcp_port):
"""Test server status information"""
uri = f'ws://localhost:{unused_tcp_port}/signaling'

status = signaling_server.get_server_status()
assert isinstance(status, dict)
assert "ip" in status
assert "port" in status
assert "connected_clients" in status
assert status["port"] == str(unused_tcp_port)

async with websockets.connect(uri) as client:
await asyncio.sleep(0.1)
updated_status = signaling_server.get_server_status()
assert updated_status["connected_clients"] == "1"

@pytest.mark.asyncio
async def test_client_capacity(signaling_server, unused_tcp_port):
"""Test server capacity handling"""
uri = f'ws://localhost:{unused_tcp_port}/signaling'

assert signaling_server.is_at_capacity() == False

async with websockets.connect(uri) as client1:
await asyncio.sleep(0.1)
assert signaling_server.is_at_capacity() == False

async with websockets.connect(uri) as client2:
await asyncio.sleep(0.1)
assert signaling_server.is_at_capacity() == True
97 changes: 97 additions & 0 deletions aura/tests/test_streamer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
import socket
import shutil
import pytest
import asyncio
import websockets
from contextlib import closing
from aura.webrtc import VideoStreamer

@pytest.fixture(autouse=True)
def setup_video_dir():
video_dir = "/tmp/video"
os.makedirs(video_dir, exist_ok=True)
yield
shutil.rmtree(video_dir, ignore_errors=True)

@pytest.fixture
def free_port():
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind(('', 0))
return s.getsockname()[1]

@pytest.fixture
def streamer(free_port):
streamer = VideoStreamer(
ws_ip="127.0.0.1",
ws_port=free_port,
ivf_dir="/tmp/video"
)
yield streamer
try:
streamer.close_connection()
except:
pass

@pytest.fixture
async def mock_websocket_server(free_port):
async def handler(websocket):
async for message in websocket:
await websocket.send(message)

server = await websockets.serve(
handler,
"127.0.0.1",
free_port,
reuse_address=True,
reuse_port=True
)
yield server
server.close()
await server.wait_closed()

@pytest.fixture
def sample_ivf_file(tmp_path):
ivf_path = tmp_path / "test.ivf"
with open(ivf_path, "wb") as f:
f.write(b"DKIF\x00\x00\x00\x00")
f.write(b"\x00" * 24) # Dummy header data
return ivf_path

@pytest.mark.asyncio
async def test_start_streaming(streamer, mock_websocket_server):
try:
async with asyncio.timeout(5):
streamer.start_streaming()

# Wait for connection establishment
for _ in range(50):
if streamer.get_connection_state() != "new":
break
await asyncio.sleep(0.1)

assert streamer.get_connection_state() in ("connecting", "connected")
assert streamer.get_signaling_state() == "stable"
finally:
# Ensure cleanup
await streamer.close_connection()

def test_take_screenshot_no_connection(streamer):
with pytest.raises(RuntimeError, match="No active peer connection"):
streamer.take_screenshot()

def test_get_stats_no_connection(streamer):
with pytest.raises(RuntimeError, match="No active peer connection"):
streamer.get_stats()

def test_close_connection_no_connection(streamer):
with pytest.raises(RuntimeError, match="No active peer connection"):
streamer.close_connection()

def test_video_directory_monitoring(streamer, sample_ivf_file):
streamer.start_streaming()
shutil.copy(sample_ivf_file, "/tmp/video/test.ivf")
import time
time.sleep(1)
os.remove("/tmp/video/test.ivf")
7 changes: 5 additions & 2 deletions aura/webrtc/signaling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import socket
from aura import SignalingServer
from aura import SignalingServer, VideoStreamer
import sys
import socket
import signal
import time
from camera import ProcessingPipeline, FaceNotFoundException
Expand All @@ -9,6 +10,7 @@
import cv2
from datetime import datetime


def get_free_port():
"""Get an unused TCP port."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
Expand Down Expand Up @@ -72,5 +74,6 @@ def main():
print("\nShutting down signaling server...")

if __name__ == "__main__":
main()
help(SignalingServer)
help(VideoStreamer)

2 changes: 1 addition & 1 deletion src/streamer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ impl VideoStreamer {
#[pyo3(text_signature = "(self) -> None")]
fn close_connection(&self) -> PyResult<()> {
let peer_connection = self.peer_connection.clone();

let rt = tokio::runtime::Runtime::new().unwrap();
rt.block_on(async move {
if let Some(pc) = peer_connection.lock().await.as_ref() {
Expand All @@ -161,6 +160,7 @@ impl VideoStreamer {
e
)));
}
*peer_connection.lock().await = None;
Ok(())
} else {
Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
Expand Down

0 comments on commit 2dc8259

Please sign in to comment.