amzn/amazon-multi-token-completion

MTC: Multi Token Completion

This package provides the code for the paper "Simple and Effective Multi-Token Completion from Masked Language Models" (Findings of EACL 2023).

Paper: https://aclanthology.org/2023.findings-eacl.179

Bibtex entry:

@inproceedings{DBLP:conf/eacl/KalinskyKLG23,
  author       = {Oren Kalinsky and
                  Guy Kushilevitz and
                  Alexander Libov and
                  Yoav Goldberg},
  title        = {Simple and Effective Multi-Token Completion from Masked Language Models},
  booktitle    = {Findings of the Association for Computational Linguistics: {EACL}
                  2023, Dubrovnik, Croatia, May 2-6, 2023},
  pages        = {2311--2324},
  publisher    = {Association for Computational Linguistics},
  year         = {2023},
  url          = {https://aclanthology.org/2023.findings-eacl.179},
  timestamp    = {Mon, 08 May 2023 14:38:37 +0200},
  biburl       = {https://dblp.org/rec/conf/eacl/KalinskyKLG23.bib},
  bibsource    = {dblp computer science bibliography, https://dblp.org}
}

Steps to run the code

Configuration

  • pip install -r requirements.txt
  • pip install transformers==4.5.1 'ray[default]==1.3.0' torch==1.8.1 (the quotes keep shells such as zsh from expanding the square brackets)
  • Update data_path in configuration.py to 's3://multi-token-completion'
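Because the install line pins exact versions, it can help to confirm that the active environment matches before preprocessing or training. The following is a minimal sketch, not part of this repository: the names PINNED, installed_version and find_mismatches are hypothetical, and it assumes Python 3.8 or later for importlib.metadata.

```python
from importlib import metadata

# Pins copied from the install line above.
PINNED = {"transformers": "4.5.1", "ray": "1.3.0", "torch": "1.8.1"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(pinned, lookup=installed_version):
    """Map each package that deviates from its pin to (expected, found)."""
    mismatches = {}
    for package, expected in pinned.items():
        found = lookup(package)
        # Some wheels carry a local suffix such as "1.8.1+cu111"; ignore it.
        if found is None or found.split("+")[0] != expected:
            mismatches[package] = (expected, found)
    return mismatches


if __name__ == "__main__":
    for package, (expected, found) in find_mismatches(PINNED).items():
        print(f"{package}: expected {expected}, found {found}")
```

An empty result means all three pinned packages are installed at the expected versions.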

Data preprocessing

To create the dataset, first parse our released data by running mtc_model.py with the --dataset_name option. For example, to create the Wikipedia dataset, run: mtc_model.py --dataset_name wiki_pub. This writes the preprocessed dataset to data/input_data_bert-base-cased_wiki_pub/, using bert-base-cased by default.
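The training and testing commands all take the preprocessed directory as an argument, so a small helper for building that path can avoid typos. This is a sketch under an assumption: the pattern data/input_data_<model>_<dataset>/ is inferred from the single wiki_pub example above, and the functions preprocessed_dir and is_preprocessed are hypothetical, not part of the repository.

```python
import os


def preprocessed_dir(dataset_name, model_name="bert-base-cased", root="data"):
    """Build the output directory name, following the example in this README.

    Assumption: other models and datasets follow the same naming pattern.
    """
    return f"{root}/input_data_{model_name}_{dataset_name}/"


def is_preprocessed(dataset_name, model_name="bert-base-cased", root="data"):
    """Return True if the preprocessed directory already exists on disk."""
    return os.path.isdir(preprocessed_dir(dataset_name, model_name, root))


if __name__ == "__main__":
    path = preprocessed_dir("wiki_pub")
    state = "found" if is_preprocessed("wiki_pub") else "missing"
    print(f"{path}: {state}")
```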

Training

  • RNN decoder - run mtc_model.py --input_path data/input_data_bert-base-cased_wiki_pub/
  • EMAT decoder - run matrix_plugin.py --input_path data/input_data_bert-base-cased_wiki_pub/

Testing

  • RNN decoder - run test.py --dataset_path data/input_data_bert-base-cased_wiki_pub/ --ckpt <CHECKPOINT_PATH> --all
  • EMAT decoder - run matrix_plugin.py --input_path data/input_data_bert-base-cased_wiki_pub/ --ckpt <CHECKPOINT_PATH> --test
  • T5 baseline - run T5_constrained_generation.py
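The training and testing commands above differ only in script name and a few flags, so they can be assembled programmatically, for example when launching several checkpoints. The sketch below only builds and prints the argument lists; it assumes the scripts are started with the Python interpreter from the repository root, and the function names and the checkpoint name model.ckpt are hypothetical placeholders. The flags themselves are copied from the lists above.

```python
import shlex
import sys

# Preprocessed dataset directory used throughout this README.
DATA_DIR = "data/input_data_bert-base-cased_wiki_pub/"


def rnn_train_cmd(data_dir=DATA_DIR):
    """Argument list for training the RNN decoder."""
    return [sys.executable, "mtc_model.py", "--input_path", data_dir]


def emat_train_cmd(data_dir=DATA_DIR):
    """Argument list for training the EMAT decoder."""
    return [sys.executable, "matrix_plugin.py", "--input_path", data_dir]


def rnn_test_cmd(ckpt, data_dir=DATA_DIR):
    """Argument list for testing the RNN decoder (note: --dataset_path)."""
    return [sys.executable, "test.py", "--dataset_path", data_dir,
            "--ckpt", ckpt, "--all"]


def emat_test_cmd(ckpt, data_dir=DATA_DIR):
    """Argument list for testing the EMAT decoder (note: --input_path)."""
    return [sys.executable, "matrix_plugin.py", "--input_path", data_dir,
            "--ckpt", ckpt, "--test"]


if __name__ == "__main__":
    # "model.ckpt" is a placeholder; substitute a real checkpoint path.
    for cmd in (rnn_train_cmd(), emat_train_cmd(),
                rnn_test_cmd("model.ckpt"), emat_test_cmd("model.ckpt")):
        print(shlex.join(cmd))  # shlex.join requires Python 3.8+
```

Note that the RNN test script takes the dataset through --dataset_path, while the other three commands use --input_path. Each list can be passed to subprocess.run once the checkpoint path is real.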

Security

See CONTRIBUTING for more information.

License

This project is licensed under the Apache-2.0 License.
