diff --git a/tests/conftest.py b/tests/conftest.py
index a1f5cd4a..e1d6372a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,7 +7,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
-from tests.factories import TranscriptFactory, VideoFactory
+from tests.factories import TranscriptFactory, TranscriptInfoFactory, VideoFactory
from tests.factories.factoryboy_sqlalchemy_session import (
clear_factoryboy_sqlalchemy_session,
set_factoryboy_sqlalchemy_session,
@@ -16,6 +16,7 @@
# Each factory has to be registered with pytest_factoryboy.
register(TranscriptFactory)
+register(TranscriptInfoFactory)
register(VideoFactory)
diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py
index 0d81dcd0..2ff9cfd2 100644
--- a/tests/factories/__init__.py
+++ b/tests/factories/__init__.py
@@ -1,2 +1,3 @@
from tests.factories.transcript import TranscriptFactory
+from tests.factories.transcript_info import TranscriptInfoFactory
from tests.factories.video import VideoFactory
diff --git a/tests/factories/transcript_info.py b/tests/factories/transcript_info.py
new file mode 100644
index 00000000..6c9ffb28
--- /dev/null
+++ b/tests/factories/transcript_info.py
@@ -0,0 +1,13 @@
+from factory import Factory
+
+from via.services.youtube_transcript import TranscriptInfo
+
+
+class TranscriptInfoFactory(Factory):
+ class Meta:
+ model = TranscriptInfo
+
+ language_code = "en-us"
+ name = "English (United States)"
+ url = "https://example.com/api/timedtext?v=foo"
+ autogenerated = False
diff --git a/tests/unit/services.py b/tests/unit/services.py
index 00211d41..26bffe11 100644
--- a/tests/unit/services.py
+++ b/tests/unit/services.py
@@ -11,6 +11,7 @@
URLDetailsService,
ViaClientService,
YouTubeService,
+ YouTubeTranscriptService,
)
@@ -69,3 +70,8 @@ def youtube_service(mock_service):
youtube_service.get_video_id.return_value = None
return youtube_service
+
+
+@pytest.fixture
+def youtube_transcript_service(mock_service):
+ return mock_service(YouTubeTranscriptService)
diff --git a/tests/unit/via/services/youtube_test.py b/tests/unit/via/services/youtube_test.py
index 7b9fa981..8fca1c7c 100644
--- a/tests/unit/via/services/youtube_test.py
+++ b/tests/unit/via/services/youtube_test.py
@@ -28,6 +28,7 @@ def test_enabled(self, db_session, enabled, api_key, expected):
enabled=enabled,
api_key=api_key,
http_service=sentinel.http_service,
+ youtube_transcript_service=sentinel.youtube_transcript_service,
).enabled
== expected
)
@@ -94,42 +95,50 @@ def test_get_video_title_raises_YouTubeDataAPIError(self, svc, http_service):
assert exc_info.value.__cause__ == http_service.get.side_effect
- def test_get_transcript(self, db_session, svc, YouTubeTranscriptApi):
- YouTubeTranscriptApi.get_transcript.return_value = [
- {"text": "foo", "start": 0.0, "duration": 1.0},
- {"text": "bar", "start": 1.0, "duration": 2.0},
- ]
+ def test_get_transcript_when_none_saved(
+ self, db_session, svc, youtube_transcript_service, transcript_info
+ ):
+ video_id = "test_video_id"
+ youtube_transcript_service.pick_default_transcript.return_value = (
+ transcript_info
+ )
+ youtube_transcript_service.get_transcript.return_value = "test_transcript"
- returned_transcript = svc.get_transcript("test_video_id")
+ returned = svc.get_transcript(video_id)
- YouTubeTranscriptApi.get_transcript.assert_called_once_with(
- "test_video_id", languages=("en",)
+ # It gets the transcript from YouTubeTranscriptService.
+ youtube_transcript_service.get_transcript_infos.assert_called_once_with(
+ video_id
+ )
+ youtube_transcript_service.pick_default_transcript.assert_called_once_with(
+ youtube_transcript_service.get_transcript_infos.return_value
)
- assert returned_transcript == YouTubeTranscriptApi.get_transcript.return_value
- # It should have cached the transcript in the DB.
+ youtube_transcript_service.get_transcript.assert_called_once_with(
+ transcript_info
+ )
+ # It saves the transcript in the DB for next time.
assert db_session.scalars(select(Transcript)).all() == [
Any.instance_of(Transcript).with_attrs(
{
- "video_id": "test_video_id",
- "transcript": YouTubeTranscriptApi.get_transcript.return_value,
+ "video_id": video_id,
+ "transcript_id": transcript_info.id,
+ "transcript": "test_transcript",
}
)
]
+ # It returns the transcript.
+ assert returned == "test_transcript"
- @pytest.mark.usefixtures("db_session")
- def test_get_transcript_returns_cached_transcripts(
- self, transcript, svc, YouTubeTranscriptApi
+ def test_get_transcript_when_one_saved(
+ self, svc, transcript, youtube_transcript_service
):
- returned_transcript = svc.get_transcript(transcript.video_id)
+ returned = svc.get_transcript(transcript.video_id)
- YouTubeTranscriptApi.get_transcript.assert_not_called()
- assert returned_transcript == transcript.transcript
+ # It returns the saved transcript from the DB without calling YouTubeService.
+ youtube_transcript_service.get_transcript.assert_not_called()
+ assert returned == transcript.transcript
- @pytest.mark.usefixtures("db_session")
- def test_get_transcript_returns_oldest_cached_transcript(
- self, transcript_factory, svc
- ):
- """If there are multiple cached transcripts get_transcript() returns the oldest one."""
+ def test_get_transcript_when_multiple_saved(self, svc, transcript_factory):
oldest_transcript, newer_transcript = transcript_factory.create_batch(
2, video_id="video_id"
)
@@ -138,9 +147,9 @@ def test_get_transcript_returns_oldest_cached_transcript(
newer_transcript.created = datetime(2023, 8, 12)
newer_transcript.transcript = "newest_transcript"
- returned_transcript = svc.get_transcript("video_id")
+ returned = svc.get_transcript("video_id")
- assert returned_transcript == "oldest_transcript"
+ assert returned == oldest_transcript.transcript
@pytest.mark.parametrize(
"video_id,expected_url",
@@ -157,18 +166,25 @@ def test_canonical_video_url(self, video_id, expected_url, svc):
assert expected_url == svc.canonical_video_url(video_id)
@pytest.fixture
- def svc(self, db_session, http_service):
+ def svc(self, db_session, http_service, youtube_transcript_service):
return YouTubeService(
db_session=db_session,
enabled=True,
api_key=sentinel.api_key,
http_service=http_service,
+ youtube_transcript_service=youtube_transcript_service,
)
class TestFactory:
def test_it(
- self, YouTubeService, youtube_service, pyramid_request, http_service, db_session
+ self,
+ YouTubeService,
+ youtube_service,
+ pyramid_request,
+ http_service,
+ db_session,
+ youtube_transcript_service,
):
returned = factory(sentinel.context, pyramid_request)
@@ -177,6 +193,7 @@ def test_it(
enabled=pyramid_request.registry.settings["youtube_transcripts"],
api_key="test_youtube_api_key",
http_service=http_service,
+ youtube_transcript_service=youtube_transcript_service,
)
assert returned == youtube_service
@@ -187,8 +204,3 @@ def YouTubeService(self, patch):
@pytest.fixture
def youtube_service(self, YouTubeService):
return YouTubeService.return_value
-
-
-@pytest.fixture(autouse=True)
-def YouTubeTranscriptApi(patch):
- return patch("via.services.youtube.YouTubeTranscriptApi")
diff --git a/tests/unit/via/services/youtube_transcript_test.py b/tests/unit/via/services/youtube_transcript_test.py
new file mode 100644
index 00000000..1476dbdc
--- /dev/null
+++ b/tests/unit/via/services/youtube_transcript_test.py
@@ -0,0 +1,207 @@
+import json
+from io import BytesIO
+from unittest.mock import sentinel
+
+import pytest
+from h_matchers import Any
+from requests import Response
+
+from tests.factories import TranscriptInfoFactory
+from via.services.youtube_transcript import (
+ TranscriptInfo,
+ YouTubeTranscriptService,
+ factory,
+)
+
+
+class TestTranscriptInfo:
+ @pytest.mark.parametrize(
+ "transcript_info,expected_id",
+ [
+ (
+ TranscriptInfo(
+ "en-us",
+ "English (United States)",
+ "https://example.com/transcript",
+ autogenerated=False,
+ ),
+ "en-us..RW5nbGlzaCAoVW5pdGVkIFN0YXRlcyk=",
+ ),
+ (
+ TranscriptInfo(
+ "en",
+ "English",
+ "https://example.com/transcript",
+ autogenerated=True,
+ ),
+ "en.a.RW5nbGlzaA==",
+ ),
+ ],
+ )
+ def test_id(self, transcript_info, expected_id):
+ assert transcript_info.id == expected_id
+
+
+class TestYouTubeTranscriptService:
+ def test_get_transcript_infos(self, svc, http_service):
+ # The JSON response body from the YouTube API.
+ response_json = {
+ "captions": {
+ "playerCaptionsTracklistRenderer": {
+ "captionTracks": [
+ {
+ "languageCode": "en",
+ "name": {"simpleText": "English"},
+ "baseUrl": "https://example.com/transcript_1",
+ },
+ {
+ "languageCode": "en-us",
+ "name": {"simpleText": "English (United States)"},
+ "baseUrl": "https://example.com/transcript_2",
+ },
+ ]
+ }
+ }
+ }
+ response = http_service.post.return_value = Response()
+ response.raw = BytesIO(json.dumps(response_json).encode("utf-8"))
+
+ transcript_infos = svc.get_transcript_infos("test_video_id")
+
+ caption_tracks = response_json["captions"]["playerCaptionsTracklistRenderer"][
+ "captionTracks"
+ ]
+ assert transcript_infos == [
+ Any.instance_of(TranscriptInfo).with_attrs(
+ {
+ "language_code": caption_tracks[0]["languageCode"],
+ "autogenerated": False,
+ "name": caption_tracks[0]["name"]["simpleText"],
+ "url": caption_tracks[0]["baseUrl"],
+ }
+ ),
+ Any.instance_of(TranscriptInfo).with_attrs(
+ {
+ "language_code": caption_tracks[1]["languageCode"],
+ "autogenerated": False,
+ "name": caption_tracks[1]["name"]["simpleText"],
+ "url": caption_tracks[1]["baseUrl"],
+ }
+ ),
+ ]
+
+ @pytest.mark.parametrize(
+ "transcript_infos,expected_default_transcript_index",
+ [
+ (
+ [
+ TranscriptInfoFactory(language_code="en", name="English"),
+ TranscriptInfoFactory(
+ language_code="en-us", name="English (United States)"
+ ),
+ ],
+ 0,
+ ),
+ (
+ [
+ TranscriptInfoFactory(
+ language_code="en-us", name="English (United States)"
+ ),
+ TranscriptInfoFactory(language_code="en", name="English - DTVCC1"),
+ ],
+ 0,
+ ),
+ (
+ [
+ TranscriptInfoFactory(language_code="en", name="English - Foo"),
+ TranscriptInfoFactory(
+ language_code="en-us", name="English (United States) - Foo"
+ ),
+ ],
+ 0,
+ ),
+ (
+ [
+ TranscriptInfoFactory(
+ language_code="en-us", name="English (United States) - Foo"
+ ),
+ TranscriptInfoFactory(
+ language_code="en", name="English", autogenerated=True
+ ),
+ ],
+ 0,
+ ),
+ (
+ [
+ TranscriptInfoFactory(
+ language_code="en", name="English", autogenerated=True
+ ),
+ TranscriptInfoFactory(
+ language_code="en-us",
+ name="English (United States)",
+ autogenerated=True,
+ ),
+ ],
+ 0,
+ ),
+ (
+ [
+ TranscriptInfoFactory(language_code="fr", name="French"),
+ TranscriptInfoFactory(
+ language_code="en", name="English", autogenerated=True
+ ),
+ ],
+ 1,
+ ),
+ (
+ [
+ TranscriptInfoFactory(language_code="fr", name="French"),
+ TranscriptInfoFactory(language_code="de", name="Deutsch"),
+ ],
+ 0,
+ ),
+ ],
+ )
+ def test_pick_default_transcript(
+ self, svc, transcript_infos, expected_default_transcript_index
+ ):
+ assert (
+ svc.pick_default_transcript(transcript_infos)
+ == transcript_infos[expected_default_transcript_index]
+ )
+
+ def test_get_transcript(self, svc, transcript_info, http_service):
+ http_service.get.return_value.text = """
+
+ Hey there guys,
+ Lichen' subscribe
+
+ <font color="#A0AAB4">Buy my merch!</font>
+
+
+ """
+
+ transcript = svc.get_transcript(transcript_info)
+ http_service.get.assert_called_once_with(transcript_info.url)
+
+ assert transcript == [
+ {"duration": 1.387, "start": 0.21, "text": "Hey there guys,"},
+ {"duration": 0.0, "start": 1.597, "text": "Lichen' subscribe"},
+ {"duration": 2.063, "start": 4.327, "text": "Buy my merch!"},
+ ]
+
+ @pytest.fixture
+ def svc(self, http_service):
+ return YouTubeTranscriptService(http_service)
+
+
+class TestFactory:
+ def test_factory(self, YouTubeTranscriptService, http_service, pyramid_request):
+ svc = factory(sentinel.context, pyramid_request)
+
+ YouTubeTranscriptService.assert_called_once_with(http_service=http_service)
+ assert svc == YouTubeTranscriptService.return_value
+
+ @pytest.fixture
+ def YouTubeTranscriptService(self, patch):
+ return patch("via.services.youtube_transcript.YouTubeTranscriptService")
diff --git a/via/services/__init__.py b/via/services/__init__.py
index ab267007..41b12877 100644
--- a/via/services/__init__.py
+++ b/via/services/__init__.py
@@ -12,6 +12,7 @@
from via.services.url_details import URLDetailsService
from via.services.via_client import ViaClientService
from via.services.youtube import YouTubeService
+from via.services.youtube_transcript import YouTubeTranscriptService
def includeme(config): # pragma: no cover
@@ -37,6 +38,9 @@ def includeme(config): # pragma: no cover
config.register_service_factory(
"via.services.youtube.factory", iface=YouTubeService
)
+ config.register_service_factory(
+ "via.services.youtube_transcript.factory", iface=YouTubeTranscriptService
+ )
config.register_service_factory(
"via.services.url_details.factory", iface=URLDetailsService
diff --git a/via/services/youtube.py b/via/services/youtube.py
index a23e26c9..6c730fec 100644
--- a/via/services/youtube.py
+++ b/via/services/youtube.py
@@ -1,10 +1,10 @@
from urllib.parse import parse_qs, quote_plus, urlparse
from sqlalchemy import select
-from youtube_transcript_api import YouTubeTranscriptApi
from via.models import Transcript, Video
from via.services.http import HTTPService
+from via.services.youtube_transcript import YouTubeTranscriptService
class YouTubeDataAPIError(Exception):
@@ -12,13 +12,19 @@ class YouTubeDataAPIError(Exception):
class YouTubeService:
- def __init__(
- self, db_session, enabled: bool, api_key: str, http_service: HTTPService
+ def __init__( # pylint:disable=too-many-arguments
+ self,
+ db_session,
+ enabled: bool,
+ api_key: str,
+ http_service: HTTPService,
+ youtube_transcript_service: YouTubeTranscriptService,
):
self._db = db_session
self._enabled = enabled
self._api_key = api_key
self._http_service = http_service
+ self._transcript_svc = youtube_transcript_service
@property
def enabled(self):
@@ -101,10 +107,18 @@ def get_transcript(self, video_id):
).first():
return transcript.transcript
- transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=("en",))
+ transcript_infos = self._transcript_svc.get_transcript_infos(video_id)
+ transcript_info = self._transcript_svc.pick_default_transcript(transcript_infos)
+ transcript = self._transcript_svc.get_transcript(transcript_info)
+
self._db.add(
- Transcript(video_id=video_id, transcript_id="en", transcript=transcript)
+ Transcript(
+ video_id=video_id,
+ transcript_id=transcript_info.id,
+ transcript=transcript,
+ )
)
+
return transcript
@@ -114,4 +128,5 @@ def factory(_context, request):
enabled=request.registry.settings["youtube_transcripts"],
api_key=request.registry.settings["youtube_api_key"],
http_service=request.find_service(HTTPService),
+ youtube_transcript_service=request.find_service(YouTubeTranscriptService),
)
diff --git a/via/services/youtube_transcript.py b/via/services/youtube_transcript.py
new file mode 100644
index 00000000..bc00c322
--- /dev/null
+++ b/via/services/youtube_transcript.py
@@ -0,0 +1,135 @@
+import re
+from base64 import b64encode
+from dataclasses import dataclass
+from typing import Dict, List
+from xml.etree import ElementTree
+
+from via.services.http import HTTPService
+
+
+@dataclass
+class TranscriptInfo:
+ language_code: str
+ name: str
+ url: str
+ autogenerated: bool
+
+ @property
+ def id(self) -> str: # pylint:disable=invalid-name
+ """Return a unique ID for this transcript."""
+ name = b64encode(self.name.encode("utf-8")).decode("utf-8")
+
+ return ".".join(
+ part or ""
+ for part in [
+ self.language_code,
+ "a" if self.autogenerated else None,
+ name,
+ ]
+ ).rstrip(".")
+
+
+DEFAULT_TRANSCRIPT_PREFERENCES = (
+ {
+ "language_code": re.compile("en"),
+ "name": re.compile("English"),
+ "autogenerated": False,
+ },
+ {
+ "language_code": re.compile("en-.*"),
+ "name": re.compile(r"English \(.*\)"),
+ "autogenerated": False,
+ },
+ {
+ "language_code": re.compile("en"),
+ "name": re.compile(".*"),
+ "autogenerated": False,
+ },
+ {
+ "language_code": re.compile("en-.*"),
+ "name": re.compile(".*"),
+ "autogenerated": False,
+ },
+ {
+ "language_code": re.compile("en"),
+ "name": re.compile(".*"),
+ "autogenerated": True,
+ },
+ {
+ "language_code": re.compile("en-"),
+ "name": re.compile(".*"),
+ "autogenerated": True,
+ },
+)
+
+
+class YouTubeTranscriptService:
+ def __init__(self, http_service: HTTPService):
+ self._http_service = http_service
+
+ def get_transcript_infos(self, video_id: str) -> List[TranscriptInfo]:
+ response = self._http_service.post(
+ "https://youtubei.googleapis.com/youtubei/v1/player",
+ json={
+ "context": {
+ "client": {
+ "hl": "en",
+ "clientName": "WEB",
+ "clientVersion": "2.20210721.00.00",
+ }
+ },
+ "videoId": video_id,
+ },
+ )
+ json = response.json()
+ dicts = json["captions"]["playerCaptionsTracklistRenderer"]["captionTracks"]
+
+ return [
+ TranscriptInfo(
+ language_code=caption_track_dict["languageCode"].lower(),
+ autogenerated=caption_track_dict.get("kind", None) == "asr",
+ name=caption_track_dict["name"]["simpleText"],
+ url=caption_track_dict["baseUrl"],
+ )
+ for caption_track_dict in dicts
+ ]
+
+ def pick_default_transcript(
+ self, transcript_infos: List[TranscriptInfo]
+ ) -> TranscriptInfo:
+ def matches(transcript_info: TranscriptInfo, preference: dict):
+ for key in ("language_code", "name"):
+ if not preference[key].fullmatch(getattr(transcript_info, key)):
+ return False
+
+ return preference["autogenerated"] == transcript_info.autogenerated
+
+ for preference in DEFAULT_TRANSCRIPT_PREFERENCES:
+ for transcript_info in transcript_infos:
+ if matches(transcript_info, preference):
+ return transcript_info
+
+ return transcript_infos[0]
+
+ def get_transcript(self, transcript_info: TranscriptInfo) -> List[Dict]:
+ response = self._http_service.get(transcript_info.url)
+ xml_elements = ElementTree.fromstring(response.text)
+
+ def strip_html(xml_string):
+ return "".join(
+ ElementTree.fromstring(f"{xml_string}").itertext()
+ ).strip()
+
+ return [
+ {
+ "text": strip_html(xml_element.text),
+ "start": float(xml_element.attrib["start"]),
+ "duration": float(xml_element.attrib.get("dur", "0.0")),
+ }
+ for xml_element in xml_elements
+ if xml_element.text is not None
+ ]
+
+
+def factory(_context, request):
+ return YouTubeTranscriptService(http_service=request.find_service(HTTPService))