From 0a1acba6577bea8626cda22af1a7c9caecb83d73 Mon Sep 17 00:00:00 2001 From: Ethan Sutin Date: Fri, 1 Mar 2024 10:04:16 -0800 Subject: [PATCH 1/4] Add person and voice sample model --- ...6aff0a993d7_add_person_and_voicesamples.py | 55 +++++++++++++++++++ owl/models/schemas.py | 17 +++++- 2 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 alembic/versions/b6aff0a993d7_add_person_and_voicesamples.py diff --git a/alembic/versions/b6aff0a993d7_add_person_and_voicesamples.py b/alembic/versions/b6aff0a993d7_add_person_and_voicesamples.py new file mode 100644 index 00000000..3b1e5f29 --- /dev/null +++ b/alembic/versions/b6aff0a993d7_add_person_and_voicesamples.py @@ -0,0 +1,55 @@ +"""Add person and voicesamples + +Revision ID: b6aff0a993d7 +Revises: 33bddba74d25 +Create Date: 2024-03-01 08:56:55.205553 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel + + +# revision identifiers, used by Alembic. +revision: str = 'b6aff0a993d7' +down_revision: Union[str, None] = '33bddba74d25' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Use batch operations to support SQLite ALTER TABLE for adding constraints + with op.batch_alter_table('utterance', schema=None) as batch_op: + batch_op.add_column(sa.Column('person_id', sa.Integer(), nullable=True)) + batch_op.create_foreign_key('fk_utterance_person', 'person', ['person_id'], ['id']) + + op.create_table('person', + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('first_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('last_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('voicesample', + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('filepath', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('speaker_embeddings', sa.JSON(), nullable=True), + sa.Column('person_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['person_id'], ['person.id'], name='fk_voicesample_person'), + sa.PrimaryKeyConstraint('id') + ) + +def downgrade() -> None: + # Use batch operations for dropping column with SQLite + with op.batch_alter_table('utterance', schema=None) as batch_op: + batch_op.drop_constraint('fk_utterance_person', type_='foreignkey') + batch_op.drop_column('person_id') + + # Commands for dropping tables remain unchanged + op.drop_table('voicesample') + op.drop_table('person') \ No newline at end of file diff --git a/owl/models/schemas.py b/owl/models/schemas.py index eee40c42..b9ad7e8a 100644 --- a/owl/models/schemas.py +++ b/owl/models/schemas.py @@ -1,5 +1,5 @@ from typing import List, Optional -from sqlmodel import SQLModel, Field, Relationship +from sqlmodel import SQLModel, Field, Relationship, Column, JSON from datetime import datetime, timezone from pydantic import BaseModel from enum import Enum @@ -36,6 +36,8 @@ class Utterance(CreatedAtMixin, table=True): transcription: "Transcription" = Relationship(back_populates="utterances") words: List[Word] = Relationship(back_populates="utterance", sa_relationship_kwargs={"cascade": "all, delete-orphan"}) + person_id: Optional[int] = Field(default=None, foreign_key="person.id") + person: Optional["Person"] = Relationship(back_populates="utterances") class Transcription(CreatedAtMixin, table=True): id: Optional[int] = Field(default=None, primary_key=True) @@ -106,6 +108,19 @@ class CaptureSegment(CreatedAtMixin, table=True): conversation: Optional[Conversation] = Relationship(back_populates="capture_segment_file") +class Person(CreatedAtMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + first_name: str + last_name: str + voice_samples: List["VoiceSample"] = Relationship(back_populates="person") + utterances: List[Utterance] = Relationship(back_populates="person") + +class VoiceSample(CreatedAtMixin, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + filepath: str = Field(...) + speaker_embeddings: dict = Field(default={}, sa_column=Column(JSON)) + person_id: Optional[int] = Field(default=None, foreign_key="person.id") + person: Optional["Person"] = Relationship(back_populates="voice_samples") # API Response Models # https://sqlmodel.tiangolo.com/tutorial/fastapi/relationships/#dont-include-all-the-data From 8444ed7becdbfe96a6c2f29b0a924cbabef37a17 Mon Sep 17 00:00:00 2001 From: Ethan Sutin Date: Fri, 1 Mar 2024 10:05:19 -0800 Subject: [PATCH 2/4] Enroll speaker from CLI --- owl/core/cli.py | 37 +++++++++++++++++++++++++++++++++++++ owl/core/config.py | 7 ++++++- owl/database/crud.py | 14 +++++++++++++- owl/sample_config.yaml | 4 ++++ 4 files changed, 60 insertions(+), 2 deletions(-) diff --git a/owl/core/cli.py b/owl/core/cli.py index 12822ce2..4db51e36 100644 --- a/owl/core/cli.py +++ b/owl/core/cli.py @@ -13,6 +13,9 @@ import subprocess from alembic import command from alembic.config import Config +from ..database.database import Database +from ..database.crud import create_person, create_voice_sample +from ..models.schemas import Person, VoiceSample import click from rich.console import Console @@ -202,6 +205,40 @@ def create_migration(config: Configuration, message: str): console.log(f"[bold green]Migration script generated with message: '{message}'") +#################################################################################################### +# Persons +#################################################################################################### + +@cli.command() +@add_options(_config_options) +@click.option('--first-name', required=True, help='First name of the person') +@click.option('--last-name', required=True, help='Last name of the person') +@click.option('--voice-sample-path', required=True, help='Path to the voice sample file') +def enroll_speaker(config: Configuration, first_name: str, last_name: str, voice_sample_path: str): + """Enroll a new person with a voice sample.""" + console = Console() + console.log("[bold green]Enrolling speaker...") + + database = Database(config.database) + with next(database.get_db()) as db: + person = create_person(db, Person(first_name=first_name, last_name=last_name)) + sample_directory = config.speaker_identification.voice_sample_directory + sample_directory = os.path.join(sample_directory, str(person.id)) + os.makedirs(sample_directory, exist_ok=True) + + filename = os.path.basename(voice_sample_path) + extension = os.path.splitext(filename)[1] + + sample_file_path = os.path.join(sample_directory, f"{uuid.uuid1().hex}.{extension[1:]}") + + with next(database.get_db()) as db: + voice_sample = create_voice_sample(db, VoiceSample(person_id=person.id, filepath=sample_file_path)) + with open(voice_sample_path, "rb") as f: + with open(sample_file_path, "wb") as f2: + f2.write(f.read()) + + console.log(f"[bold green]Enrolled new person: '{person.id} ({voice_sample.id})'") + #################################################################################################### # Server #################################################################################################### diff --git a/owl/core/config.py b/owl/core/config.py index 03b9e4d6..9bc94571 100644 --- a/owl/core/config.py +++ b/owl/core/config.py @@ -49,6 +49,10 @@ class StreamingTranscriptionConfiguration(BaseModel): class AsyncTranscriptionConfiguration(BaseModel): provider: str +class SpeakerIdentificationConfiguration(BaseModel): + provider: str + voice_sample_directory: Optional[str] = None + class DatabaseConfiguration(BaseModel): url: str @@ -104,4 +108,5 @@ def load_config_yaml(cls, config_filepath: str) -> 'Configuration': conversation_endpointing: ConversationEndpointingConfiguration notification: NotificationConfiguration udp: UDPConfiguration - bing: BingConfiguration | None = None \ No newline at end of file + bing: BingConfiguration | None = None + speaker_identification: SpeakerIdentificationConfiguration | None = None \ No newline at end of file diff --git a/owl/database/crud.py b/owl/database/crud.py index 60055e52..55e5f088 100644 --- a/owl/database/crud.py +++ b/owl/database/crud.py @@ -1,5 +1,5 @@ from sqlmodel import SQLModel, Session, select -from ..models.schemas import Transcription, Conversation, Utterance, Location, CaptureSegment, Capture, ConversationState +from ..models.schemas import Transcription, Conversation, Utterance, Location, CaptureSegment, Capture, ConversationState, Person, VoiceSample from typing import List, Optional from sqlalchemy.orm import joinedload, selectinload from sqlalchemy import desc, func, or_ @@ -8,6 +8,18 @@ logger = logging.getLogger(__name__) +def create_person(db: Session, person: Person) -> Person: + db.add(person) + db.commit() + db.refresh(person) + return person + +def create_voice_sample(db: Session, voice_sample: VoiceSample) -> VoiceSample: + db.add(voice_sample) + db.commit() + db.refresh(voice_sample) + return voice_sample + def create_utterance(db: Session, utterance: Utterance) -> Utterance: db.add(utterance) db.commit() diff --git a/owl/sample_config.yaml b/owl/sample_config.yaml index 1915375a..eb2093d9 100644 --- a/owl/sample_config.yaml +++ b/owl/sample_config.yaml @@ -86,6 +86,10 @@ udp: host: '0.0.0.0' port: 8001 +speaker_identification: + provider: speech_brain + voice_sample_directory: voice_samples + # To enable web search # bing: # subscription_key: your_bing_subscription_service_key \ No newline at end of file From ff584f60abcef8bd89c40f304a1edde2737d8a69 Mon Sep 17 00:00:00 2001 From: Ethan Sutin Date: Fri, 1 Mar 2024 10:05:33 -0800 Subject: [PATCH 3/4] Speaker identification service --- owl/services/stt/speaker_identification/__init__.py | 0 .../abstract_speaker_identification_service.py | 8 ++++++++ 2 files changed, 8 insertions(+) create mode 100644 owl/services/stt/speaker_identification/__init__.py create mode 100644 owl/services/stt/speaker_identification/abstract_speaker_identification_service.py diff --git a/owl/services/stt/speaker_identification/__init__.py b/owl/services/stt/speaker_identification/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/owl/services/stt/speaker_identification/abstract_speaker_identification_service.py b/owl/services/stt/speaker_identification/abstract_speaker_identification_service.py new file mode 100644 index 00000000..713e3ae0 --- /dev/null +++ b/owl/services/stt/speaker_identification/abstract_speaker_identification_service.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod +from ....models.schemas import Transcript + +class AbstractSpeakerIdentificationService(ABC): + + @abstractmethod + async def identifiy_speakers(self, transcript: Transcript, persons) -> Transcript: + pass \ No newline at end of file From 7467cb8b8cac15565d4307593d5c871b8553e8a5 Mon Sep 17 00:00:00 2001 From: Ethan Sutin Date: Fri, 1 Mar 2024 13:41:23 -0800 Subject: [PATCH 4/4] Stub identification service --- clients/ios/Owl/Models/Transcript.swift | 15 +++++++++++++- .../Owl/Views/ConversationDetailView.swift | 2 +- owl/core/config.py | 6 +++++- owl/database/crud.py | 3 +++ owl/models/schemas.py | 11 ++++++++++ owl/sample_config.yaml | 3 +++ owl/server/main.py | 4 +++- .../conversation/conversation_service.py | 14 +++++++++++-- ...abstract_speaker_identification_service.py | 5 +++-- .../speaker_identification_service_factory.py | 20 +++++++++++++++++++ ...ch_brain_speaker_identification_service.py | 17 ++++++++++++++++ 11 files changed, 92 insertions(+), 8 deletions(-) create mode 100644 owl/services/stt/speaker_identification/speaker_identification_service_factory.py create mode 100644 owl/services/stt/speaker_identification/speech_brain_speaker_identification_service.py diff --git a/clients/ios/Owl/Models/Transcript.swift b/clients/ios/Owl/Models/Transcript.swift index 7b21aef8..b45320f9 100644 --- a/clients/ios/Owl/Models/Transcript.swift +++ b/clients/ios/Owl/Models/Transcript.swift @@ -21,7 +21,6 @@ struct CaptureFile: Codable { } } - struct CaptureFileSegment: Codable { var id: Int var filePath: String @@ -34,6 +33,7 @@ struct CaptureFileSegment: Codable { case sourceCapture = "source_capture" } } + struct Transcription: Codable { var id: Int var model: String @@ -53,6 +53,7 @@ struct Utterance: Codable { var end: Double? var text: String? var speaker: String? + var person: Person? // var words: [Word] } @@ -70,3 +71,15 @@ struct Word: Codable { case utteranceId = "utterance_id" } } + +struct Person: Codable { + var id: Int + var firstName: String + var lastName: String? + + enum CodingKeys: String, CodingKey { + case id + case firstName = "first_name" + case lastName = "last_name" + } +} diff --git a/clients/ios/Owl/Views/ConversationDetailView.swift b/clients/ios/Owl/Views/ConversationDetailView.swift index 0cfb4744..784f8296 100644 --- a/clients/ios/Owl/Views/ConversationDetailView.swift +++ b/clients/ios/Owl/Views/ConversationDetailView.swift @@ -57,7 +57,7 @@ struct ConversationDetailView: View { if !transcription.utterances.isEmpty { ForEach(transcription.utterances, id: \.id) { utterance in HStack { - Text("\(utterance.speaker ?? "Unknown"):") + Text("\(utterance.person?.firstName ?? utterance.speaker ?? "Unknown"):") .fontWeight(.semibold) .foregroundColor(.blue) diff --git a/owl/core/config.py b/owl/core/config.py index 9bc94571..29874c69 100644 --- a/owl/core/config.py +++ b/owl/core/config.py @@ -53,6 +53,9 @@ class SpeakerIdentificationConfiguration(BaseModel): provider: str voice_sample_directory: Optional[str] = None +class SpeechBrainConfiguration(BaseModel): + threshold: float # todo actual config + class DatabaseConfiguration(BaseModel): url: str @@ -109,4 +112,5 @@ def load_config_yaml(cls, config_filepath: str) -> 'Configuration': notification: NotificationConfiguration udp: UDPConfiguration bing: BingConfiguration | None = None - speaker_identification: SpeakerIdentificationConfiguration | None = None \ No newline at end of file + speaker_identification: SpeakerIdentificationConfiguration | None = None + speech_brain: SpeechBrainConfiguration | None = None \ No newline at end of file diff --git a/owl/database/crud.py b/owl/database/crud.py index 55e5f088..67ce876b 100644 --- a/owl/database/crud.py +++ b/owl/database/crud.py @@ -20,6 +20,9 @@ def create_voice_sample(db: Session, voice_sample: VoiceSample) -> VoiceSample: db.refresh(voice_sample) return voice_sample +def get_persons(db: Session) -> List[Person]: + return db.query(Person).all() + def create_utterance(db: Session, utterance: Utterance) -> Utterance: db.add(utterance) db.commit() diff --git a/owl/models/schemas.py b/owl/models/schemas.py index b9ad7e8a..61477a2b 100644 --- a/owl/models/schemas.py +++ b/owl/models/schemas.py @@ -136,6 +136,15 @@ class WordRead(BaseModel): class Config: from_attributes=True +class PersonRead(BaseModel): + id: Optional[int] + first_name: str + last_name: Optional[str] + + class Config: + from_attributes=True + + class UtteranceRead(BaseModel): id: Optional[int] start: Optional[float] @@ -143,6 +152,8 @@ class UtteranceRead(BaseModel): spoken_at: Optional[datetime] text: Optional[str] speaker: Optional[str] + person: Optional[PersonRead] = None + class Config: from_attributes=True json_encoders = { diff --git a/owl/sample_config.yaml b/owl/sample_config.yaml index eb2093d9..d626d523 100644 --- a/owl/sample_config.yaml +++ b/owl/sample_config.yaml @@ -89,6 +89,9 @@ udp: speaker_identification: provider: speech_brain voice_sample_directory: voice_samples + +speech_brain: + threshold: 0.2 # To enable web search # bing: diff --git a/owl/server/main.py b/owl/server/main.py index 58a72393..97473842 100644 --- a/owl/server/main.py +++ b/owl/server/main.py @@ -21,6 +21,7 @@ from ..services import LLMService, CaptureService, ConversationService, NotificationService, BingSearchService from ..database.database import Database from ..services.stt.asynchronous.async_transcription_service_factory import AsyncTranscriptionServiceFactory +from ..services.stt.speaker_identification.speaker_identification_service_factory import SpeakerIdentificationServiceFactory from .task import Task import logging import asyncio @@ -83,7 +84,8 @@ def create_server_app(config: Configuration) -> FastAPI: notification_service = NotificationService(config.notification) capture_service = CaptureService(config=config, database=database) bing_search_service = BingSearchService(config=config.bing) if config.bing else None - conversation_service = ConversationService(config, database, transcription_service, notification_service, bing_search_service) + speaker_identification_service = SpeakerIdentificationServiceFactory.get_service(config=config) if config.speaker_identification else None + conversation_service = ConversationService(config, database, transcription_service, notification_service, bing_search_service, speaker_identification_service) # Create server app app = FastAPI() diff --git a/owl/services/conversation/conversation_service.py b/owl/services/conversation/conversation_service.py index d7c9997d..dc055829 100644 --- a/owl/services/conversation/conversation_service.py +++ b/owl/services/conversation/conversation_service.py @@ -6,7 +6,7 @@ from ..stt.asynchronous.abstract_async_transcription_service import AbstractAsyncTranscriptionService from ..conversation.transcript_summarizer import TranscriptionSummarizer -from ...database.crud import create_transcription, create_conversation, find_most_common_location, create_capture_file_segment_file_ref, update_conversation_state, get_conversation_by_conversation_uuid, get_capturing_conversation_by_capture_uuid +from ...database.crud import create_transcription, create_conversation, find_most_common_location, create_capture_file_segment_file_ref, update_conversation_state, get_conversation_by_conversation_uuid, get_capturing_conversation_by_capture_uuid, get_persons from ...database.database import Database from ...core.config import Configuration from ...models.schemas import Transcription, Conversation, ConversationState, Capture, CaptureSegment, TranscriptionRead, ConversationRead, SuggestedLink @@ -15,13 +15,14 @@ logger = logging.getLogger(__name__) class ConversationService: - def __init__(self, config: Configuration, database: Database, transcription_service: AbstractAsyncTranscriptionService, notification_service, bing_search_service=None): + def __init__(self, config: Configuration, database: Database, transcription_service: AbstractAsyncTranscriptionService, notification_service, bing_search_service=None, speaker_identification_service=None): self._config = config self._database = database self._transcription_service = transcription_service self._notification_service = notification_service self._summarizer = TranscriptionSummarizer(config) self._bing_search_service = bing_search_service + self._speaker_identification_service = speaker_identification_service async def create_conversation(self, conversation_uuid: str, start_time: datetime, capture_file: Capture) -> Conversation: with next(self._database.get_db()) as db: @@ -109,6 +110,15 @@ async def process_conversation_from_audio(self, conversation_uuid: str, voice_sa serialized_payload = deleted_data.json() await self._notification_service.send_notification("Empty Conversation", "An empty conversation was deleted", "delete_conversation", payload=serialized_payload) return None, None + + if self._speaker_identification_service: + persons = get_persons(db) + if not persons: + logger.info("No persons found in the database. Skipping speaker identification.") + else: + start_time = time.time() + transcription = await self._speaker_identification_service.identify_speakers(transcription, persons) + logger.info(f"Speaker identification complete in {time.time() - start_time:.2f} seconds") for utterance in transcription.utterances: utterance.spoken_at = conversation_start_time + timedelta(seconds=utterance.start) diff --git a/owl/services/stt/speaker_identification/abstract_speaker_identification_service.py b/owl/services/stt/speaker_identification/abstract_speaker_identification_service.py index 713e3ae0..58e7f462 100644 --- a/owl/services/stt/speaker_identification/abstract_speaker_identification_service.py +++ b/owl/services/stt/speaker_identification/abstract_speaker_identification_service.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from ....models.schemas import Transcript +from ....models.schemas import Transcription, Person +from typing import List class AbstractSpeakerIdentificationService(ABC): @abstractmethod - async def identifiy_speakers(self, transcript: Transcript, persons) -> Transcript: + async def identify_speakers(self, transcript: Transcription, persons: List[Person]) -> Transcription: pass \ No newline at end of file diff --git a/owl/services/stt/speaker_identification/speaker_identification_service_factory.py b/owl/services/stt/speaker_identification/speaker_identification_service_factory.py new file mode 100644 index 00000000..6044e5c7 --- /dev/null +++ b/owl/services/stt/speaker_identification/speaker_identification_service_factory.py @@ -0,0 +1,20 @@ +from .speech_brain_speaker_identification_service import SpeechBrainIdentificationService +import logging + +logger = logging.getLogger(__name__) + +class SpeakerIdentificationServiceFactory: + _instances = {} + + @staticmethod + def get_service(config): + service_type = config.speaker_identification.provider + if service_type not in SpeakerIdentificationServiceFactory._instances: + logger.info(f"Creating new {service_type} speaker identification service") + if service_type == "speech_brain": + SpeakerIdentificationServiceFactory._instances[service_type] = SpeechBrainIdentificationService(config.speech_brain) + else: + raise ValueError(f"Unknown speaker identification service type: {service_type}") + + return SpeakerIdentificationServiceFactory._instances[service_type] + diff --git a/owl/services/stt/speaker_identification/speech_brain_speaker_identification_service.py b/owl/services/stt/speaker_identification/speech_brain_speaker_identification_service.py new file mode 100644 index 00000000..e8916567 --- /dev/null +++ b/owl/services/stt/speaker_identification/speech_brain_speaker_identification_service.py @@ -0,0 +1,17 @@ + +from .abstract_speaker_identification_service import AbstractSpeakerIdentificationService +from ....models.schemas import Transcription, Person +from typing import List +import logging + +logger = logging.getLogger(__name__) + +class SpeechBrainIdentificationService(AbstractSpeakerIdentificationService): + def __init__(self, config): + self._config = config + + async def identify_speakers(self, transcript: Transcription, persons: List[Person]) -> Transcription: + # stub implementation. just set the first person in the list as the speaker for all utterances + for utterance in transcript.utterances: + utterance.person = persons[0] + return transcript \ No newline at end of file