diff --git a/aura/tests/test_signaling.py b/aura/tests/test_signaling.py index ca8c792..d216243 100644 --- a/aura/tests/test_signaling.py +++ b/aura/tests/test_signaling.py @@ -4,7 +4,6 @@ import json from aura.webrtc import SignalingServer import sys -print(sys.path) import time @@ -62,7 +61,6 @@ async def test_client_connection(signaling_server, unused_tcp_port): uri = f'ws://localhost:{unused_tcp_port}/signaling' async with websockets.connect(uri) as websocket: - # Wait briefly for connection to be registered await asyncio.sleep(0.1) assert signaling_server.get_client_count() == 1 @@ -72,10 +70,8 @@ async def test_signaling_message_exchange(signaling_server, unused_tcp_port): uri = f'ws://localhost:{unused_tcp_port}/signaling' async with websockets.connect(uri) as client1, websockets.connect(uri) as client2: - # Wait for connections to be established await asyncio.sleep(0.1) - # Test SDP offer exchange offer_message = { "type": "offer", "sdp": "test_sdp_offer" @@ -94,16 +90,82 @@ async def test_broadcast_message(signaling_server, unused_tcp_port): uri = f'ws://localhost:{unused_tcp_port}/signaling' async with websockets.connect(uri) as client1, websockets.connect(uri) as client2: - # Wait for connections to be established await asyncio.sleep(0.1) test_message = "broadcast test message" signaling_server.broadcast_message(test_message) - # Both clients should receive the message msg1 = await client1.recv() msg2 = await client2.recv() assert msg1 == test_message assert msg2 == test_message +@pytest.mark.asyncio +async def test_disconnect_client(signaling_server, unused_tcp_port): + """Test disconnecting a client""" + uri = f'ws://localhost:{unused_tcp_port}/signaling' + + async with websockets.connect(uri) as client: + await asyncio.sleep(0.1) + + connected_clients = signaling_server.get_connected_clients() + assert len(connected_clients) == 1 + client_id = connected_clients[0] + + assert signaling_server.disconnect_client(client_id) == True + await asyncio.sleep(0.1) + + assert signaling_server.get_client_count() == 0 + assert signaling_server.disconnect_client(client_id) == False + +@pytest.mark.asyncio +async def test_send_to_client(signaling_server, unused_tcp_port): + """Test sending messages to specific clients""" + uri = f'ws://localhost:{unused_tcp_port}/signaling' + + async with websockets.connect(uri) as client1, websockets.connect(uri) as client2: + await asyncio.sleep(0.1) + + clients = signaling_server.get_connected_clients() + client_id = clients[0] + + test_message = "test message" + assert signaling_server.send_to_client(client_id, test_message) == True + + assert signaling_server.send_to_client("invalid_id", test_message) == False + + received_msg = await client1.recv() + assert received_msg == test_message + +@pytest.mark.asyncio +async def test_server_status(signaling_server, unused_tcp_port): + """Test server status information""" + uri = f'ws://localhost:{unused_tcp_port}/signaling' + + status = signaling_server.get_server_status() + assert isinstance(status, dict) + assert "ip" in status + assert "port" in status + assert "connected_clients" in status + assert status["port"] == str(unused_tcp_port) + + async with websockets.connect(uri) as client: + await asyncio.sleep(0.1) + updated_status = signaling_server.get_server_status() + assert updated_status["connected_clients"] == "1" + +@pytest.mark.asyncio +async def test_client_capacity(signaling_server, unused_tcp_port): + """Test server capacity handling""" + uri = f'ws://localhost:{unused_tcp_port}/signaling' + + assert signaling_server.is_at_capacity() == False + + async with websockets.connect(uri) as client1: + await asyncio.sleep(0.1) + assert signaling_server.is_at_capacity() == False + + async with websockets.connect(uri) as client2: + await asyncio.sleep(0.1) + assert signaling_server.is_at_capacity() == True diff --git a/aura/tests/test_streamer.py b/aura/tests/test_streamer.py new file mode 100644 index 0000000..4b068ba --- /dev/null +++ b/aura/tests/test_streamer.py @@ -0,0 +1,97 @@ +import os +import socket +import shutil +import pytest +import asyncio +import websockets +from contextlib import closing +from aura.webrtc import VideoStreamer + +@pytest.fixture(autouse=True) +def setup_video_dir(): + video_dir = "/tmp/video" + os.makedirs(video_dir, exist_ok=True) + yield + shutil.rmtree(video_dir, ignore_errors=True) + +@pytest.fixture +def free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(('', 0)) + return s.getsockname()[1] + +@pytest.fixture +def streamer(free_port): + streamer = VideoStreamer( + ws_ip="127.0.0.1", + ws_port=free_port, + ivf_dir="/tmp/video" + ) + yield streamer + try: + streamer.close_connection() + except: + pass + +@pytest.fixture +async def mock_websocket_server(free_port): + async def handler(websocket): + async for message in websocket: + await websocket.send(message) + + server = await websockets.serve( + handler, + "127.0.0.1", + free_port, + reuse_address=True, + reuse_port=True + ) + yield server + server.close() + await server.wait_closed() + +@pytest.fixture +def sample_ivf_file(tmp_path): + ivf_path = tmp_path / "test.ivf" + with open(ivf_path, "wb") as f: + f.write(b"DKIF\x00\x00\x00\x00") + f.write(b"\x00" * 24) # Dummy header data + return ivf_path + +@pytest.mark.asyncio +async def test_start_streaming(streamer, mock_websocket_server): + try: + async with asyncio.timeout(5): + streamer.start_streaming() + + # Wait for connection establishment + for _ in range(50): + if streamer.get_connection_state() != "new": + break + await asyncio.sleep(0.1) + + assert streamer.get_connection_state() in ("connecting", "connected") + assert streamer.get_signaling_state() == "stable" + finally: + # Ensure cleanup + await streamer.close_connection() + +def test_take_screenshot_no_connection(streamer): + with pytest.raises(RuntimeError, match="No active peer connection"): + streamer.take_screenshot() + +def test_get_stats_no_connection(streamer): + with pytest.raises(RuntimeError, match="No active peer connection"): + streamer.get_stats() + +def test_close_connection_no_connection(streamer): + with pytest.raises(RuntimeError, match="No active peer connection"): + streamer.close_connection() + +def test_video_directory_monitoring(streamer, sample_ivf_file): + streamer.start_streaming() + shutil.copy(sample_ivf_file, "/tmp/video/test.ivf") + import time + time.sleep(1) + os.remove("/tmp/video/test.ivf") diff --git a/aura/webrtc/signaling.py b/aura/webrtc/signaling.py index c66a8ed..bcd56c9 100644 --- a/aura/webrtc/signaling.py +++ b/aura/webrtc/signaling.py @@ -1,6 +1,7 @@ import socket -from aura import SignalingServer +from aura import SignalingServer, VideoStreamer import sys +import socket import signal import time from camera import ProcessingPipeline, FaceNotFoundException @@ -9,6 +10,7 @@ import cv2 from datetime import datetime + def get_free_port(): """Get an unused TCP port.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -72,5 +74,6 @@ def main(): print("\nShutting down signaling server...") if __name__ == "__main__": - main() + help(SignalingServer) + help(VideoStreamer) \ No newline at end of file diff --git a/src/streamer.rs b/src/streamer.rs index a7ac7b6..05a5ac7 100644 --- a/src/streamer.rs +++ b/src/streamer.rs @@ -151,7 +151,6 @@ impl VideoStreamer { #[pyo3(text_signature = "(self) -> None")] fn close_connection(&self) -> PyResult<()> { let peer_connection = self.peer_connection.clone(); - let rt = tokio::runtime::Runtime::new().unwrap(); rt.block_on(async move { if let Some(pc) = peer_connection.lock().await.as_ref() { @@ -161,6 +160,7 @@ impl VideoStreamer { e ))); } + *peer_connection.lock().await = None; Ok(()) } else { Err(PyErr::new::(