diff --git a/t5/data/tasks.py b/t5/data/tasks.py
index faeb423c..5ef43fab 100644
--- a/t5/data/tasks.py
+++ b/t5/data/tasks.py
@@ -223,6 +223,38 @@
     metric_fns=[metrics.bleu],
     output_features=DEFAULT_OUTPUT_FEATURES)
 
+# SentencePiece model file backing the ULM v0 vocabulary.
+# NOTE(review): this looks like a site-specific filesystem path; confirm it is
+# reachable from every environment that imports this module.
+ULM_V0_VOCAB_SPM = "/cns/mf-d/home/multipod-language-data/vocab/m4_meena_vocab_0304/spm.256k.model"
+ULM_V0_VOCAB = t5.data.SentencePieceVocabulary(ULM_V0_VOCAB_SPM, extra_ids=100)
+# Feature spec that assigns ULM_V0_VOCAB to both "inputs" and "targets".
+ULM_OUTPUT_FEATURES = {
+    "inputs": seqio.Feature(
+        vocabulary=ULM_V0_VOCAB, add_eos=True, required=False),
+    "targets": seqio.Feature(
+        vocabulary=ULM_V0_VOCAB, add_eos=True)
+}
+
+# Registers a translation task over "wmt_t2t_translate/de-en:1.0.0" that uses
+# ULM_OUTPUT_FEATURES as its output features.
+# NOTE(review): `b` is not assigned in this hunk and is rebound by the loop
+# below; presumably it is set earlier in the file -- confirm before merging.
+TaskRegistry.add(
+    "wmt_t2t_ende_v003_ulm_v0_vocab",
+    source=seqio.TfdsDataSource(tfds_name="wmt_t2t_translate/de-en:1.0.0"),
+    preprocessors=[
+        functools.partial(
+            preprocessors.translate,
+            source_language=b.language_pair[1],
+            target_language=b.language_pair[0]),
+        seqio.preprocessors.tokenize,
+        seqio.CacheDatasetPlaceholder(),
+        seqio.preprocessors.append_eos_after_trim,
+    ],
+    metric_fns=[metrics.bleu],
+    output_features=ULM_OUTPUT_FEATURES)
+
 # ================================= SuperGlue ==================================
 for b in tfds.text.super_glue.SuperGlue.builder_configs.values():
   # We use a simplified version of WSC, defined below