From 268eff10a1e38118a2734745b9db14f7419a08a5 Mon Sep 17 00:00:00 2001 From: John Welsh Date: Tue, 15 Oct 2024 15:42:48 -0700 Subject: [PATCH] rename node --- examples/ros2/README.md | 18 ++++++++++++++++++ .../ros2/{asr_node.py => whisper_trt_node.py} | 10 +++++----- ...asr_pipeline.py => whisper_trt_pipeline.py} | 4 ++-- 3 files changed, 25 insertions(+), 7 deletions(-) create mode 100644 examples/ros2/README.md rename examples/ros2/{asr_node.py => whisper_trt_node.py} (94%) rename examples/ros2/{asr_pipeline.py => whisper_trt_pipeline.py} (99%) diff --git a/examples/ros2/README.md b/examples/ros2/README.md new file mode 100644 index 0000000..78f4be0 --- /dev/null +++ b/examples/ros2/README.md @@ -0,0 +1,18 @@ +# WhisperTRT - ROS2 Node + +This example includes a ROS2 node for interfacing with WhisperTRT. + +It includes the full pipeline, including connecting to a microphone, and outputs recognized speech +segments on the ``/speech`` topic. + +| Name | Description | Default | +|------|-------------|---------| +| model | The Whisper model to use. | "small.en" | +| backend | The Whisper backend to use. | "whisper_trt" | +| cache_dir | Directory to cache the built models. | None | +| vad_window | Number of audio chunks to use in max-filter window for voice activity detection. | 5 | +| mic_device_index | The microphone device index. | None | +| mic_sample_rate | The microphone sample rate. | 16000 | +| mic_channels | The microphone number of channels. | 6 | +| mic_bitwidth | The microphone bitwidth. | 2 | +| speech_topic | The topic to publish speech segments to. | "/speech" | \ No newline at end of file diff --git a/examples/ros2/asr_node.py b/examples/ros2/whisper_trt_node.py similarity index 94% rename from examples/ros2/asr_node.py rename to examples/ros2/whisper_trt_node.py index 8ab8a7d..b6f7477 100644 --- a/examples/ros2/asr_node.py +++ b/examples/ros2/whisper_trt_node.py @@ -24,12 +24,12 @@ from rclpy.node import Node from std_msgs.msg import String from rcl_interfaces.msg import ParameterDescriptor -from asr_pipeline import ASRPipeline +from whisper_trt_pipeline import WhisperTRTPipeline -class AsrNode(Node): +class WhisperTRTNode(Node): def __init__(self): - super().__init__('AsrNode') + super().__init__('WhisperTRTNode') self.declare_parameter("model", "small.en") self.declare_parameter("backend", "whisper_trt") @@ -66,7 +66,7 @@ def handle_asr(text): self.speech_publisher.publish(msg) logger.info("published " + text) - self.pipeline = ASRPipeline( + self.pipeline = WhisperTRTPipeline( model=self.get_parameter("model").value, vad_window=self.get_parameter("vad_window").value, backend=self.get_parameter("backend").value, @@ -87,7 +87,7 @@ def start_asr_pipeline(self): def main(args=None): rclpy.init(args=args) - node = AsrNode() + node = WhisperTRTNode() node.start_asr_pipeline() diff --git a/examples/ros2/asr_pipeline.py b/examples/ros2/whisper_trt_pipeline.py similarity index 99% rename from examples/ros2/asr_pipeline.py rename to examples/ros2/whisper_trt_pipeline.py index bafa048..71eae2b 100644 --- a/examples/ros2/asr_pipeline.py +++ b/examples/ros2/whisper_trt_pipeline.py @@ -301,7 +301,7 @@ def transcribe(self, audio): self.asr_callback(text) -class ASRPipeline: +class WhisperTRTPipeline: def __init__(self, model: str = "small.en", @@ -389,7 +389,7 @@ def handle_vad_end(): def handle_asr(text): print("asr done: " + text) - pipeline = ASRPipeline( + pipeline = WhisperTRTPipeline( model=args.model, backend=args.backend, cache_dir=args.cache_dir,