diff --git a/tests/conftest.py b/tests/conftest.py index a1f5cd4a..e1d6372a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from tests.factories import TranscriptFactory, VideoFactory +from tests.factories import TranscriptFactory, TranscriptInfoFactory, VideoFactory from tests.factories.factoryboy_sqlalchemy_session import ( clear_factoryboy_sqlalchemy_session, set_factoryboy_sqlalchemy_session, @@ -16,6 +16,7 @@ # Each factory has to be registered with pytest_factoryboy. register(TranscriptFactory) +register(TranscriptInfoFactory) register(VideoFactory) diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py index 0d81dcd0..2ff9cfd2 100644 --- a/tests/factories/__init__.py +++ b/tests/factories/__init__.py @@ -1,2 +1,3 @@ from tests.factories.transcript import TranscriptFactory +from tests.factories.transcript_info import TranscriptInfoFactory from tests.factories.video import VideoFactory diff --git a/tests/factories/transcript_info.py b/tests/factories/transcript_info.py new file mode 100644 index 00000000..eaf0cd90 --- /dev/null +++ b/tests/factories/transcript_info.py @@ -0,0 +1,13 @@ +from factory import Factory, Sequence + +from via.services.youtube_transcript import TranscriptInfo + + +class TranscriptInfoFactory(Factory): + class Meta: + model = TranscriptInfo + + language_code = "en-us" + name = "English (United States)" + url = Sequence(lambda n: f"https://example.com/api/timedtext?v={n}") + autogenerated = False diff --git a/tests/unit/services.py b/tests/unit/services.py index 00211d41..26bffe11 100644 --- a/tests/unit/services.py +++ b/tests/unit/services.py @@ -11,6 +11,7 @@ URLDetailsService, ViaClientService, YouTubeService, + YouTubeTranscriptService, ) @@ -69,3 +70,8 @@ def youtube_service(mock_service): youtube_service.get_video_id.return_value = None return youtube_service + + +@pytest.fixture +def 
youtube_transcript_service(mock_service): + return mock_service(YouTubeTranscriptService) diff --git a/tests/unit/via/services/youtube_test.py b/tests/unit/via/services/youtube_test.py index 6d409aea..8f747dae 100644 --- a/tests/unit/via/services/youtube_test.py +++ b/tests/unit/via/services/youtube_test.py @@ -28,6 +28,7 @@ def test_enabled(self, db_session, enabled, api_key, expected): enabled=enabled, api_key=api_key, http_service=sentinel.http_service, + youtube_transcript_service=sentinel.youtube_transcript_service, ).enabled == expected ) @@ -94,38 +95,45 @@ def test_get_video_title_raises_YouTubeDataAPIError(self, svc, http_service): assert exc_info.value.__cause__ == http_service.get.side_effect - def test_get_transcript(self, db_session, svc, YouTubeTranscriptApi): - YouTubeTranscriptApi.get_transcript.return_value = [ - {"text": "foo", "start": 0.0, "duration": 1.0}, - {"text": "bar", "start": 1.0, "duration": 2.0}, - ] + def test_get_transcript( + self, db_session, svc, youtube_transcript_service, transcript_info + ): + youtube_transcript_service.pick_default_transcript.return_value = ( + transcript_info + ) + youtube_transcript_service.get_transcript.return_value = "test_transcript" returned_transcript = svc.get_transcript("test_video_id") - YouTubeTranscriptApi.get_transcript.assert_called_once_with( - "test_video_id", languages=("en",) + youtube_transcript_service.get_transcript_infos.assert_called_once_with( + "test_video_id" + ) + youtube_transcript_service.pick_default_transcript.assert_called_once_with( + youtube_transcript_service.get_transcript_infos.return_value ) - assert returned_transcript == YouTubeTranscriptApi.get_transcript.return_value + youtube_transcript_service.get_transcript.assert_called_once_with( + transcript_info + ) + assert returned_transcript == "test_transcript" # It should have cached the transcript in the DB. 
assert db_session.scalars(select(Transcript)).all() == [ Any.instance_of(Transcript).with_attrs( { "video_id": "test_video_id", - "transcript": YouTubeTranscriptApi.get_transcript.return_value, + "transcript_id": transcript_info.id, + "transcript": "test_transcript", } ) ] - @pytest.mark.usefixtures("db_session") def test_get_transcript_returns_cached_transcripts( - self, transcript, svc, YouTubeTranscriptApi + self, svc, transcript, youtube_transcript_service ): returned_transcript = svc.get_transcript(transcript.video_id) - YouTubeTranscriptApi.get_transcript.assert_not_called() + youtube_transcript_service.get_transcript.assert_not_called() assert returned_transcript == transcript.transcript - @pytest.mark.usefixtures("db_session") def test_get_transcript_returns_oldest_cached_transcript( self, transcript_factory, svc ): @@ -155,18 +163,25 @@ def test_canonical_video_url(self, video_id, expected_url, svc): assert expected_url == svc.canonical_video_url(video_id) @pytest.fixture - def svc(self, db_session, http_service): + def svc(self, db_session, http_service, youtube_transcript_service): return YouTubeService( db_session=db_session, enabled=True, api_key=sentinel.api_key, http_service=http_service, + youtube_transcript_service=youtube_transcript_service, ) class TestFactory: def test_it( - self, YouTubeService, youtube_service, pyramid_request, http_service, db_session + self, + YouTubeService, + youtube_service, + pyramid_request, + http_service, + db_session, + youtube_transcript_service, ): returned = factory(sentinel.context, pyramid_request) @@ -175,6 +190,7 @@ def test_it( enabled=pyramid_request.registry.settings["youtube_transcripts"], api_key="test_youtube_api_key", http_service=http_service, + youtube_transcript_service=youtube_transcript_service, ) assert returned == youtube_service @@ -185,8 +201,3 @@ def YouTubeService(self, patch): @pytest.fixture def youtube_service(self, YouTubeService): return YouTubeService.return_value - - 
-@pytest.fixture(autouse=True) -def YouTubeTranscriptApi(patch): - return patch("via.services.youtube.YouTubeTranscriptApi") diff --git a/tests/unit/via/services/youtube_transcript_test.py b/tests/unit/via/services/youtube_transcript_test.py new file mode 100644 index 00000000..4b39563e --- /dev/null +++ b/tests/unit/via/services/youtube_transcript_test.py @@ -0,0 +1,300 @@ +import json +from io import BytesIO +from json import JSONDecodeError +from unittest.mock import sentinel +from xml.etree import ElementTree + +import pytest +from h_matchers import Any +from requests import Response + +from tests.factories import TranscriptInfoFactory +from via.exceptions import UnhandledUpstreamException +from via.services.youtube_transcript import ( + TranscriptInfo, + YouTubeTranscriptService, + factory, +) + + +class TestTranscriptInfo: + @pytest.mark.parametrize( + "transcript_info,expected_id", + [ + ( + TranscriptInfoFactory(), + "en-us..RW5nbGlzaCAoVW5pdGVkIFN0YXRlcyk=", + ), + ( + TranscriptInfoFactory(autogenerated=True), + "en-us.a.RW5nbGlzaCAoVW5pdGVkIFN0YXRlcyk=", + ), + ], + ) + def test_id(self, transcript_info, expected_id): + assert transcript_info.id == expected_id + + +class TestYouTubeTranscriptService: + def test_get_transcript_infos(self, svc, http_service): + # The JSON response body from the YouTube API. 
+ response_json = { + "captions": { + "playerCaptionsTracklistRenderer": { + "captionTracks": [ + { + "languageCode": "en", + "name": {"simpleText": "English"}, + "baseUrl": "https://example.com/transcript_1", + }, + { + "languageCode": "en-us", + "name": {"simpleText": "English (United States)"}, + "baseUrl": "https://example.com/transcript_2", + }, + ] + } + } + } + response = http_service.post.return_value = Response() + response.raw = BytesIO(json.dumps(response_json).encode("utf-8")) + + transcript_infos = svc.get_transcript_infos("test_video_id") + + http_service.post.assert_called_once_with( + "https://youtubei.googleapis.com/youtubei/v1/player", + json={ + "context": { + "client": { + "hl": "en", + "clientName": "WEB", + "clientVersion": "2.20210721.00.00", + } + }, + "videoId": "test_video_id", + }, + ) + caption_tracks = response_json["captions"]["playerCaptionsTracklistRenderer"][ + "captionTracks" + ] + assert transcript_infos == [ + Any.instance_of(TranscriptInfo).with_attrs( + { + "language_code": caption_tracks[0]["languageCode"], + "autogenerated": False, + "name": caption_tracks[0]["name"]["simpleText"], + "url": caption_tracks[0]["baseUrl"], + } + ), + Any.instance_of(TranscriptInfo).with_attrs( + { + "language_code": caption_tracks[1]["languageCode"], + "autogenerated": False, + "name": caption_tracks[1]["name"]["simpleText"], + "url": caption_tracks[1]["baseUrl"], + } + ), + ] + + def test_get_transcript_infos_error_response(self, svc, http_service): + # We get an error response from the YouTube API. + http_service.post.side_effect = UnhandledUpstreamException( + "Something went wrong" + ) + + with pytest.raises(UnhandledUpstreamException): + svc.get_transcript_infos("test_video_id") + + @pytest.mark.parametrize( + "response_body,exception_class", + [ + (b"foo", JSONDecodeError), # Not valid JSON. + (b"[]", TypeError), # Not a dict. + (b"{}", KeyError), # No "captions" key. + (b'{"captions": 23}', TypeError), # "captions" isn't a dict. 
+ (b'{"captions": {}}', KeyError), # No "playerCaptionsTracklistRenderer". + # "playerCaptionsTracklistRenderer" isn't a dict. + ( + b'{"captions": {"playerCaptionsTracklistRenderer": 23}}', + TypeError, + ), + # No "captionTracks". + ( + b'{"captions": {"playerCaptionsTracklistRenderer": {}}}', + KeyError, + ), + # "captionTracks" isn't a list. + ( + b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": 23}}}', + TypeError, + ), + # No "languageCode". + ( + b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"name": {"simpleText": "English"}, "baseUrl": "https://example.com"}]}}}', + KeyError, + ), + # No "name". + ( + b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"languageCode": "en", "baseUrl": "https://example.com"}]}}}', + KeyError, + ), + # "name" isn't a dict. + ( + b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"languageCode": "en", "name": 23, "baseUrl": "https://example.com"}]}}}', + TypeError, + ), + # No "simpleText". + ( + b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"languageCode": "en", "name": {}, "baseUrl": "https://example.com"}]}}}', + KeyError, + ), + # No "baseUrl". 
+ ( + b'{"captions": {"playerCaptionsTracklistRenderer": {"captionTracks": [{"languageCode": "en", "name": {"simpleText": "English"}}]}}}', + KeyError, + ), + ], + ) + def test_get_transcript_infos_unexpected_response( + self, svc, http_service, response_body, exception_class + ): + """It crashes if the response body isn't what we expect.""" + response = http_service.post.return_value = Response() + response.encoding = "utf-8" + response.raw = BytesIO(response_body) + + with pytest.raises(exception_class): + svc.get_transcript_infos("test_video_id") + + @pytest.mark.parametrize( + "transcript_infos,expected_default_transcript_index", + [ + ( + [ + TranscriptInfoFactory(language_code="en", name="English"), + TranscriptInfoFactory( + language_code="en-us", name="English (United States)" + ), + ], + 0, + ), + ( + [ + TranscriptInfoFactory( + language_code="en-us", name="English (United States)" + ), + TranscriptInfoFactory(language_code="en", name="English - DTVCC1"), + ], + 0, + ), + ( + [ + TranscriptInfoFactory(language_code="en", name="English - Foo"), + TranscriptInfoFactory( + language_code="en-us", name="English (United States) - Foo" + ), + ], + 0, + ), + ( + [ + TranscriptInfoFactory( + language_code="en-us", name="English (United States) - Foo" + ), + TranscriptInfoFactory( + language_code="en", name="English", autogenerated=True + ), + ], + 0, + ), + ( + [ + TranscriptInfoFactory( + language_code="en", name="English", autogenerated=True + ), + TranscriptInfoFactory( + language_code="en-us", + name="English (United States)", + autogenerated=True, + ), + ], + 0, + ), + ( + [ + TranscriptInfoFactory(language_code="fr", name="French"), + TranscriptInfoFactory( + language_code="en", name="English", autogenerated=True + ), + ], + 1, + ), + ( + [ + TranscriptInfoFactory(language_code="fr", name="French"), + TranscriptInfoFactory(language_code="de", name="Deutsch"), + ], + 0, + ), + ], + ) + def test_pick_default_transcript( + self, svc, transcript_infos, 
expected_default_transcript_index + ): + assert ( + svc.pick_default_transcript(transcript_infos) + == transcript_infos[expected_default_transcript_index] + ) + + def test_get_transcript(self, svc, transcript_info, http_service): + http_service.get.return_value.text = """ + + Hey there guys, + Lichen' subscribe + + <font color="#A0AAB4">Buy my merch!</font> + + + """ + + transcript = svc.get_transcript(transcript_info) + http_service.get.assert_called_once_with(transcript_info.url) + + assert transcript == [ + {"duration": 1.387, "start": 0.21, "text": "Hey there guys,"}, + {"duration": 0.0, "start": 1.597, "text": "Lichen' subscribe"}, + {"duration": 2.063, "start": 4.327, "text": "Buy my merch!"}, + ] + + def test_get_transcript_error_response(self, svc, transcript_info, http_service): + # We get an error response from the YouTube API. + http_service.get.side_effect = UnhandledUpstreamException( + "Something went wrong" + ) + + with pytest.raises(UnhandledUpstreamException): + svc.get_transcript(transcript_info) + + def test_get_transcript_unexpected_response( + self, svc, transcript_info, http_service + ): + http_service.get.return_value.text = "foo" + + with pytest.raises(ElementTree.ParseError): + svc.get_transcript(transcript_info) + + @pytest.fixture + def svc(self, http_service): + return YouTubeTranscriptService(http_service) + + +class TestFactory: + def test_factory(self, YouTubeTranscriptService, http_service, pyramid_request): + svc = factory(sentinel.context, pyramid_request) + + YouTubeTranscriptService.assert_called_once_with(http_service=http_service) + assert svc == YouTubeTranscriptService.return_value + + @pytest.fixture + def YouTubeTranscriptService(self, patch): + return patch("via.services.youtube_transcript.YouTubeTranscriptService") diff --git a/via/services/__init__.py b/via/services/__init__.py index ab267007..41b12877 100644 --- a/via/services/__init__.py +++ b/via/services/__init__.py @@ -12,6 +12,7 @@ from via.services.url_details 
import URLDetailsService from via.services.via_client import ViaClientService from via.services.youtube import YouTubeService +from via.services.youtube_transcript import YouTubeTranscriptService def includeme(config): # pragma: no cover @@ -37,6 +38,9 @@ def includeme(config): # pragma: no cover config.register_service_factory( "via.services.youtube.factory", iface=YouTubeService ) + config.register_service_factory( + "via.services.youtube_transcript.factory", iface=YouTubeTranscriptService + ) config.register_service_factory( "via.services.url_details.factory", iface=URLDetailsService diff --git a/via/services/youtube.py b/via/services/youtube.py index a23e26c9..6c730fec 100644 --- a/via/services/youtube.py +++ b/via/services/youtube.py @@ -1,10 +1,10 @@ from urllib.parse import parse_qs, quote_plus, urlparse from sqlalchemy import select -from youtube_transcript_api import YouTubeTranscriptApi from via.models import Transcript, Video from via.services.http import HTTPService +from via.services.youtube_transcript import YouTubeTranscriptService class YouTubeDataAPIError(Exception): @@ -12,13 +12,19 @@ class YouTubeDataAPIError(Exception): class YouTubeService: - def __init__( - self, db_session, enabled: bool, api_key: str, http_service: HTTPService + def __init__( # pylint:disable=too-many-arguments + self, + db_session, + enabled: bool, + api_key: str, + http_service: HTTPService, + youtube_transcript_service: YouTubeTranscriptService, ): self._db = db_session self._enabled = enabled self._api_key = api_key self._http_service = http_service + self._transcript_svc = youtube_transcript_service @property def enabled(self): @@ -101,10 +107,18 @@ def get_transcript(self, video_id): ).first(): return transcript.transcript - transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=("en",)) + transcript_infos = self._transcript_svc.get_transcript_infos(video_id) + transcript_info = self._transcript_svc.pick_default_transcript(transcript_infos) + transcript 
= self._transcript_svc.get_transcript(transcript_info) + self._db.add( - Transcript(video_id=video_id, transcript_id="en", transcript=transcript) + Transcript( + video_id=video_id, + transcript_id=transcript_info.id, + transcript=transcript, + ) ) + return transcript @@ -114,4 +128,5 @@ def factory(_context, request): enabled=request.registry.settings["youtube_transcripts"], api_key=request.registry.settings["youtube_api_key"], http_service=request.find_service(HTTPService), + youtube_transcript_service=request.find_service(YouTubeTranscriptService), ) diff --git a/via/services/youtube_transcript.py b/via/services/youtube_transcript.py new file mode 100644 index 00000000..9a69ccf5 --- /dev/null +++ b/via/services/youtube_transcript.py @@ -0,0 +1,148 @@ +import re +from base64 import b64encode +from dataclasses import dataclass +from typing import Dict, List +from xml.etree import ElementTree + +from via.services.http import HTTPService + + +@dataclass +class TranscriptInfo: + """Information about a YouTube video transcript.""" + + language_code: str + """The transcript's language code, e.g. "en-us".""" + + name: str + """The transcript's name, e.g. "English (United States)".""" + + url: str + """The transcript's download URL.""" + + autogenerated: bool + """Whether or not the transcript was autogenerated by speech recognition.""" + + @property + def id(self) -> str: # pylint:disable=invalid-name + """Return a unique ID for this transcript.""" + return ".".join( + [ + self.language_code, + "a" if self.autogenerated else "", + b64encode(self.name.encode("utf-8")).decode("utf-8"), + ] + ).rstrip(".") + + +#: The preferences that we apply when choosing a default transcript for a video. 
# The preferences are tried in order (most preferred first) and every pattern
# is applied with ``fullmatch``, so a pattern must describe the *whole* value:
# "en" only matches the bare code "en", regional variants need "en-.*".
DEFAULT_TRANSCRIPT_PREFERENCES = (
    # Hand-written plain English.
    {
        "language_code": re.compile("en"),
        "name": re.compile("English"),
        "autogenerated": False,
    },
    # Hand-written regional English, e.g. "English (United States)".
    {
        "language_code": re.compile("en-.*"),
        "name": re.compile(r"English \(.*\)"),
        "autogenerated": False,
    },
    # Any other hand-written English track, whatever its name.
    {
        "language_code": re.compile("en"),
        "name": re.compile(".*"),
        "autogenerated": False,
    },
    {
        "language_code": re.compile("en-.*"),
        "name": re.compile(".*"),
        "autogenerated": False,
    },
    # Autogenerated English tracks come last.
    {
        "language_code": re.compile("en"),
        "name": re.compile(".*"),
        "autogenerated": True,
    },
    {
        # This must be "en-.*" rather than "en-": with fullmatch the latter
        # could never match a real code such as "en-us".
        "language_code": re.compile("en-.*"),
        "name": re.compile(".*"),
        "autogenerated": True,
    },
)

#: Matches anything that looks like a markup tag. Only used as a fallback when
#: a caption's text can't be parsed as XML.
_HTML_TAG_PATTERN = re.compile(r"<[^>]*>")


def _matches_preference(transcript_info: "TranscriptInfo", preference: dict) -> bool:
    """Return True if `transcript_info` satisfies every field of `preference`."""
    return (
        preference["autogenerated"] == transcript_info.autogenerated
        and all(
            preference[key].fullmatch(getattr(transcript_info, key))
            for key in ("language_code", "name")
        )
    )


def _strip_html(xml_string: str) -> str:
    """Return the plain text of a caption, with any inline markup removed.

    Caption text may itself contain markup (e.g. ``<font>`` tags). The text is
    a fragment rather than a document, so it's wrapped in a synthetic root
    element before being parsed.
    """
    try:
        root = ElementTree.fromstring(f"<span>{xml_string}</span>")
    except ElementTree.ParseError:
        # The caption isn't well-formed XML (for example it contains a bare
        # "&"). Don't fail the whole transcript because of one caption: fall
        # back to a cruder tag strip plus entity decoding.
        from html import unescape  # pylint:disable=import-outside-toplevel

        return unescape(_HTML_TAG_PATTERN.sub("", xml_string)).strip()

    return "".join(root.itertext()).strip()


class YouTubeTranscriptService:
    """A service for getting text transcripts of YouTube videos."""

    def __init__(self, http_service: "HTTPService"):
        self._http_service = http_service

    def get_transcript_infos(self, video_id: str) -> "List[TranscriptInfo]":
        """Return the list of available transcripts for `video_id`.

        :raises KeyError: if the response body lacks one of the expected keys
            (including when the video has no "captions" entry at all)
        :raises TypeError: if part of the response body has an unexpected type

        Errors raised by the HTTP service or by JSON decoding are not caught.
        """
        response = self._http_service.post(
            "https://youtubei.googleapis.com/youtubei/v1/player",
            json={
                "context": {
                    "client": {
                        "hl": "en",
                        "clientName": "WEB",
                        "clientVersion": "2.20210721.00.00",
                    }
                },
                "videoId": video_id,
            },
        )
        payload = response.json()
        # Deliberately index rather than .get(): an unexpected response shape
        # should surface as an error rather than as "no transcripts".
        caption_tracks = payload["captions"]["playerCaptionsTracklistRenderer"][
            "captionTracks"
        ]

        return [
            TranscriptInfo(
                language_code=caption_track["languageCode"].lower(),
                # Tracks whose "kind" is "asr" are treated as autogenerated;
                # other tracks may have no "kind" key at all.
                autogenerated=caption_track.get("kind") == "asr",
                name=caption_track["name"]["simpleText"],
                url=caption_track["baseUrl"],
            )
            for caption_track in caption_tracks
        ]

    def pick_default_transcript(
        self, transcript_infos: "List[TranscriptInfo]"
    ) -> "TranscriptInfo":
        """Return a choice of default transcript from `transcript_infos`.

        Each entry of DEFAULT_TRANSCRIPT_PREFERENCES is tried in turn and the
        first transcript matching the earliest satisfiable preference wins. If
        nothing matches any preference the first transcript is returned.

        :raises IndexError: if `transcript_infos` is empty
        """
        for preference in DEFAULT_TRANSCRIPT_PREFERENCES:
            for transcript_info in transcript_infos:
                if _matches_preference(transcript_info, preference):
                    return transcript_info

        return transcript_infos[0]

    def get_transcript(self, transcript_info: "TranscriptInfo") -> List[Dict]:
        """Download and return the actual transcript text for `transcript_info`.

        The result is a list of dicts with "text", "start" and "duration" keys,
        one per non-empty child element of the downloaded XML document.

        :raises xml.etree.ElementTree.ParseError: if the downloaded document
            isn't well-formed XML
        """
        response = self._http_service.get(transcript_info.url)
        xml_elements = ElementTree.fromstring(response.text)

        return [
            {
                "text": _strip_html(xml_element.text),
                "start": float(xml_element.attrib["start"]),
                # An element without a "dur" attribute gets a zero duration.
                "duration": float(xml_element.attrib.get("dur", "0.0")),
            }
            for xml_element in xml_elements
            if xml_element.text is not None
        ]


def factory(_context, request):
    """Return a YouTubeTranscriptService that uses the request's HTTPService."""
    return YouTubeTranscriptService(http_service=request.find_service(HTTPService))