diff --git a/tests/conftest.py b/tests/conftest.py
index a1f5cd4a..e1d6372a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,7 +7,7 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
-from tests.factories import TranscriptFactory, VideoFactory
+from tests.factories import TranscriptFactory, TranscriptInfoFactory, VideoFactory
from tests.factories.factoryboy_sqlalchemy_session import (
clear_factoryboy_sqlalchemy_session,
set_factoryboy_sqlalchemy_session,
@@ -16,6 +16,7 @@
# Each factory has to be registered with pytest_factoryboy.
register(TranscriptFactory)
+register(TranscriptInfoFactory)
register(VideoFactory)
diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py
index 0d81dcd0..2ff9cfd2 100644
--- a/tests/factories/__init__.py
+++ b/tests/factories/__init__.py
@@ -1,2 +1,3 @@
from tests.factories.transcript import TranscriptFactory
+from tests.factories.transcript_info import TranscriptInfoFactory
from tests.factories.video import VideoFactory
diff --git a/tests/factories/transcript_info.py b/tests/factories/transcript_info.py
new file mode 100644
index 00000000..eaf0cd90
--- /dev/null
+++ b/tests/factories/transcript_info.py
@@ -0,0 +1,13 @@
+from factory import Factory, Sequence
+
+from via.services.youtube_transcript import TranscriptInfo
+
+
+class TranscriptInfoFactory(Factory):
+ """Test factory that builds `TranscriptInfo` objects with default field values."""
+
+ class Meta:
+ model = TranscriptInfo
+
+ language_code = "en-us"
+ name = "English (United States)"
+ # The URL is derived from the sequence counter `n`, so it differs per object.
+ url = Sequence(lambda n: f"https://example.com/api/timedtext?v={n}")
+ autogenerated = False
diff --git a/tests/unit/services.py b/tests/unit/services.py
index 00211d41..26bffe11 100644
--- a/tests/unit/services.py
+++ b/tests/unit/services.py
@@ -11,6 +11,7 @@
URLDetailsService,
ViaClientService,
YouTubeService,
+ YouTubeTranscriptService,
)
@@ -69,3 +70,8 @@ def youtube_service(mock_service):
youtube_service.get_video_id.return_value = None
return youtube_service
+
+
+@pytest.fixture
+def youtube_transcript_service(mock_service):
+    """Return what `mock_service` produces for `YouTubeTranscriptService`."""
+    mocked_service = mock_service(YouTubeTranscriptService)
+    return mocked_service
diff --git a/tests/unit/via/services/youtube_test.py b/tests/unit/via/services/youtube_test.py
index 6d409aea..8f747dae 100644
--- a/tests/unit/via/services/youtube_test.py
+++ b/tests/unit/via/services/youtube_test.py
@@ -28,6 +28,7 @@ def test_enabled(self, db_session, enabled, api_key, expected):
enabled=enabled,
api_key=api_key,
http_service=sentinel.http_service,
+ youtube_transcript_service=sentinel.youtube_transcript_service,
).enabled
== expected
)
@@ -94,38 +95,45 @@ def test_get_video_title_raises_YouTubeDataAPIError(self, svc, http_service):
assert exc_info.value.__cause__ == http_service.get.side_effect
- def test_get_transcript(self, db_session, svc, YouTubeTranscriptApi):
- YouTubeTranscriptApi.get_transcript.return_value = [
- {"text": "foo", "start": 0.0, "duration": 1.0},
- {"text": "bar", "start": 1.0, "duration": 2.0},
- ]
+ def test_get_transcript(
+ self, db_session, svc, youtube_transcript_service, transcript_info
+ ):
+ youtube_transcript_service.pick_default_transcript.return_value = (
+ transcript_info
+ )
+ youtube_transcript_service.get_transcript.return_value = "test_transcript"
returned_transcript = svc.get_transcript("test_video_id")
- YouTubeTranscriptApi.get_transcript.assert_called_once_with(
- "test_video_id", languages=("en",)
+ youtube_transcript_service.get_transcript_infos.assert_called_once_with(
+ "test_video_id"
+ )
+ youtube_transcript_service.pick_default_transcript.assert_called_once_with(
+ youtube_transcript_service.get_transcript_infos.return_value
)
- assert returned_transcript == YouTubeTranscriptApi.get_transcript.return_value
+ youtube_transcript_service.get_transcript.assert_called_once_with(
+ transcript_info
+ )
+ assert returned_transcript == "test_transcript"
# It should have cached the transcript in the DB.
assert db_session.scalars(select(Transcript)).all() == [
Any.instance_of(Transcript).with_attrs(
{
"video_id": "test_video_id",
- "transcript": YouTubeTranscriptApi.get_transcript.return_value,
+ "transcript_id": transcript_info.id,
+ "transcript": "test_transcript",
}
)
]
- @pytest.mark.usefixtures("db_session")
def test_get_transcript_returns_cached_transcripts(
- self, transcript, svc, YouTubeTranscriptApi
+ self, svc, transcript, youtube_transcript_service
):
returned_transcript = svc.get_transcript(transcript.video_id)
- YouTubeTranscriptApi.get_transcript.assert_not_called()
+ youtube_transcript_service.get_transcript.assert_not_called()
assert returned_transcript == transcript.transcript
- @pytest.mark.usefixtures("db_session")
def test_get_transcript_returns_oldest_cached_transcript(
self, transcript_factory, svc
):
@@ -155,18 +163,25 @@ def test_canonical_video_url(self, video_id, expected_url, svc):
assert expected_url == svc.canonical_video_url(video_id)
@pytest.fixture
- def svc(self, db_session, http_service):
+ def svc(self, db_session, http_service, youtube_transcript_service):
return YouTubeService(
db_session=db_session,
enabled=True,
api_key=sentinel.api_key,
http_service=http_service,
+ youtube_transcript_service=youtube_transcript_service,
)
class TestFactory:
def test_it(
- self, YouTubeService, youtube_service, pyramid_request, http_service, db_session
+ self,
+ YouTubeService,
+ youtube_service,
+ pyramid_request,
+ http_service,
+ db_session,
+ youtube_transcript_service,
):
returned = factory(sentinel.context, pyramid_request)
@@ -175,6 +190,7 @@ def test_it(
enabled=pyramid_request.registry.settings["youtube_transcripts"],
api_key="test_youtube_api_key",
http_service=http_service,
+ youtube_transcript_service=youtube_transcript_service,
)
assert returned == youtube_service
@@ -185,8 +201,3 @@ def YouTubeService(self, patch):
@pytest.fixture
def youtube_service(self, YouTubeService):
return YouTubeService.return_value
-
-
-@pytest.fixture(autouse=True)
-def YouTubeTranscriptApi(patch):
- return patch("via.services.youtube.YouTubeTranscriptApi")
diff --git a/tests/unit/via/services/youtube_transcript_test.py b/tests/unit/via/services/youtube_transcript_test.py
new file mode 100644
index 00000000..4b39563e
--- /dev/null
+++ b/tests/unit/via/services/youtube_transcript_test.py
@@ -0,0 +1,300 @@
+import json
+from io import BytesIO
+from json import JSONDecodeError
+from unittest.mock import sentinel
+from xml.etree import ElementTree
+
+import pytest
+from h_matchers import Any
+from requests import Response
+
+from tests.factories import TranscriptInfoFactory
+from via.exceptions import UnhandledUpstreamException
+from via.services.youtube_transcript import (
+ TranscriptInfo,
+ YouTubeTranscriptService,
+ factory,
+)
+
+
+class TestTranscriptInfo:
+ """Tests for the `TranscriptInfo.id` property."""
+
+ @pytest.mark.parametrize(
+ "transcript_info,expected_id",
+ [
+ # Not autogenerated: the middle segment of the expected ID is empty.
+ (
+ TranscriptInfoFactory(),
+ "en-us..RW5nbGlzaCAoVW5pdGVkIFN0YXRlcyk=",
+ ),
+ # Autogenerated: the middle segment of the expected ID is "a".
+ (
+ TranscriptInfoFactory(autogenerated=True),
+ "en-us.a.RW5nbGlzaCAoVW5pdGVkIFN0YXRlcyk=",
+ ),
+ ],
+ )
+ def test_id(self, transcript_info, expected_id):
+ """The `id` property equals the expected string for each case."""
+ assert transcript_info.id == expected_id
+
+
+class TestYouTubeTranscriptService:
+    def test_get_transcript_infos(self, svc, http_service):
+        """It posts the video ID to the player endpoint and parses the caption tracks."""
+        caption_tracks = [
+            {
+                "languageCode": "en",
+                "name": {"simpleText": "English"},
+                "baseUrl": "https://example.com/transcript_1",
+            },
+            {
+                "languageCode": "en-us",
+                "name": {"simpleText": "English (United States)"},
+                "baseUrl": "https://example.com/transcript_2",
+            },
+        ]
+        # The JSON response body from the YouTube API.
+        response_json = {
+            "captions": {
+                "playerCaptionsTracklistRenderer": {"captionTracks": caption_tracks}
+            }
+        }
+        response = Response()
+        response.raw = BytesIO(json.dumps(response_json).encode("utf-8"))
+        http_service.post.return_value = response
+
+        transcript_infos = svc.get_transcript_infos("test_video_id")
+
+        expected_request_body = {
+            "context": {
+                "client": {
+                    "hl": "en",
+                    "clientName": "WEB",
+                    "clientVersion": "2.20210721.00.00",
+                }
+            },
+            "videoId": "test_video_id",
+        }
+        http_service.post.assert_called_once_with(
+            "https://youtubei.googleapis.com/youtubei/v1/player",
+            json=expected_request_body,
+        )
+        # One TranscriptInfo is expected per caption track, in the same order.
+        assert transcript_infos == [
+            Any.instance_of(TranscriptInfo).with_attrs(
+                {
+                    "language_code": track["languageCode"],
+                    "autogenerated": False,
+                    "name": track["name"]["simpleText"],
+                    "url": track["baseUrl"],
+                }
+            )
+            for track in caption_tracks
+        ]
+
+    def test_get_transcript_infos_error_response(self, svc, http_service):
+        """An exception raised by the HTTP service's `post` propagates to the caller."""
+        # We get an error response from the YouTube API.
+        upstream_error = UnhandledUpstreamException("Something went wrong")
+        http_service.post.side_effect = upstream_error
+
+        with pytest.raises(UnhandledUpstreamException):
+            svc.get_transcript_infos("test_video_id")
+
+ # Each case pairs a malformed response body with the exception type that
+ # get_transcript_infos() is expected to let propagate for it.
+ @pytest.mark.parametrize(
+ "response_body,exception_class",
+ [
+ (b"foo", JSONDecodeError), # Not valid JSON.
+ (b"[]", TypeError), # Not a dict.
+ (b"{}", KeyError), # No "captions" key.
+ (b'{"captions": 23}', TypeError), # "captions" isn't a dict.
+ (b'{"captions": {}}', KeyError), # No "playerCaptionsTracklistRenderer".
+ # "playerCaptionsTracklistRenderer" isn't a dict.
+ (
+ b'{"captions": {"playerCaptionsTracklistRenderer": 23}}',
+ TypeError,
+ ),
+ # No "captionTracks".
+ (
+ b'{"captions": {"playerCaptionsTracklistRenderer": {}}}',
+ KeyError,
+ ),
+ # "captionTracks" isn't a list.
+ (
+ b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": 23}}}',
+ TypeError,
+ ),
+ # No "languageCode".
+ (
+ b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"name": {"simpleText": "English"}, "baseUrl": "https://example.com"}]}}}',
+ KeyError,
+ ),
+ # No "name".
+ (
+ b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"languageCode": "en", "baseUrl": "https://example.com"}]}}}',
+ KeyError,
+ ),
+ # "name" isn't a dict.
+ (
+ b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"languageCode": "en", "name": 23, "baseUrl": "https://example.com"}]}}}',
+ TypeError,
+ ),
+ # No "simpleText".
+ (
+ b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"languageCode": "en", "name": {}, "baseUrl": "https://example.com"}]}}}',
+ KeyError,
+ ),
+ # No "baseUrl".
+ (
+ b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"languageCode": "en", "name": {"simpleText": "English"}}]}}}',
+ KeyError,
+ ),
+ ],
+ )
+ def test_get_transcript_infos_unexpected_response(
+ self, svc, http_service, response_body, exception_class
+ ):
+ """It crashes if the response body isn't what we expect."""
+ # Serve response_body as the raw body of the mocked POST response.
+ response = http_service.post.return_value = Response()
+ response.encoding = "utf-8"
+ response.raw = BytesIO(response_body)
+
+ with pytest.raises(exception_class):
+ svc.get_transcript_infos("test_video_id")
+
+ # Each case is a list of available transcripts plus the index of the one
+ # that pick_default_transcript() is expected to return.
+ @pytest.mark.parametrize(
+ "transcript_infos,expected_default_transcript_index",
+ [
+ # Plain "en" named "English" wins over a regional English transcript.
+ (
+ [
+ TranscriptInfoFactory(language_code="en", name="English"),
+ TranscriptInfoFactory(
+ language_code="en-us", name="English (United States)"
+ ),
+ ],
+ 0,
+ ),
+ # A regional transcript named "English (...)" wins over an "en"
+ # transcript with some other name.
+ (
+ [
+ TranscriptInfoFactory(
+ language_code="en-us", name="English (United States)"
+ ),
+ TranscriptInfoFactory(language_code="en", name="English - DTVCC1"),
+ ],
+ 0,
+ ),
+ # With only non-standard names, "en" wins over a regional code.
+ (
+ [
+ TranscriptInfoFactory(language_code="en", name="English - Foo"),
+ TranscriptInfoFactory(
+ language_code="en-us", name="English (United States) - Foo"
+ ),
+ ],
+ 0,
+ ),
+ # A non-autogenerated regional transcript wins over an autogenerated
+ # "en" transcript.
+ (
+ [
+ TranscriptInfoFactory(
+ language_code="en-us", name="English (United States) - Foo"
+ ),
+ TranscriptInfoFactory(
+ language_code="en", name="English", autogenerated=True
+ ),
+ ],
+ 0,
+ ),
+ # Both autogenerated: "en" wins over the regional code.
+ (
+ [
+ TranscriptInfoFactory(
+ language_code="en", name="English", autogenerated=True
+ ),
+ TranscriptInfoFactory(
+ language_code="en-us",
+ name="English (United States)",
+ autogenerated=True,
+ ),
+ ],
+ 0,
+ ),
+ # An autogenerated English transcript wins over a non-English one.
+ (
+ [
+ TranscriptInfoFactory(language_code="fr", name="French"),
+ TranscriptInfoFactory(
+ language_code="en", name="English", autogenerated=True
+ ),
+ ],
+ 1,
+ ),
+ # No English transcript at all: the first transcript is returned.
+ (
+ [
+ TranscriptInfoFactory(language_code="fr", name="French"),
+ TranscriptInfoFactory(language_code="de", name="Deutsch"),
+ ],
+ 0,
+ ),
+ ],
+ )
+ def test_pick_default_transcript(
+ self, svc, transcript_infos, expected_default_transcript_index
+ ):
+ """It returns the transcript at the expected index for each case."""
+ assert (
+ svc.pick_default_transcript(transcript_infos)
+ == transcript_infos[expected_default_transcript_index]
+ )
+
+    def test_get_transcript(self, svc, transcript_info, http_service):
+        """It downloads the transcript XML and returns one dict per caption."""
+        # The fixture must be a well-formed XML document whose child elements
+        # carry the "start"/"dur" attributes asserted below:
+        #   * the second caption has no "dur", so a duration of 0.0 is expected;
+        #   * the third caption contains entity-escaped markup, which is
+        #     expected to be stripped from the returned text.
+        http_service.get.return_value.text = """
+            <transcript>
+                <text start="0.21" dur="1.387">Hey there guys,</text>
+                <text start="1.597">Lichen' subscribe</text>
+                <text start="4.327" dur="2.063">
+                    &lt;font color="#A0AAB4"&gt;Buy my merch!&lt;/font&gt;
+                </text>
+            </transcript>
+        """
+
+        transcript = svc.get_transcript(transcript_info)
+        http_service.get.assert_called_once_with(transcript_info.url)
+
+        assert transcript == [
+            {"duration": 1.387, "start": 0.21, "text": "Hey there guys,"},
+            {"duration": 0.0, "start": 1.597, "text": "Lichen' subscribe"},
+            {"duration": 2.063, "start": 4.327, "text": "Buy my merch!"},
+        ]
+
+    def test_get_transcript_error_response(self, svc, transcript_info, http_service):
+        """An exception raised by the HTTP service's `get` propagates to the caller."""
+        # We get an error response from the YouTube API.
+        upstream_error = UnhandledUpstreamException("Something went wrong")
+        http_service.get.side_effect = upstream_error
+
+        with pytest.raises(UnhandledUpstreamException):
+            svc.get_transcript(transcript_info)
+
+    def test_get_transcript_unexpected_response(
+        self, svc, transcript_info, http_service
+    ):
+        """A response body that is not XML makes the parse error propagate."""
+        not_xml = "foo"
+        http_service.get.return_value.text = not_xml
+
+        with pytest.raises(ElementTree.ParseError):
+            svc.get_transcript(transcript_info)
+
+    @pytest.fixture
+    def svc(self, http_service):
+        """Return the service under test, built with the `http_service` fixture."""
+        service = YouTubeTranscriptService(http_service)
+        return service
+
+
+class TestFactory:
+    """Tests for the module-level `factory` function."""
+
+    def test_factory(self, YouTubeTranscriptService, http_service, pyramid_request):
+        """It builds the service from the request's HTTP service and returns it."""
+        returned_service = factory(sentinel.context, pyramid_request)
+
+        YouTubeTranscriptService.assert_called_once_with(http_service=http_service)
+        expected_service = YouTubeTranscriptService.return_value
+        assert returned_service == expected_service
+
+    @pytest.fixture
+    def YouTubeTranscriptService(self, patch):
+        """Patch the service class in its own module and return the patch result."""
+        target = "via.services.youtube_transcript.YouTubeTranscriptService"
+        return patch(target)
diff --git a/via/services/__init__.py b/via/services/__init__.py
index ab267007..41b12877 100644
--- a/via/services/__init__.py
+++ b/via/services/__init__.py
@@ -12,6 +12,7 @@
from via.services.url_details import URLDetailsService
from via.services.via_client import ViaClientService
from via.services.youtube import YouTubeService
+from via.services.youtube_transcript import YouTubeTranscriptService
def includeme(config): # pragma: no cover
@@ -37,6 +38,9 @@ def includeme(config): # pragma: no cover
config.register_service_factory(
"via.services.youtube.factory", iface=YouTubeService
)
+ config.register_service_factory(
+ "via.services.youtube_transcript.factory", iface=YouTubeTranscriptService
+ )
config.register_service_factory(
"via.services.url_details.factory", iface=URLDetailsService
diff --git a/via/services/youtube.py b/via/services/youtube.py
index a23e26c9..6c730fec 100644
--- a/via/services/youtube.py
+++ b/via/services/youtube.py
@@ -1,10 +1,10 @@
from urllib.parse import parse_qs, quote_plus, urlparse
from sqlalchemy import select
-from youtube_transcript_api import YouTubeTranscriptApi
from via.models import Transcript, Video
from via.services.http import HTTPService
+from via.services.youtube_transcript import YouTubeTranscriptService
class YouTubeDataAPIError(Exception):
@@ -12,13 +12,19 @@ class YouTubeDataAPIError(Exception):
class YouTubeService:
- def __init__(
- self, db_session, enabled: bool, api_key: str, http_service: HTTPService
+ def __init__( # pylint:disable=too-many-arguments
+ self,
+ db_session,
+ enabled: bool,
+ api_key: str,
+ http_service: HTTPService,
+ youtube_transcript_service: YouTubeTranscriptService,
):
self._db = db_session
self._enabled = enabled
self._api_key = api_key
self._http_service = http_service
+ self._transcript_svc = youtube_transcript_service
@property
def enabled(self):
@@ -101,10 +107,18 @@ def get_transcript(self, video_id):
).first():
return transcript.transcript
- transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=("en",))
+ transcript_infos = self._transcript_svc.get_transcript_infos(video_id)
+ transcript_info = self._transcript_svc.pick_default_transcript(transcript_infos)
+ transcript = self._transcript_svc.get_transcript(transcript_info)
+
self._db.add(
- Transcript(video_id=video_id, transcript_id="en", transcript=transcript)
+ Transcript(
+ video_id=video_id,
+ transcript_id=transcript_info.id,
+ transcript=transcript,
+ )
)
+
return transcript
@@ -114,4 +128,5 @@ def factory(_context, request):
enabled=request.registry.settings["youtube_transcripts"],
api_key=request.registry.settings["youtube_api_key"],
http_service=request.find_service(HTTPService),
+ youtube_transcript_service=request.find_service(YouTubeTranscriptService),
)
diff --git a/via/services/youtube_transcript.py b/via/services/youtube_transcript.py
new file mode 100644
index 00000000..9a69ccf5
--- /dev/null
+++ b/via/services/youtube_transcript.py
@@ -0,0 +1,148 @@
+import re
+from base64 import b64encode
+from dataclasses import dataclass
+from typing import Dict, List
+from xml.etree import ElementTree
+
+from via.services.http import HTTPService
+
+
+@dataclass
+class TranscriptInfo:
+    """Metadata describing one transcript that is available for a YouTube video."""
+
+    language_code: str
+    """Language code of the transcript, for example "en-us"."""
+
+    name: str
+    """Human-readable name of the transcript, for example "English (United States)"."""
+
+    url: str
+    """URL that the transcript can be downloaded from."""
+
+    autogenerated: bool
+    """True if the transcript was autogenerated by speech recognition."""
+
+    @property
+    def id(self) -> str:  # pylint:disable=invalid-name
+        """Return an ID string built from the language code, origin and name.
+
+        The ID has the form "<language_code>.<flag>.<base64 of name>", where
+        <flag> is "a" for autogenerated transcripts and empty otherwise. Any
+        trailing "." characters are stripped from the result.
+        """
+        origin_flag = "a" if self.autogenerated else ""
+        encoded_name = b64encode(self.name.encode("utf-8")).decode("utf-8")
+        return f"{self.language_code}.{origin_flag}.{encoded_name}".rstrip(".")
+
+
+#: The preferences that we apply when choosing a default transcript for a video.
+#:
+#: Entries are ordered from most to least preferred. The "language_code" and
+#: "name" patterns are applied with `fullmatch()`, so each pattern has to cover
+#: the whole value (hence "en-.*" rather than "en-" for regional codes).
+DEFAULT_TRANSCRIPT_PREFERENCES = (
+    # 1. Manually created, language "en", named exactly "English".
+    {
+        "language_code": re.compile("en"),
+        "name": re.compile("English"),
+        "autogenerated": False,
+    },
+    # 2. Manually created, regional English, named "English (<something>)".
+    {
+        "language_code": re.compile("en-.*"),
+        "name": re.compile(r"English \(.*\)"),
+        "autogenerated": False,
+    },
+    # 3. Manually created, language "en", any name.
+    {
+        "language_code": re.compile("en"),
+        "name": re.compile(".*"),
+        "autogenerated": False,
+    },
+    # 4. Manually created, regional English, any name.
+    {
+        "language_code": re.compile("en-.*"),
+        "name": re.compile(".*"),
+        "autogenerated": False,
+    },
+    # 5. Autogenerated, language "en", any name.
+    {
+        "language_code": re.compile("en"),
+        "name": re.compile(".*"),
+        "autogenerated": True,
+    },
+    # 6. Autogenerated, regional English, any name.
+    {
+        "language_code": re.compile("en-.*"),
+        "name": re.compile(".*"),
+        "autogenerated": True,
+    },
+)
+
+
+class YouTubeTranscriptService:
+ """A service for getting text transcripts of YouTube videos."""
+
+ def __init__(self, http_service: HTTPService):
+ """Keep a reference to `http_service`, which the other methods use for HTTP requests."""
+ self._http_service = http_service
+
+    def get_transcript_infos(self, video_id: str) -> List[TranscriptInfo]:
+        """Return the list of available transcripts for `video_id`.
+
+        Posts `video_id` to the YouTube player endpoint and builds one
+        `TranscriptInfo` per entry of the response's "captionTracks" list.
+        Language codes are lower-cased, and a track counts as autogenerated
+        when its "kind" is "asr".
+
+        Raises `KeyError` or `TypeError` if the response JSON doesn't have the
+        expected shape. Exceptions from the HTTP service or from decoding the
+        JSON body are not caught.
+        """
+        request_body = {
+            "context": {
+                "client": {
+                    "hl": "en",
+                    "clientName": "WEB",
+                    "clientVersion": "2.20210721.00.00",
+                }
+            },
+            "videoId": video_id,
+        }
+        response = self._http_service.post(
+            "https://youtubei.googleapis.com/youtubei/v1/player", json=request_body
+        )
+        captions = response.json()["captions"]
+        caption_tracks = captions["playerCaptionsTracklistRenderer"]["captionTracks"]
+
+        transcript_infos = []
+        for track in caption_tracks:
+            # The fields are read in a fixed order so that a malformed track
+            # always fails on the same field.
+            language_code = track["languageCode"].lower()
+            autogenerated = track.get("kind") == "asr"
+            name = track["name"]["simpleText"]
+            url = track["baseUrl"]
+            transcript_infos.append(
+                TranscriptInfo(
+                    language_code=language_code,
+                    autogenerated=autogenerated,
+                    name=name,
+                    url=url,
+                )
+            )
+        return transcript_infos
+
+    def pick_default_transcript(
+        self, transcript_infos: List[TranscriptInfo]
+    ) -> TranscriptInfo:
+        """Return a choice of default transcript from `transcript_infos`.
+
+        The preferences in `DEFAULT_TRANSCRIPT_PREFERENCES` are tried in order
+        and the first transcript that satisfies a preference is returned. If no
+        transcript satisfies any preference the first transcript in the list is
+        returned, so an empty list raises `IndexError`.
+        """
+        pattern_fields = ("language_code", "name")
+
+        for preference in DEFAULT_TRANSCRIPT_PREFERENCES:
+            for candidate in transcript_infos:
+                # Both patterns must match the whole attribute value.
+                patterns_match = all(
+                    preference[field].fullmatch(getattr(candidate, field))
+                    for field in pattern_fields
+                )
+                if (
+                    patterns_match
+                    and preference["autogenerated"] == candidate.autogenerated
+                ):
+                    return candidate
+
+        return transcript_infos[0]
+
+    def get_transcript(self, transcript_info: TranscriptInfo) -> List[Dict]:
+        """Download and return the actual transcript text for `transcript_info`.
+
+        Fetches `transcript_info.url` and parses the response body as XML. Each
+        child element of the root that has text becomes a dict with the keys:
+
+        * "text": the element's text with any markup removed and surrounding
+          whitespace stripped
+        * "start": the element's "start" attribute as a float
+        * "duration": the element's "dur" attribute as a float, or 0.0 if the
+          attribute is missing
+
+        Elements without text are skipped.
+
+        Raises `ElementTree.ParseError` if the response body, or the text of
+        one of its elements, can't be parsed. Exceptions from the HTTP service
+        are not caught.
+        """
+        response = self._http_service.get(transcript_info.url)
+        xml_elements = ElementTree.fromstring(response.text)
+
+        def strip_html(xml_string):
+            """Return the plain text of `xml_string` with any tags removed."""
+            # The caption text is a fragment (bare text, possibly with inline
+            # tags), not a document. Wrap it in a root element so that it is
+            # well-formed; parsing bare text directly raises ParseError.
+            return "".join(
+                ElementTree.fromstring(f"<span>{xml_string}</span>").itertext()
+            ).strip()
+
+        return [
+            {
+                "text": strip_html(xml_element.text),
+                "start": float(xml_element.attrib["start"]),
+                "duration": float(xml_element.attrib.get("dur", "0.0")),
+            }
+            for xml_element in xml_elements
+            if xml_element.text is not None
+        ]
+
+
+def factory(_context, request):
+    """Build a `YouTubeTranscriptService` using the request's `HTTPService`."""
+    http_service = request.find_service(HTTPService)
+    return YouTubeTranscriptService(http_service=http_service)