This is a HuggingFace 🤗 transformers + Lightning ⚡️ implementation of k-Nearest Neighbor Augmented Language Models, designed to be easy to read and understand, useful in research, and suitable for experimenting with new kNN-based model ideas.
The implementation is originally based on the k-NN Transformers repository. I found the original implementation difficult to work with, especially in distributed environments, so I reimplemented the method on top of Lightning ⚡️, which enables parallelization across multiple nodes and GPUs as well as training with DeepSpeed through Lightning ⚡️.
The repository currently implements the k-nearest-neighbor language model (kNN-LM) (Khandelwal et al., ICLR'2020). Support for k-nearest-neighbor machine translation (kNN-MT) (Khandelwal et al., ICLR'2021), Neuro-Symbolic Language Modeling with Automaton-augmented Retrieval (ICML'2022), and decoder-style (GPT-based) architectures is planned for the future.
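At inference time, kNN-LM interpolates the base model's next-token distribution with a distribution built from retrieved nearest neighbors. Below is a minimal sketch of that interpolation; the function and argument names are illustrative only and do not reflect this repository's actual API.

```python
# Sketch of the kNN-LM interpolation (Khandelwal et al., ICLR'2020).
# Names and shapes are illustrative; lmbda and temperature are hyperparameters.
import torch

def knnlm_next_token_probs(lm_logits, knn_distances, knn_token_ids,
                           vocab_size, lmbda=0.25, temperature=1.0):
    """p(y | x) = lmbda * p_knn(y | x) + (1 - lmbda) * p_lm(y | x)."""
    # Base LM distribution over the vocabulary.
    p_lm = torch.softmax(lm_logits, dim=-1)                                 # (vocab_size,)

    # Convert neighbor distances into weights: closer neighbors get more mass.
    neighbor_weights = torch.softmax(-knn_distances / temperature, dim=-1)  # (k,)

    # Accumulate neighbor weights onto the tokens that followed each retrieved key.
    p_knn = torch.zeros(vocab_size)
    p_knn.scatter_add_(0, knn_token_ids, neighbor_weights)

    return lmbda * p_knn + (1.0 - lmbda) * p_lm
```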
There are 4 main files in `knnlm/training`:

- `generate.py`: Generates a `.arrow` tokenized dataset from a `.jsonl` file of input-output pairs.
- `train.py`: Trains the model on the generated dataset.
- `store.py`: Generates a `faiss` Approximate Nearest Neighbor (ANN) index from the training set (a rough sketch of this step follows below).
- `eval.py`: Evaluates the model with/without the index.
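For context on the `store.py` step, here is a rough sketch of what building a `faiss` ANN index over datastore keys can look like. The index type, parameters, and file names are assumptions for illustration, not this repository's exact code.

```python
# Illustrative sketch: build an approximate nearest-neighbor index over
# datastore keys (e.g. final-layer hidden states) with faiss.
import faiss
import numpy as np

keys = np.load("datastore_keys.npy").astype("float32")  # (num_tokens, hidden_dim); hypothetical file
dim = keys.shape[1]

# IVF-PQ index: coarse clustering plus product quantization, a common choice
# for large kNN-LM datastores. The 64 sub-quantizers must divide the key dimension.
quantizer = faiss.IndexFlatL2(dim)
index = faiss.IndexIVFPQ(quantizer, dim, 4096, 64, 8)  # 4096 clusters, 64 sub-quantizers, 8 bits each

index.train(keys)   # learn coarse centroids and PQ codebooks
index.add(keys)     # add all datastore keys
index.nprobe = 32   # number of clusters to visit at query time

faiss.write_index(index, "datastore.faiss")
```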
All of these steps are controlled by a single config file, `knnlm/configs/main.yaml`. Simply specify the paths to your data (`train_path`, `val_path`), the path to save the store (`store_dir`), the path to the checkpoint of your finetuned model (`checkpoint`), and the other typical training settings (base model name and training hyperparameters). Then run the scripts in the sequence above (generate, train, store, and eval).
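For orientation, a hypothetical layout of `main.yaml` is shown below. Only the keys named above (`train_path`, `val_path`, `store_dir`, `checkpoint`) come from this README; all values and the remaining keys are assumed examples and may differ from the actual config.

```yaml
# Hypothetical knnlm/configs/main.yaml; keys other than the four named in the
# README, and all values, are illustrative assumptions.
train_path: data/train.jsonl
val_path: data/val.jsonl
store_dir: stores/my_dataset
checkpoint: checkpoints/finetuned-model.ckpt

model_name: <huggingface-model-name>  # base model name
batch_size: 16
learning_rate: 5.0e-5
max_epochs: 3
```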
- k-NN Transformers for the original code.
- k-NN LM for the original kNN-LM implementation.