diff --git a/vocode/streaming/action/worker.py b/vocode/streaming/action/worker.py index b2dbd380c4..2b3e486b65 100644 --- a/vocode/streaming/action/worker.py +++ b/vocode/streaming/action/worker.py @@ -19,7 +19,7 @@ from vocode.streaming.utils.state_manager import AbstractConversationStateManager -class ActionsWorker(InterruptibleWorker): +class ActionsWorker(InterruptibleWorker[InterruptibleEvent[ActionInput]]): consumer: AbstractWorker[InterruptibleEvent[ActionResultAgentInput]] def __init__( @@ -48,16 +48,6 @@ async def process(self, item: InterruptibleEvent[ActionInput]): conversation_id=action_input.conversation_id, action_input=action_input, action_output=action_output, - vonage_uuid=( - action_input.vonage_uuid - if isinstance(action_input, VonagePhoneConversationActionInput) - else None - ), - twilio_sid=( - action_input.twilio_sid - if isinstance(action_input, TwilioPhoneConversationActionInput) - else None - ), is_quiet=action.quiet, ), is_interruptible=False, diff --git a/vocode/streaming/agent/base_agent.py b/vocode/streaming/agent/base_agent.py index e38de8e99b..ee09f0d642 100644 --- a/vocode/streaming/agent/base_agent.py +++ b/vocode/streaming/agent/base_agent.py @@ -44,6 +44,10 @@ ) from vocode.streaming.utils import unrepeating_randomizer from vocode.streaming.utils.speed_manager import SpeedManager +from vocode.streaming.utils.state_manager import ( + TwilioPhoneConversationStateManager, + VonagePhoneConversationStateManager, +) from vocode.utils.sentry_utils import CustomSentrySpans, sentry_create_span if TYPE_CHECKING: @@ -70,8 +74,6 @@ class AgentInputType(str, Enum): class AgentInput(TypedModel, type=AgentInputType.BASE.value): # type: ignore conversation_id: str - vonage_uuid: Optional[str] - twilio_sid: Optional[str] agent_response_tracker: Optional[asyncio.Event] = None class Config: @@ -515,24 +517,25 @@ def create_action_input( ) -> ActionInput: action_input: ActionInput if isinstance(action, VonagePhoneConversationAction): - assert ( - agent_input.vonage_uuid is not None - ), "Cannot use VonagePhoneConversationActionFactory unless the attached conversation is a VonagePhoneConversation" + assert isinstance( + self.conversation_state_manager, + VonagePhoneConversationStateManager, + ), "Cannot use a VonagePhoneConversationAction unless the attached conversation is a VonagePhoneConversation" action_input = action.create_phone_conversation_action_input( - agent_input.conversation_id, - params, - agent_input.vonage_uuid, - user_message_tracker, + conversation_id=agent_input.conversation_id, + params=params, + vonage_uuid=self.conversation_state_manager.get_vonage_uuid(), + user_message_tracker=user_message_tracker, ) elif isinstance(action, TwilioPhoneConversationAction): - assert ( - agent_input.twilio_sid is not None - ), "Cannot use TwilioPhoneConversationActionFactory unless the attached conversation is a TwilioPhoneConversation" + assert isinstance( + self.conversation_state_manager, TwilioPhoneConversationStateManager + ), "Cannot use a TwilioPhoneConversationAction unless the attached conversation is a TwilioPhoneConversation" action_input = action.create_phone_conversation_action_input( - agent_input.conversation_id, - params, - agent_input.twilio_sid, - user_message_tracker, + conversation_id=agent_input.conversation_id, + params=params, + twilio_sid=self.conversation_state_manager.get_twilio_sid(), + user_message_tracker=user_message_tracker, ) else: action_input = action.create_action_input( diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index 62fa590732..67291009e2 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -346,8 +346,6 @@ async def process(self, transcription: Transcription): TranscriptionAgentInput( transcription=transcription, conversation_id=self.conversation.id, - vonage_uuid=getattr(self.conversation, "vonage_uuid", None), - twilio_sid=getattr(self.conversation, "twilio_sid", None), agent_response_tracker=agent_response_tracker, ), ) diff --git a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py index 9dac564fcb..50859b9ee1 100644 --- a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Generic, Literal, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Generic, Literal, Optional, TypeVar, Union from fastapi import WebSocket from loguru import logger @@ -18,6 +18,7 @@ from vocode.streaming.transcriber.abstract_factory import AbstractTranscriberFactory from vocode.streaming.utils import create_conversation_id from vocode.streaming.utils.events_manager import EventsManager +from vocode.streaming.utils.state_manager import PhoneConversationStateManager TelephonyOutputDeviceType = TypeVar( "TelephonyOutputDeviceType", bound=Union[TwilioOutputDevice, VonageOutputDevice] @@ -28,7 +29,9 @@ TelephonyProvider = Literal["twilio", "vonage"] -class AbstractPhoneConversation(Generic[TelephonyOutputDeviceType]): +class AbstractPhoneConversation( + AudioPipeline[TelephonyOutputDeviceType, PhoneConversationStateManager] +): telephony_provider: TelephonyProvider def __init__( @@ -59,5 +62,7 @@ async def attach_ws_and_start(self, ws: WebSocket): pass async def terminate(self): - self.pipeline.events_manager.publish_event(PhoneCallEndedEvent(conversation_id=self.id)) + self.pipeline.events_manager.publish_event( + PhoneCallEndedEvent(conversation_id=self.pipeline.id) + ) await self.pipeline.terminate() diff --git a/vocode/streaming/utils/state_manager.py b/vocode/streaming/utils/state_manager.py index 516965199f..bd69952152 100644 --- a/vocode/streaming/utils/state_manager.py +++ b/vocode/streaming/utils/state_manager.py @@ -148,6 +148,9 @@ def create_vonage_client(self): maybe_vonage_config=self._vonage_phone_conversation.vonage_config, ) + def get_vonage_uuid(self): + return self._vonage_phone_conversation.vonage_uuid + class TwilioPhoneConversationStateManager(PhoneConversationStateManager): def __init__(self, conversation: "TwilioPhoneConversation"): @@ -162,3 +165,6 @@ def create_twilio_client(self): base_url=self._twilio_phone_conversation.base_url, maybe_twilio_config=self.get_twilio_config(), ) + + def get_twilio_sid(self): + return self._twilio_phone_conversation.twilio_sid