ULLME is a flexible, plug-and-play implementation that enables bidirectional attention across various LLMs and supports a range of fine-tuning strategies to learn passage embeddings.
ULLME can be easily installed via one of the following methods:
git clone https://github.com/nlp-uoregon/ullme.git
cd ullme
pip install -e .
# if you using flash-attention-2 (this is the default for ullme)
pip install flash-attn --no-build-isolation
ULLME offers follwing main features:
ULLME can support enhancing HuggingFace models by adding support for bidirectional processing in decoder-only Large Language Models (LLMs), as well as sequence encoding and pooling operations.
from ullme.models import ULLME
model = ULLME(
model_name_or_path="mistralai/Mistral-7B-v0.1",
model_backbone_type="mistral",
lora_name="ullme-mistral",
loar_r=16,
lora_alpha=32,
)
input_sentence = "This a example sentence."
model_inputs = model.tokenizer(
[input_sentence],
return_tensors='pt'
)
model_output = model(
input_ids=model_inputs['input_ids'],
attention_mask=model_inputs['attention_mask'],
is_generate=False
)
The ULLME's returned model is a PyTorch object, providing users with the flexibility to integrate it into various frameworks or pipelines. By default the ULLME model uses the mean
pooling strategy. The is_generate
parameter plays a crucial role in controlling the attention mechanism: when set to False
, the model employs bidirectional attention, optimizing it for dense retrieval tasks, while True
reverts the model to causal attention, mimicking the standard Hugging Face Transformer model output.
Our ULLME framework supports multiple fine-tuning strategies
from ullme.trainer import GradCacheTrainer
trainer = GradCacheTrainer(
con_loss_type='NTXentLoss',
gen_loss_type='sigmoid', # 'sft'
use_kl_loss=True
)
trainer.fit_epoch(
model=model,
train_loader=train_dataloader,
)
ULLME enables efficient and effective CL. It comes equipped with a range of advanced features designed to enhance the CL process and optimize performance, such as GradCache, cross-devices contrastive loss computation, miners, ... Note that, ULLME enables CL by default.
ULLME not only supports Contrastive Learning (CL) but also enables Supervised Fine-Tuning (SFT) and provides a range of preference loss functions to further enhance model performance. The loss functions that can be easily selected through the gen_loss_type
argument, currently support sft
, sigmoid
(i.e., DPO), kto
, ipo
.
In ULLME, we also introduce a novel fine-tuning strategy, GRL, that explicitly aligns the model's understanding of relevance in both embedding and generation spaces through a Kullback-Leibler (KL) divergence loss. You can enbale this by set use_kl_loss=True
.
from ullme.models import WrappedULLME
from ullme.eval import eval_mteb_dataset
model = WrappedULLME(
model_name_or_path="mistralai/Mistral-7B-v0.1",
model_backbone_type="mistral",
lora_name="ullme-mistral",
loar_r=16,
lora_alpha=32,
model_checkpoint="path/to/your/checkpoint"
)
eval_result = eval_mteb_dataset(
model=model,
dataset_name='MSMARCO',
langs=['eng'],
)
>> {'eng': 35.8}
ULLME streamlines the evaluation process by integrating direct support for evaluating LLM-based text embedding models over MTEB. ULLME allows users to select specific datasets and language subsets for evaluation through parameters dataset_name
and langs
.
We publish three fine-tuned model using GRL on three popular LLMs: Meta-Llama-3-8B; Mistral-2-7B; Phi-1.5B
To finetune the Meta-Llama-3-8B model, run the following command:
python -m ullme.main \
--config_file scripts/configs/llama.yaml \
--nodes 4 \
--devices 1 \
--gc_chunk_size 4 \
--output_dir output/ullme-grl-llam3
To evaluate the model on the MTEB benchmark, run the following command:
python -m eval.eval_mteb \
--config_file scripts/configs/llama.yaml \
--output_dir output/ullme-grl-llam3/checkpoint.ckpt
This project has been supported by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA), via IARPA Contract No. 2022-22072200003 under the Human Interpretable Attribution of Text using Underlying Structure (HIATUS) Program.
If you have any questions about the code, feel free to open an issue on the GitHub repository.