-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor tunnels module for streaming audio
- Loading branch information
Showing
5 changed files
with
241 additions
and
214 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import base64 | ||
import logging | ||
import wave | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
class AudioStream: | ||
"""Handles audio file processing and validation.""" | ||
|
||
def __init__(self, source: Path, chunk_duration: float, sample_rate: int): | ||
self.source = source | ||
self.chunk_duration = chunk_duration | ||
self.sample_rate = sample_rate | ||
|
||
def validate_config(self): | ||
"""Validate audio configuration parameters.""" | ||
if not self.source.exists(): | ||
raise ValueError(f"Audio file not found: {self.source}") | ||
if self.chunk_duration <= 0: | ||
raise ValueError(f"Invalid chunk duration: {self.chunk_duration}") | ||
if self.sample_rate <= 0: | ||
raise ValueError(f"Invalid sample rate: {self.sample_rate}") | ||
|
||
# Validate audio file format | ||
try: | ||
with wave.open(str(self.source), 'rb') as wf: | ||
if wf.getframerate() != self.sample_rate: | ||
raise ValueError( | ||
f"Expected sample rate {self.sample_rate}, " | ||
f"got {wf.getframerate()}" | ||
) | ||
if wf.getsampwidth() != 2: | ||
raise ValueError("Only 16-bit PCM WAV files are supported") | ||
if wf.getnchannels() != 1: | ||
raise ValueError("Only mono audio is supported") | ||
except wave.Error as e: | ||
raise ValueError(f"Invalid WAV file: {e}") | ||
|
||
def encode_chunk(self, chunk: np.ndarray) -> str: | ||
"""Encode audio chunk to base64 string.""" | ||
float32_data = chunk.astype(np.float32) / 32768.0 # Normalize to [-1.0, 1.0] | ||
return base64.b64encode(float32_data.tobytes()).decode("utf-8") | ||
|
||
def get_total_chunks(self) -> int: | ||
"""Calculate total number of chunks in the audio file.""" | ||
chunk_frames = int(self.chunk_duration * self.sample_rate) | ||
with wave.open(str(self.source), 'rb') as wf: | ||
frames = wf.getnframes() | ||
return (frames + chunk_frames - 1) // chunk_frames # Round up division | ||
|
||
def read_chunks(self): | ||
"""Generator that yields audio chunks.""" | ||
chunk_frames = int(self.chunk_duration * self.sample_rate) | ||
|
||
with wave.open(str(self.source), 'rb') as wf: | ||
while True: | ||
frames = wf.readframes(chunk_frames) | ||
if not frames: | ||
break | ||
|
||
chunk = np.frombuffer(frames, dtype=np.int16) | ||
yield chunk |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import argparse | ||
import logging | ||
from pathlib import Path | ||
|
||
from tunnels.client import AudioClient | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def main(): | ||
"""Main entry point.""" | ||
parser = argparse.ArgumentParser( | ||
description="Stream audio to a WebSocket server as base64 text." | ||
) | ||
parser.add_argument( | ||
"source", | ||
type=str, | ||
help="Path to the audio file (e.g., test_audio.wav)." | ||
) | ||
parser.add_argument( | ||
"server_url", | ||
type=str, | ||
help="WebSocket server URL (e.g., ws://localhost:8765)." | ||
) | ||
parser.add_argument( | ||
"--chunk-duration", | ||
type=float, | ||
default=0.5, | ||
help="Duration of each chunk in seconds. Default: 0.5" | ||
) | ||
parser.add_argument( | ||
"--sample-rate", | ||
type=int, | ||
default=16000, | ||
help="Expected sample rate of the audio file. Default: 16000" | ||
) | ||
parser.add_argument( | ||
"--debug", | ||
action="store_true", | ||
help="Enable debug logging" | ||
) | ||
|
||
args = parser.parse_args() | ||
|
||
# Configure logging | ||
logging.basicConfig( | ||
level=logging.DEBUG if args.debug else logging.INFO, | ||
format='%(asctime)s - %(levelname)s - %(message)s' | ||
) | ||
|
||
try: | ||
client = AudioClient( | ||
source=Path(args.source), | ||
server_url=args.server_url, | ||
chunk_duration=args.chunk_duration, | ||
sample_rate=args.sample_rate | ||
) | ||
client.run() | ||
except KeyboardInterrupt: | ||
pass | ||
except Exception as e: | ||
logger.error(f"Error: {e}") | ||
raise | ||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.