feat: Add WebSocket server with multi-client support #263
base: develop
Conversation
- Add testing configuration and diarization tests
- Add aggregation tests
- Add end-to-end test for a sample wav file and several latencies
- Fix rounding error in min latency unit test
- Improve CI workflows and add pytest. Fix matplotlib colormap error
- Install missing dependencies in CI
- Add onnxruntime as a test dependency
- Update expected timestamp tolerance to up to 50ms
- Updated readme embed-extraction pipeline
- Updated readme embed-extraction pipeline
- Update README.md
- Apply suggestions from code review
- Update README.md
- Update README.md

Co-authored-by: Juan Coria <[email protected]>
Force-pushed from 662ad3a to fb9fecf
A couple of things I still want to work on:
I have also added a cleanup step in the server for when a client disconnects. This was mostly to ensure explicit memory management, since client streams do not share resources at the moment, but it should also address #255
Moved to LazyModel for resource management, based on this comment. Client-specific Pipeline instances now share resources that are initialised in a common PipelineConfig instance. I'm still unsure how this would scale with client connections and would appreciate any thoughts on this!
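For illustration, a rough sketch of this sharing scheme (not the exact PR code; the loaders mirror the README example and `make_client_pipeline` is a hypothetical helper):

```python
# Sketch of the sharing scheme: model wrappers are created once in a shared
# SpeakerDiarizationConfig, and each client builds its own SpeakerDiarization
# pipeline from that config, so per-client state stays isolated while the
# heavy models are shared.
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.models import EmbeddingModel, SegmentationModel
from pyannote.audio import Model

def segmentation_loader():
    return Model.from_pretrained("pyannote/segmentation")

def embedding_loader():
    return Model.from_pretrained("pyannote/embedding")

# Created once for the whole server; the loaders run lazily on first use
shared_config = SpeakerDiarizationConfig(
    segmentation=SegmentationModel(segmentation_loader),
    embedding=EmbeddingModel(embedding_loader),
)

def make_client_pipeline() -> SpeakerDiarization:
    # One pipeline per client, all backed by the same shared config
    return SpeakerDiarization(shared_config)
```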
Force-pushed from bbf2df2 to d4380c4
Force-pushed from d4380c4 to 7ba2f55
Force-pushed from 2cbcdc3 to bba43ae
Oh, it would definitely still be a subclass of
Okay, this makes sense to me too: not exposing such websocket-specific functionality. Made the changes!
… in the server itself
Force-pushed from 604d140 to 12f7ba9
Added error handling for the following edge cases in
These edge cases can occur due to race conditions in the client lifecycle (connect/disconnect/cleanup) or network issues that leave client state mismatched between server and client. Added warnings to catch these async timing issues, and documented the edge-case conditions in the respective methods' docstrings.
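As a rough illustration of the guarded cleanup (names are hypothetical, not the exact PR code):

```python
# Warn instead of raising when the client is unknown, since disconnects,
# cleanup, and network errors can race and desynchronise server state.
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)

class StreamingInferenceHandler:
    def __init__(self) -> None:
        self._clients: Dict[str, Any] = {}

    def _close_client(self, client_id: str) -> None:
        """Release a client's resources.

        Warns if `client_id` is not in `self._clients`, e.g. when a
        disconnect arrives after cleanup has already run.
        """
        if client_id not in self._clients:
            logger.warning("Client %s not found; already cleaned up?", client_id)
            return
        # Ensure the client is removed even if later cleanup steps fail
        self._clients.pop(client_id, None)
```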
Modified. Apart from this, complete documentation in the README is pending. Will get to that next.
Force-pushed from e6607a2 to 74bd40b
```diff
@@ -202,6 +202,7 @@ def embedding_loader():
 segmentation = SegmentationModel(segmentation_loader)
 embedding = EmbeddingModel(embedding_loader)
 config = SpeakerDiarizationConfig(
+    # Set the segmentation model used in the paper
```
This isn't correct. To remove
```diff
@@ -332,20 +333,57 @@ diart.client microphone --host <server-address> --port 7007

 See `-h` for more options.

 ### From the Dockerfile
```
Suggested change:

```diff
-### From the Dockerfile
+### From a Docker container
```
You can also run the server in a Docker container. First, build the image:

```shell
docker build -t diart -f Dockerfile .
```
`-f Dockerfile` is not needed, as docker will pick up the file with that name in the specified directory.
Run the server with default configuration:

```shell
docker run -p 7007:7007 --gpus all -e HF_TOKEN=<token> diart
```
We should probably add a note somewhere saying that for GPU usage they need to install `nvidia-container-toolkit`.

Also, is there a way to pick up the HF token from the `huggingface-cli` config? That way we avoid passing it directly and keeping it in the terminal history. This is possible when running outside docker, and we shouldn't make it mandatory, as it's an important security feature.
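One way this could look (a sketch, assuming a recent huggingface_hub; older releases expose the same lookup via `HfFolder.get_token()`):

```python
# Prefer an explicit HF_TOKEN, otherwise fall back to the token saved by
# `huggingface-cli login`, so it never has to be typed on the command line.
import os
from huggingface_hub import get_token

hf_token = os.environ.get("HF_TOKEN") or get_token()
if hf_token is None:
    raise RuntimeError("No HF token found; run `huggingface-cli login` first")
```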
```shell
docker run -p 7007:7007 --gpus all -e HF_TOKEN=<token> diart
```

Run with custom configuration:
Suggested change:

```diff
-Run with custom configuration:
+Example with a custom configuration:
```
```
Raises
------
Warning
    If client not found in self._clients. Common cases:
```
same as previous comment
```python
try:
    # Clean up pipeline state using built-in reset method
    client_state = self._clients[client_id]
    client_state.inference.pipeline.reset()
```
Not sure a reset is required because the pipeline will be removed from memory anyway
```python
        # Ensure client is removed even if cleanup fails
        self._clients.pop(client_id, None)

    def close_all(self) -> None:
```
This should be called `shutdown()`, because it shuts down the server after closing all clients.
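Something like this (a sketch, assuming a recent websocket-server that provides `shutdown_gracefully()`; `_close_client` is an illustrative name for the per-client cleanup):

```python
from typing import Any, Dict
from websocket_server import WebsocketServer

class StreamingInferenceHandler:
    def __init__(self, host: str = "0.0.0.0", port: int = 7007) -> None:
        self.server = WebsocketServer(host=host, port=port)
        self._clients: Dict[str, Any] = {}

    def _close_client(self, client_id: str) -> None:
        self._clients.pop(client_id, None)

    def shutdown(self) -> None:
        # Close every client first, then stop the server itself;
        # the name reflects that this does more than closing clients.
        for client_id in list(self._clients):
            self._close_client(client_id)
        self.server.shutdown_gracefully()
```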
```python
try:
    while retry_count < max_retries:
        try:
            self.server.run_forever()
            break  # If server exits normally, break the retry loop
        except OSError as e:
            logger.warning(f"WebSocket server connection error: {e}")
            retry_count += 1
            if retry_count < max_retries:
                delay = base_delay * (2 ** (retry_count - 1))  # Exponential backoff
                logger.info(
                    f"Retrying in {delay} seconds... "
                    f"(attempt {retry_count}/{max_retries})"
                )
                time.sleep(delay)
            else:
                logger.error(
                    f"WebSocket server failed to start after {max_retries} attempts. "
                    f"Last error: {e}"
                )
        except Exception as e:
            logger.error(f"Fatal server error: {e}")
            break
finally:
    self.close_all()
```
Now that I think about it, it's probably not required to retry starting the server, right? I mean if starting the server doesn't work, it's probably a configuration error that should be fixed by the developer, for example if the port is already in use. What do you think? What use case did you have in mind for retrying?
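For reference, a fail-fast sketch of that alternative (illustrative only; continues the handler sketched in the earlier comment):

```python
import logging

logger = logging.getLogger(__name__)

class StreamingInferenceHandler:
    # __init__ and shutdown() as sketched in the earlier comment

    def run(self) -> None:
        try:
            self.server.run_forever()
        except OSError:
            # A configuration problem (e.g. port already in use) that the
            # developer should fix, rather than something worth retrying
            logger.exception("WebSocket server failed to start")
            raise
        finally:
            self.shutdown()
```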
```python
        return ClientState(audio_source=audio_source, inference=inference)

    def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
```
Maybe we should allow a max number of clients to connect? My reasoning is the following: if we have to copy `StreamingInference` instances (including models) for every new client, the server will most likely crash at some point (especially if sharing a GPU). However, given system resources, we can probably estimate how many clients fit in the machine, or whether a new client fits in the remaining available resources.

If this is too complicated, we can simply add a parameter inside `__init__()` for the maximum number of simultaneous clients. Something like `client_pool_size: int = 4`.
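A minimal sketch of that idea (illustrative; `_create_client_state` stands in for the PR's actual client setup):

```python
import logging
from typing import Any, Dict, Text

from websocket_server import WebsocketServer

logger = logging.getLogger(__name__)

class StreamingInferenceHandler:
    def __init__(self, client_pool_size: int = 4) -> None:
        self.client_pool_size = client_pool_size
        self._clients: Dict[Text, Any] = {}

    def _on_connect(self, client: Dict[Text, Any], server: WebsocketServer) -> None:
        if len(self._clients) >= self.client_pool_size:
            # Reject instead of copying yet another StreamingInference
            # instance (and its models) into memory
            logger.warning("Rejecting client %s: pool is full", client["id"])
            server.send_message(client, "Server at capacity, please retry later")
            return
        self._clients[client["id"]] = self._create_client_state(client)

    def _create_client_state(self, client: Dict[Text, Any]) -> Any:
        ...  # allocate the audio source and inference for this client (omitted)
```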
@janaab11 do you have any updates on this PR? How can we unblock it?
Overview

Implements a WebSocket server that can handle audio streams from multiple client connections.

Changes

- `StreamingInferenceHandler` for managing connections

Testing

Please let me know if any changes or improvements are needed!

Fixes #252