Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Alignment passthrough #26

Merged
merged 38 commits into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
4ea23bb
wip
XapaJIaMnu Apr 3, 2023
a144ae2
Initial implementation of noise augmenters
XapaJIaMnu Apr 3, 2023
be4ebe5
Simplify code a bit
jelmervdl Apr 4, 2023
7dcc2a4
Fix tests
jelmervdl Apr 4, 2023
7a1189f
Fix possible bug in test
jelmervdl Apr 4, 2023
6da73ff
Add specific tests for the three modes
jelmervdl Apr 4, 2023
55b443c
Add alignment info to the simple tests
jelmervdl Apr 4, 2023
42b850e
Make placeholder modifier produce (corrected) alignment pairs
jelmervdl Apr 4, 2023
831d564
Make sure `get_placeholding_candidates` returns the original instances
jelmervdl Apr 5, 2023
2e3d29f
update tests
jelmervdl Apr 5, 2023
38e54de
Merge branch 'main' into alignment-passthrough
jelmervdl Jul 6, 2023
723760e
Attempt to improve the alignment fix-up
jelmervdl Jul 6, 2023
17733d4
Fix unit tests
jelmervdl Jul 7, 2023
5c768e4
Implement retokenize modifier
jelmervdl Jul 24, 2023
e0adec6
Merge remote-tracking branch 'origin/main' into alignment-passthrough
jelmervdl Jul 25, 2023
294a18d
Let PlaceholderModifier use Retokenizer implementation for now
jelmervdl Jul 27, 2023
d8b1b10
Add unittest for spm retokenize in placeholders
jelmervdl Jul 27, 2023
704bd65
Add test to confirm that even when no placeholder is added, retokeniz…
jelmervdl Jul 27, 2023
38a3cae
Efficiency: don't bother calculating candidates if prob = 0.
jelmervdl Jul 27, 2023
0c4868f
Add tests covering spaces tokenizer
jelmervdl Jul 27, 2023
aab72a4
Document the `spm_vocab` option of the `Tags` modifier
jelmervdl Jul 27, 2023
973906a
Be nicer about issues with the alignment info
jelmervdl Jul 28, 2023
6b4abe0
Explain the `StopIteration` bit
jelmervdl Jul 28, 2023
c200c9c
Remove unreachable else
jelmervdl Jul 28, 2023
126587d
Remove debug code
jelmervdl Jul 28, 2023
106d832
Document and rename methods
jelmervdl Jul 28, 2023
b9ad9f6
Skip trainer backtrace test for now
jelmervdl Jul 28, 2023
b822e8c
Only print alignment info when spm_vocab is passed in
jelmervdl Aug 7, 2023
ef3c780
Make `retokenize` a little less `O(n^2)`
jelmervdl Aug 7, 2023
7069872
Replace placeholder-specific end-to-end tests with specific test for …
jelmervdl Aug 7, 2023
6b62198
Use `Path` in type signature of modifiers to resolve relative paths
jelmervdl Aug 9, 2023
7a80d2c
Rewrite end-to-end tests
jelmervdl Aug 9, 2023
9603208
Rewrite DatasetReader to not always produce n+1 lines
jelmervdl Aug 9, 2023
a2248ad
Add option for batch size
jelmervdl Aug 9, 2023
2f72e76
Add some comments to the tests
jelmervdl Aug 9, 2023
4779dd6
Fix missing sentencepiece dependency
jelmervdl Aug 9, 2023
a17af46
Fix other pyproject.toml entries while we're at it
jelmervdl Aug 9, 2023
2479f09
Make trainer skip lines that can't be processed by modifier
jelmervdl Aug 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions src/opustrainer/alignments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,10 @@ def parse_alignments(pairs:str, src_tokens:Optional[TokenList]=None, trg_tokens:
]

if src_tokens is not None and trg_tokens is not None:
invalid_pairs = [
pair
for pair in pairs
if pair.src < 0 or pair.src >= len(src_tokens)
or pair.trg < 0 or pair.trg >= len(trg_tokens)
]
if invalid_pairs:
raise ValueError('Out-of-bound alignment pairs: ' + ' '.join(map(repr, invalid_pairs)))
for pair in pairs:
if pair.src < 0 or pair.src >= len(src_tokens) \
or pair.trg < 0 or pair.trg >= len(trg_tokens):
raise ValueError('Out-of-bound alignment pairs')

return pairs

Expand Down
8 changes: 4 additions & 4 deletions src/opustrainer/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def get_log_level(name: str) -> int:
logging.log(logging.WARNING, f"unknown log level level used: {name} assuming warning...")
return logging.WARNING

def log(msg: str, loglevel: str = "INFO") -> None:
def log(msg: str, loglevel: str = "INFO", **kwargs) -> None:
level = get_log_level(loglevel)
logging.log(level, msg)
logging.log(level, msg, **kwargs)


@lru_cache(None)
def log_once(msg: str, loglevel: str = "INFO") -> None:
def log_once(msg: str, loglevel: str = "INFO", **kwargs) -> None:
"""A wrapper to log, to make sure that we only print things once"""
log(msg, loglevel)
log(msg, loglevel, **kwargs)


def setup_logger(outputfilename: Optional[str] = None, loglevel: str = "INFO", disable_stderr: bool=False) -> None:
Expand Down
12 changes: 3 additions & 9 deletions src/opustrainer/modifiers/placeholders.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,9 @@ def __call__(self, line:str) -> str:
target = trg.split()
alignments = []

# Try parsing alignments. If we fail, just treat this sentence pair as one with out any
# alignment info.
try:
alignments = parse_alignments(rest[0], source, target)
except IndexError:
logger.log_once(f"Encountered empty alignment field, ignoring alignment info for such lines", loglevel="WARNING")
except ValueError:
logger.log_once(f"Encountered invalid alignments, ignoring alignment info for such lines", loglevel="WARNING")

# Try parsing alignments. If we fail, the sentence will be thrown out
# by the trainer.
alignments = parse_alignments(rest[0], source, target)
candidate_offset = 0;

while self.probability > 0.0:
Expand Down
9 changes: 5 additions & 4 deletions src/opustrainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,12 +593,13 @@ def state(self) -> EpochTrackerState:

Out = TypeVar('Out')

def trace_map(fn: Callable[[In], Out], items: Iterable[In]) -> Iterable[Out]:
for n, item in enumerate(items):
def try_trace_map(fn: Callable[[In], Out], items: Iterable[In]) -> Iterable[Out]:
for item in items:
try:
yield fn(item)
except Exception as exc:
raise Exception(f'Exception while processing item {n}: {item!r}') from exc
logger.log(f'Exception while processing line, skipping: {item!r}', 'WARNING',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should ideally be `log_once`; otherwise we would get spammed every time we loop through the dataset. (I assume the exception kwargs would always be the same.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be, but it might also not be hashable. I'll experiment.

exc_info=(type(exc), exc, exc.__traceback__.tb_next)) # skip fn(item) frame


class Trainer:
Expand Down Expand Up @@ -698,7 +699,7 @@ def run(self, *, batch_size:int=100) -> Iterable[List[str]]:
# Apply any modifiers to random lines in the batch, or sentence
# (Multiple modifiers can be applied to the same line!)
for modifier in modifiers:
batch = list(trace_map(lambda line: modifier(line.rstrip('\r\n')) + '\n', batch))
batch = list(try_trace_map(lambda line: modifier(line.rstrip('\r\n')) + '\n', batch))

if self.shuffle:
random.shuffle(batch)
Expand Down
19 changes: 6 additions & 13 deletions tests/test_placeholders.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,12 @@ def test_warn_if_tag_modifier_is_not_last(self):
"""))
self.assertRegex(logger_ctx.output[0], r"Tags modifier should to be the last modifier to be applied")

def test_warn_if_alignment_is_missing(self):
def test_exception_if_alignment_is_missing(self):
tagger = PlaceholderTagModifier()
with self.assertLogs(logger, level='WARNING') as logger_ctx:
self.assertEqual(
tagger('Hello world\tHallo welt\t'),
'Hello world\tHallo welt')
self.assertRegex(logger_ctx.output[0], r'empty alignment field')
with self.assertRaises(IndexError):
tagger('Hello world\tHallo welt\t')

def test_warn_if_alignment_is_missing(self):
def test_exception_if_alignment_is_invalid(self):
tagger = PlaceholderTagModifier()
with self.assertLogs(level='WARNING') as logger_ctx:
self.assertEqual(
tagger('Hello world\tHallo welt\t0-0 1-2'),
'Hello world\tHallo welt')
self.assertRegex(logger_ctx.output[0], r'invalid alignments')

with self.assertRaises(ValueError):
tagger('Hello world\tHallo welt\t0-0 1-2')
12 changes: 9 additions & 3 deletions tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from contextlib import closing
from textwrap import dedent
from io import StringIO
from itertools import chain

import yaml

Expand Down Expand Up @@ -329,7 +330,6 @@ def test_combined_stage_configuration(self):
curriculum = CurriculumLoader().load(config)
self.assertEqual([modifier.__class__.__name__ for modifier in curriculum.stages['start'].modifiers or []], ['UpperCaseModifier', 'TitleCaseModifier'])

@unittest.skip('`Tags` no longer raises an exception on invalid alignment pairs')
def test_modifier_error_line_context(self):
"""Test that when a modifier fails, we get context information about the line that failed"""
with tempfile.NamedTemporaryFile('w', encoding='utf-8') as fd:
Expand Down Expand Up @@ -357,5 +357,11 @@ def test_modifier_error_line_context(self):

trainer = Trainer(curriculum)

with self.assertRaisesRegex(Exception, "Exception while processing item 1:"):
list(trainer.run(batch_size=2))
with self.assertLogs(level='WARNING') as logger_ctx:
output = list(chain.from_iterable(trainer.run(batch_size=1)))
# Assert we skipped the line
self.assertEqual(len(output), 1)
# Assert that we got the general error message
self.assertRegex(logger_ctx.output[0], r'Exception while processing line, skipping:')
# Assert that we got the specific error as well
self.assertRegex(logger_ctx.output[0], r'ValueError: Out-of-bound alignment pairs')