Skip to content

Commit

Permalink
Add machine translation options to the YouTube model
Browse files Browse the repository at this point in the history
  • Loading branch information
Jon Betts committed Aug 2, 2023
1 parent a507b1e commit 1117652
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 6 deletions.
2 changes: 1 addition & 1 deletion tests/unit/via/services/youtube_api/client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_get_transcript(self, client, http_session):

transcript = client.get_transcript(caption_track)

http_session.get.assert_called_once_with(url=caption_track.base_url)
http_session.get.assert_called_once_with(url=caption_track.url)
assert transcript == Transcript(
track=caption_track,
text=[
Expand Down
55 changes: 52 additions & 3 deletions tests/unit/via/services/youtube_api/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,20 @@ def test_from_v1_json(self, kind):
(CaptionTrack(language_code="en"), "en"),
(CaptionTrack(language_code="en", kind="asr"), "en.a"),
(CaptionTrack(language_code="en", name="Hello"), "en..SGVsbG8="),
# Let's try everything at once
(
CaptionTrack(language_code="en-gb", kind="asr", name="Name"),
"en-gb.a.TmFtZQ==",
CaptionTrack(language_code="en", translated_language_code="fr"),
"en...fr",
),
# This combination isn't actually possible, but let's try everything at
# once
(
CaptionTrack(
language_code="en-gb",
kind="asr",
name="Name",
translated_language_code="fr",
),
"en-gb.a.TmFtZQ==.fr",
),
),
)
Expand All @@ -53,6 +63,28 @@ def test_is_auto_generated(self):
caption_track.kind = None
assert not caption_track.is_auto_generated

@pytest.mark.parametrize(
"caption_track,url",
(
(
CaptionTrack("en", base_url="http://example.com?a=1"),
"http://example.com?a=1",
),
(
CaptionTrack(
"en",
base_url="http://example.com?a=1",
translated_language_code="fr",
),
"http://example.com?a=1&tlang=fr",
),
(CaptionTrack("en", base_url=None), None),
(CaptionTrack("en", base_url=None, translated_language_code="fr"), None),
),
)
def test_url(self, caption_track, url):
assert caption_track.url == url


class TestCaptions:
def test_from_v1_json(self, CaptionTrack):
Expand Down Expand Up @@ -122,6 +154,23 @@ def test_find_matching_track(self, preferences, expected_label):
else not caption_track
)

def test_find_matching_track_with_translation(self):
captions = Captions(tracks=[CaptionTrack("fr", label="plain_fr")])

caption_track = captions.find_matching_track(
[
CaptionTrack(
language_code=Any(),
name=Any(),
kind=Any(),
translated_language_code="de",
)
]
)

assert caption_track.label == "plain_fr"
assert caption_track.translated_language_code == "de"

@pytest.fixture
def CaptionTrack(self, patch):
return patch("via.services.youtube_api.models.CaptionTrack")
Expand Down
4 changes: 2 additions & 2 deletions via/services/youtube_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def get_transcript(self, caption_track: CaptionTrack) -> Transcript:
the value before returning it.
"""

if not caption_track.base_url:
if not caption_track.url:
raise ValueError("Cannot get a transcript without a URL")

response = self._http.get(url=caption_track.base_url)
response = self._http.get(url=caption_track.url)
xml_elements = ElementTree.fromstring(response.text)

return Transcript(
Expand Down
29 changes: 29 additions & 0 deletions via/services/youtube_api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class CaptionTrack:
kind: Optional[str] = None
"""Is this track automatically generated by audio to text AI?"""

translated_language_code: Optional[str] = None
"""Language to machine translate this into. Set this manually."""

label: Optional[str] = None
"""Human readable name (determined by language + name)."""

Expand Down Expand Up @@ -54,6 +57,7 @@ def id(self) -> str: # pylint: disable=invalid-name
self.language_code,
"a" if self.is_auto_generated else None,
name,
self.translated_language_code,
]
).rstrip(".")

Expand All @@ -63,6 +67,19 @@ def is_auto_generated(self) -> bool:

return self.kind == "asr"

@property
def url(self) -> Optional[str]:
"""Get the URL to download a transcript of this caption track."""
if not self.base_url:
return None

url = self.base_url

if self.translated_language_code:
url += f"&tlang={self.translated_language_code}"

return url


@dataclass
class Captions:
Expand Down Expand Up @@ -98,6 +115,10 @@ def find_matching_track(
* language_code
* name
* is_auto_generated / kind
* translation_language_code
For a match to happen, we must match the first three items, and be
translatable to the last if present.
Earlier items are higher priority.
Expand All @@ -123,6 +144,14 @@ def get_key(track: CaptionTrack):
if best_index is None or best_index > index:
best_index, best_caption_track = index, deepcopy(caption_track)

if best_index is None:
return None

if target_language := preferences[best_index].translated_language_code:
# Convert the track to a translated language if required, we've
# checked above this is ok.
best_caption_track.translated_language_code = target_language

return best_caption_track


Expand Down

0 comments on commit 1117652

Please sign in to comment.