Skip to content

Commit

Permalink
Replace youtube-transcript-api library
Browse files Browse the repository at this point in the history
  • Loading branch information
seanh committed Aug 15, 2023
1 parent f431a6d commit 44f7ca9
Show file tree
Hide file tree
Showing 9 changed files with 421 additions and 27 deletions.
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from tests.factories import TranscriptFactory, VideoFactory
from tests.factories import TranscriptFactory, TranscriptInfoFactory, VideoFactory
from tests.factories.factoryboy_sqlalchemy_session import (
clear_factoryboy_sqlalchemy_session,
set_factoryboy_sqlalchemy_session,
Expand All @@ -16,6 +16,7 @@

# Each factory has to be registered with pytest_factoryboy.
register(TranscriptFactory)
register(TranscriptInfoFactory)
register(VideoFactory)


Expand Down
1 change: 1 addition & 0 deletions tests/factories/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from tests.factories.transcript import TranscriptFactory
from tests.factories.transcript_info import TranscriptInfoFactory
from tests.factories.video import VideoFactory
13 changes: 13 additions & 0 deletions tests/factories/transcript_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from factory import Factory, Sequence

from via.services.youtube_transcript import TranscriptInfo


class TranscriptInfoFactory(Factory):
    """Factory producing `TranscriptInfo` objects for use in tests.

    Unless overridden, instances describe a manually created (not
    auto-generated) "en-us" transcript.
    """

    class Meta:
        model = TranscriptInfo

    language_code = "en-us"
    name = "English (United States)"
    # The sequence counter is embedded in the URL so each built object gets
    # a different URL.
    url = Sequence(lambda index: f"https://example.com/api/timedtext?v={index}")
    autogenerated = False
6 changes: 6 additions & 0 deletions tests/unit/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
URLDetailsService,
ViaClientService,
YouTubeService,
YouTubeTranscriptService,
)


Expand Down Expand Up @@ -69,3 +70,8 @@ def youtube_service(mock_service):
youtube_service.get_video_id.return_value = None

return youtube_service


@pytest.fixture
def youtube_transcript_service(mock_service):
    """Return what `mock_service` produces for `YouTubeTranscriptService`."""
    mocked_service = mock_service(YouTubeTranscriptService)
    return mocked_service
54 changes: 33 additions & 21 deletions tests/unit/via/services/youtube_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_enabled(self, db_session, enabled, api_key, expected):
enabled=enabled,
api_key=api_key,
http_service=sentinel.http_service,
youtube_transcript_service=sentinel.youtube_transcript_service,
).enabled
== expected
)
Expand Down Expand Up @@ -94,40 +95,48 @@ def test_get_video_title_raises_YouTubeDataAPIError(self, svc, http_service):

assert exc_info.value.__cause__ == http_service.get.side_effect

def test_get_transcript(self, db_session, svc, YouTubeTranscriptApi):
YouTubeTranscriptApi.get_transcript.return_value = [
{"text": "foo", "start": 0.0, "duration": 1.0},
{"text": "bar", "start": 1.0, "duration": 2.0},
]
def test_get_transcript(
self, db_session, svc, youtube_transcript_service, transcript_info
):
youtube_transcript_service.pick_default_transcript.return_value = (
transcript_info
)
youtube_transcript_service.get_transcript.return_value = "test_transcript"

returned_transcript = svc.get_transcript("test_video_id")

YouTubeTranscriptApi.get_transcript.assert_called_once_with(
"test_video_id", languages=("en",)
# It gets the transcript from YouTubeTranscriptService.
youtube_transcript_service.get_transcript_infos.assert_called_once_with(
"test_video_id"
)
youtube_transcript_service.pick_default_transcript.assert_called_once_with(
youtube_transcript_service.get_transcript_infos.return_value
)
assert returned_transcript == YouTubeTranscriptApi.get_transcript.return_value
youtube_transcript_service.get_transcript.assert_called_once_with(
transcript_info
)
assert returned_transcript == "test_transcript"
# It should have cached the transcript in the DB.
assert db_session.scalars(select(Transcript)).all() == [
Any.instance_of(Transcript).with_attrs(
{
"video_id": "test_video_id",
"transcript": YouTubeTranscriptApi.get_transcript.return_value,
"transcript_id": transcript_info.id,
"transcript": "test_transcript",
}
)
]

@pytest.mark.usefixtures("db_session")
def test_get_transcript_returns_cached_transcripts(
self, transcript, svc, YouTubeTranscriptApi
self, svc, transcript, youtube_transcript_service
):
returned_transcript = svc.get_transcript(transcript.video_id)

YouTubeTranscriptApi.get_transcript.assert_not_called()
youtube_transcript_service.get_transcript.assert_not_called()
assert returned_transcript == transcript.transcript

@pytest.mark.usefixtures("db_session")
def test_get_transcript_returns_oldest_cached_transcript(
self, transcript_factory, svc
self, svc, transcript_factory
):
"""If there are multiple cached transcripts get_transcript() returns the oldest one."""
oldest_transcript, newer_transcript = transcript_factory.create_batch(
Expand Down Expand Up @@ -155,18 +164,25 @@ def test_canonical_video_url(self, video_id, expected_url, svc):
assert expected_url == svc.canonical_video_url(video_id)

@pytest.fixture
def svc(self, db_session, http_service):
def svc(self, db_session, http_service, youtube_transcript_service):
return YouTubeService(
db_session=db_session,
enabled=True,
api_key=sentinel.api_key,
http_service=http_service,
youtube_transcript_service=youtube_transcript_service,
)


class TestFactory:
def test_it(
self, YouTubeService, youtube_service, pyramid_request, http_service, db_session
self,
YouTubeService,
youtube_service,
pyramid_request,
http_service,
db_session,
youtube_transcript_service,
):
returned = factory(sentinel.context, pyramid_request)

Expand All @@ -175,6 +191,7 @@ def test_it(
enabled=pyramid_request.registry.settings["youtube_transcripts"],
api_key="test_youtube_api_key",
http_service=http_service,
youtube_transcript_service=youtube_transcript_service,
)
assert returned == youtube_service

Expand All @@ -185,8 +202,3 @@ def YouTubeService(self, patch):
@pytest.fixture
def youtube_service(self, YouTubeService):
return YouTubeService.return_value


@pytest.fixture(autouse=True)
def YouTubeTranscriptApi(patch):
return patch("via.services.youtube.YouTubeTranscriptApi")
207 changes: 207 additions & 0 deletions tests/unit/via/services/youtube_transcript_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import json
from io import BytesIO
from unittest.mock import sentinel

import pytest
from h_matchers import Any
from requests import Response

from tests.factories import TranscriptInfoFactory
from via.services.youtube_transcript import (
TranscriptInfo,
YouTubeTranscriptService,
factory,
)


class TestTranscriptInfo:
    """Tests for the `id` attribute of `TranscriptInfo`."""

    @pytest.mark.parametrize(
        ("transcript_info", "expected_id"),
        [
            # autogenerated=False: the expected ID has an empty middle
            # segment; the last segment is base64 of the transcript's name.
            (
                TranscriptInfo(
                    "en-us",
                    "English (United States)",
                    "https://example.com/transcript",
                    autogenerated=False,
                ),
                "en-us..RW5nbGlzaCAoVW5pdGVkIFN0YXRlcyk=",
            ),
            # autogenerated=True: the expected ID has "a" as its middle segment.
            (
                TranscriptInfo(
                    "en",
                    "English",
                    "https://example.com/transcript",
                    autogenerated=True,
                ),
                "en.a.RW5nbGlzaA==",
            ),
        ],
    )
    def test_id(self, transcript_info, expected_id):
        """`TranscriptInfo.id` equals the expected ID string for each case."""
        actual_id = transcript_info.id

        assert actual_id == expected_id


class TestYouTubeTranscriptService:
    """Unit tests for `YouTubeTranscriptService` built on the `http_service` fixture."""

    @pytest.fixture
    def svc(self, http_service):
        """Return the service under test, constructed with `http_service`."""
        return YouTubeTranscriptService(http_service)

    def test_get_transcript_infos(self, svc, http_service):
        """get_transcript_infos() returns one TranscriptInfo per caption track."""
        caption_tracks = [
            {
                "languageCode": "en",
                "name": {"simpleText": "English"},
                "baseUrl": "https://example.com/transcript_1",
            },
            {
                "languageCode": "en-us",
                "name": {"simpleText": "English (United States)"},
                "baseUrl": "https://example.com/transcript_2",
            },
        ]
        # The JSON response body from the YouTube API.
        response_json = {
            "captions": {
                "playerCaptionsTracklistRenderer": {"captionTracks": caption_tracks}
            }
        }
        # A real `Response` whose raw body is the serialised JSON above.
        response = Response()
        response.raw = BytesIO(json.dumps(response_json).encode("utf-8"))
        http_service.post.return_value = response

        transcript_infos = svc.get_transcript_infos("test_video_id")

        # Each caption track should map onto a TranscriptInfo, in order.
        expected_infos = [
            Any.instance_of(TranscriptInfo).with_attrs(
                {
                    "language_code": track["languageCode"],
                    "autogenerated": False,
                    "name": track["name"]["simpleText"],
                    "url": track["baseUrl"],
                }
            )
            for track in caption_tracks
        ]
        assert transcript_infos == expected_infos

    @pytest.mark.parametrize(
        ("transcript_infos", "expected_default_transcript_index"),
        [
            # "en" listed before "en-us", both with plain names.
            (
                [
                    TranscriptInfoFactory(language_code="en", name="English"),
                    TranscriptInfoFactory(
                        language_code="en-us", name="English (United States)"
                    ),
                ],
                0,
            ),
            # Plain-named "en-us" listed before an "en" track named "English - DTVCC1".
            (
                [
                    TranscriptInfoFactory(
                        language_code="en-us", name="English (United States)"
                    ),
                    TranscriptInfoFactory(language_code="en", name="English - DTVCC1"),
                ],
                0,
            ),
            # Both names carry a " - Foo" suffix.
            (
                [
                    TranscriptInfoFactory(language_code="en", name="English - Foo"),
                    TranscriptInfoFactory(
                        language_code="en-us", name="English (United States) - Foo"
                    ),
                ],
                0,
            ),
            # Suffixed manual "en-us" listed before an auto-generated "en".
            (
                [
                    TranscriptInfoFactory(
                        language_code="en-us", name="English (United States) - Foo"
                    ),
                    TranscriptInfoFactory(
                        language_code="en", name="English", autogenerated=True
                    ),
                ],
                0,
            ),
            # Both transcripts are auto-generated.
            (
                [
                    TranscriptInfoFactory(
                        language_code="en", name="English", autogenerated=True
                    ),
                    TranscriptInfoFactory(
                        language_code="en-us",
                        name="English (United States)",
                        autogenerated=True,
                    ),
                ],
                0,
            ),
            # Manual French listed before auto-generated English: English expected.
            (
                [
                    TranscriptInfoFactory(language_code="fr", name="French"),
                    TranscriptInfoFactory(
                        language_code="en", name="English", autogenerated=True
                    ),
                ],
                1,
            ),
            # No English transcript at all: the first one is expected.
            (
                [
                    TranscriptInfoFactory(language_code="fr", name="French"),
                    TranscriptInfoFactory(language_code="de", name="Deutsch"),
                ],
                0,
            ),
        ],
    )
    def test_pick_default_transcript(
        self, svc, transcript_infos, expected_default_transcript_index
    ):
        """pick_default_transcript() returns the entry at the expected index."""
        picked = svc.pick_default_transcript(transcript_infos)

        assert picked == transcript_infos[expected_default_transcript_index]

    def test_get_transcript(self, svc, transcript_info, http_service):
        """get_transcript() fetches the transcript URL and returns parsed entries."""
        # The XML literal is kept flush-left so that its contents are unchanged.
        http_service.get.return_value.text = """
<transcript>
<text start="0.21" dur="1.387">Hey there guys,</text>
<text start="1.597">Lichen&#39; subscribe</text>
<text start="4.327" dur="2.063">
&lt;font color=&quot;#A0AAB4&quot;&gt;Buy my merch!&lt;/font&gt;
</text>
</transcript>
"""
        # The second <text> has no "dur" attribute and is expected to come
        # back with a duration of 0.0; entities and the <font> markup are
        # expected to be gone from the text.
        expected_transcript = [
            {"duration": 1.387, "start": 0.21, "text": "Hey there guys,"},
            {"duration": 0.0, "start": 1.597, "text": "Lichen' subscribe"},
            {"duration": 2.063, "start": 4.327, "text": "Buy my merch!"},
        ]

        transcript = svc.get_transcript(transcript_info)

        http_service.get.assert_called_once_with(transcript_info.url)
        assert transcript == expected_transcript


class TestFactory:
    """Tests for the module-level `factory()` function."""

    @pytest.fixture
    def YouTubeTranscriptService(self, patch):
        """Patch the service class in `via.services.youtube_transcript`."""
        patched_class = patch("via.services.youtube_transcript.YouTubeTranscriptService")
        return patched_class

    def test_factory(self, YouTubeTranscriptService, http_service, pyramid_request):
        """factory() builds the service with `http_service` and returns it."""
        returned = factory(sentinel.context, pyramid_request)

        YouTubeTranscriptService.assert_called_once_with(http_service=http_service)
        assert returned == YouTubeTranscriptService.return_value
4 changes: 4 additions & 0 deletions via/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from via.services.url_details import URLDetailsService
from via.services.via_client import ViaClientService
from via.services.youtube import YouTubeService
from via.services.youtube_transcript import YouTubeTranscriptService


def includeme(config): # pragma: no cover
Expand All @@ -37,6 +38,9 @@ def includeme(config): # pragma: no cover
config.register_service_factory(
"via.services.youtube.factory", iface=YouTubeService
)
config.register_service_factory(
"via.services.youtube_transcript.factory", iface=YouTubeTranscriptService
)

config.register_service_factory(
"via.services.url_details.factory", iface=URLDetailsService
Expand Down
Loading

0 comments on commit 44f7ca9

Please sign in to comment.