Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speaker identification WIP #46

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions alembic/versions/b6aff0a993d7_add_person_and_voicesamples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Add person and voicesamples

Revision ID: b6aff0a993d7
Revises: 33bddba74d25
Create Date: 2024-03-01 08:56:55.205553

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel


# revision identifiers, used by Alembic.
revision: str = 'b6aff0a993d7'
down_revision: Union[str, None] = '33bddba74d25'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Apply this revision: add ``utterance.person_id`` and create the new tables."""
    # Use batch operations to support SQLite ALTER TABLE for adding constraints
    # NOTE(review): the foreign key below targets 'person', but that table is
    # only created later in this function. Backends that validate the referenced
    # table when the constraint is added would presumably reject this ordering;
    # consider creating 'person' first. TODO confirm on the target databases.
    with op.batch_alter_table('utterance', schema=None) as batch_op:
        # Nullable: an utterance does not have to be linked to a person.
        batch_op.add_column(sa.Column('person_id', sa.Integer(), nullable=True))
        batch_op.create_foreign_key('fk_utterance_person', 'person', ['person_id'], ['id'])
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to store the vector embedding of the voice here? This way, we would be able to

  • Show distinct speakers in the UI even without having the Persons in the DB
  • On creating a new person, easily find all the instances in the past when that person spoke by fetching all the utterances with similar enough voice embeddings


op.create_table('person',
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('first_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('last_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('voicesample',
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('filepath', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('speaker_embeddings', sa.JSON(), nullable=True),
sa.Column('person_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['person_id'], ['person.id'], name='fk_voicesample_person'),
sa.PrimaryKeyConstraint('id')
)

def downgrade() -> None:
    """Revert this revision: unlink utterances from persons, then drop the new tables."""
    # Batch mode is used so that dropping the column also works on SQLite.
    with op.batch_alter_table('utterance', schema=None) as utterance_batch:
        utterance_batch.drop_constraint('fk_utterance_person', type_='foreignkey')
        utterance_batch.drop_column('person_id')

    # Order matters: 'voicesample' holds a foreign key to 'person', so it is
    # dropped first.
    for table_name in ('voicesample', 'person'):
        op.drop_table(table_name)
15 changes: 14 additions & 1 deletion clients/ios/Owl/Models/Transcript.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ struct CaptureFile: Codable {
}
}


struct CaptureFileSegment: Codable {
var id: Int
var filePath: String
Expand All @@ -34,6 +33,7 @@ struct CaptureFileSegment: Codable {
case sourceCapture = "source_capture"
}
}

struct Transcription: Codable {
var id: Int
var model: String
Expand All @@ -53,6 +53,7 @@ struct Utterance: Codable {
var end: Double?
var text: String?
var speaker: String?
var person: Person?
// var words: [Word]
}

Expand All @@ -70,3 +71,15 @@ struct Word: Codable {
case utteranceId = "utterance_id"
}
}

/// A person that an utterance can be attributed to.
///
/// Encoded to and decoded from JSON using the snake_case keys declared in
/// `CodingKeys`.
struct Person: Codable {
    var id: Int
    var firstName: String
    // Optional, so decoding still succeeds when `last_name` is missing or null.
    var lastName: String?

    // Maps the Swift property names to the JSON field names.
    enum CodingKeys: String, CodingKey {
        case id
        case firstName = "first_name"
        case lastName = "last_name"
    }
}
2 changes: 1 addition & 1 deletion clients/ios/Owl/Views/ConversationDetailView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct ConversationDetailView: View {
if !transcription.utterances.isEmpty {
ForEach(transcription.utterances, id: \.id) { utterance in
HStack {
Text("\(utterance.speaker ?? "Unknown"):")
Text("\(utterance.person?.firstName ?? utterance.speaker ?? "Unknown"):")
.fontWeight(.semibold)
.foregroundColor(.blue)

Expand Down
37 changes: 37 additions & 0 deletions owl/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
import subprocess
from alembic import command
from alembic.config import Config
from ..database.database import Database
from ..database.crud import create_person, create_voice_sample
from ..models.schemas import Person, VoiceSample

import click
from rich.console import Console
Expand Down Expand Up @@ -202,6 +205,40 @@ def create_migration(config: Configuration, message: str):

console.log(f"[bold green]Migration script generated with message: '{message}'")

####################################################################################################
# Persons
####################################################################################################

@cli.command()
@add_options(_config_options)
@click.option('--first-name', required=True, help='First name of the person')
@click.option('--last-name', required=True, help='Last name of the person')
@click.option('--voice-sample-path', required=True, help='Path to the voice sample file')
def enroll_speaker(config: Configuration, first_name: str, last_name: str, voice_sample_path: str):
    """Enroll a new person with a voice sample.

    Creates a ``Person`` row, copies the voice sample into
    ``<voice_sample_directory>/<person id>/`` under a generated file name, and
    records the copy as a ``VoiceSample`` row linked to that person.

    Args:
        config: Loaded application configuration. Its ``speaker_identification``
            section must define ``voice_sample_directory``.
        first_name: First name of the person to enroll.
        last_name: Last name of the person to enroll.
        voice_sample_path: Path of an existing audio file to use as the sample.

    Raises:
        click.ClickException: If the voice sample directory is not configured or
            ``voice_sample_path`` is not an existing file.
    """
    import shutil  # Local import: only this command needs it.

    console = Console()
    console.log("[bold green]Enrolling speaker...")

    # Validate everything up front so that a bad invocation does not leave an
    # orphaned person row behind in the database.
    speaker_config = config.speaker_identification
    if speaker_config is None or not speaker_config.voice_sample_directory:
        raise click.ClickException(
            "speaker_identification.voice_sample_directory must be configured to enroll a speaker"
        )
    if not os.path.isfile(voice_sample_path):
        raise click.ClickException(f"Voice sample file not found: '{voice_sample_path}'")

    database = Database(config.database)
    with next(database.get_db()) as db:
        person = create_person(db, Person(first_name=first_name, last_name=last_name))
        # Read the id while the session is still open.
        person_id = person.id

    sample_directory = os.path.join(speaker_config.voice_sample_directory, str(person_id))
    os.makedirs(sample_directory, exist_ok=True)

    # Keep the source extension (including its dot); a source without an
    # extension yields a bare name instead of one with a trailing dot.
    extension = os.path.splitext(os.path.basename(voice_sample_path))[1]
    sample_file_path = os.path.join(sample_directory, f"{uuid.uuid1().hex}{extension}")

    # Copy first, then record: the database never points at a file that failed
    # to copy, and the file is streamed rather than read fully into memory.
    shutil.copyfile(voice_sample_path, sample_file_path)

    with next(database.get_db()) as db:
        voice_sample = create_voice_sample(db, VoiceSample(person_id=person_id, filepath=sample_file_path))
        voice_sample_id = voice_sample.id

    console.log(f"[bold green]Enrolled new person: '{person_id} ({voice_sample_id})'")

####################################################################################################
# Server
####################################################################################################
Expand Down
11 changes: 10 additions & 1 deletion owl/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@ class StreamingTranscriptionConfiguration(BaseModel):
class AsyncTranscriptionConfiguration(BaseModel):
provider: str

class SpeakerIdentificationConfiguration(BaseModel):
    """Settings for speaker identification."""
    # Name of the speaker identification provider to use.
    provider: str
    # Base directory for voice sample files; optional and unset by default.
    voice_sample_directory: Optional[str] = None

class SpeechBrainConfiguration(BaseModel):
    """Settings for the SpeechBrain speaker identification provider."""
    # NOTE(review): presumably a match/similarity cutoff; its exact meaning is
    # not defined yet (see the todo below). TODO confirm.
    threshold: float # todo actual config

class DatabaseConfiguration(BaseModel):
url: str

Expand Down Expand Up @@ -104,4 +111,6 @@ def load_config_yaml(cls, config_filepath: str) -> 'Configuration':
conversation_endpointing: ConversationEndpointingConfiguration
notification: NotificationConfiguration
udp: UDPConfiguration
bing: BingConfiguration | None = None
bing: BingConfiguration | None = None
speaker_identification: SpeakerIdentificationConfiguration | None = None
speech_brain: SpeechBrainConfiguration | None = None
17 changes: 16 additions & 1 deletion owl/database/crud.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sqlmodel import SQLModel, Session, select
from ..models.schemas import Transcription, Conversation, Utterance, Location, CaptureSegment, Capture, ConversationState
from ..models.schemas import Transcription, Conversation, Utterance, Location, CaptureSegment, Capture, ConversationState, Person, VoiceSample
from typing import List, Optional
from sqlalchemy.orm import joinedload, selectinload
from sqlalchemy import desc, func, or_
Expand All @@ -8,6 +8,21 @@

logger = logging.getLogger(__name__)

def create_person(db: Session, person: Person) -> Person:
    """Persist a new person and return it.

    Adds ``person`` to the session, commits, and refreshes the instance from
    the database before returning it.

    Args:
        db: Open database session.
        person: The person to store.

    Returns:
        The same ``person`` instance after the refresh.
    """
    db.add(person)
    db.commit()
    db.refresh(person)
    return person

def create_voice_sample(db: Session, voice_sample: VoiceSample) -> VoiceSample:
    """Persist a new voice sample and return it.

    Adds ``voice_sample`` to the session, commits, and refreshes the instance
    from the database before returning it.

    Args:
        db: Open database session.
        voice_sample: The voice sample to store.

    Returns:
        The same ``voice_sample`` instance after the refresh.
    """
    db.add(voice_sample)
    db.commit()
    db.refresh(voice_sample)
    return voice_sample

def get_persons(db: Session) -> List[Person]:
    """Return all ``Person`` rows, unfiltered."""
    return db.query(Person).all()

def create_utterance(db: Session, utterance: Utterance) -> Utterance:
db.add(utterance)
db.commit()
Expand Down
28 changes: 27 additions & 1 deletion owl/models/schemas.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import List, Optional
from sqlmodel import SQLModel, Field, Relationship
from sqlmodel import SQLModel, Field, Relationship, Column, JSON
from datetime import datetime, timezone
from pydantic import BaseModel
from enum import Enum
Expand Down Expand Up @@ -36,6 +36,8 @@ class Utterance(CreatedAtMixin, table=True):
transcription: "Transcription" = Relationship(back_populates="utterances")

words: List[Word] = Relationship(back_populates="utterance", sa_relationship_kwargs={"cascade": "all, delete-orphan"})
person_id: Optional[int] = Field(default=None, foreign_key="person.id")
person: Optional["Person"] = Relationship(back_populates="utterances")

class Transcription(CreatedAtMixin, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
Expand Down Expand Up @@ -106,6 +108,19 @@ class CaptureSegment(CreatedAtMixin, table=True):

conversation: Optional[Conversation] = Relationship(back_populates="capture_segment_file")

class Person(CreatedAtMixin, table=True):
    """Table model for a person that voice samples and utterances can link to."""
    # Defaults to None until a primary key is assigned.
    id: Optional[int] = Field(default=None, primary_key=True)
    first_name: str
    last_name: str
    # Inverse side of ``VoiceSample.person``.
    voice_samples: List["VoiceSample"] = Relationship(back_populates="person")
    # Inverse side of ``Utterance.person``.
    utterances: List[Utterance] = Relationship(back_populates="person")

class VoiceSample(CreatedAtMixin, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
filepath: str = Field(...)
speaker_embeddings: dict = Field(default={}, sa_column=Column(JSON))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the key is the embedding model, can we name this explicitly `speaker_embeddings_by_model`?

person_id: Optional[int] = Field(default=None, foreign_key="person.id")
person: Optional["Person"] = Relationship(back_populates="voice_samples")

# API Response Models
# https://sqlmodel.tiangolo.com/tutorial/fastapi/relationships/#dont-include-all-the-data
Expand All @@ -121,13 +136,24 @@ class WordRead(BaseModel):
class Config:
from_attributes=True

class PersonRead(BaseModel):
    """API response model exposing a person's id and name."""
    id: Optional[int]
    first_name: str
    last_name: Optional[str]

    class Config:
        # Lets the model be populated from object attributes, not only dicts.
        from_attributes=True


class UtteranceRead(BaseModel):
id: Optional[int]
start: Optional[float]
end: Optional[float]
spoken_at: Optional[datetime]
text: Optional[str]
speaker: Optional[str]
person: Optional[PersonRead] = None

class Config:
from_attributes=True
json_encoders = {
Expand Down
7 changes: 7 additions & 0 deletions owl/sample_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ udp:
host: '0.0.0.0'
port: 8001

speaker_identification:
provider: speech_brain
voice_sample_directory: voice_samples

speech_brain:
threshold: 0.2

# To enable web search
# bing:
# subscription_key: your_bing_subscription_service_key
4 changes: 3 additions & 1 deletion owl/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..services import LLMService, CaptureService, ConversationService, NotificationService, BingSearchService
from ..database.database import Database
from ..services.stt.asynchronous.async_transcription_service_factory import AsyncTranscriptionServiceFactory
from ..services.stt.speaker_identification.speaker_identification_service_factory import SpeakerIdentificationServiceFactory
from .task import Task
import logging
import asyncio
Expand Down Expand Up @@ -83,7 +84,8 @@ def create_server_app(config: Configuration) -> FastAPI:
notification_service = NotificationService(config.notification)
capture_service = CaptureService(config=config, database=database)
bing_search_service = BingSearchService(config=config.bing) if config.bing else None
conversation_service = ConversationService(config, database, transcription_service, notification_service, bing_search_service)
speaker_identification_service = SpeakerIdentificationServiceFactory.get_service(config=config) if config.speaker_identification else None
conversation_service = ConversationService(config, database, transcription_service, notification_service, bing_search_service, speaker_identification_service)

# Create server app
app = FastAPI()
Expand Down
14 changes: 12 additions & 2 deletions owl/services/conversation/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..stt.asynchronous.abstract_async_transcription_service import AbstractAsyncTranscriptionService
from ..conversation.transcript_summarizer import TranscriptionSummarizer
from ...database.crud import create_transcription, create_conversation, find_most_common_location, create_capture_file_segment_file_ref, update_conversation_state, get_conversation_by_conversation_uuid, get_capturing_conversation_by_capture_uuid
from ...database.crud import create_transcription, create_conversation, find_most_common_location, create_capture_file_segment_file_ref, update_conversation_state, get_conversation_by_conversation_uuid, get_capturing_conversation_by_capture_uuid, get_persons
from ...database.database import Database
from ...core.config import Configuration
from ...models.schemas import Transcription, Conversation, ConversationState, Capture, CaptureSegment, TranscriptionRead, ConversationRead, SuggestedLink
Expand All @@ -15,13 +15,14 @@
logger = logging.getLogger(__name__)

class ConversationService:
def __init__(self, config: Configuration, database: Database, transcription_service: AbstractAsyncTranscriptionService, notification_service, bing_search_service=None):
def __init__(self, config: Configuration, database: Database, transcription_service: AbstractAsyncTranscriptionService, notification_service, bing_search_service=None, speaker_identification_service=None):
    """Store the injected services and build the transcription summarizer.

    ``bing_search_service`` and ``speaker_identification_service`` are
    optional and default to ``None``.
    """
    self._config = config
    self._database = database
    self._transcription_service = transcription_service
    self._notification_service = notification_service
    # Built here from ``config`` rather than injected.
    self._summarizer = TranscriptionSummarizer(config)
    self._bing_search_service = bing_search_service
    self._speaker_identification_service = speaker_identification_service

async def create_conversation(self, conversation_uuid: str, start_time: datetime, capture_file: Capture) -> Conversation:
with next(self._database.get_db()) as db:
Expand Down Expand Up @@ -109,6 +110,15 @@ async def process_conversation_from_audio(self, conversation_uuid: str, voice_sa
serialized_payload = deleted_data.json()
await self._notification_service.send_notification("Empty Conversation", "An empty conversation was deleted", "delete_conversation", payload=serialized_payload)
return None, None

if self._speaker_identification_service:
persons = get_persons(db)
if not persons:
logger.info("No persons found in the database. Skipping speaker identification.")
else:
start_time = time.time()
transcription = await self._speaker_identification_service.identify_speakers(transcription, persons)
logger.info(f"Speaker identification complete in {time.time() - start_time:.2f} seconds")

for utterance in transcription.utterances:
utterance.spoken_at = conversation_start_time + timedelta(seconds=utterance.start)
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from abc import ABC, abstractmethod
from ....models.schemas import Transcription, Person
from typing import List

class AbstractSpeakerIdentificationService(ABC):
    """Interface for services that attribute utterances to known persons."""

    @abstractmethod
    async def identify_speakers(self, transcript: Transcription, persons: List[Person]) -> Transcription:
        """Identify the speakers in ``transcript`` among ``persons``.

        Args:
            transcript: The transcription to process.
            persons: Candidate persons to match against.

        Returns:
            A transcription; how utterances are matched is left to the
            implementation.
        """
        pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from .speech_brain_speaker_identification_service import SpeechBrainIdentificationService
import logging

logger = logging.getLogger(__name__)

class SpeakerIdentificationServiceFactory:
    """Creates speaker identification services, keeping one instance per provider name."""

    # Provider name -> service instance already created for it.
    _instances = {}

    @staticmethod
    def get_service(config):
        """Return the service for ``config.speaker_identification.provider``.

        The first request for a provider builds its service; later requests
        return that same instance.

        Raises:
            ValueError: If the provider name is not recognised.
        """
        cache = SpeakerIdentificationServiceFactory._instances
        provider = config.speaker_identification.provider

        # Fast path: already built for this provider.
        if provider in cache:
            return cache[provider]

        logger.info(f"Creating new {provider} speaker identification service")
        if provider != "speech_brain":
            raise ValueError(f"Unknown speaker identification service type: {provider}")

        cache[provider] = SpeechBrainIdentificationService(config.speech_brain)
        return cache[provider]

Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

from .abstract_speaker_identification_service import AbstractSpeakerIdentificationService
from ....models.schemas import Transcription, Person
from typing import List
import logging

logger = logging.getLogger(__name__)

class SpeechBrainIdentificationService(AbstractSpeakerIdentificationService):
    """Speaker identification provider for SpeechBrain (currently a stub)."""

    def __init__(self, config):
        """Keep the provider configuration for later use."""
        self._config = config

    async def identify_speakers(self, transcript: Transcription, persons: List[Person]) -> Transcription:
        """Assign a person to every utterance in ``transcript``.

        Stub implementation: the first entry of ``persons`` is set as the
        speaker of all utterances.

        Args:
            transcript: The transcription whose utterances are updated in place.
            persons: Candidate persons; may be empty.

        Returns:
            The same ``transcript`` object. It is returned unchanged when
            ``persons`` is empty.
        """
        # Nothing to attribute to: return early instead of raising IndexError.
        if not persons:
            logger.info("No persons provided. Leaving utterances unassigned.")
            return transcript

        # stub implementation. just set the first person in the list as the speaker for all utterances
        default_person = persons[0]
        for utterance in transcript.utterances:
            utterance.person = default_person
        return transcript