From de03ec33853eadb2b75c7df32b41576db011713e Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Sat, 29 Jun 2024 17:41:01 -0700 Subject: [PATCH] deprecate state manager --- playground/streaming/agent/chat.py | 8 +- tests/fakedata/conversation.py | 109 ++++++++++- tests/streaming/action/test_dtmf.py | 42 +++-- .../streaming/action/test_end_conversation.py | 78 ++++---- .../streaming/action/test_external_actions.py | 77 ++++---- tests/streaming/action/test_record_email.py | 30 +--- tests/streaming/action/test_transfer_call.py | 95 ++++++---- tests/streaming/agent/test_base_agent.py | 6 - vocode/streaming/action/base_action.py | 8 +- vocode/streaming/action/default_factory.py | 18 +- vocode/streaming/action/dtmf.py | 10 +- vocode/streaming/action/end_conversation.py | 8 +- .../action/execute_external_action.py | 8 +- vocode/streaming/action/phone_call_action.py | 26 +-- .../action/streaming_conversation_action.py | 17 ++ vocode/streaming/action/transfer_call.py | 31 ++-- vocode/streaming/action/worker.py | 28 ++- vocode/streaming/agent/anthropic_agent.py | 2 +- vocode/streaming/agent/base_agent.py | 80 +++------ vocode/streaming/agent/chat_gpt_agent.py | 2 +- vocode/streaming/agent/groq_agent.py | 2 +- vocode/streaming/agent/langchain_agent.py | 2 +- .../agent/websocket_user_implemented_agent.py | 2 +- vocode/streaming/models/actions.py | 7 + vocode/streaming/models/pipeline.py | 1 + .../pipeline/abstract_pipeline_factory.py | 10 +- vocode/streaming/pipeline/audio_pipeline.py | 4 +- vocode/streaming/pipeline/worker.py | 8 +- vocode/streaming/streaming_conversation.py | 87 +++++---- .../abstract_phone_conversation.py | 13 +- .../conversation/twilio_phone_conversation.py | 46 ++++- .../conversation/vonage_phone_conversation.py | 47 ++++- .../telephony/server/router/calls.py | 22 +-- vocode/streaming/utils/state_manager.py | 170 ------------------ 34 files changed, 538 insertions(+), 566 deletions(-) create mode 100644 vocode/streaming/action/streaming_conversation_action.py delete mode 100644 vocode/streaming/utils/state_manager.py diff --git a/playground/streaming/agent/chat.py b/playground/streaming/agent/chat.py index e2f17857aa..1c5a76cf0c 100644 --- a/playground/streaming/agent/chat.py +++ b/playground/streaming/agent/chat.py @@ -21,7 +21,7 @@ from vocode.streaming.models.message import BaseMessage from vocode.streaming.models.transcript import Transcript from vocode.streaming.pipeline.worker import InterruptibleAgentResponseEvent, QueueConsumer -from vocode.streaming.utils.state_manager import AbstractConversationStateManager +from vocode.streaming.streaming_conversation import StreamingConversation load_dotenv() @@ -75,7 +75,7 @@ def create_action(self, action_config: ActionConfig) -> BaseAction: raise Exception("Invalid action type") -class DummyConversationManager(AbstractConversationStateManager): +class DummyStreamingConversation(StreamingConversation): """For use with Agents operating in a non-call context.""" def __init__( @@ -192,7 +192,7 @@ async def sender(): ) actions_worker.consumer = agent agent.actions_consumer = actions_worker - actions_worker.attach_conversation_state_manager(agent.conversation_state_manager) + actions_worker.pipeline = agent.streaming_conversation actions_worker.start() await asyncio.gather(receiver(), sender()) @@ -226,7 +226,7 @@ async def agent_main(): ), action_factory=ShoutActionFactory(), ) - agent.attach_conversation_state_manager(DummyConversationManager()) + agent.streaming_conversation = DummyStreamingConversation() agent.attach_transcript(transcript) agent.start() diff --git a/tests/fakedata/conversation.py b/tests/fakedata/conversation.py index 69e407d33a..31563fd0f7 100644 --- a/tests/fakedata/conversation.py +++ b/tests/fakedata/conversation.py @@ -9,13 +9,24 @@ from vocode.streaming.models.audio import AudioEncoding from vocode.streaming.models.message import BaseMessage from vocode.streaming.models.synthesizer import PlayHtSynthesizerConfig, SynthesizerConfig +from vocode.streaming.models.telephony import PhoneCallDirection, TwilioConfig, VonageConfig from vocode.streaming.models.transcriber import DeepgramTranscriberConfig, TranscriberConfig from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice from vocode.streaming.output_device.audio_chunk import ChunkState -from vocode.streaming.streaming_conversation import StreamingConversation +from vocode.streaming.streaming_conversation import ( + StreamingConversation, + StreamingConversationFactory, +) from vocode.streaming.synthesizer.base_synthesizer import BaseSynthesizer +from vocode.streaming.telephony.config_manager.base_config_manager import BaseConfigManager from vocode.streaming.telephony.constants import DEFAULT_CHUNK_SIZE, DEFAULT_SAMPLING_RATE -from vocode.streaming.transcriber.base_transcriber import BaseTranscriber +from vocode.streaming.telephony.conversation.twilio_phone_conversation import ( + TwilioPhoneConversation, +) +from vocode.streaming.telephony.conversation.vonage_phone_conversation import ( + VonagePhoneConversation, +) +from vocode.streaming.transcriber.base_transcriber import AbstractTranscriber, BaseTranscriber from vocode.streaming.transcriber.deepgram_transcriber import DeepgramEndpointingConfig from vocode.streaming.utils.events_manager import EventsManager @@ -95,18 +106,47 @@ def create_fake_transcriber(mocker: MockerFixture, transcriber_config: Transcrib return transcriber +def create_fake_transcriber_factory( + mocker: MockerFixture, transcriber: Optional[AbstractTranscriber] = None +): + factory = mocker.MagicMock() + factory.create_transcriber = mocker.MagicMock( + return_value=transcriber + or create_fake_transcriber(mocker, DEFAULT_DEEPGRAM_TRANSCRIBER_CONFIG) + ) + return factory + + def create_fake_agent(mocker: MockerFixture, agent_config: AgentConfig): agent = mocker.MagicMock() agent.get_agent_config = mocker.MagicMock(return_value=agent_config) return agent +def create_fake_agent_factory(mocker: MockerFixture, agent: Optional[BaseAgent] = None): + factory = mocker.MagicMock() + factory.create_agent = mocker.MagicMock( + return_value=agent or create_fake_agent(mocker, DEFAULT_CHAT_GPT_AGENT_CONFIG) + ) + return factory + + def create_fake_synthesizer(mocker: MockerFixture, synthesizer_config: SynthesizerConfig): synthesizer = mocker.MagicMock() synthesizer.get_synthesizer_config = mocker.MagicMock(return_value=synthesizer_config) return synthesizer +def create_fake_synthesizer_factory( + mocker: MockerFixture, synthesizer: Optional[BaseSynthesizer] = None +): + factory = mocker.MagicMock() + factory.create_synthesizer = mocker.MagicMock( + return_value=synthesizer or create_fake_synthesizer(mocker, DEFAULT_SYNTHESIZER_CONFIG) + ) + return factory + + def create_fake_streaming_conversation( mocker: MockerFixture, transcriber: Optional[BaseTranscriber[TranscriberConfig]] = None, @@ -132,3 +172,68 @@ def create_fake_streaming_conversation( conversation_id=conversation_id, events_manager=events_manager, ) + + +def create_fake_streaming_conversation_factory( + mocker: MockerFixture, + transcriber: Optional[BaseTranscriber[TranscriberConfig]] = None, + agent: Optional[BaseAgent] = None, + synthesizer: Optional[BaseSynthesizer] = None, +): + return StreamingConversationFactory( + transcriber_factory=create_fake_transcriber_factory(mocker, transcriber), + agent_factory=create_fake_agent_factory(mocker, agent), + synthesizer_factory=create_fake_synthesizer_factory(mocker, synthesizer), + ) + + +def create_fake_twilio_phone_conversation_with_streaming_conversation_pipeline( + mocker: MockerFixture, + streaming_conversation_factory: StreamingConversationFactory, + direction: PhoneCallDirection = "outbound", + from_phone: str = "+1234567890", + to_phone: str = "+0987654321", + base_url: str = "http://test.com", + twilio_sid: str = "test_sid", + twilio_config: Optional[TwilioConfig] = None, + config_manager: Optional[BaseConfigManager] = None, + events_manager: Optional[EventsManager] = None, +): + return TwilioPhoneConversation( + direction=direction, + from_phone=from_phone, + to_phone=to_phone, + base_url=base_url, + config_manager=config_manager, + pipeline_factory=streaming_conversation_factory, + pipeline_config=mocker.MagicMock(), + twilio_sid=twilio_sid, + twilio_config=twilio_config, + events_manager=events_manager, + ) + + +def create_fake_vonage_phone_conversation_with_streaming_conversation_pipeline( + mocker: MockerFixture, + streaming_conversation_factory: StreamingConversationFactory, + direction: PhoneCallDirection = "outbound", + from_phone: str = "+1234567890", + to_phone: str = "+0987654321", + base_url: str = "http://test.com", + vonage_uuid: str = "test_uuid", + vonage_config: Optional[VonageConfig] = None, + config_manager: Optional[BaseConfigManager] = None, + events_manager: Optional[EventsManager] = None, +): + return VonagePhoneConversation( + direction=direction, + from_phone=from_phone, + to_phone=to_phone, + base_url=base_url, + config_manager=config_manager, + pipeline_factory=streaming_conversation_factory, + pipeline_config=mocker.MagicMock(), + vonage_uuid=vonage_uuid, + vonage_config=vonage_config or mocker.MagicMock(), + events_manager=events_manager, + ) diff --git a/tests/streaming/action/test_dtmf.py b/tests/streaming/action/test_dtmf.py index 24050c0525..a0834cf32e 100644 --- a/tests/streaming/action/test_dtmf.py +++ b/tests/streaming/action/test_dtmf.py @@ -1,6 +1,11 @@ import pytest from aioresponses import aioresponses +from tests.fakedata.conversation import ( + create_fake_agent, + create_fake_streaming_conversation_factory, + create_fake_vonage_phone_conversation_with_streaming_conversation_pipeline, +) from tests.fakedata.id import generate_uuid from vocode.streaming.action.dtmf import ( DTMFParameters, @@ -8,13 +13,10 @@ TwilioDTMF, VonageDTMF, ) -from vocode.streaming.models.actions import ( - TwilioPhoneConversationActionInput, - VonagePhoneConversationActionInput, -) +from vocode.streaming.models.actions import ActionInput +from vocode.streaming.models.agent import ChatGPTAgentConfig from vocode.streaming.models.telephony import VonageConfig from vocode.streaming.utils import create_conversation_id -from vocode.streaming.utils.state_manager import VonagePhoneConversationStateManager @pytest.mark.asyncio @@ -23,22 +25,35 @@ async def test_vonage_dtmf_press_digits(mocker, mock_env): vonage_uuid = generate_uuid() digits = "1234" - vonage_phone_conversation_mock = mocker.MagicMock() vonage_config = VonageConfig( api_key="api_key", api_secret="api_secret", application_id="application_id", private_key="-----BEGIN PRIVATE KEY-----\nasdf\n-----END PRIVATE KEY-----", ) - vonage_phone_conversation_mock.vonage_config = vonage_config + vonage_phone_conversation_mock = ( + create_fake_vonage_phone_conversation_with_streaming_conversation_pipeline( + mocker, + streaming_conversation_factory=create_fake_streaming_conversation_factory( + mocker, + agent=create_fake_agent( + mocker, + agent_config=ChatGPTAgentConfig( + prompt_preamble="", + actions=[action.action_config], + ), + ), + ), + vonage_config=vonage_config, + vonage_uuid=vonage_uuid, + ) + ) mocker.patch("vonage.Client._create_jwt_auth_string", return_value=b"asdf") - action.attach_conversation_state_manager( - VonagePhoneConversationStateManager(vonage_phone_conversation_mock) - ) + vonage_phone_conversation_mock.pipeline.actions_worker.attach_state(action) assert ( - action.conversation_state_manager.create_vonage_client().get_telephony_config() + vonage_phone_conversation_mock.create_vonage_client().get_telephony_config() == vonage_config ) @@ -48,11 +63,10 @@ async def test_vonage_dtmf_press_digits(mocker, mock_env): status=200, ) action_output = await action.run( - action_input=VonagePhoneConversationActionInput( + action_input=ActionInput( action_config=DTMFVocodeActionConfig(), conversation_id=create_conversation_id(), params=DTMFParameters(buttons=digits), - vonage_uuid=str(vonage_uuid), ) ) @@ -66,7 +80,7 @@ async def test_twilio_dtmf_press_digits(mocker, mock_env): twilio_sid = "twilio_sid" action_output = await action.run( - action_input=TwilioPhoneConversationActionInput( + action_input=ActionInput( action_config=DTMFVocodeActionConfig(), conversation_id=create_conversation_id(), params=DTMFParameters(buttons=digits), diff --git a/tests/streaming/action/test_end_conversation.py b/tests/streaming/action/test_end_conversation.py index 312c72285b..0e6e0eaeae 100644 --- a/tests/streaming/action/test_end_conversation.py +++ b/tests/streaming/action/test_end_conversation.py @@ -7,6 +7,7 @@ from pydantic.v1 import BaseModel from pytest_mock import MockerFixture +from tests.fakedata.conversation import create_fake_agent, create_fake_streaming_conversation from tests.fakedata.id import generate_uuid from vocode.streaming.action.end_conversation import ( EndConversation, @@ -14,10 +15,13 @@ EndConversationVocodeActionConfig, ) from vocode.streaming.models.actions import ( + ActionInput, TwilioPhoneConversationActionInput, VonagePhoneConversationActionInput, ) +from vocode.streaming.models.agent import ChatGPTAgentConfig from vocode.streaming.models.transcript import Transcript +from vocode.streaming.streaming_conversation import StreamingConversation from vocode.streaming.utils import create_conversation_id @@ -26,8 +30,6 @@ class Config: arbitrary_types_allowed = True action: EndConversation - vonage_uuid: UUID - twilio_sid: str conversation_id: str @@ -36,20 +38,10 @@ def end_conversation_action_test_case(mocker: MockerFixture) -> EndConversationA action = EndConversation(action_config=EndConversationVocodeActionConfig()) return EndConversationActionTestCase( action=action, - vonage_uuid=generate_uuid(), - twilio_sid="twilio_sid", conversation_id=create_conversation_id(), ) -@pytest.fixture -def conversation_state_manager_mock(mocker: MockerFixture) -> MagicMock: - mock = mocker.MagicMock() - mock.terminate_conversation = mocker.AsyncMock() - mock.transcript = Transcript() - return mock - - @pytest.fixture def user_message_tracker() -> asyncio.Event: tracker = asyncio.Event() @@ -57,76 +49,68 @@ def user_message_tracker() -> asyncio.Event: return tracker +@pytest.fixture +def mock_streaming_conversation_with_end_conversation_action( + mocker, end_conversation_action_test_case: EndConversationActionTestCase +): + mock_streaming_conversation = create_fake_streaming_conversation( + mocker, + agent=create_fake_agent( + mocker, + agent_config=ChatGPTAgentConfig( + prompt_preamble="", actions=[end_conversation_action_test_case.action.action_config] + ), + ), + ) + mock_streaming_conversation.actions_worker.attach_state( + end_conversation_action_test_case.action + ) + mock_streaming_conversation.active = True + return mock_streaming_conversation + + @pytest.mark.asyncio -@pytest.mark.parametrize( - "action_input_class, identifier", - [ - (VonagePhoneConversationActionInput, "vonage_uuid"), - (TwilioPhoneConversationActionInput, "twilio_sid"), - ], -) async def test_end_conversation_success( mocker: MockerFixture, mock_env: Generator, + mock_streaming_conversation_with_end_conversation_action: StreamingConversation, end_conversation_action_test_case: EndConversationActionTestCase, - conversation_state_manager_mock: MagicMock, user_message_tracker: asyncio.Event, - action_input_class: Type[BaseModel], - identifier: str, ): - end_conversation_action_test_case.action.attach_conversation_state_manager( - conversation_state_manager_mock - ) - identifier_value = getattr(end_conversation_action_test_case, identifier) - action_input = action_input_class( + action_input = ActionInput( action_config=EndConversationVocodeActionConfig(), conversation_id=end_conversation_action_test_case.conversation_id, params=EndConversationParameters(), - **{identifier: str(identifier_value)}, user_message_tracker=user_message_tracker, ) response = await end_conversation_action_test_case.action.run(action_input=action_input) assert response.response.success - assert conversation_state_manager_mock.terminate_conversation.call_count == 1 + assert not mock_streaming_conversation_with_end_conversation_action.is_active() @pytest.mark.asyncio -@pytest.mark.parametrize( - "action_input_class, identifier", - [ - (VonagePhoneConversationActionInput, "vonage_uuid"), - (TwilioPhoneConversationActionInput, "twilio_sid"), - ], -) async def test_end_conversation_fails_if_interrupted( mocker: MockerFixture, mock_env: Generator, + mock_streaming_conversation_with_end_conversation_action: StreamingConversation, end_conversation_action_test_case: EndConversationActionTestCase, - conversation_state_manager_mock: MagicMock, user_message_tracker: asyncio.Event, - action_input_class: Type[BaseModel], - identifier: str, ): - conversation_state_manager_mock.transcript.add_bot_message( + mock_streaming_conversation_with_end_conversation_action.transcript.add_bot_message( "Unfinished", conversation_id=end_conversation_action_test_case.conversation_id ) - end_conversation_action_test_case.action.attach_conversation_state_manager( - conversation_state_manager_mock - ) - identifier_value = getattr(end_conversation_action_test_case, identifier) - action_input = action_input_class( + action_input = ActionInput( action_config=EndConversationVocodeActionConfig(), conversation_id=end_conversation_action_test_case.conversation_id, params=EndConversationParameters(), - **{identifier: str(identifier_value)}, user_message_tracker=user_message_tracker, ) response = await end_conversation_action_test_case.action.run(action_input=action_input) assert not response.response.success - assert conversation_state_manager_mock.terminate_conversation.call_count == 0 + assert mock_streaming_conversation_with_end_conversation_action.is_active() diff --git a/tests/streaming/action/test_external_actions.py b/tests/streaming/action/test_external_actions.py index 315f670d9f..c35b3a09e8 100644 --- a/tests/streaming/action/test_external_actions.py +++ b/tests/streaming/action/test_external_actions.py @@ -3,7 +3,15 @@ import os import pytest - +from pytest_mock import MockerFixture + +from tests.fakedata.conversation import ( + create_fake_agent, + create_fake_streaming_conversation, + create_fake_streaming_conversation_factory, + create_fake_twilio_phone_conversation_with_streaming_conversation_pipeline, + create_fake_vonage_phone_conversation_with_streaming_conversation_pipeline, +) from tests.fakedata.id import generate_uuid from vocode.streaming.action.execute_external_action import ( ExecuteExternalAction, @@ -11,15 +19,21 @@ ExecuteExternalActionVocodeActionConfig, ) from vocode.streaming.action.external_actions_requester import ExternalActionResponse +from vocode.streaming.agent.base_agent import BaseAgent from vocode.streaming.models.actions import ( + ActionInput, TwilioPhoneConversationActionInput, VonagePhoneConversationActionInput, ) -from vocode.streaming.utils import create_conversation_id -from vocode.streaming.utils.state_manager import ( - TwilioPhoneConversationStateManager, - VonagePhoneConversationStateManager, +from vocode.streaming.models.agent import ChatGPTAgentConfig +from vocode.streaming.streaming_conversation import StreamingConversation +from vocode.streaming.telephony.conversation.twilio_phone_conversation import ( + TwilioPhoneConversation, ) +from vocode.streaming.telephony.conversation.vonage_phone_conversation import ( + VonagePhoneConversation, +) +from vocode.streaming.utils import create_conversation_id ACTION_INPUT_SCHEMA: dict = { "type": "object", @@ -62,59 +76,32 @@ def execute_action_setup(mocker, action_config) -> ExecuteExternalAction: @pytest.fixture -def mock_twilio_conversation_state_manager(mocker) -> TwilioPhoneConversationStateManager: - """Fixture to mock TwilioPhoneConversationStateManager.""" - manager = mocker.MagicMock(spec=TwilioPhoneConversationStateManager) - manager.mute_agent = mocker.MagicMock() - # Add any other necessary mock setup here - return manager +def mock_agent_with_execute_external_action(mocker: MockerFixture, action_config) -> BaseAgent: + return create_fake_agent( + mocker, + agent_config=ChatGPTAgentConfig(prompt_preamble="", actions=[action_config]), + ) @pytest.fixture -def mock_vonage_conversation_state_manager(mocker) -> VonagePhoneConversationStateManager: - """Fixture to mock VonagePhoneConversationStateManager.""" - manager = mocker.MagicMock(spec=VonagePhoneConversationStateManager) - manager.mute_agent = mocker.MagicMock() - # Add any other necessary mock setup here - return manager +def mock_streaming_conversation( + mocker: MockerFixture, mock_agent_with_execute_external_action: BaseAgent +) -> StreamingConversation: + return create_fake_streaming_conversation(mocker, agent=mock_agent_with_execute_external_action) @pytest.mark.asyncio -async def test_vonage_execute_external_action_success( +async def test_execute_external_action_success( mocker, - mock_vonage_conversation_state_manager: VonagePhoneConversationStateManager, + mock_streaming_conversation: StreamingConversation, execute_action_setup: ExecuteExternalAction, ): - execute_action_setup.attach_conversation_state_manager(mock_vonage_conversation_state_manager) - vonage_uuid = generate_uuid() - - response = await execute_action_setup.run( - action_input=VonagePhoneConversationActionInput( - action_config=execute_action_setup.action_config, - conversation_id=create_conversation_id(), - params=ExecuteExternalActionParameters(payload={}), - vonage_uuid=str(vonage_uuid), - ), - ) - - assert response.response.success - assert response.response.result == {"test": "test"} - - -@pytest.mark.asyncio -async def test_twilio_execute_external_action_success( - mocker, - mock_twilio_conversation_state_manager: TwilioPhoneConversationStateManager, - execute_action_setup: ExecuteExternalAction, -): - execute_action_setup.attach_conversation_state_manager(mock_twilio_conversation_state_manager) - + mock_streaming_conversation.actions_worker.attach_state(execute_action_setup) response = await execute_action_setup.run( - action_input=TwilioPhoneConversationActionInput( + action_input=ActionInput( action_config=execute_action_setup.action_config, conversation_id=create_conversation_id(), params=ExecuteExternalActionParameters(payload={}), - twilio_sid="twilio_sid", ), ) diff --git a/tests/streaming/action/test_record_email.py b/tests/streaming/action/test_record_email.py index 7864dd37ff..1586725e17 100644 --- a/tests/streaming/action/test_record_email.py +++ b/tests/streaming/action/test_record_email.py @@ -6,10 +6,7 @@ RecordEmailParameters, RecordEmailVocodeActionConfig, ) -from vocode.streaming.models.actions import ( - TwilioPhoneConversationActionInput, - VonagePhoneConversationActionInput, -) +from vocode.streaming.models.actions import ActionInput from vocode.streaming.utils import create_conversation_id # id is just a description of the parameterized test case's input @@ -40,38 +37,17 @@ def record_email_action() -> RecordEmail: @pytest.mark.asyncio @pytest.mark.parametrize("email_input,expected_success", EMAIL_TEST_CASES) -async def test_vonage_email_validation( - record_email_action: RecordEmail, email_input: str, expected_success: bool -): - vonage_uuid = generate_uuid() - res = await record_email_action.run( - action_input=VonagePhoneConversationActionInput( - action_config=RecordEmailVocodeActionConfig(), - conversation_id=create_conversation_id(), - params=RecordEmailParameters( - raw_value="", - formatted_value=email_input, - ), - vonage_uuid=str(vonage_uuid), - ), - ) - assert res.response.success == expected_success - - -@pytest.mark.asyncio -@pytest.mark.parametrize("email_input,expected_success", EMAIL_TEST_CASES) -async def test_twilio_email_validation( +async def test_email_validation( record_email_action: RecordEmail, email_input: str, expected_success: bool ): res = await record_email_action.run( - action_input=TwilioPhoneConversationActionInput( + action_input=ActionInput( action_config=RecordEmailVocodeActionConfig(), conversation_id=create_conversation_id(), params=RecordEmailParameters( raw_value="", formatted_value=email_input, ), - twilio_sid="twilio_sid", ), ) assert res.response.success == expected_success diff --git a/tests/streaming/action/test_transfer_call.py b/tests/streaming/action/test_transfer_call.py index f4111d877e..776eab10c7 100644 --- a/tests/streaming/action/test_transfer_call.py +++ b/tests/streaming/action/test_transfer_call.py @@ -7,6 +7,13 @@ from aioresponses import aioresponses from pytest_mock import MockerFixture +from tests.fakedata.conversation import ( + create_fake_agent, + create_fake_streaming_conversation, + create_fake_streaming_conversation_factory, + create_fake_twilio_phone_conversation_with_streaming_conversation_pipeline, + create_fake_vonage_phone_conversation_with_streaming_conversation_pipeline, +) from tests.fakedata.id import generate_uuid from vocode.streaming.action.transfer_call import ( TransferCallEmptyParameters, @@ -14,18 +21,23 @@ TwilioTransferCall, VonageTransferCall, ) +from vocode.streaming.agent.base_agent import BaseAgent from vocode.streaming.models.actions import ( TwilioPhoneConversationActionInput, VonagePhoneConversationActionInput, ) +from vocode.streaming.models.agent import ChatGPTAgentConfig from vocode.streaming.models.events import Sender from vocode.streaming.models.telephony import TwilioConfig, VonageConfig from vocode.streaming.models.transcript import Message, Transcript -from vocode.streaming.utils import create_conversation_id -from vocode.streaming.utils.state_manager import ( - TwilioPhoneConversationStateManager, - VonagePhoneConversationStateManager, +from vocode.streaming.streaming_conversation import StreamingConversation +from vocode.streaming.telephony.conversation.twilio_phone_conversation import ( + TwilioPhoneConversation, +) +from vocode.streaming.telephony.conversation.vonage_phone_conversation import ( + VonagePhoneConversation, ) +from vocode.streaming.utils import create_conversation_id TRANSFER_PHONE_NUMBER = "12345678920" @@ -49,38 +61,51 @@ def mock_vonage_config(): @pytest.fixture -def mock_twilio_phone_conversation(mock_twilio_config) -> MagicMock: - twilio_phone_conversation = MagicMock() - twilio_phone_conversation.twilio_config = mock_twilio_config - return twilio_phone_conversation +def mock_agent_with_transfer_call_action(mocker: MockerFixture) -> BaseAgent: + return create_fake_agent( + mocker, + agent_config=ChatGPTAgentConfig( + prompt_preamble="", + actions=[TransferCallVocodeActionConfig(phone_number=TRANSFER_PHONE_NUMBER)], + ), + ) @pytest.fixture -def mock_vonage_phone_conversation(mock_vonage_config) -> MagicMock: - vonage_phone_conversation = MagicMock() - vonage_phone_conversation.vonage_config = mock_vonage_config - return vonage_phone_conversation +def mock_streaming_conversation_factory( + mocker: MockerFixture, mock_agent_with_transfer_call_action: BaseAgent +) -> StreamingConversation: + return create_fake_streaming_conversation_factory( + mocker, agent=mock_agent_with_transfer_call_action + ) @pytest.fixture -def mock_twilio_conversation_state_manager( - mocker: Any, mock_twilio_phone_conversation: MagicMock -) -> TwilioPhoneConversationStateManager: - return TwilioPhoneConversationStateManager(mock_twilio_phone_conversation) +def mock_twilio_phone_conversation( + mocker: MockerFixture, mock_twilio_config, mock_streaming_conversation_factory +) -> TwilioPhoneConversation: + return create_fake_twilio_phone_conversation_with_streaming_conversation_pipeline( + mocker, + streaming_conversation_factory=mock_streaming_conversation_factory, + twilio_config=mock_twilio_config, + ) @pytest.fixture -def mock_vonage_conversation_state_manager( - mocker: Any, mock_vonage_phone_conversation: MagicMock -) -> VonagePhoneConversationStateManager: - return VonagePhoneConversationStateManager(mock_vonage_phone_conversation) +def mock_vonage_phone_conversation( + mocker: MockerFixture, mock_vonage_config, mock_streaming_conversation_factory +) -> VonagePhoneConversation: + return create_fake_vonage_phone_conversation_with_streaming_conversation_pipeline( + mocker, + streaming_conversation_factory=mock_streaming_conversation_factory, + vonage_config=mock_vonage_config, + ) @pytest.mark.asyncio async def test_twilio_transfer_call_succeeds( mocker: Any, - mock_twilio_conversation_state_manager: TwilioPhoneConversationStateManager, - mock_twilio_phone_conversation: MagicMock, + mock_twilio_phone_conversation: TwilioPhoneConversation, mock_twilio_config: TwilioConfig, ): action = TwilioTransferCall( @@ -88,7 +113,8 @@ async def test_twilio_transfer_call_succeeds( ) user_message_tracker = asyncio.Event() user_message_tracker.set() - action.attach_conversation_state_manager(mock_twilio_conversation_state_manager) + + mock_twilio_phone_conversation.pipeline.actions_worker.attach_state(action) conversation_id = create_conversation_id() twilio_sid = "twilio_sid" @@ -100,8 +126,6 @@ async def test_twilio_transfer_call_succeeds( user_message_tracker=user_message_tracker, ) - mock_twilio_phone_conversation.transcript = Transcript(event_logs=[]) - with aioresponses() as m: m.post( "https://api.twilio.com/2010-04-01/Accounts/{twilio_account_sid}/Calls/{twilio_call_sid}.json".format( @@ -126,15 +150,16 @@ async def test_twilio_transfer_call_succeeds( @pytest.mark.asyncio async def test_twilio_transfer_call_fails_if_interrupted( mocker: Any, - mock_twilio_conversation_state_manager: TwilioPhoneConversationStateManager, - mock_twilio_phone_conversation: MagicMock, + mock_twilio_phone_conversation: TwilioPhoneConversation, ) -> None: action = TwilioTransferCall( action_config=TransferCallVocodeActionConfig(phone_number=TRANSFER_PHONE_NUMBER), ) user_message_tracker = asyncio.Event() user_message_tracker.set() - action.attach_conversation_state_manager(mock_twilio_conversation_state_manager) + + mock_twilio_phone_conversation.pipeline.actions_worker.attach_state(action) + conversation_id = create_conversation_id() inner_transfer_call_mock = mocker.patch( @@ -142,7 +167,7 @@ async def test_twilio_transfer_call_fails_if_interrupted( autospec=True, ) - mock_twilio_phone_conversation.transcript = Transcript( + mock_twilio_phone_conversation.pipeline.transcript = Transcript( event_logs=[ Message( sender=Sender.BOT, @@ -170,7 +195,7 @@ async def test_twilio_transfer_call_fails_if_interrupted( async def test_vonage_transfer_call_inbound( mocker: MockerFixture, mock_env, - mock_vonage_conversation_state_manager: VonagePhoneConversationStateManager, + mock_vonage_phone_conversation: VonagePhoneConversation, ) -> None: transfer_phone_number = "12345678920" action = VonageTransferCall( @@ -181,13 +206,13 @@ async def test_vonage_transfer_call_inbound( vonage_uuid = generate_uuid() - mock_vonage_conversation_state_manager._vonage_phone_conversation.direction = "inbound" - mock_vonage_conversation_state_manager._vonage_phone_conversation.to_phone = "1234567894" - mock_vonage_conversation_state_manager._vonage_phone_conversation.from_phone = "1234567895" + mock_vonage_phone_conversation.direction = "inbound" + mock_vonage_phone_conversation.to_phone = "1234567894" + mock_vonage_phone_conversation.from_phone = "1234567895" conversation_id = create_conversation_id() - action.attach_conversation_state_manager(mock_vonage_conversation_state_manager) + mock_vonage_phone_conversation.pipeline.actions_worker.attach_state(action) user_message_tracker = asyncio.Event() user_message_tracker.set() @@ -218,5 +243,5 @@ async def test_vonage_transfer_call_inbound( assert ncco[0]["endpoint"][0]["number"] == transfer_phone_number assert ( ncco[0]["from"] - == mock_vonage_conversation_state_manager._vonage_phone_conversation.to_phone # if inbound, the agent number is the to_phone + == mock_vonage_phone_conversation.to_phone # if inbound, the agent number is the to_phone ) diff --git a/tests/streaming/agent/test_base_agent.py b/tests/streaming/agent/test_base_agent.py index a95e74f022..9dfea4931d 100644 --- a/tests/streaming/agent/test_base_agent.py +++ b/tests/streaming/agent/test_base_agent.py @@ -23,7 +23,6 @@ InterruptibleEvent, QueueConsumer, ) -from vocode.streaming.utils.state_manager import ConversationStateManager @pytest.fixture(autouse=True) @@ -41,17 +40,12 @@ def _create_agent( agent_config: ChatGPTAgentConfig, transcript: Optional[Transcript] = None, action_factory: Optional[AbstractActionFactory] = None, - conversation_state_manager: Optional[ConversationStateManager] = None, ) -> ChatGPTAgent: agent = ChatGPTAgent(agent_config, action_factory=action_factory) if transcript: agent.attach_transcript(transcript) else: agent.attach_transcript(Transcript()) - if conversation_state_manager: - agent.attach_conversation_state_manager(conversation_state_manager) - else: - agent.attach_conversation_state_manager(mocker.MagicMock()) return agent diff --git a/vocode/streaming/action/base_action.py b/vocode/streaming/action/base_action.py index e1e07f028b..a0e742cbf8 100644 --- a/vocode/streaming/action/base_action.py +++ b/vocode/streaming/action/base_action.py @@ -14,13 +14,14 @@ ) if TYPE_CHECKING: - from vocode.streaming.utils.state_manager import AbstractConversationStateManager + from vocode.streaming.pipeline.audio_pipeline import AudioPipeline ActionConfigType = TypeVar("ActionConfigType", bound=ActionConfig) class BaseAction(Generic[ActionConfigType, ParametersType, ResponseType]): # type: ignore description: str = "" + pipeline: "AudioPipeline" def __init__( self, @@ -35,11 +36,6 @@ def __init__( self.quiet = quiet self.is_interruptible = is_interruptible - def attach_conversation_state_manager( - self, conversation_state_manager: "AbstractConversationStateManager" - ): - self.conversation_state_manager = conversation_state_manager - async def run(self, action_input: ActionInput[ParametersType]) -> ActionOutput[ResponseType]: raise NotImplementedError diff --git a/vocode/streaming/action/default_factory.py b/vocode/streaming/action/default_factory.py index 227ed56c42..6111b10e8b 100644 --- a/vocode/streaming/action/default_factory.py +++ b/vocode/streaming/action/default_factory.py @@ -33,15 +33,23 @@ class DefaultActionFactory(AbstractActionFactory): - def __init__(self, actions: Sequence[ActionConfig] | dict = {}): - - self.action_configs_dict = {action.type: action for action in actions} - self.actions = CONVERSATION_ACTIONS + def __init__(self): + self.actions = CONVERSATION_ACTIONS # TODO (DOW-119): StreamingConversationActionFactory def create_action(self, action_config: ActionConfig): - if action_config.type not in self.action_configs_dict: + if action_config.type not in self.actions: raise Exception("Action type not supported by Agent config.") action_class = self.actions[action_config.type] return action_class(action_config) + + +class DefaultTwilioPhoneConversationActionFactory(DefaultActionFactory): + def __init__(self): + self.actions = {**TWILIO_ACTIONS, **CONVERSATION_ACTIONS} + + +class DefaultVonagePhoneConversationActionFactory(DefaultActionFactory): + def __init__(self): + self.actions = {**VONAGE_ACTIONS, **CONVERSATION_ACTIONS} diff --git a/vocode/streaming/action/dtmf.py b/vocode/streaming/action/dtmf.py index 392b8eef6c..d41302ed82 100644 --- a/vocode/streaming/action/dtmf.py +++ b/vocode/streaming/action/dtmf.py @@ -9,10 +9,6 @@ ) from vocode.streaming.models.actions import ActionConfig as VocodeActionConfig from vocode.streaming.models.actions import ActionInput, ActionOutput -from vocode.streaming.utils.state_manager import ( - TwilioPhoneConversationStateManager, - VonagePhoneConversationStateManager, -) class DTMFParameters(BaseModel): @@ -43,16 +39,15 @@ class VonageDTMF( description: str = FUNCTION_DESCRIPTION parameters_type: Type[DTMFParameters] = DTMFParameters response_type: Type[DTMFResponse] = DTMFResponse - conversation_state_manager: VonagePhoneConversationStateManager def __init__(self, action_config: DTMFVocodeActionConfig): super().__init__(action_config, quiet=True) async def run(self, action_input: ActionInput[DTMFParameters]) -> ActionOutput[DTMFResponse]: buttons = action_input.params.buttons - vonage_client = self.conversation_state_manager.create_vonage_client() + vonage_client = self.vonage_phone_conversation.create_vonage_client() await vonage_client.send_dtmf( - vonage_uuid=self.get_vonage_uuid(action_input), digits=buttons + vonage_uuid=self.vonage_phone_conversation.vonage_uuid, digits=buttons ) return ActionOutput( @@ -67,7 +62,6 @@ class TwilioDTMF( description: str = FUNCTION_DESCRIPTION parameters_type: Type[DTMFParameters] = DTMFParameters response_type: Type[DTMFResponse] = DTMFResponse - conversation_state_manager: TwilioPhoneConversationStateManager def __init__(self, action_config: DTMFVocodeActionConfig): super().__init__( diff --git a/vocode/streaming/action/end_conversation.py b/vocode/streaming/action/end_conversation.py index 7175ad8461..548df660d0 100644 --- a/vocode/streaming/action/end_conversation.py +++ b/vocode/streaming/action/end_conversation.py @@ -3,7 +3,7 @@ from loguru import logger from pydantic.v1 import BaseModel -from vocode.streaming.action.base_action import BaseAction +from vocode.streaming.action.streaming_conversation_action import StreamingConversationAction from vocode.streaming.models.actions import ActionConfig as VocodeActionConfig from vocode.streaming.models.actions import ActionInput, ActionOutput @@ -41,7 +41,7 @@ def action_result_to_string(self, input: ActionInput, output: ActionOutput) -> s class EndConversation( - BaseAction[ + StreamingConversationAction[ EndConversationVocodeActionConfig, EndConversationParameters, EndConversationResponse, @@ -73,14 +73,14 @@ async def run( if action_input.user_message_tracker is not None: await action_input.user_message_tracker.wait() - if self.conversation_state_manager.transcript.was_last_message_interrupted(): + if self.pipeline.transcript.was_last_message_interrupted(): logger.info("Last bot message was interrupted") return ActionOutput( action_type=action_input.action_config.type, response=EndConversationResponse(success=False), ) - await self.conversation_state_manager.terminate_conversation() + self.pipeline.mark_terminated() await self._end_of_run_hook() return ActionOutput( diff --git a/vocode/streaming/action/execute_external_action.py b/vocode/streaming/action/execute_external_action.py index c689b206d6..35b67ddbf2 100644 --- a/vocode/streaming/action/execute_external_action.py +++ b/vocode/streaming/action/execute_external_action.py @@ -3,11 +3,11 @@ from pydantic.v1 import BaseModel -from vocode.streaming.action.base_action import BaseAction from vocode.streaming.action.external_actions_requester import ( ExternalActionResponse, ExternalActionsRequester, ) +from vocode.streaming.action.streaming_conversation_action import StreamingConversationAction from vocode.streaming.models.actions import ActionConfig as VocodeActionConfig from vocode.streaming.models.actions import ActionInput, ActionOutput, ExternalActionProcessingMode from vocode.streaming.models.message import BaseMessage @@ -36,7 +36,7 @@ class ExecuteExternalActionResponse(BaseModel): class ExecuteExternalAction( - BaseAction[ + StreamingConversationAction[ ExecuteExternalActionVocodeActionConfig, ExecuteExternalActionParameters, ExecuteExternalActionResponse, @@ -92,9 +92,9 @@ async def run( if self.should_respond and action_input.user_message_tracker is not None: await action_input.user_message_tracker.wait() - self.conversation_state_manager.mute_agent() + self.pipeline.agent.is_muted = True response = await self.send_external_action_request(action_input) - self.conversation_state_manager.unmute_agent() + self.pipeline.agent.is_muted = False # TODO (EA): pass specific context based on error return ActionOutput( diff --git a/vocode/streaming/action/phone_call_action.py b/vocode/streaming/action/phone_call_action.py index f780413fc3..49db053fea 100644 --- a/vocode/streaming/action/phone_call_action.py +++ b/vocode/streaming/action/phone_call_action.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from vocode.streaming.action.base_action import ActionConfigType, BaseAction from vocode.streaming.models.actions import ( @@ -9,13 +9,19 @@ TwilioPhoneConversationActionInput, VonagePhoneConversationActionInput, ) -from vocode.streaming.utils.state_manager import ( - TwilioPhoneConversationStateManager, - VonagePhoneConversationStateManager, -) + +if TYPE_CHECKING: + from vocode.streaming.telephony.conversation.twilio_phone_conversation import ( + TwilioPhoneConversation, + ) + from vocode.streaming.telephony.conversation.vonage_phone_conversation import ( + VonagePhoneConversation, + ) class VonagePhoneConversationAction(BaseAction[ActionConfigType, ParametersType, ResponseType]): + vonage_phone_conversation: "VonagePhoneConversation" + def create_phone_conversation_action_input( self, conversation_id: str, @@ -37,12 +43,10 @@ def get_vonage_uuid(self, action_input: ActionInput[ParametersType]) -> str: assert isinstance(action_input, VonagePhoneConversationActionInput) return action_input.vonage_uuid - def attach_conversation_state_manager(self, conversation_state_manager: Any): - assert isinstance(conversation_state_manager, VonagePhoneConversationStateManager) - self.conversation_state_manager = conversation_state_manager - class TwilioPhoneConversationAction(BaseAction[ActionConfigType, ParametersType, ResponseType]): + twilio_phone_conversation: "TwilioPhoneConversation" + def create_phone_conversation_action_input( self, conversation_id: str, @@ -63,7 +67,3 @@ def create_phone_conversation_action_input( def get_twilio_sid(self, action_input: ActionInput[ParametersType]) -> str: assert isinstance(action_input, TwilioPhoneConversationActionInput) return action_input.twilio_sid - - def attach_conversation_state_manager(self, conversation_state_manager: Any): - assert isinstance(conversation_state_manager, TwilioPhoneConversationStateManager) - self.conversation_state_manager = conversation_state_manager diff --git a/vocode/streaming/action/streaming_conversation_action.py b/vocode/streaming/action/streaming_conversation_action.py new file mode 100644 index 0000000000..2e885e159c --- /dev/null +++ b/vocode/streaming/action/streaming_conversation_action.py @@ -0,0 +1,17 @@ +from typing import TYPE_CHECKING + +from vocode.streaming.action.base_action import ActionConfigType, BaseAction +from vocode.streaming.models.actions import ( + ActionInput, + ParametersType, + ResponseType, + TwilioPhoneConversationActionInput, + VonagePhoneConversationActionInput, +) + +if TYPE_CHECKING: + from vocode.streaming.streaming_conversation import StreamingConversation + + +class StreamingConversationAction(BaseAction[ActionConfigType, ParametersType, ResponseType]): + pipeline: "StreamingConversation" diff --git a/vocode/streaming/action/transfer_call.py b/vocode/streaming/action/transfer_call.py index 397dd7a4ef..77ec6533a0 100644 --- a/vocode/streaming/action/transfer_call.py +++ b/vocode/streaming/action/transfer_call.py @@ -7,14 +7,11 @@ TwilioPhoneConversationAction, VonagePhoneConversationAction, ) +from vocode.streaming.action.streaming_conversation_action import StreamingConversationAction from vocode.streaming.models.actions import ActionConfig as VocodeActionConfig from vocode.streaming.models.actions import ActionInput, ActionOutput from vocode.streaming.utils.async_requester import AsyncRequestor from vocode.streaming.utils.phone_numbers import sanitize_phone_number -from vocode.streaming.utils.state_manager import ( - TwilioPhoneConversationStateManager, - VonagePhoneConversationStateManager, -) class TransferCallEmptyParameters(BaseModel): @@ -66,13 +63,15 @@ def action_result_to_string(self, input: ActionInput, output: ActionOutput) -> s class TwilioTransferCall( + StreamingConversationAction[ + TransferCallVocodeActionConfig, TransferCallParameters, TransferCallResponse + ], TwilioPhoneConversationAction[ TransferCallVocodeActionConfig, TransferCallParameters, TransferCallResponse - ] + ], ): description: str = FUNCTION_DESCRIPTION response_type: Type[TransferCallResponse] = TransferCallResponse - conversation_state_manager: TwilioPhoneConversationStateManager @property def parameters_type(self) -> Type[TransferCallParameters]: @@ -93,7 +92,7 @@ def __init__( ) async def transfer_call(self, twilio_call_sid: str, to_phone: str): - twilio_client = self.conversation_state_manager.create_twilio_client() + twilio_client = self.twilio_phone_conversation.create_twilio_client() url = "https://api.twilio.com/2010-04-01/Accounts/{twilio_account_sid}/Calls/{twilio_call_sid}.json".format( twilio_account_sid=twilio_client.get_telephony_config().account_sid, @@ -125,7 +124,7 @@ async def run( logger.info("Finished waiting for user message tracker, now attempting to transfer call") - if self.conversation_state_manager.transcript.was_last_message_interrupted(): + if self.pipeline.transcript.was_last_message_interrupted(): logger.info("Last bot message was interrupted, not transferring call") return ActionOutput( action_type=action_input.action_config.type, @@ -141,13 +140,15 @@ async def run( class VonageTransferCall( + StreamingConversationAction[ + TransferCallVocodeActionConfig, TransferCallParameters, TransferCallResponse + ], VonagePhoneConversationAction[ TransferCallVocodeActionConfig, TransferCallParameters, TransferCallResponse - ] + ], ): description: str = FUNCTION_DESCRIPTION response_type: Type[TransferCallResponse] = TransferCallResponse - conversation_state_manager: VonagePhoneConversationStateManager @property def parameters_type(self) -> Type[TransferCallParameters]: @@ -169,17 +170,17 @@ async def run( ) -> ActionOutput[TransferCallResponse]: if action_input.user_message_tracker is not None: await action_input.user_message_tracker.wait() - self.conversation_state_manager.mute_agent() + self.pipeline.agent.is_muted = True phone_number = self.action_config.get_phone_number(action_input) sanitized_phone_number = sanitize_phone_number(phone_number) - if self.conversation_state_manager.get_direction() == "outbound": - agent_phone_number = self.conversation_state_manager.get_from_phone() + if self.vonage_phone_conversation.direction == "outbound": + agent_phone_number = self.vonage_phone_conversation.from_phone else: - agent_phone_number = self.conversation_state_manager.get_to_phone() + agent_phone_number = self.vonage_phone_conversation.to_phone - await self.conversation_state_manager.create_vonage_client().update_call( + await self.vonage_phone_conversation.create_vonage_client().update_call( vonage_uuid=self.get_vonage_uuid(action_input), new_ncco=[ { diff --git a/vocode/streaming/action/worker.py b/vocode/streaming/action/worker.py index 2b3e486b65..ac668e8da3 100644 --- a/vocode/streaming/action/worker.py +++ b/vocode/streaming/action/worker.py @@ -1,26 +1,24 @@ from __future__ import annotations -import asyncio +from typing import TYPE_CHECKING from vocode.streaming.action.abstract_factory import AbstractActionFactory -from vocode.streaming.action.default_factory import DefaultActionFactory -from vocode.streaming.agent.base_agent import ActionResultAgentInput, AgentInput -from vocode.streaming.models.actions import ( - ActionInput, - TwilioPhoneConversationActionInput, - VonagePhoneConversationActionInput, -) +from vocode.streaming.action.base_action import BaseAction +from vocode.streaming.models.actions import ActionInput, ActionResponse from vocode.streaming.pipeline.worker import ( AbstractWorker, InterruptibleEvent, InterruptibleEventFactory, InterruptibleWorker, ) -from vocode.streaming.utils.state_manager import AbstractConversationStateManager + +if TYPE_CHECKING: + from vocode.streaming.pipeline.audio_pipeline import AudioPipeline class ActionsWorker(InterruptibleWorker[InterruptibleEvent[ActionInput]]): - consumer: AbstractWorker[InterruptibleEvent[ActionResultAgentInput]] + consumer: AbstractWorker[InterruptibleEvent[ActionResponse]] + pipeline: "AudioPipeline" def __init__( self, @@ -32,19 +30,17 @@ def __init__( ) self.action_factory = action_factory - def attach_conversation_state_manager( - self, conversation_state_manager: AbstractConversationStateManager - ): - self.conversation_state_manager = conversation_state_manager + def attach_state(self, action: BaseAction): + action.pipeline = self.pipeline async def process(self, item: InterruptibleEvent[ActionInput]): action_input = item.payload action = self.action_factory.create_action(action_input.action_config) - action.attach_conversation_state_manager(self.conversation_state_manager) + self.attach_state(action) action_output = await action.run(action_input) self.consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_event( - ActionResultAgentInput( + ActionResponse( conversation_id=action_input.conversation_id, action_input=action_input, action_output=action_output, diff --git a/vocode/streaming/agent/anthropic_agent.py b/vocode/streaming/agent/anthropic_agent.py index 820fa6539a..d410bb82bf 100644 --- a/vocode/streaming/agent/anthropic_agent.py +++ b/vocode/streaming/agent/anthropic_agent.py @@ -91,7 +91,7 @@ async def generate_response( response_generator = collate_response_async using_input_streaming_synthesizer = ( - self.conversation_state_manager.using_input_streaming_synthesizer() + self.streaming_conversation.using_input_streaming_synthesizer() ) if using_input_streaming_synthesizer: response_generator = stream_response_async diff --git a/vocode/streaming/agent/base_agent.py b/vocode/streaming/agent/base_agent.py index ee09f0d642..217c526ef7 100644 --- a/vocode/streaming/agent/base_agent.py +++ b/vocode/streaming/agent/base_agent.py @@ -26,6 +26,7 @@ ActionConfig, ActionInput, ActionOutput, + ActionResponse, EndOfTurn, FunctionCall, ) @@ -44,14 +45,10 @@ ) from vocode.streaming.utils import unrepeating_randomizer from vocode.streaming.utils.speed_manager import SpeedManager -from vocode.streaming.utils.state_manager import ( - TwilioPhoneConversationStateManager, - VonagePhoneConversationStateManager, -) from vocode.utils.sentry_utils import CustomSentrySpans, sentry_create_span if TYPE_CHECKING: - from vocode.streaming.utils.state_manager import AbstractConversationStateManager + from vocode.streaming.streaming_conversation import StreamingConversation AGENT_TRACE_NAME = "agent" POST_QUESTION_BACKCHANNELS = [ @@ -134,6 +131,7 @@ class StreamedResponse(GeneratedResponse): AgentConfigType = TypeVar("AgentConfigType", bound=AgentConfig) +# TODO: consolidate BaseAgent and AbstractAgent class AbstractAgent(Generic[AgentConfigType]): def __init__(self, agent_config: AgentConfigType): self.agent_config = agent_config @@ -156,9 +154,13 @@ def get_cut_off_response(self) -> str: return random.choice(on_cut_off_messages).text -class BaseAgent(AbstractAgent[AgentConfigType], InterruptibleWorker): +class BaseAgent( + AbstractAgent[AgentConfigType], + InterruptibleWorker[InterruptibleEvent[AgentInput | ActionResponse]], +): agent_responses_consumer: AbstractWorker[InterruptibleAgentResponseEvent[AgentResponse]] actions_consumer: Optional[AbstractWorker[InterruptibleEvent[ActionInput]]] + streaming_conversation: "StreamingConversation" def __init__( self, @@ -166,7 +168,6 @@ def __init__( action_factory: AbstractActionFactory = DefaultActionFactory(), interruptible_event_factory: InterruptibleEventFactory = InterruptibleEventFactory(), ): - self.input_queue: asyncio.Queue[InterruptibleEvent[AgentInput]] = asyncio.Queue() AbstractAgent.__init__(self, agent_config=agent_config) InterruptibleWorker.__init__( self, @@ -188,12 +189,6 @@ def get_functions(self): def attach_transcript(self, transcript: Transcript): self.transcript = transcript - def attach_conversation_state_manager( - self, - conversation_state_manager: AbstractConversationStateManager, - ): - self.conversation_state_manager = conversation_state_manager - def attach_speed_manager(self, speed_manager: SpeedManager): self.speed_manager = speed_manager @@ -206,11 +201,6 @@ def _get_speed_adjusted_silence_seconds(self, seconds: float) -> float: def set_interruptible_event_factory(self, factory: InterruptibleEventFactory): self.interruptible_event_factory = factory - def get_input_queue( - self, - ) -> asyncio.Queue[InterruptibleEvent[AgentInput]]: - return self.input_queue - def is_first_response(self): assert self.transcript is not None @@ -249,7 +239,7 @@ async def _maybe_prepend_interrupt_responses( async def handle_generate_response( self, transcription: Transcription, - agent_input: AgentInput, + agent_input: AgentInput | ActionResponse, ) -> bool: conversation_id = agent_input.conversation_id responses = self._maybe_prepend_interrupt_responses( @@ -293,7 +283,9 @@ async def handle_generate_response( function_call = generated_response.message continue - agent_response_tracker = agent_input.agent_response_tracker or asyncio.Event() + agent_response_tracker = ( + agent_input.agent_response_tracker if isinstance(agent_input, AgentInput) else None + ) or asyncio.Event() self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( AgentResponseMessage( @@ -325,8 +317,8 @@ async def handle_generate_response( # if the client (the implemented agent) doesn't create an EndOfTurn, then we need to create one if not end_of_turn_agent_response_tracker: end_of_turn_agent_response_tracker = ( - agent_input.agent_response_tracker or asyncio.Event() - ) + agent_input.agent_response_tracker if isinstance(agent_input, AgentInput) else None + ) or asyncio.Event() self.agent_responses_consumer.consume_nonblocking( self.interruptible_event_factory.create_interruptible_agent_response_event( AgentResponseMessage( @@ -389,7 +381,7 @@ async def handle_respond(self, transcription: Transcription, conversation_id: st logger.debug("No response generated") return False - async def process(self, item: InterruptibleEvent[AgentInput]): + async def process(self, item: InterruptibleEvent[AgentInput | ActionResponse]): assert self.transcript is not None try: agent_input = item.payload @@ -399,7 +391,7 @@ async def process(self, item: InterruptibleEvent[AgentInput]): text=transcription.message, conversation_id=agent_input.conversation_id, ) - elif isinstance(agent_input, ActionResultAgentInput): + elif isinstance(agent_input, ActionResponse): self.transcript.add_action_finish_log( action_input=agent_input.action_input, action_output=agent_input.action_output, @@ -479,7 +471,9 @@ def _get_action_config(self, function_name: str) -> Optional[ActionConfig]: return action_config return None - async def call_function(self, function_call: FunctionCall, agent_input: AgentInput): + async def call_function( + self, function_call: FunctionCall, agent_input: AgentInput | ActionResponse + ): action_config = self._get_action_config(function_call.name) if action_config is None: logger.error(f"Function {function_call.name} not found in agent config, skipping") @@ -511,39 +505,15 @@ async def call_function(self, function_call: FunctionCall, agent_input: AgentInp def create_action_input( self, action: BaseAction, - agent_input: AgentInput, + agent_input: AgentInput | ActionResponse, params: Dict, user_message_tracker: Optional[asyncio.Event] = None, ) -> ActionInput: - action_input: ActionInput - if isinstance(action, VonagePhoneConversationAction): - assert isinstance( - self.conversation_state_manager, - VonagePhoneConversationStateManager, - ), "Cannot use a VonagePhoneConversationAction unless the attached conversation is a VonagePhoneConversation" - action_input = action.create_phone_conversation_action_input( - conversation_id=agent_input.conversation_id, - params=params, - vonage_uuid=self.conversation_state_manager.get_vonage_uuid(), - user_message_tracker=user_message_tracker, - ) - elif isinstance(action, TwilioPhoneConversationAction): - assert isinstance( - self.conversation_state_manager, TwilioPhoneConversationStateManager - ), "Cannot use a TwilioPhoneConversationAction unless the attached conversation is a TwilioPhoneConversation" - action_input = action.create_phone_conversation_action_input( - conversation_id=agent_input.conversation_id, - params=params, - twilio_sid=self.conversation_state_manager.get_twilio_sid(), - user_message_tracker=user_message_tracker, - ) - else: - action_input = action.create_action_input( - agent_input.conversation_id, - params, - user_message_tracker, - ) - return action_input + return action.create_action_input( + agent_input.conversation_id, + params, + user_message_tracker, + ) def enqueue_action_input( self, diff --git a/vocode/streaming/agent/chat_gpt_agent.py b/vocode/streaming/agent/chat_gpt_agent.py index 047aa6c90e..69eff4664a 100644 --- a/vocode/streaming/agent/chat_gpt_agent.py +++ b/vocode/streaming/agent/chat_gpt_agent.py @@ -272,7 +272,7 @@ async def generate_response( response_generator = collate_response_async using_input_streaming_synthesizer = ( - self.conversation_state_manager.using_input_streaming_synthesizer() + self.streaming_conversation.using_input_streaming_synthesizer() ) if using_input_streaming_synthesizer: response_generator = stream_response_async diff --git a/vocode/streaming/agent/groq_agent.py b/vocode/streaming/agent/groq_agent.py index bd4a170396..7e69877f06 100644 --- a/vocode/streaming/agent/groq_agent.py +++ b/vocode/streaming/agent/groq_agent.py @@ -205,7 +205,7 @@ async def generate_response( response_generator = collate_response_async using_input_streaming_synthesizer = ( - self.conversation_state_manager.using_input_streaming_synthesizer() + self.streaming_conversation.using_input_streaming_synthesizer() ) if using_input_streaming_synthesizer: response_generator = stream_response_async diff --git a/vocode/streaming/agent/langchain_agent.py b/vocode/streaming/agent/langchain_agent.py index 9e41ea03e8..f0f581ad8b 100644 --- a/vocode/streaming/agent/langchain_agent.py +++ b/vocode/streaming/agent/langchain_agent.py @@ -113,7 +113,7 @@ async def generate_response( response_generator = collate_response_async using_input_streaming_synthesizer = ( - self.conversation_state_manager.using_input_streaming_synthesizer() + self.streaming_conversation.using_input_streaming_synthesizer() ) if using_input_streaming_synthesizer: response_generator = stream_response_async diff --git a/vocode/streaming/agent/websocket_user_implemented_agent.py b/vocode/streaming/agent/websocket_user_implemented_agent.py index ae9096b25c..186d173eae 100644 --- a/vocode/streaming/agent/websocket_user_implemented_agent.py +++ b/vocode/streaming/agent/websocket_user_implemented_agent.py @@ -26,7 +26,7 @@ class WebSocketUserImplementedAgent(BaseAgent[WebSocketUserImplementedAgentConfig]): - input_queue: asyncio.Queue[InterruptibleEvent[AgentInput]] + input_queue: asyncio.Queue[InterruptibleEvent[AgentInput | AgentResponse]] def __init__( self, diff --git a/vocode/streaming/models/actions.py b/vocode/streaming/models/actions.py index 47b7e454cb..39ffe058f2 100644 --- a/vocode/streaming/models/actions.py +++ b/vocode/streaming/models/actions.py @@ -126,4 +126,11 @@ class ActionOutput(BaseModel, Generic[ResponseType]): response: ResponseType +class ActionResponse(BaseModel, Generic[ParametersType, ResponseType]): + conversation_id: str + action_input: ActionInput + action_output: ActionOutput + is_quiet: bool = False + + ExternalActionProcessingMode = Literal["muted"] diff --git a/vocode/streaming/models/pipeline.py b/vocode/streaming/models/pipeline.py index c8445eda13..ffab9c4173 100644 --- a/vocode/streaming/models/pipeline.py +++ b/vocode/streaming/models/pipeline.py @@ -19,3 +19,4 @@ class StreamingConversationConfig(TypedModel, type=PipelineType.STREAMING_CONVER transcriber_config: TranscriberConfig agent_config: AgentConfig synthesizer_config: SynthesizerConfig + speed_coefficient: float = 1.0 diff --git a/vocode/streaming/pipeline/abstract_pipeline_factory.py b/vocode/streaming/pipeline/abstract_pipeline_factory.py index 88b81ef849..bd92587aa1 100644 --- a/vocode/streaming/pipeline/abstract_pipeline_factory.py +++ b/vocode/streaming/pipeline/abstract_pipeline_factory.py @@ -1,22 +1,24 @@ from abc import ABC, abstractmethod from typing import Generic, Optional, TypeVar +from vocode.streaming.action.worker import ActionsWorker from vocode.streaming.models.model import BaseModel from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice -from vocode.streaming.pipeline.audio_pipeline import AudioPipeline +from vocode.streaming.pipeline.audio_pipeline import AudioPipeline, OutputDeviceType from vocode.streaming.utils.events_manager import EventsManager PipelineConfigType = TypeVar("PipelineConfigType", bound=BaseModel) -class AbstractPipelineFactory(Generic[PipelineConfigType], ABC): +class AbstractPipelineFactory(Generic[PipelineConfigType, OutputDeviceType], ABC): @abstractmethod def create_pipeline( self, config: PipelineConfigType, - output_device: AbstractOutputDevice, + output_device: OutputDeviceType, id: Optional[str] = None, events_manager: Optional[EventsManager] = None, - ) -> AudioPipeline: + actions_worker: Optional[ActionsWorker] = None, + ) -> AudioPipeline[OutputDeviceType]: raise NotImplementedError diff --git a/vocode/streaming/pipeline/audio_pipeline.py b/vocode/streaming/pipeline/audio_pipeline.py index 5b3e548fbe..803b60bb8b 100644 --- a/vocode/streaming/pipeline/audio_pipeline.py +++ b/vocode/streaming/pipeline/audio_pipeline.py @@ -1,6 +1,7 @@ from abc import abstractmethod -from typing import Generic, TypeVar +from typing import Generic, Optional, TypeVar +from vocode.streaming.action.worker import ActionsWorker from vocode.streaming.output_device.abstract_output_device import AbstractOutputDevice from vocode.streaming.pipeline.worker import AbstractWorker from vocode.streaming.utils.events_manager import EventsManager @@ -11,6 +12,7 @@ class AudioPipeline(AbstractWorker[bytes], Generic[OutputDeviceType]): output_device: OutputDeviceType events_manager: EventsManager + actions_worker: Optional[ActionsWorker] id: str def receive_audio(self, chunk: bytes): diff --git a/vocode/streaming/pipeline/worker.py b/vocode/streaming/pipeline/worker.py index 06b069ab18..c707250ae1 100644 --- a/vocode/streaming/pipeline/worker.py +++ b/vocode/streaming/pipeline/worker.py @@ -171,16 +171,16 @@ def interrupt(self) -> bool: class InterruptibleEventFactory: def create_interruptible_event( - self, payload: Any, is_interruptible: bool = True - ) -> InterruptibleEvent: + self, payload: Payload, is_interruptible: bool = True + ) -> InterruptibleEvent[Payload]: return InterruptibleEvent(payload, is_interruptible=is_interruptible) def create_interruptible_agent_response_event( self, - payload: Any, + payload: Payload, is_interruptible: bool = True, agent_response_tracker: Optional[asyncio.Event] = None, - ) -> InterruptibleAgentResponseEvent: + ) -> InterruptibleAgentResponseEvent[Payload]: return InterruptibleAgentResponseEvent( payload, is_interruptible=is_interruptible, diff --git a/vocode/streaming/streaming_conversation.py b/vocode/streaming/streaming_conversation.py index 67291009e2..8fcaf95832 100644 --- a/vocode/streaming/streaming_conversation.py +++ b/vocode/streaming/streaming_conversation.py @@ -25,6 +25,9 @@ from sentry_sdk.tracing import Span from vocode import conversation_id as ctx_conversation_id +from vocode.streaming.action.abstract_factory import AbstractActionFactory +from vocode.streaming.action.base_action import BaseAction +from vocode.streaming.action.streaming_conversation_action import StreamingConversationAction from vocode.streaming.action.worker import ActionsWorker from vocode.streaming.agent.abstract_factory import AbstractAgentFactory from vocode.streaming.agent.base_agent import ( @@ -85,7 +88,6 @@ from vocode.streaming.utils.create_task import asyncio_create_task_with_done_error_log from vocode.streaming.utils.events_manager import EventsManager from vocode.streaming.utils.speed_manager import SpeedManager -from vocode.streaming.utils.state_manager import ConversationStateManager from vocode.utils.sentry_utils import ( CustomSentrySpans, complete_span_by_op, @@ -119,35 +121,6 @@ LOW_INTERRUPT_SENSITIVITY_BACKCHANNEL_UTTERANCE_LENGTH_THRESHOLD = 3 -class StreamingConversationFactory(AbstractPipelineFactory[StreamingConversationConfig]): - - def __init__( - self, - transcriber_factory: AbstractTranscriberFactory = DefaultTranscriberFactory(), - agent_factory: AbstractAgentFactory = DefaultAgentFactory(), - synthesizer_factory: AbstractSynthesizerFactory = DefaultSynthesizerFactory(), - ): - self.transcriber_factory = transcriber_factory - self.agent_factory = agent_factory - self.synthesizer_factory = synthesizer_factory - - def create_pipeline( - self, - config: StreamingConversationConfig, - output_device: OutputDeviceType, - id: Optional[str] = None, - events_manager: Optional[EventsManager] = None, - ): - return StreamingConversation( - output_device=output_device, - transcriber=self.transcriber_factory.create_transcriber(config.transcriber_config), - agent=self.agent_factory.create_agent(config.agent_config), - synthesizer=self.synthesizer_factory.create_synthesizer(config.synthesizer_config), - conversation_id=id, - events_manager=events_manager, - ) - - class StreamingConversation(AudioPipeline[OutputDeviceType]): class QueueingInterruptibleEventFactory(InterruptibleEventFactory): def __init__(self, conversation: "StreamingConversation"): @@ -183,7 +156,7 @@ class TranscriptionsWorker(AsyncQueueWorker[Transcription]): """Processes all transcriptions: sends an interrupt if needed and sends final transcriptions to the output queue""" - consumer: AbstractWorker[InterruptibleEvent[Transcription]] + consumer: AbstractWorker[InterruptibleEvent[TranscriptionAgentInput]] def __init__( self, @@ -626,6 +599,7 @@ def __init__( synthesizer: BaseSynthesizer, speed_coefficient: float = 1.0, conversation_id: Optional[str] = None, + actions_worker: Optional[ActionsWorker] = None, events_manager: Optional[EventsManager] = None, ): self.id = conversation_id or create_conversation_id() @@ -644,7 +618,6 @@ def __init__( Tuple[Union[BaseMessage, EndOfTurn], Optional[SynthesisResult]] ] ] = asyncio.Queue() - self.state_manager = self.create_state_manager() # Transcriptions Worker self.transcriptions_worker = self.TranscriptionsWorker( @@ -656,7 +629,7 @@ def __init__( # Agent self.transcriptions_worker.consumer = self.agent self.agent.set_interruptible_event_factory(self.interruptible_event_factory) - self.agent.attach_conversation_state_manager(self.state_manager) + self.agent.streaming_conversation = self # Agent Responses Worker self.agent_responses_worker = self.AgentResponsesWorker( @@ -668,11 +641,11 @@ def __init__( # Actions Worker self.actions_worker = None if self.agent.get_agent_config().actions: - self.actions_worker = ActionsWorker( + self.actions_worker = actions_worker or ActionsWorker( action_factory=self.agent.action_factory, interruptible_event_factory=self.interruptible_event_factory, ) - self.actions_worker.attach_conversation_state_manager(self.state_manager) + self.actions_worker.pipeline = self self.actions_worker.consumer = self.agent self.agent.actions_consumer = self.actions_worker @@ -719,9 +692,6 @@ def __init__( self.interrupt_lock = asyncio.Lock() - def create_state_manager(self) -> ConversationStateManager: - return ConversationStateManager(conversation=self) - async def start(self, mark_ready: Optional[Callable[[], Awaitable[None]]] = None): self.transcriber.start() self.transcriptions_worker.start() @@ -1026,6 +996,12 @@ def _on_interrupt(): synthesis_result.synthesis_total_span.finish() return message_sent, cut_off + def using_input_streaming_synthesizer(self): + return isinstance( + self.synthesizer, + InputStreamingSynthesizer, + ) + def mark_terminated(self, bot_disconnect: bool = False): self.is_terminated.set() @@ -1078,3 +1054,38 @@ def is_active(self): async def wait_for_termination(self): await self.is_terminated.wait() + + +class StreamingConversationFactory( + AbstractPipelineFactory[StreamingConversationConfig, OutputDeviceType] +): + + def __init__( + self, + transcriber_factory: AbstractTranscriberFactory = DefaultTranscriberFactory(), + agent_factory: AbstractAgentFactory = DefaultAgentFactory(), + synthesizer_factory: AbstractSynthesizerFactory = DefaultSynthesizerFactory(), + speed_coefficient: float = 1.0, + ): + self.transcriber_factory = transcriber_factory + self.agent_factory = agent_factory + self.synthesizer_factory = synthesizer_factory + + def create_pipeline( + self, + config: StreamingConversationConfig, + output_device: OutputDeviceType, + id: Optional[str] = None, + events_manager: Optional[EventsManager] = None, + actions_worker: Optional[ActionsWorker] = None, + ): + return StreamingConversation( + output_device=output_device, + transcriber=self.transcriber_factory.create_transcriber(config.transcriber_config), + agent=self.agent_factory.create_agent(config.agent_config), + synthesizer=self.synthesizer_factory.create_synthesizer(config.synthesizer_config), + conversation_id=id, + events_manager=events_manager, + actions_worker=actions_worker, + speed_coefficient=config.speed_coefficient, + ) diff --git a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py index 50859b9ee1..5304a4fb39 100644 --- a/vocode/streaming/telephony/conversation/abstract_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/abstract_phone_conversation.py @@ -4,21 +4,12 @@ from fastapi import WebSocket from loguru import logger -from vocode.streaming.agent.abstract_factory import AbstractAgentFactory -from vocode.streaming.models.agent import AgentConfig from vocode.streaming.models.events import PhoneCallEndedEvent -from vocode.streaming.models.synthesizer import SynthesizerConfig from vocode.streaming.models.telephony import PhoneCallDirection -from vocode.streaming.models.transcriber import TranscriberConfig from vocode.streaming.output_device.twilio_output_device import TwilioOutputDevice from vocode.streaming.output_device.vonage_output_device import VonageOutputDevice from vocode.streaming.pipeline.audio_pipeline import AudioPipeline -from vocode.streaming.synthesizer.abstract_factory import AbstractSynthesizerFactory from vocode.streaming.telephony.config_manager.base_config_manager import BaseConfigManager -from vocode.streaming.transcriber.abstract_factory import AbstractTranscriberFactory -from vocode.streaming.utils import create_conversation_id -from vocode.streaming.utils.events_manager import EventsManager -from vocode.streaming.utils.state_manager import PhoneConversationStateManager TelephonyOutputDeviceType = TypeVar( "TelephonyOutputDeviceType", bound=Union[TwilioOutputDevice, VonageOutputDevice] @@ -29,9 +20,7 @@ TelephonyProvider = Literal["twilio", "vonage"] -class AbstractPhoneConversation( - AudioPipeline[TelephonyOutputDeviceType, PhoneConversationStateManager] -): +class AbstractPhoneConversation(Generic[TelephonyOutputDeviceType]): telephony_provider: TelephonyProvider def __init__( diff --git a/vocode/streaming/telephony/conversation/twilio_phone_conversation.py b/vocode/streaming/telephony/conversation/twilio_phone_conversation.py index bc3c67f0e6..a8124d5684 100644 --- a/vocode/streaming/telephony/conversation/twilio_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/twilio_phone_conversation.py @@ -7,25 +7,40 @@ from fastapi import WebSocket from loguru import logger +from vocode.streaming.action.abstract_factory import AbstractActionFactory +from vocode.streaming.action.base_action import BaseAction +from vocode.streaming.action.default_factory import DefaultTwilioPhoneConversationActionFactory +from vocode.streaming.action.phone_call_action import TwilioPhoneConversationAction +from vocode.streaming.action.worker import ActionsWorker from vocode.streaming.models.events import PhoneCallConnectedEvent +from vocode.streaming.models.model import BaseModel from vocode.streaming.models.telephony import PhoneCallDirection, TwilioConfig from vocode.streaming.output_device.twilio_output_device import ( ChunkFinishedMarkMessage, TwilioOutputDevice, ) -from vocode.streaming.pipeline.audio_pipeline import AudioPipeline +from vocode.streaming.pipeline.abstract_pipeline_factory import AbstractPipelineFactory from vocode.streaming.telephony.client.twilio_client import TwilioClient from vocode.streaming.telephony.config_manager.base_config_manager import BaseConfigManager from vocode.streaming.telephony.conversation.abstract_phone_conversation import ( AbstractPhoneConversation, ) -from vocode.streaming.utils.state_manager import TwilioPhoneConversationStateManager +from vocode.streaming.utils.events_manager import EventsManager class TwilioPhoneConversationWebsocketAction(Enum): CLOSE_WEBSOCKET = 1 +class TwilioPhoneConversationActionsWorker(ActionsWorker): + twilio_phone_conversation: "TwilioPhoneConversation" + + def attach_state(self, action: BaseAction): + super().attach_state(action) + if isinstance(action, TwilioPhoneConversationAction): + action.twilio_phone_conversation = self.twilio_phone_conversation + + class TwilioPhoneConversation(AbstractPhoneConversation[TwilioOutputDevice]): telephony_provider = "twilio" @@ -36,12 +51,28 @@ def __init__( to_phone: str, base_url: str, config_manager: BaseConfigManager, - pipeline: AudioPipeline[TwilioOutputDevice], + pipeline_factory: AbstractPipelineFactory[BaseModel, TwilioOutputDevice], + pipeline_config: BaseModel, twilio_sid: str, twilio_config: Optional[TwilioConfig] = None, + id: Optional[str] = None, + action_factory: Optional[AbstractActionFactory] = None, + events_manager: Optional[EventsManager] = None, record_call: bool = False, noise_suppression: bool = False, # is currently a no-op ): + actions_worker = TwilioPhoneConversationActionsWorker( + action_factory=action_factory or DefaultTwilioPhoneConversationActionFactory() + ) + pipeline = pipeline_factory.create_pipeline( + config=pipeline_config, + output_device=TwilioOutputDevice(), + id=id, + events_manager=events_manager, + actions_worker=actions_worker, + ) + actions_worker.twilio_phone_conversation = self + super().__init__( direction=direction, from_phone=from_phone, @@ -61,9 +92,6 @@ def __init__( self.twilio_sid = twilio_sid self.record_call = record_call - def create_state_manager(self) -> TwilioPhoneConversationStateManager: - return TwilioPhoneConversationStateManager(self) - async def attach_ws_and_start(self, ws: WebSocket): super().attach_ws(ws) @@ -114,3 +142,9 @@ async def _handle_ws_message(self, message) -> Optional[TwilioPhoneConversationW logger.debug("Stopping...") return TwilioPhoneConversationWebsocketAction.CLOSE_WEBSOCKET return None + + def create_twilio_client(self): + return TwilioClient( + base_url=self.base_url, + maybe_twilio_config=self.twilio_config, + ) diff --git a/vocode/streaming/telephony/conversation/vonage_phone_conversation.py b/vocode/streaming/telephony/conversation/vonage_phone_conversation.py index dafdf3e3b0..faa3cac0b5 100644 --- a/vocode/streaming/telephony/conversation/vonage_phone_conversation.py +++ b/vocode/streaming/telephony/conversation/vonage_phone_conversation.py @@ -1,23 +1,39 @@ import os +from typing import Optional import numpy as np from fastapi import WebSocket, WebSocketDisconnect from loguru import logger +from vocode.streaming.action.abstract_factory import AbstractActionFactory +from vocode.streaming.action.base_action import BaseAction +from vocode.streaming.action.default_factory import DefaultVonagePhoneConversationActionFactory +from vocode.streaming.action.phone_call_action import VonagePhoneConversationAction +from vocode.streaming.action.worker import ActionsWorker from vocode.streaming.models.events import PhoneCallConnectedEvent +from vocode.streaming.models.model import BaseModel from vocode.streaming.models.telephony import PhoneCallDirection, VonageConfig from vocode.streaming.output_device.vonage_output_device import VonageOutputDevice -from vocode.streaming.pipeline.audio_pipeline import AudioPipeline +from vocode.streaming.pipeline.abstract_pipeline_factory import AbstractPipelineFactory from vocode.streaming.telephony.client.vonage_client import VonageClient from vocode.streaming.telephony.config_manager.base_config_manager import BaseConfigManager from vocode.streaming.telephony.conversation.abstract_phone_conversation import ( AbstractPhoneConversation, ) -from vocode.streaming.utils.state_manager import VonagePhoneConversationStateManager +from vocode.streaming.utils.events_manager import EventsManager KOALA_CHUNK_SIZE = 512 # 16 bit samples, size 256 +class VonagePhoneConversationActionsWorker(ActionsWorker): + vonage_phone_conversation: "VonagePhoneConversation" + + def attach_state(self, action: BaseAction): + super().attach_state(action) + if isinstance(action, VonagePhoneConversationAction): + action.vonage_phone_conversation = self.vonage_phone_conversation + + class VonagePhoneConversation(AbstractPhoneConversation[VonageOutputDevice]): telephony_provider = "vonage" @@ -28,11 +44,27 @@ def __init__( to_phone: str, base_url: str, config_manager: BaseConfigManager, - pipeline: AudioPipeline[VonageOutputDevice], + pipeline_factory: AbstractPipelineFactory[BaseModel, VonageOutputDevice], + pipeline_config: BaseModel, vonage_uuid: str, vonage_config: VonageConfig, + id: Optional[str] = None, + action_factory: Optional[AbstractActionFactory] = None, + events_manager: Optional[EventsManager] = None, noise_suppression: bool = False, ): + actions_worker = VonagePhoneConversationActionsWorker( + action_factory=action_factory or DefaultVonagePhoneConversationActionFactory() + ) + pipeline = pipeline_factory.create_pipeline( + config=pipeline_config, + output_device=VonageOutputDevice(), + id=id, + events_manager=events_manager, + actions_worker=actions_worker, + ) + actions_worker.vonage_phone_conversation = self + super().__init__( direction=direction, from_phone=from_phone, @@ -57,9 +89,6 @@ def __init__( access_key=os.environ["KOALA_ACCESS_KEY"], ) - def create_state_manager(self) -> VonagePhoneConversationStateManager: - return VonagePhoneConversationStateManager(self) - async def attach_ws_and_start(self, ws: WebSocket): # start message await ws.receive() @@ -110,3 +139,9 @@ def receive_audio(self, chunk: bytes): self.buffer = self.buffer[KOALA_CHUNK_SIZE:] else: self.pipeline.receive_audio(chunk) + + def create_vonage_client(self): + return VonageClient( + base_url=self.base_url, + maybe_vonage_config=self.vonage_config, + ) diff --git a/vocode/streaming/telephony/server/router/calls.py b/vocode/streaming/telephony/server/router/calls.py index f6ee221a7f..fa2148ff01 100644 --- a/vocode/streaming/telephony/server/router/calls.py +++ b/vocode/streaming/telephony/server/router/calls.py @@ -65,12 +65,10 @@ def _from_call_config( twilio_config=call_config.twilio_config, twilio_sid=call_config.twilio_sid, direction=call_config.direction, - pipeline=pipeline_factory.create_pipeline( - config=call_config.pipeline_config, - output_device=TwilioOutputDevice(), - id=conversation_id, - events_manager=events_manager, - ), + pipeline_factory=pipeline_factory, + pipeline_config=call_config.pipeline_config, + id=conversation_id, + events_manager=events_manager, ) elif isinstance(call_config, VonageCallConfig): return VonagePhoneConversation( @@ -81,14 +79,10 @@ def _from_call_config( vonage_config=call_config.vonage_config, vonage_uuid=call_config.vonage_uuid, direction=call_config.direction, - pipeline=pipeline_factory.create_pipeline( - config=call_config.pipeline_config, - output_device=VonageOutputDevice( - output_to_speaker=call_config.output_to_speaker - ), - id=conversation_id, - events_manager=events_manager, - ), + pipeline_factory=pipeline_factory, + pipeline_config=call_config.pipeline_config, + id=conversation_id, + events_manager=events_manager, ) else: raise ValueError(f"Unknown call config type {call_config.type}") diff --git a/vocode/streaming/utils/state_manager.py b/vocode/streaming/utils/state_manager.py deleted file mode 100644 index bd69952152..0000000000 --- a/vocode/streaming/utils/state_manager.py +++ /dev/null @@ -1,170 +0,0 @@ -from typing import TYPE_CHECKING, Optional - -from vocode.streaming.models.transcriber import EndpointingConfig -from vocode.streaming.synthesizer.input_streaming_synthesizer import InputStreamingSynthesizer -from vocode.streaming.telephony.client.twilio_client import TwilioClient -from vocode.streaming.telephony.client.vonage_client import VonageClient -from vocode.streaming.utils.redis_conversation_message_queue import RedisConversationMessageQueue - -if TYPE_CHECKING: - from vocode.streaming.streaming_conversation import StreamingConversation - from vocode.streaming.telephony.conversation.abstract_phone_conversation import ( - AbstractPhoneConversation, - ) - from vocode.streaming.telephony.conversation.twilio_phone_conversation import ( - TwilioPhoneConversation, - ) - from vocode.streaming.telephony.conversation.vonage_phone_conversation import ( - VonagePhoneConversation, - ) - - -# TODO: make this a proper ABC -class AbstractConversationStateManager: - @property - def logger(self): - raise NotImplementedError - - @property - def transcript(self): - raise NotImplementedError - - def get_transcriber_endpointing_config(self) -> Optional[EndpointingConfig]: - raise NotImplementedError - - def set_transcriber_endpointing_config(self, endpointing_config: EndpointingConfig): - raise NotImplementedError - - def disable_synthesis(self): - raise NotImplementedError - - def enable_synthesis(self): - raise NotImplementedError - - def mute_agent(self): - raise NotImplementedError - - def unmute_agent(self): - raise NotImplementedError - - def using_input_streaming_synthesizer(self): - raise NotImplementedError - - async def terminate_conversation(self): - raise NotImplementedError - - def get_conversation_id(self): - raise NotImplementedError - - -class AbstractPhoneConversationStateManager(AbstractConversationStateManager): - def get_config_manager(self): - raise NotImplementedError - - def get_to_phone(self): - raise NotImplementedError - - def get_from_phone(self): - raise NotImplementedError - - -class ConversationStateManager(AbstractConversationStateManager): - def __init__(self, conversation: "StreamingConversation"): - self._conversation = conversation - if not hasattr(self, "redis_message_queue"): - self.redis_message_queue = RedisConversationMessageQueue() - - @property - def transcript(self): - return self._conversation.transcript - - def get_transcriber_endpointing_config(self) -> Optional[EndpointingConfig]: - return self._conversation.transcriber.get_transcriber_config().endpointing_config - - def set_transcriber_endpointing_config(self, endpointing_config: EndpointingConfig): - assert self.get_transcriber_endpointing_config() is not None - self._conversation.transcriber.get_transcriber_config().endpointing_config = ( - endpointing_config - ) - - def disable_synthesis(self): - self._conversation.synthesis_enabled = False - - def enable_synthesis(self): - self._conversation.synthesis_enabled = True - - def mute_agent(self): - self._conversation.agent.is_muted = True - - def unmute_agent(self): - self._conversation.agent.is_muted = False - - def using_input_streaming_synthesizer(self): - return isinstance( - self._conversation.synthesizer, - InputStreamingSynthesizer, - ) - - async def terminate_conversation(self): - self._conversation.mark_terminated() - - def set_call_check_for_idle_paused(self, value: bool): - if not self._conversation: - return - self._conversation.set_check_for_idle_paused(value) - - def get_conversation_id(self): - return self._conversation.id - - -class PhoneConversationStateManager( - AbstractPhoneConversationStateManager, ConversationStateManager -): - def __init__(self, conversation: "AbstractPhoneConversation"): - ConversationStateManager.__init__(self, conversation.pipeline) - self._phone_conversation = conversation - - def get_config_manager(self): - return self._phone_conversation.config_manager - - def get_to_phone(self): - return self._phone_conversation.to_phone - - def get_from_phone(self): - return self._phone_conversation.from_phone - - def get_direction(self): - return self._phone_conversation.direction - - -class VonagePhoneConversationStateManager(PhoneConversationStateManager): - def __init__(self, conversation: "VonagePhoneConversation"): - super().__init__(conversation=conversation) - self._vonage_phone_conversation = conversation - - def create_vonage_client(self): - return VonageClient( - base_url=self._vonage_phone_conversation.base_url, - maybe_vonage_config=self._vonage_phone_conversation.vonage_config, - ) - - def get_vonage_uuid(self): - return self._vonage_phone_conversation.vonage_uuid - - -class TwilioPhoneConversationStateManager(PhoneConversationStateManager): - def __init__(self, conversation: "TwilioPhoneConversation"): - super().__init__(conversation=conversation) - self._twilio_phone_conversation = conversation - - def get_twilio_config(self): - return self._twilio_phone_conversation.twilio_config - - def create_twilio_client(self): - return TwilioClient( - base_url=self._twilio_phone_conversation.base_url, - maybe_twilio_config=self.get_twilio_config(), - ) - - def get_twilio_sid(self): - return self._twilio_phone_conversation.twilio_sid