Skip to content

Commit

Permalink
rename node
Browse files Browse the repository at this point in the history
  • Loading branch information
jaybdub committed Oct 15, 2024
1 parent 56cebc0 commit 268eff1
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
18 changes: 18 additions & 0 deletions examples/ros2/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# WhisperTRT - ROS2 Node

This example includes a ROS2 node for interfacing with WhisperTRT.

It includes the full pipeline, including connecting to a microphone, and outputs recognized speech
segments on the ``/speech`` topic.

| Name | Description | Default |
|------|-------------|---------|
| model | The Whisper model to use. | "small.en" |
| backend | The Whisper backend to use. | "whisper_trt" |
| cache_dir | Directory to cache the built models. | None |
| vad_window | Number of audio chunks to use in max-filter window for voice activity detection. | 5 |
| mic_device_index | The microphone device index. | None |
| mic_sample_rate | The microphone sample rate. | 16000 |
| mic_channels | The microphone number of channels. | 6 |
| mic_bitwidth | The microphone bitwidth. | 2 |
| speech_topic | The topic to publish speech segments to. | "/speech" |
10 changes: 5 additions & 5 deletions examples/ros2/asr_node.py → examples/ros2/whisper_trt_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
from rclpy.node import Node
from std_msgs.msg import String
from rcl_interfaces.msg import ParameterDescriptor
from asr_pipeline import ASRPipeline
from whisper_trt_pipeline import WhisperTRTPipeline


class AsrNode(Node):
class WhisperTRTNode(Node):
def __init__(self):
super().__init__('AsrNode')
super().__init__('WhisperTRTNode')

self.declare_parameter("model", "small.en")
self.declare_parameter("backend", "whisper_trt")
Expand Down Expand Up @@ -66,7 +66,7 @@ def handle_asr(text):
self.speech_publisher.publish(msg)
logger.info("published " + text)

self.pipeline = ASRPipeline(
self.pipeline = WhisperTRTPipeline(
model=self.get_parameter("model").value,
vad_window=self.get_parameter("vad_window").value,
backend=self.get_parameter("backend").value,
Expand All @@ -87,7 +87,7 @@ def start_asr_pipeline(self):

def main(args=None):
rclpy.init(args=args)
node = AsrNode()
node = WhisperTRTNode()

node.start_asr_pipeline()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def transcribe(self, audio):
self.asr_callback(text)


class ASRPipeline:
class WhisperTRTPipeline:

def __init__(self,
model: str = "small.en",
Expand Down Expand Up @@ -389,7 +389,7 @@ def handle_vad_end():
def handle_asr(text):
print("asr done: " + text)

pipeline = ASRPipeline(
pipeline = WhisperTRTPipeline(
model=args.model,
backend=args.backend,
cache_dir=args.cache_dir,
Expand Down

0 comments on commit 268eff1

Please sign in to comment.