Skip to content

Commit

Permalink
Cache YouTube transcripts in the DB
Browse files Browse the repository at this point in the history
  • Loading branch information
seanh committed Jul 31, 2023
1 parent 31db832 commit a4e77cc
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 11 deletions.
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@
import httpretty
import pytest
from h_matchers import Any
from pytest_factoryboy import register
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from tests.factories import TranscriptFactory
from tests.factories.factoryboy_sqlalchemy_session import (
clear_factoryboy_sqlalchemy_session,
set_factoryboy_sqlalchemy_session,
)
from via.db import Base

# Each factory has to be registered with pytest_factoryboy.
register(TranscriptFactory)


@pytest.fixture
def pyramid_settings():
Expand Down
1 change: 1 addition & 0 deletions tests/factories/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from tests.factories.transcript import TranscriptFactory
29 changes: 29 additions & 0 deletions tests/factories/transcript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from factory import Sequence
from factory.alchemy import SQLAlchemyModelFactory

from via.models import Transcript


class TranscriptFactory(SQLAlchemyModelFactory):
class Meta:
model = Transcript

video_id = Sequence(lambda n: f"video_id_{n}")
transcript_id = Sequence(lambda n: f"transcript_id_{n}")
transcript = [
{
"text": "[Music]",
"start": 0.0,
"duration": 7.52,
},
{
"text": "how many of you remember the first time",
"start": 5.6,
"duration": 4.72,
},
{
"text": "you saw a playstation 1 game if you were",
"start": 7.52,
"duration": 4.72,
},
]
56 changes: 47 additions & 9 deletions tests/unit/via/services/youtube_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from unittest.mock import sentinel

import pytest
from h_matchers import Any
from requests import Response
from sqlalchemy import select

from via.models import Transcript
from via.services.youtube import YouTubeDataAPIError, YouTubeService, factory


Expand All @@ -17,9 +20,10 @@ class TestYouTubeService:
(True, sentinel.api_key, True),
],
)
def test_enabled(self, enabled, api_key, expected):
def test_enabled(self, db_session, enabled, api_key, expected):
assert (
YouTubeService(
db_session=db_session,
enabled=enabled,
api_key=api_key,
http_service=sentinel.http_service,
Expand Down Expand Up @@ -79,11 +83,37 @@ def test_get_video_title_raises_YouTubeDataAPIError(self, svc, http_service):

assert exc_info.value.__cause__ == http_service.get.side_effect

def test_get_transcript(self, YouTubeTranscriptApi, svc):
transcript = svc.get_transcript(sentinel.video_id)
def test_get_transcript(self, db_session, svc, YouTubeTranscriptApi):
YouTubeTranscriptApi.get_transcript.return_value = [
{"text": "foo", "start": 0.0, "duration": 1.0},
{"text": "bar", "start": 1.0, "duration": 2.0},
]

YouTubeTranscriptApi.get_transcript.assert_called_once_with(sentinel.video_id)
assert transcript == YouTubeTranscriptApi.get_transcript.return_value
returned_transcript = svc.get_transcript("test_video_id")

YouTubeTranscriptApi.get_transcript.assert_called_once_with(
"test_video_id", languages=("en",)
)
assert returned_transcript == YouTubeTranscriptApi.get_transcript.return_value
# It should have cached the transcript in the DB.
assert db_session.scalars(select(Transcript)).all() == [
Any.instance_of(Transcript).with_attrs(
{
"video_id": "test_video_id",
"transcript": YouTubeTranscriptApi.get_transcript.return_value,
}
)
]

@pytest.mark.usefixtures("db_session")
@pytest.mark.parametrize("transcript__transcript_id", ["en"])
def test_get_transcript_returns_cached_transcripts(
self, transcript, svc, YouTubeTranscriptApi
):
returned_transcript = svc.get_transcript(transcript.video_id)

YouTubeTranscriptApi.get_transcript.assert_not_called()
assert returned_transcript == transcript.transcript

@pytest.mark.parametrize(
"video_id,expected_url",
Expand All @@ -100,18 +130,26 @@ def test_canonical_video_url(self, video_id, expected_url, svc):
assert expected_url == svc.canonical_video_url(video_id)

@pytest.fixture
def svc(self, http_service):
def svc(self, db_session, http_service):
return YouTubeService(
enabled=True, api_key=sentinel.api_key, http_service=http_service
db_session=db_session,
enabled=True,
api_key=sentinel.api_key,
http_service=http_service,
)


class TestFactory:
def test_it(self, YouTubeService, youtube_service, http_service, pyramid_request):
def test_it(
self, YouTubeService, youtube_service, pyramid_request, http_service, db_session
):
returned = factory(sentinel.context, pyramid_request)

YouTubeService.assert_called_once_with(
enabled=True, api_key="test_youtube_api_key", http_service=http_service
db_session=db_session,
enabled=pyramid_request.registry.settings["youtube_transcripts"],
api_key="test_youtube_api_key",
http_service=http_service,
)
assert returned == youtube_service

Expand Down
3 changes: 3 additions & 0 deletions via/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from via.models.transcript import Transcript


def includeme(_config): # pragma: no cover
pass
20 changes: 20 additions & 0 deletions via/models/_mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from datetime import datetime

from sqlalchemy import func
from sqlalchemy.orm import Mapped, MappedAsDataclass, mapped_column


class CreatedUpdatedMixin(MappedAsDataclass):
created: Mapped[datetime] = mapped_column(
init=False,
repr=False,
server_default=func.now(), # pylint:disable=not-callable
sort_order=-10,
)
updated: Mapped[datetime] = mapped_column(
init=False,
repr=False,
server_default=func.now(), # pylint:disable=not-callable
onupdate=func.now(), # pylint:disable=not-callable
sort_order=-10,
)
16 changes: 16 additions & 0 deletions via/models/transcript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column

from via.db import Base
from via.models._mixins import CreatedUpdatedMixin


class Transcript(CreatedUpdatedMixin, Base):
__tablename__ = "transcript"
__table_args__ = (UniqueConstraint("video_id", "transcript_id"),)

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True, init=False)
video_id: Mapped[str]
transcript_id: Mapped[str]
transcript: Mapped[list] = mapped_column(JSONB, repr=False)
36 changes: 34 additions & 2 deletions via/services/youtube.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from urllib.parse import parse_qs, quote_plus, urlparse

from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
from youtube_transcript_api import YouTubeTranscriptApi

from via.models import Transcript
from via.services.http import HTTPService


Expand All @@ -10,7 +13,10 @@ class YouTubeDataAPIError(Exception):


class YouTubeService:
def __init__(self, enabled: bool, api_key: str, http_service: HTTPService):
def __init__(
self, db_session, enabled: bool, api_key: str, http_service: HTTPService
):
self._db = db_session
self._enabled = enabled
self._api_key = api_key
self._http_service = http_service
Expand Down Expand Up @@ -82,11 +88,37 @@ def get_transcript(self, video_id):
:raise Exception: this method might raise any type of exception that
YouTubeTranscriptApi raises
"""
return YouTubeTranscriptApi.get_transcript(video_id)
transcript_id = language_code = "en"

try:
transcript = (
self._db.scalars(
select(Transcript).where(
Transcript.video_id == video_id,
Transcript.transcript_id == transcript_id,
)
)
.one()
.transcript
)
except NoResultFound:
transcript = YouTubeTranscriptApi.get_transcript(
video_id, languages=(language_code,)
)
self._db.add(
Transcript(
video_id=video_id,
transcript_id=transcript_id,
transcript=transcript,
)
)

return transcript


def factory(_context, request):
return YouTubeService(
db_session=request.db,
enabled=request.registry.settings["youtube_transcripts"],
api_key=request.registry.settings["youtube_api_key"],
http_service=request.find_service(HTTPService),
Expand Down

0 comments on commit a4e77cc

Please sign in to comment.