Merge pull request #296 from derkmed/intent-detection-dev
Add Intent Detection Node
Showing 4 changed files with 362 additions and 0 deletions.
164 changes: 164 additions & 0 deletions
ros/angel_system_nodes/angel_system_nodes/base_intent_detector.py
@@ -0,0 +1,164 @@
import queue
import rclpy
from rclpy.node import Node
from termcolor import colored
import threading

from angel_msgs.msg import InterpretedAudioUserIntent, Utterance
from angel_utils import declare_and_get_parameters

NEXT_STEP_KEYPHRASES = ["skip", "next", "next step"]
PREV_STEP_KEYPHRASES = ["previous", "previous step", "last step", "go back"]
QUESTION_KEYPHRASES = ["question"]
OVERRIDE_KEYPHRASES = ["angel", "angel system"]

# TODO(derekahmed): Please figure out how to keep this sync-ed with
# config/angel_system_cmds/user_intent_to_sys_cmd_v1.yaml.
# Please refer to labels defined in
# https://docs.google.com/document/d/1uuvSL5de3LVM9c0tKpRKYazDxckffRHf7IAcabSw9UA .
INTENT_LABELS = ["next_step", "prev_step", "inquiry", "other"]

UTTERANCES_TOPIC = "utterances_topic"
PARAM_EXPECT_USER_INTENT_TOPIC = "expect_user_intent_topic"
PARAM_INTERP_USER_INTENT_TOPIC = "interp_user_intent_topic"


class BaseIntentDetector(Node):
    def __init__(self):
        super().__init__(self.__class__.__name__)
        self.log = self.get_logger()

        # Handle parameterization.
        param_values = declare_and_get_parameters(
            self,
            [
                (UTTERANCES_TOPIC,),
                (PARAM_EXPECT_USER_INTENT_TOPIC,),
                (PARAM_INTERP_USER_INTENT_TOPIC,),
            ],
        )
        self._utterances_topic = param_values[UTTERANCES_TOPIC]
        self._expect_uintent_topic = param_values[PARAM_EXPECT_USER_INTENT_TOPIC]
        self._interp_uintent_topic = param_values[PARAM_INTERP_USER_INTENT_TOPIC]

        # Handle subscription/publication topics.
        self.subscription = self.create_subscription(
            Utterance, self._utterances_topic, self.utterance_callback, 1
        )
        self._expected_publisher = self.create_publisher(
            InterpretedAudioUserIntent, self._expect_uintent_topic, 1
        )
        self._interp_publisher = self.create_publisher(
            InterpretedAudioUserIntent, self._interp_uintent_topic, 1
        )

        self.utterance_message_queue = queue.Queue()
        self.handler_thread = threading.Thread(
            target=self.process_utterance_message_queue
        )
        self.handler_thread.start()

    def utterance_callback(self, msg):
        """
        Main ROS listener callback: queues every message received via the
        subscribed utterances topic for asynchronous processing.
        """
        self.log.debug(f'Received message:\n\n"{msg.value}"')
        self.utterance_message_queue.put(msg)

    def process_utterance_message_queue(self):
        """
        Constant loop to process received messages.
        """
        while True:
            msg = self.utterance_message_queue.get()
            self.log.debug(f'Processing message:\n\n"{msg.value}"')
            intent, score = self.detect_intents(msg)
            if not intent:
                continue
            self.publish_msg(msg.value, intent, score)

    def detect_intents(self, msg):
        """
        Keyphrase search for intent detection. This implementation does simple
        string matching to assign a detected label. When multiple intents are
        detected, the message is classified as the first intent, or as an
        'inquiry' if 'inquiry' is one of the classifications.
        """

        def _tiebreak_intents(intents, confidences):
            classification = intents[0]
            score = confidences[0]
            if len(intents) > 1:
                for i, intent in enumerate(intents):
                    if intent == INTENT_LABELS[2]:
                        classification, score = intent, confidences[i]
                self.log.info(
                    f'Detected multiple intents: {intents}. Selected "{classification}".'
                )
            return classification, score

        lower_utterance = msg.value.lower()
        intents = []
        confidences = []
        if self._contains_phrase(lower_utterance, NEXT_STEP_KEYPHRASES):
            intents.append(INTENT_LABELS[0])
            confidences.append(0.5)
        if self._contains_phrase(lower_utterance, PREV_STEP_KEYPHRASES):
            intents.append(INTENT_LABELS[1])
            confidences.append(0.5)
        if self._contains_phrase(lower_utterance, QUESTION_KEYPHRASES):
            intents.append(INTENT_LABELS[2])
            confidences.append(0.5)
        if not intents:
            colored_utterance = colored(msg.value, "light_blue")
            self.log.info(f'No intents detected for:\n>>> "{colored_utterance}":')
            return None, -1.0

        classification, confidence = _tiebreak_intents(intents, confidences)
        classification = colored(classification, "light_green")
        return classification, confidence

    def publish_msg(self, utterance, intent, score):
        """
        Handles message publishing for an utterance with a detected intent.
        """
        intent_msg = InterpretedAudioUserIntent()
        intent_msg.header.frame_id = "Intent Detection"
        intent_msg.header.stamp = self.get_clock().now().to_msg()
        intent_msg.utterance_text = utterance
        intent_msg.user_intent = intent
        intent_msg.confidence = score
        published_topic = None
        if self._contains_phrase(utterance.lower(), OVERRIDE_KEYPHRASES):
            intent_msg.confidence = 1.0
            self._expected_publisher.publish(intent_msg)
            published_topic = PARAM_EXPECT_USER_INTENT_TOPIC
        else:
            self._interp_publisher.publish(intent_msg)
            published_topic = PARAM_INTERP_USER_INTENT_TOPIC

        colored_utterance = colored(utterance, "light_blue")
        colored_intent = colored(intent_msg.user_intent, "light_green")
        self.log.info(
            f'Publishing {{"{colored_intent}": {score}}} to {published_topic} '
            + f'for:\n>>> "{colored_utterance}"'
        )

    def _contains_phrase(self, utterance, phrases):
        for phrase in phrases:
            if phrase in utterance:
                return True
        return False


def main():
    rclpy.init()
    intent_detector = BaseIntentDetector()
    rclpy.spin(intent_detector)
    intent_detector.destroy_node()
    rclpy.shutdown()


if __name__ == "__main__":
    main()
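BaseIntentDetector is designed to be subclassed: a derived detector only needs to override detect_intents() and return an (intent, score) pair, or (None, -1.0) when nothing is found, while subscription, queueing, and publishing are inherited. A minimal sketch of such a subclass follows (hypothetical, not part of this commit; the class name and behavior are placeholders):

import rclpy

from angel_system_nodes.base_intent_detector import BaseIntentDetector, INTENT_LABELS


class ConstantIntentDetector(BaseIntentDetector):
    """Toy detector that labels every utterance as an inquiry."""

    def detect_intents(self, msg):
        # The base class hands over the full Utterance message; msg.value holds the text.
        self.log.debug(f"Classifying: {msg.value}")
        return INTENT_LABELS[2], 0.25


def main():
    rclpy.init()
    node = ConstantIntentDetector()
    rclpy.spin(node)
    node.destroy_node()
    rclpy.shutdown()


if __name__ == "__main__":
    main()

Launched with the same three topic parameters as base_intent_detector, this would publish every utterance to the interpreted-intent topic with confidence 0.25, or to the expected-intent topic with confidence 1.0 when an override keyphrase such as "angel" is present. The GPT-backed detector in the next file follows exactly this pattern.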
100 changes: 100 additions & 0 deletions
ros/angel_system_nodes/angel_system_nodes/gpt_intent_detector.py
@@ -0,0 +1,100 @@
from langchain import PromptTemplate, FewShotPromptTemplate
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAI
import openai
import os
import rclpy

from angel_system_nodes.base_intent_detector import BaseIntentDetector, INTENT_LABELS

openai.organization = os.getenv("OPENAI_ORG_ID")
openai.api_key = os.getenv("OPENAI_API_KEY")

# The following are few-shot examples used when prompting GPT.
FEW_SHOT_EXAMPLES = [
    {"utterance": "Go back to the previous step!", "label": "prev_step"},
    {"utterance": "Next step, please.", "label": "next_step"},
    {"utterance": "How should I wrap this tourniquet?", "label": "inquiry"},
    {"utterance": "The sky is blue", "label": "other"},
]


class GptIntentDetector(BaseIntentDetector):
    def __init__(self):
        super().__init__()
        self.log = self.get_logger()

        # This node additionally includes fields for interacting with OpenAI
        # via LangChain.
        if not os.getenv("OPENAI_API_KEY"):
            self.log.info("OPENAI_API_KEY environment variable is unset!")
        else:
            self.openai_api_key = os.getenv("OPENAI_API_KEY")
        if not os.getenv("OPENAI_ORG_ID"):
            self.log.info("OPENAI_ORG_ID environment variable is unset!")
        else:
            self.openai_org_id = os.getenv("OPENAI_ORG_ID")
        if not bool(self.openai_api_key and self.openai_org_id):
            raise ValueError("Please configure OpenAI API Keys.")
        self.chain = self._configure_langchain()

    def _configure_langchain(self):
        def _labels_list_parenthetical_str(labels):
            concat_labels = ", ".join(labels)
            return f"({concat_labels})"

        def _labels_list_str(labels):
            return ", ".join(labels[:-1]) + f" or {labels[-1]}"

        all_intents_parenthetical = _labels_list_parenthetical_str(INTENT_LABELS)
        all_intents = _labels_list_str(INTENT_LABELS)

        # Define the few-shot template.
        template = (
            f"Utterance: {{utterance}}\nIntent {all_intents_parenthetical}: {{label}}"
        )
        example_prompt = PromptTemplate(
            input_variables=["utterance", "label"], template=template
        )
        prompt_instructions = f"Classify each utterance as {all_intents}.\n"
        inference_sample = (
            f"Utterance: {{utterance}}\nIntent {all_intents_parenthetical}:"
        )
        few_shot_prompt = FewShotPromptTemplate(
            examples=FEW_SHOT_EXAMPLES,
            example_prompt=example_prompt,
            prefix=prompt_instructions,
            suffix=inference_sample,
            input_variables=["utterance"],
            example_separator="\n",
        )

        # Please refer to https://github.com/hwchase17/langchain/blob/master/langchain/llms/openai.py
        openai_llm = ChatOpenAI(
            model_name="gpt-3.5-turbo",
            openai_api_key=self.openai_api_key,
            temperature=0.0,
            # Only 2 tokens are needed for classification (tokens are delimited by
            # use of '_', e.g. 'next_step' counts as 2 tokens).
            max_tokens=2,
        )
        return LLMChain(llm=openai_llm, prompt=few_shot_prompt)

    def detect_intents(self, msg):
        """
        Detects the user intent via LangChain execution of GPT.
        """
        return self.chain.run(utterance=msg), 0.5


def main():
    rclpy.init()
    intent_detector = GptIntentDetector()
    rclpy.spin(intent_detector)
    intent_detector.destroy_node()
    rclpy.shutdown()


if __name__ == "__main__":
    main()
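To inspect the assembled few-shot prompt without calling OpenAI, the FewShotPromptTemplate can be formatted directly. The snippet below is a rough sketch (not part of this commit) that simply mirrors the template construction in _configure_langchain() with the examples inlined:

from langchain import PromptTemplate, FewShotPromptTemplate

LABELS = "(next_step, prev_step, inquiry, other)"
EXAMPLES = [
    {"utterance": "Next step, please.", "label": "next_step"},
    {"utterance": "How should I wrap this tourniquet?", "label": "inquiry"},
]

example_prompt = PromptTemplate(
    input_variables=["utterance", "label"],
    template=f"Utterance: {{utterance}}\nIntent {LABELS}: {{label}}",
)
few_shot_prompt = FewShotPromptTemplate(
    examples=EXAMPLES,
    example_prompt=example_prompt,
    prefix="Classify each utterance as next_step, prev_step, inquiry or other.\n",
    suffix=f"Utterance: {{utterance}}\nIntent {LABELS}:",
    input_variables=["utterance"],
    example_separator="\n",
)

# Prints the instruction line, the labeled examples, and the final unlabeled
# query that GPT is asked to complete.
print(few_shot_prompt.format(utterance="Can you go back one step?"))

With temperature 0.0 and max_tokens=2, the chain.run() call in detect_intents() should come back with just the completed label string, which the base class then publishes with the fixed 0.5 confidence.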
96 changes: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
#
# Used to evaluate Intent Detection with vocal processing for a specified ROS bag of data.
# This configuration should be run by itself (e.g. not in combination with
# another tmuxinator launch).
#
# NOTE: In order to query GPT, you will need to execute
# ```
# export OPENAI_API_KEY="YOUR API KEY"
# export OPENAI_ORG_ID="YOUR ORG ID"
# ```
#

name: Intent Detection with VAD
root: <%= ENV["ANGEL_WORKSPACE_DIR"] %>

# Optional tmux socket
# socket_name: foo

# Note that the pre and post options have been deprecated and will be replaced by
# project hooks.

# Project hooks

# Runs on project start, always
# on_project_start: command
on_project_start: |
  export ROS_NAMESPACE=${ROS_NAMESPACE:-/debug}
  export CONFIG_DIR=${ANGEL_WORKSPACE_DIR}/src/angel_system_nodes/configs
  export NODE_RESOURCES_DIR=${ANGEL_WORKSPACE_DIR}/src/angel_system_nodes/resource
# Run on project start, the first time
# on_project_first_start: command

# Run on project start, after the first time
# on_project_restart: command

# Run on project exit ( detaching from tmux session )
# on_project_exit: command

# Run on project stop
# on_project_stop: command

# Runs in each window and pane before window/pane specific commands. Useful for setting up interpreter versions.
# pre_window: rbenv shell 2.0.0-p247

# Pass command line options to tmux. Useful for specifying a different tmux.conf.
# tmux_options: -f ~/.tmux.mac.conf
tmux_options: -f <%= ENV["ANGEL_WORKSPACE_DIR"] %>/tmux/tmux.conf

# Change the command to call tmux. This can be used by derivatives/wrappers like byobu.
# tmux_command: byobu

# Specifies (by name or index) which window will be selected on project startup. If not set, the first window is used.
# startup_window: editor

# Specifies (by index) which pane of the specified window will be selected on project startup. If not set, the first pane is used.
# startup_pane: 1

# Controls whether the tmux session should be attached to automatically. Defaults to true.
# attach: false

windows:
  # - ros_bag_play: ros2 bag play <<PATH_TO_BAG_FILE>>
  - ros_bag_play: sleep 5; ros2 bag play /angel_workspace/ros_bags/rosbag2_2023_03_01-17_28_00_0.db3
  - vocal:
      layout: even-vertical
      panes:
        - vad: ros2 run angel_system_nodes voice_activity_detector --ros-args
            -r __ns:=${ROS_NAMESPACE}
            -p input_audio_topic:=HeadsetAudioData
            -p output_voice_activity_topic:=DetectedVoiceData
            -p vad_server_url:=http://communication.cs.columbia.edu:55667/vad
            -p vad_cadence:=3
            -p vad_margin:=0.20
            -p max_accumulation_length:=10
            -p debug_mode:=True
        - asr: ros2 run angel_system_nodes asr --ros-args
            -r __ns:=${ROS_NAMESPACE}
            -p audio_topic:=DetectedVoiceData
            -p utterances_topic:=utterances_topic
            -p asr_server_url:=http://communication.cs.columbia.edu:55667/asr
            -p asr_req_segment_duration:=1
            -p is_sentence_tokenize:=False
            -p debug_mode:=True
  - intent:
      layout: even-vertical
      panes:
        - base_intent_detection: ros2 run angel_system_nodes base_intent_detector --ros-args
            -r __ns:=${ROS_NAMESPACE}
            -p utterances_topic:=utterances_topic
            -p expect_user_intent_topic:=expect_user_intent_topic
            -p interp_user_intent_topic:=interp_user_intent_topic
        - gpt_intent_detection: ros2 run angel_system_nodes gpt_intent_detector --ros-args
            -r __ns:=${ROS_NAMESPACE}
            -p utterances_topic:=utterances_topic
            -p expect_user_intent_topic:=expect_user_intent_topic
            -p interp_user_intent_topic:=interp_user_intent_topic
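Once these windows are up, both detectors sit idle until something arrives on the utterances topic (normally fed by the VAD/ASR panes replaying the ROS bag). As a sanity check independent of the bag, a small one-shot publisher can inject a test utterance. The following is a hedged sketch (not part of this commit), assuming the /debug namespace and the utterances_topic name configured above:

import rclpy
from rclpy.node import Node

from angel_msgs.msg import Utterance


def main():
    rclpy.init()
    node = Node("utterance_smoke_test")
    pub = node.create_publisher(Utterance, "/debug/utterances_topic", 1)

    msg = Utterance()
    msg.value = "Angel, please go to the next step."

    # Give the publisher a moment to discover subscribers, then send once.
    rclpy.spin_once(node, timeout_sec=1.0)
    pub.publish(msg)

    node.destroy_node()
    rclpy.shutdown()


if __name__ == "__main__":
    main()

Because this utterance contains the override keyphrase "angel", both detectors should publish to expect_user_intent_topic with confidence 1.0; dropping the keyphrase routes the result to interp_user_intent_topic instead.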