diff --git a/tests/conftest.py b/tests/conftest.py index a1f5cd4a..e1d6372a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker -from tests.factories import TranscriptFactory, VideoFactory +from tests.factories import TranscriptFactory, TranscriptInfoFactory, VideoFactory from tests.factories.factoryboy_sqlalchemy_session import ( clear_factoryboy_sqlalchemy_session, set_factoryboy_sqlalchemy_session, @@ -16,6 +16,7 @@ # Each factory has to be registered with pytest_factoryboy. register(TranscriptFactory) +register(TranscriptInfoFactory) register(VideoFactory) diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py index 0d81dcd0..2ff9cfd2 100644 --- a/tests/factories/__init__.py +++ b/tests/factories/__init__.py @@ -1,2 +1,3 @@ from tests.factories.transcript import TranscriptFactory +from tests.factories.transcript_info import TranscriptInfoFactory from tests.factories.video import VideoFactory diff --git a/tests/factories/transcript_info.py b/tests/factories/transcript_info.py new file mode 100644 index 00000000..6c9ffb28 --- /dev/null +++ b/tests/factories/transcript_info.py @@ -0,0 +1,13 @@ +from factory import Factory + +from via.services.youtube_transcript import TranscriptInfo + + +class TranscriptInfoFactory(Factory): + class Meta: + model = TranscriptInfo + + language_code = "en-us" + name = "English (United States)" + url = "https://example.com/api/timedtext?v=foo" + autogenerated = False diff --git a/tests/unit/services.py b/tests/unit/services.py index 00211d41..26bffe11 100644 --- a/tests/unit/services.py +++ b/tests/unit/services.py @@ -11,6 +11,7 @@ URLDetailsService, ViaClientService, YouTubeService, + YouTubeTranscriptService, ) @@ -69,3 +70,8 @@ def youtube_service(mock_service): youtube_service.get_video_id.return_value = None return youtube_service + + +@pytest.fixture +def youtube_transcript_service(mock_service): + return mock_service(YouTubeTranscriptService) diff --git a/tests/unit/via/services/youtube_test.py b/tests/unit/via/services/youtube_test.py index 7b9fa981..8fca1c7c 100644 --- a/tests/unit/via/services/youtube_test.py +++ b/tests/unit/via/services/youtube_test.py @@ -28,6 +28,7 @@ def test_enabled(self, db_session, enabled, api_key, expected): enabled=enabled, api_key=api_key, http_service=sentinel.http_service, + youtube_transcript_service=sentinel.youtube_transcript_service, ).enabled == expected ) @@ -94,42 +95,50 @@ def test_get_video_title_raises_YouTubeDataAPIError(self, svc, http_service): assert exc_info.value.__cause__ == http_service.get.side_effect - def test_get_transcript(self, db_session, svc, YouTubeTranscriptApi): - YouTubeTranscriptApi.get_transcript.return_value = [ - {"text": "foo", "start": 0.0, "duration": 1.0}, - {"text": "bar", "start": 1.0, "duration": 2.0}, - ] + def test_get_transcript_when_none_saved( + self, db_session, svc, youtube_transcript_service, transcript_info + ): + video_id = "test_video_id" + youtube_transcript_service.pick_default_transcript.return_value = ( + transcript_info + ) + youtube_transcript_service.get_transcript.return_value = "test_transcript" - returned_transcript = svc.get_transcript("test_video_id") + returned = svc.get_transcript(video_id) - YouTubeTranscriptApi.get_transcript.assert_called_once_with( - "test_video_id", languages=("en",) + # It gets the transcript from YouTubeTranscriptService. + youtube_transcript_service.get_transcript_infos.assert_called_once_with( + video_id + ) + youtube_transcript_service.pick_default_transcript.assert_called_once_with( + youtube_transcript_service.get_transcript_infos.return_value ) - assert returned_transcript == YouTubeTranscriptApi.get_transcript.return_value - # It should have cached the transcript in the DB. + youtube_transcript_service.get_transcript.assert_called_once_with( + transcript_info + ) + # It saves the transcript in the DB for next time. assert db_session.scalars(select(Transcript)).all() == [ Any.instance_of(Transcript).with_attrs( { - "video_id": "test_video_id", - "transcript": YouTubeTranscriptApi.get_transcript.return_value, + "video_id": video_id, + "transcript_id": transcript_info.id, + "transcript": "test_transcript", } ) ] + # It returns the transcript. + assert returned == "test_transcript" - @pytest.mark.usefixtures("db_session") - def test_get_transcript_returns_cached_transcripts( - self, transcript, svc, YouTubeTranscriptApi + def test_get_transcript_when_one_saved( + self, svc, transcript, youtube_transcript_service ): - returned_transcript = svc.get_transcript(transcript.video_id) + returned = svc.get_transcript(transcript.video_id) - YouTubeTranscriptApi.get_transcript.assert_not_called() - assert returned_transcript == transcript.transcript + # It returns the saved transcript from the DB without calling YouTubeService. + youtube_transcript_service.get_transcript.assert_not_called() + assert returned == transcript.transcript - @pytest.mark.usefixtures("db_session") - def test_get_transcript_returns_oldest_cached_transcript( - self, transcript_factory, svc - ): - """If there are multiple cached transcripts get_transcript() returns the oldest one.""" + def test_get_transcript_when_multiple_saved(self, svc, transcript_factory): oldest_transcript, newer_transcript = transcript_factory.create_batch( 2, video_id="video_id" ) @@ -138,9 +147,9 @@ def test_get_transcript_returns_oldest_cached_transcript( newer_transcript.created = datetime(2023, 8, 12) newer_transcript.transcript = "newest_transcript" - returned_transcript = svc.get_transcript("video_id") + returned = svc.get_transcript("video_id") - assert returned_transcript == "oldest_transcript" + assert returned == oldest_transcript.transcript @pytest.mark.parametrize( "video_id,expected_url", @@ -157,18 +166,25 @@ def test_canonical_video_url(self, video_id, expected_url, svc): assert expected_url == svc.canonical_video_url(video_id) @pytest.fixture - def svc(self, db_session, http_service): + def svc(self, db_session, http_service, youtube_transcript_service): return YouTubeService( db_session=db_session, enabled=True, api_key=sentinel.api_key, http_service=http_service, + youtube_transcript_service=youtube_transcript_service, ) class TestFactory: def test_it( - self, YouTubeService, youtube_service, pyramid_request, http_service, db_session + self, + YouTubeService, + youtube_service, + pyramid_request, + http_service, + db_session, + youtube_transcript_service, ): returned = factory(sentinel.context, pyramid_request) @@ -177,6 +193,7 @@ def test_it( enabled=pyramid_request.registry.settings["youtube_transcripts"], api_key="test_youtube_api_key", http_service=http_service, + youtube_transcript_service=youtube_transcript_service, ) assert returned == youtube_service @@ -187,8 +204,3 @@ def YouTubeService(self, patch): @pytest.fixture def youtube_service(self, YouTubeService): return YouTubeService.return_value - - -@pytest.fixture(autouse=True) -def YouTubeTranscriptApi(patch): - return patch("via.services.youtube.YouTubeTranscriptApi") diff --git a/tests/unit/via/services/youtube_transcript_test.py b/tests/unit/via/services/youtube_transcript_test.py new file mode 100644 index 00000000..1476dbdc --- /dev/null +++ b/tests/unit/via/services/youtube_transcript_test.py @@ -0,0 +1,207 @@ +import json +from io import BytesIO +from unittest.mock import sentinel + +import pytest +from h_matchers import Any +from requests import Response + +from tests.factories import TranscriptInfoFactory +from via.services.youtube_transcript import ( + TranscriptInfo, + YouTubeTranscriptService, + factory, +) + + +class TestTranscriptInfo: + @pytest.mark.parametrize( + "transcript_info,expected_id", + [ + ( + TranscriptInfo( + "en-us", + "English (United States)", + "https://example.com/transcript", + autogenerated=False, + ), + "en-us..RW5nbGlzaCAoVW5pdGVkIFN0YXRlcyk=", + ), + ( + TranscriptInfo( + "en", + "English", + "https://example.com/transcript", + autogenerated=True, + ), + "en.a.RW5nbGlzaA==", + ), + ], + ) + def test_id(self, transcript_info, expected_id): + assert transcript_info.id == expected_id + + +class TestYouTubeTranscriptService: + def test_get_transcript_infos(self, svc, http_service): + # The JSON response body from the YouTube API. + response_json = { + "captions": { + "playerCaptionsTracklistRenderer": { + "captionTracks": [ + { + "languageCode": "en", + "name": {"simpleText": "English"}, + "baseUrl": "https://example.com/transcript_1", + }, + { + "languageCode": "en-us", + "name": {"simpleText": "English (United States)"}, + "baseUrl": "https://example.com/transcript_2", + }, + ] + } + } + } + response = http_service.post.return_value = Response() + response.raw = BytesIO(json.dumps(response_json).encode("utf-8")) + + transcript_infos = svc.get_transcript_infos("test_video_id") + + caption_tracks = response_json["captions"]["playerCaptionsTracklistRenderer"][ + "captionTracks" + ] + assert transcript_infos == [ + Any.instance_of(TranscriptInfo).with_attrs( + { + "language_code": caption_tracks[0]["languageCode"], + "autogenerated": False, + "name": caption_tracks[0]["name"]["simpleText"], + "url": caption_tracks[0]["baseUrl"], + } + ), + Any.instance_of(TranscriptInfo).with_attrs( + { + "language_code": caption_tracks[1]["languageCode"], + "autogenerated": False, + "name": caption_tracks[1]["name"]["simpleText"], + "url": caption_tracks[1]["baseUrl"], + } + ), + ] + + @pytest.mark.parametrize( + "transcript_infos,expected_default_transcript_index", + [ + ( + [ + TranscriptInfoFactory(language_code="en", name="English"), + TranscriptInfoFactory( + language_code="en-us", name="English (United States)" + ), + ], + 0, + ), + ( + [ + TranscriptInfoFactory( + language_code="en-us", name="English (United States)" + ), + TranscriptInfoFactory(language_code="en", name="English - DTVCC1"), + ], + 0, + ), + ( + [ + TranscriptInfoFactory(language_code="en", name="English - Foo"), + TranscriptInfoFactory( + language_code="en-us", name="English (United States) - Foo" + ), + ], + 0, + ), + ( + [ + TranscriptInfoFactory( + language_code="en-us", name="English (United States) - Foo" + ), + TranscriptInfoFactory( + language_code="en", name="English", autogenerated=True + ), + ], + 0, + ), + ( + [ + TranscriptInfoFactory( + language_code="en", name="English", autogenerated=True + ), + TranscriptInfoFactory( + language_code="en-us", + name="English (United States)", + autogenerated=True, + ), + ], + 0, + ), + ( + [ + TranscriptInfoFactory(language_code="fr", name="French"), + TranscriptInfoFactory( + language_code="en", name="English", autogenerated=True + ), + ], + 1, + ), + ( + [ + TranscriptInfoFactory(language_code="fr", name="French"), + TranscriptInfoFactory(language_code="de", name="Deutsch"), + ], + 0, + ), + ], + ) + def test_pick_default_transcript( + self, svc, transcript_infos, expected_default_transcript_index + ): + assert ( + svc.pick_default_transcript(transcript_infos) + == transcript_infos[expected_default_transcript_index] + ) + + def test_get_transcript(self, svc, transcript_info, http_service): + http_service.get.return_value.text = """ + + Hey there guys, + Lichen' subscribe + + <font color="#A0AAB4">Buy my merch!</font> + + + """ + + transcript = svc.get_transcript(transcript_info) + http_service.get.assert_called_once_with(transcript_info.url) + + assert transcript == [ + {"duration": 1.387, "start": 0.21, "text": "Hey there guys,"}, + {"duration": 0.0, "start": 1.597, "text": "Lichen' subscribe"}, + {"duration": 2.063, "start": 4.327, "text": "Buy my merch!"}, + ] + + @pytest.fixture + def svc(self, http_service): + return YouTubeTranscriptService(http_service) + + +class TestFactory: + def test_factory(self, YouTubeTranscriptService, http_service, pyramid_request): + svc = factory(sentinel.context, pyramid_request) + + YouTubeTranscriptService.assert_called_once_with(http_service=http_service) + assert svc == YouTubeTranscriptService.return_value + + @pytest.fixture + def YouTubeTranscriptService(self, patch): + return patch("via.services.youtube_transcript.YouTubeTranscriptService") diff --git a/via/services/__init__.py b/via/services/__init__.py index ab267007..41b12877 100644 --- a/via/services/__init__.py +++ b/via/services/__init__.py @@ -12,6 +12,7 @@ from via.services.url_details import URLDetailsService from via.services.via_client import ViaClientService from via.services.youtube import YouTubeService +from via.services.youtube_transcript import YouTubeTranscriptService def includeme(config): # pragma: no cover @@ -37,6 +38,9 @@ def includeme(config): # pragma: no cover config.register_service_factory( "via.services.youtube.factory", iface=YouTubeService ) + config.register_service_factory( + "via.services.youtube_transcript.factory", iface=YouTubeTranscriptService + ) config.register_service_factory( "via.services.url_details.factory", iface=URLDetailsService diff --git a/via/services/youtube.py b/via/services/youtube.py index a23e26c9..6c730fec 100644 --- a/via/services/youtube.py +++ b/via/services/youtube.py @@ -1,10 +1,10 @@ from urllib.parse import parse_qs, quote_plus, urlparse from sqlalchemy import select -from youtube_transcript_api import YouTubeTranscriptApi from via.models import Transcript, Video from via.services.http import HTTPService +from via.services.youtube_transcript import YouTubeTranscriptService class YouTubeDataAPIError(Exception): @@ -12,13 +12,19 @@ class YouTubeDataAPIError(Exception): class YouTubeService: - def __init__( - self, db_session, enabled: bool, api_key: str, http_service: HTTPService + def __init__( # pylint:disable=too-many-arguments + self, + db_session, + enabled: bool, + api_key: str, + http_service: HTTPService, + youtube_transcript_service: YouTubeTranscriptService, ): self._db = db_session self._enabled = enabled self._api_key = api_key self._http_service = http_service + self._transcript_svc = youtube_transcript_service @property def enabled(self): @@ -101,10 +107,18 @@ def get_transcript(self, video_id): ).first(): return transcript.transcript - transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=("en",)) + transcript_infos = self._transcript_svc.get_transcript_infos(video_id) + transcript_info = self._transcript_svc.pick_default_transcript(transcript_infos) + transcript = self._transcript_svc.get_transcript(transcript_info) + self._db.add( - Transcript(video_id=video_id, transcript_id="en", transcript=transcript) + Transcript( + video_id=video_id, + transcript_id=transcript_info.id, + transcript=transcript, + ) ) + return transcript @@ -114,4 +128,5 @@ def factory(_context, request): enabled=request.registry.settings["youtube_transcripts"], api_key=request.registry.settings["youtube_api_key"], http_service=request.find_service(HTTPService), + youtube_transcript_service=request.find_service(YouTubeTranscriptService), ) diff --git a/via/services/youtube_transcript.py b/via/services/youtube_transcript.py new file mode 100644 index 00000000..bc00c322 --- /dev/null +++ b/via/services/youtube_transcript.py @@ -0,0 +1,135 @@ +import re +from base64 import b64encode +from dataclasses import dataclass +from typing import Dict, List +from xml.etree import ElementTree + +from via.services.http import HTTPService + + +@dataclass +class TranscriptInfo: + language_code: str + name: str + url: str + autogenerated: bool + + @property + def id(self) -> str: # pylint:disable=invalid-name + """Return a unique ID for this transcript.""" + name = b64encode(self.name.encode("utf-8")).decode("utf-8") + + return ".".join( + part or "" + for part in [ + self.language_code, + "a" if self.autogenerated else None, + name, + ] + ).rstrip(".") + + +DEFAULT_TRANSCRIPT_PREFERENCES = ( + { + "language_code": re.compile("en"), + "name": re.compile("English"), + "autogenerated": False, + }, + { + "language_code": re.compile("en-.*"), + "name": re.compile(r"English \(.*\)"), + "autogenerated": False, + }, + { + "language_code": re.compile("en"), + "name": re.compile(".*"), + "autogenerated": False, + }, + { + "language_code": re.compile("en-.*"), + "name": re.compile(".*"), + "autogenerated": False, + }, + { + "language_code": re.compile("en"), + "name": re.compile(".*"), + "autogenerated": True, + }, + { + "language_code": re.compile("en-"), + "name": re.compile(".*"), + "autogenerated": True, + }, +) + + +class YouTubeTranscriptService: + def __init__(self, http_service: HTTPService): + self._http_service = http_service + + def get_transcript_infos(self, video_id: str) -> List[TranscriptInfo]: + response = self._http_service.post( + "https://youtubei.googleapis.com/youtubei/v1/player", + json={ + "context": { + "client": { + "hl": "en", + "clientName": "WEB", + "clientVersion": "2.20210721.00.00", + } + }, + "videoId": video_id, + }, + ) + json = response.json() + dicts = json["captions"]["playerCaptionsTracklistRenderer"]["captionTracks"] + + return [ + TranscriptInfo( + language_code=caption_track_dict["languageCode"].lower(), + autogenerated=caption_track_dict.get("kind", None) == "asr", + name=caption_track_dict["name"]["simpleText"], + url=caption_track_dict["baseUrl"], + ) + for caption_track_dict in dicts + ] + + def pick_default_transcript( + self, transcript_infos: List[TranscriptInfo] + ) -> TranscriptInfo: + def matches(transcript_info: TranscriptInfo, preference: dict): + for key in ("language_code", "name"): + if not preference[key].fullmatch(getattr(transcript_info, key)): + return False + + return preference["autogenerated"] == transcript_info.autogenerated + + for preference in DEFAULT_TRANSCRIPT_PREFERENCES: + for transcript_info in transcript_infos: + if matches(transcript_info, preference): + return transcript_info + + return transcript_infos[0] + + def get_transcript(self, transcript_info: TranscriptInfo) -> List[Dict]: + response = self._http_service.get(transcript_info.url) + xml_elements = ElementTree.fromstring(response.text) + + def strip_html(xml_string): + return "".join( + ElementTree.fromstring(f"{xml_string}").itertext() + ).strip() + + return [ + { + "text": strip_html(xml_element.text), + "start": float(xml_element.attrib["start"]), + "duration": float(xml_element.attrib.get("dur", "0.0")), + } + for xml_element in xml_elements + if xml_element.text is not None + ] + + +def factory(_context, request): + return YouTubeTranscriptService(http_service=request.find_service(HTTPService))