Skip to content

Commit

Permalink
fix: fix realtime session closure and streaming of files
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 186ff5b53adb50fcb353a9cec87d68a33885fcce
  • Loading branch information
s0h3yl committed Oct 20, 2023
1 parent b775605 commit 1b239ff
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 12 deletions.
12 changes: 9 additions & 3 deletions assemblyai/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,18 @@ def stream_file(

with open(filepath, "rb") as f:
while True:
data = f.read(int(sample_rate * 0.30) * 2)
enough_data = ((len(data) / (16 / 8)) / sample_rate) * 1_000
# send in 300ms segments
data = f.read(int(sample_rate * 0.300) * 2)

if not data or enough_data < 300.0:
if not data:
yield b"\x00" * int(sample_rate * 1 * 2)
break

enough_data = (len(data) / 2) / sample_rate

if enough_data < 0.300:
data = data + b"\x00" * int(sample_rate * (1 - enough_data) * 2)

yield data

time.sleep(0.15)
Expand Down
21 changes: 12 additions & 9 deletions assemblyai/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,19 +1045,13 @@ def close(self, terminate: bool = False) -> None:
"""
Closes the connection to the real-time service gracefully.
"""

with self._write_queue.mutex:
self._write_queue.queue.clear()

if terminate and not self._stop_event.is_set():
self._websocket.send(json.dumps({"terminate_session": True}))
self._websocket.close()

self._stop_event.set()
self._write_queue.put({"terminate_session": True})

try:
self._read_thread.join()
self._write_thread.join()
self._websocket.close()
except Exception:
pass

Expand Down Expand Up @@ -1105,7 +1099,12 @@ def _write(self) -> None:
continue

try:
self._websocket.send(self._encode_data(data))
if isinstance(data, dict):
self._websocket.send(json.dumps(data))
elif isinstance(data, bytes):
self._websocket.send(self._encode_data(data))
else:
raise ValueError("unsupported message type")
except websockets.exceptions.ConnectionClosed as exc:
return self._handle_error(exc)

Expand Down Expand Up @@ -1143,6 +1142,10 @@ def _handle_message(
and self._on_open
):
self._on_open(types.RealtimeSessionOpened(**message))
elif (
message["message_type"] == types.RealtimeMessageTypes.session_terminated
):
self._stop_event.set()
elif "error" in message:
self._on_error(types.RealtimeError(message["error"]))

Expand Down
1 change: 1 addition & 0 deletions assemblyai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,7 @@ class RealtimeMessageTypes(str, Enum):
partial_transcript = "PartialTranscript"
final_transcript = "FinalTranscript"
session_begins = "SessionBegins"
session_terminated = "SessionTerminated"


class RealtimeSessionOpened(BaseModel):
Expand Down
74 changes: 74 additions & 0 deletions tests/unit/test_extras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from unittest.mock import mock_open, patch

import assemblyai as aai


def test_stream_file_empty_file():
"""
Test streaming of an empty file.
"""

data = b""
sample_rate = 44100

m = mock_open(read_data=data)

with patch("builtins.open", m), patch("time.sleep", return_value=None):
chunks = list(aai.extras.stream_file("fake_path", sample_rate))

# Always expect one chunk due to padding
expected_chunk_length = int(sample_rate * 1 * 2)
assert len(chunks) == 1
assert len(chunks[0]) == expected_chunk_length
assert chunks[0] == b"\x00" * expected_chunk_length


def test_stream_file_small_file():
"""
Tests streaming a file smaller than 300ms.
"""

data = b"\x00" * int(0.2 * 44100) * 2
sample_rate = 44100

m = mock_open(read_data=data)

with patch("builtins.open", m), patch("time.sleep", return_value=None):
chunks = list(aai.extras.stream_file("fake_path", sample_rate))

# Expecting two chunks because of padding at the end
assert len(chunks) == 2


def test_stream_file_large_file():
"""
Test streaming a file larger than 300ms.
"""

data = b"\x00" * int(0.6 * 44100) * 2
sample_rate = 44100

m = mock_open(read_data=data)

with patch("builtins.open", m), patch("time.sleep", return_value=None):
chunks = list(aai.extras.stream_file("fake_path", sample_rate))

# Expecting three chunks because of padding at the end
assert len(chunks) == 3


def test_stream_file_exact_file():
"""
Test streaming a file exactly 300ms long.
"""

data = b"\x00" * int(0.3 * 44100) * 2
sample_rate = 44100

m = mock_open(read_data=data)

with patch("builtins.open", m), patch("time.sleep", return_value=None):
chunks = list(aai.extras.stream_file("fake_path", sample_rate))

# Expecting two chunks because of padding at the end
assert len(chunks) == 2

0 comments on commit 1b239ff

Please sign in to comment.