Update Checkpoints Format #63

Merged · 14 commits · Jul 16, 2024
README.md · 39 additions, 24 deletions

@@ -234,15 +234,18 @@ Check out [Lean Copilot](https://github.com/lean-dojo/LeanCopilot) if you want t
1. Download and install [Miniconda Python 3](https://docs.conda.io/en/latest/miniconda.html) (Anaconda should also work).
2. Create the conda environment and install Python dependencies:
```bash
-conda create --yes --name ReProver python=3.10 ipython numpy
+conda create --yes --name ReProver python=3.11 ipython
conda activate ReProver
-pip install torch --index-url https://download.pytorch.org/whl/cu121 # Depending on your CUDA version; see https://pytorch.org/.
-pip install tqdm loguru deepspeed "pytorch-lightning[extra]" transformers tensorboard openai rank_bm25 lean-dojo
+pip install torch # Depending on your CUDA version; see https://pytorch.org/.
+pip install tqdm loguru deepspeed "pytorch-lightning[extra]" transformers wandb openai rank_bm25 lean-dojo vllm
+pip install git+https://github.com/pytorch/torchtune
```
3. Prepend the repo's root to the `PYTHONPATH` environment variable (see the snippet after this list).
4. Make sure `wget` and `tar` are available. Then, run `python scripts/download_data.py` to download [LeanDojo Benchmark 4](https://zenodo.org/doi/10.5281/zenodo.8040109). They will be saved to `./data`.
5. Satisfy the requirements of [LeanDojo](https://github.com/lean-dojo/LeanDojo#requirements).
6. Use [LeanDojo](https://github.com/lean-dojo/LeanDojo) to trace all repos in the datasets: `python scripts/trace_repos.py`. This step may take some time. Please refer to [LeanDojo's documentation](https://leandojo.readthedocs.io/en/latest/) if you encounter any issues.
+7. Run `wandb login` to log in to Weights & Biases.
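For step 3, a typical way to do this (assuming a bash-compatible shell, run from the repo's root) is:
```bash
# Make ReProver's modules importable by prepending the repo root to PYTHONPATH.
export PYTHONPATH="$(pwd):$PYTHONPATH"
```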



## Premise Selection
@@ -256,28 +259,29 @@ The config files for our experiments are in [./retrieval/confs](./retrieval/conf

Run `python retrieval/main.py fit --help` to see how to use the training script. For example:
```bash
-python retrieval/main.py fit --config retrieval/confs/cli_lean4_random.yaml # Train the retriever on the `random` split of LeanDojo Benchmark 4.
-python retrieval/main.py fit --config retrieval/confs/cli_lean4_novel_premises.yaml # Train the retriever on the `novel_premises` split of LeanDojo Benchmark 4.
+mkdir logs # Create the directory for training logs.
+python retrieval/main.py fit --config retrieval/confs/cli_lean4_random.yaml --trainer.logger.name train_retriever_random --trainer.logger.save_dir logs/train_retriever_random # Train the retriever on the `random` split of LeanDojo Benchmark 4.
+python retrieval/main.py fit --config retrieval/confs/cli_lean4_novel_premises.yaml --trainer.logger.name train_retriever_novel_premises --trainer.logger.save_dir logs/train_retriever_novel_premises # Train the retriever on the `novel_premises` split of LeanDojo Benchmark 4.
```
-The training script saves hyperparameters, model checkpoints, and other information to `./lightning_logs/EXP_ID/`, where `EXP_ID` is an arbitrary experiment ID that will be printed by the training script.
+Hyperparameters and model checkpoints are saved in `./logs/train_retriever_*`, and you can monitor the training process on Weights & Biases.


### Retrieving Premises for All Proof States

After the models are trained, run the following commands to retrieve premises for all proof states in the dataset.
```bash
-python retrieval/main.py predict --config retrieval/confs/cli_lean4_random.yaml --ckpt_path PATH_TO_RETRIEVER_CHECKPOINT
-python retrieval/main.py predict --config retrieval/confs/cli_lean4_novel_premises.yaml --ckpt_path PATH_TO_RETRIEVER_CHECKPOINT
+python retrieval/main.py predict --config retrieval/confs/cli_lean4_random.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_random --trainer.logger.save_dir logs/predict_retriever_random
+python retrieval/main.py predict --config retrieval/confs/cli_lean4_novel_premises.yaml --ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --trainer.logger.name predict_retriever_novel_premises --trainer.logger.save_dir logs/predict_retriever_novel_premises
```
-Retrieved premises are saved to `./lightning_logs/EXP_ID'/predictions.pickle`.
+Here, `PATH_TO_RETRIEVER_CHECKPOINT` is the model checkpoint produced in the previous step. Retrieved premises are saved to `./logs/predict_retriever_*/predictions.pickle`.
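To sanity-check the predictions, you can load the pickle directly. The exact schema of each record is an implementation detail of `retrieval/main.py`, so the sketch below assumes only that the file unpickles to a sequence of records:
```python
# Minimal inspection of the retriever's saved predictions (schema not guaranteed).
import pickle

with open("logs/predict_retriever_random/predictions.pickle", "rb") as f:
    predictions = pickle.load(f)

print(type(predictions), len(predictions))  # container type and number of records
print(predictions[0])                       # one retrieval record
```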


### Evaluating the Retrieved Premises

After predictions are saved, evaluate them using metrics such as R@1, R@10, and MRR.
```bash
-python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/random --preds-file PATH_TO_PREDICTIONS_PICKLE
-python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises --preds-file PATH_TO_PREDICTIONS_PICKLE
+python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/random --preds-file logs/predict_retriever_random/predictions.pickle
+python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises --preds-file logs/predict_retriever_novel_premises/predictions.pickle
```
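The metrics themselves are standard: R@k is the fraction of ground-truth premises recovered among the top-k retrieved ones, and MRR is the mean reciprocal rank of the first correct premise. A minimal sketch of these definitions, independent of the data layout used by `retrieval/evaluate.py`:
```python
from typing import List, Set

def recall_at_k(retrieved: List[str], relevant: Set[str], k: int) -> float:
    """Fraction of ground-truth premises among the top-k retrieved ones."""
    if not relevant:
        return 0.0
    return len(set(retrieved[:k]) & relevant) / len(relevant)

def reciprocal_rank(retrieved: List[str], relevant: Set[str]) -> float:
    """1 / rank of the first relevant premise; 0.0 if none was retrieved."""
    for rank, premise in enumerate(retrieved, start=1):
        if premise in relevant:
            return 1.0 / rank
    return 0.0

# Toy example with hypothetical premise names.
retrieved = ["Nat.add_comm", "Nat.succ_le_succ", "List.map_append"]
relevant = {"Nat.add_comm", "Nat.add_assoc"}
print(recall_at_k(retrieved, relevant, 1))   # 0.5
print(reciprocal_rank(retrieved, relevant))  # 1.0
```
Averaging these per-state values over the test set gives the reported R@1, R@10, and MRR.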


@@ -286,40 +290,51 @@ python retrieval/evaluate.py --data-path data/leandojo_benchmark_4/novel_premise

### Training the Tactic Generator

-Similar to premise selection, you can run `python generator/main.py --help` and `python generator/main.py fit --help` to check the command line options.
+Similar to premise selection, you can run `python generation/main.py --help` and `python generation/main.py fit --help` to check the command line options.

To train tactic generators without retrieval:
```bash
-python generator/main.py fit --config generator/confs/cli_lean4_random.yaml # LeanDojo Benchmark 4, `random` split
-python generator/main.py fit --config generator/confs/cli_lean4_novel_premises.yaml # LeanDojo Benchmark 4, `novel_premises` split
+python generation/main.py fit --config generation/confs/cli_lean4_random.yaml --trainer.logger.name train_generator_random --trainer.logger.save_dir logs/train_generator_random # LeanDojo Benchmark 4, `random` split
+python generation/main.py fit --config generation/confs/cli_lean4_novel_premises.yaml --trainer.logger.name train_generator_novel_premises --trainer.logger.save_dir logs/train_generator_novel_premises # LeanDojo Benchmark 4, `novel_premises` split
```
+Hyperparameters and model checkpoints are saved in `./logs/train_generator_*`, and you can monitor the training process on Weights & Biases.

To train models augmented by retrieval, we need to provide a retriever checkpoint and its predictions on all proof states in the dataset:
```bash
-python generator/main.py fit --config generator/confs/cli_lean4_random.yaml --model.ret_ckpt_path PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path PATH_TO_PREDICTIONS_PICKLE
-python generator/main.py fit --config generator/confs/cli_lean4_novel_premises.yaml --model.ret_ckpt_path PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path PATH_TO_PREDICTIONS_PICKLE
+python generation/main.py fit --config generation/confs/cli_lean4_random.yaml --model.ret_ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path logs/predict_retriever_random/predictions.pickle --trainer.logger.name train_reprover_random --trainer.logger.save_dir logs/train_reprover_random
+python generation/main.py fit --config generation/confs/cli_lean4_novel_premises.yaml --model.ret_ckpt_path $PATH_TO_RETRIEVER_CHECKPOINT --data.preds_path logs/predict_retriever_novel_premises/predictions.pickle --trainer.logger.name train_reprover_novel_premises --trainer.logger.save_dir logs/train_reprover_novel_premises
```


### Theorem Proving Evaluation on LeanDojo Benchmark

After the tactic generator is trained, we combine it with best-first search to prove theorems by interacting with Lean.
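Conceptually, the search keeps a priority queue of proof states ordered by the cumulative log-probability of the tactics that produced them, always expanding the most promising state next. The sketch below is illustrative only; `generate_tactics` and `run_tactic` are hypothetical stand-ins for the tactic generator and the LeanDojo interaction layer, and the real implementation lives in `prover/`:
```python
import heapq
import itertools

def best_first_search(init_state, generate_tactics, run_tactic, max_expansions=1000):
    """Illustrative best-first proof search over cumulative tactic log-probs."""
    tie = itertools.count()  # tie-breaker so the heap never compares states
    queue = [(0.0, next(tie), init_state)]  # (negated cumulative log-prob, tie, state)
    seen = {str(init_state)}
    for _ in range(max_expansions):
        if not queue:
            return None  # search space exhausted without finding a proof
        neg_logp, _, state = heapq.heappop(queue)
        for tactic, logp in generate_tactics(state):  # candidate tactics + log-probs
            result = run_tactic(state, tactic)
            if result is None:
                continue  # the tactic failed to apply
            if result.is_proved:
                return result  # proof found
            if str(result) not in seen:
                seen.add(str(result))
                heapq.heappush(queue, (neg_logp - logp, next(tie), result))
    return None  # expansion budget exhausted
```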

-For models without retrieval, run:
+The evaluation script takes Hugging Face model checkpoints (either local or remote) as input. For remote models, you can simply use their names, e.g., [kaiyuy/leandojo-lean4-tacgen-byt5-small](https://huggingface.co/kaiyuy/leandojo-lean4-tacgen-byt5-small). For locally trained models, you first need to convert them from PyTorch Lightning checkpoints to Hugging Face checkpoints:
```bash
-python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1
-python prover/evaluate.py --data-path data/leandojo_benchmark_4/novel_premises/ --ckpt_path PATH_TO_MODEL_CHECKPOINT --split test --num-workers 5 --num-gpus 1
+python scripts/convert_checkpoint.py generator --src $PATH_TO_GENERATOR_CHECKPOINT --dst ./leandojo-lean4-tacgen-byt5-small
+python scripts/convert_checkpoint.py retriever --src $PATH_TO_RETRIEVER_CHECKPOINT --dst ./leandojo-lean4-retriever-byt5-small
```
+Here, `PATH_TO_GENERATOR_CHECKPOINT` and `PATH_TO_RETRIEVER_CHECKPOINT` are PyTorch Lightning checkpoints produced by the training script.
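In essence, the conversion unwraps the Hugging Face model nested inside the Lightning module and re-saves it with `save_pretrained`. A rough sketch, assuming the wrapped model's weights sit under a `model.` prefix in the checkpoint's `state_dict` (the authoritative logic is `scripts/convert_checkpoint.py`):
```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

ckpt = torch.load("path/to/lightning.ckpt", map_location="cpu")
state_dict = {
    k.removeprefix("model."): v
    for k, v in ckpt["state_dict"].items()
    if k.startswith("model.")  # assumed prefix of the wrapped HF model
}
model = AutoModelForSeq2SeqLM.from_pretrained("google/byt5-small")
model.load_state_dict(state_dict, strict=False)
model.save_pretrained("./leandojo-lean4-tacgen-byt5-small")
AutoTokenizer.from_pretrained("google/byt5-small").save_pretrained(
    "./leandojo-lean4-tacgen-byt5-small"
)
```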


+To evaluate the model without retrieval, run (using the `random` data split as an example):
```bash
+python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --gen_ckpt_path ./leandojo-lean4-tacgen-byt5-small --split test --num-workers 5 --num-gpus 1
```
You may tweak `--num-workers` and `--num-gpus` to fit your hardware.


-For models with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):
+For the model with retrieval, first use the retriever to index the corpus (pre-computing the embeddings of all premises):
```bash
-python retrieval/index.py --ckpt_path PATH_TO_RETRIEVER_CHECKPOINT --corpus-path data/leandojo_benchmark_4/corpus.jsonl --output-path PATH_TO_INDEXED_CORPUS
+# Do it separately for the two data splits.
+python retrieval/index.py --ckpt_path ./leandojo-lean4-retriever-byt5-small --corpus-path data/leandojo_benchmark_4/corpus.jsonl --output-path $PATH_TO_INDEXED_CORPUS
```
+This saves the indexed corpus as a pickle file to `PATH_TO_INDEXED_CORPUS`.
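Conceptually, indexing encodes every premise once, so that retrieval at proving time reduces to a nearest-neighbor lookup over a single embedding matrix. A simplified sketch, where `encode` is a hypothetical stand-in for the retriever's encoder and embeddings are assumed unit-normalized:
```python
import torch

@torch.no_grad()
def index_corpus(premises, encode, batch_size=64):
    """Pre-compute one embedding per premise; returns a (num_premises, d) tensor."""
    chunks = [
        encode(premises[i : i + batch_size])
        for i in range(0, len(premises), batch_size)
    ]
    return torch.cat(chunks)

def retrieve(state_emb, corpus_embs, k=10):
    """Top-k premise indices by dot product (cosine similarity when normalized)."""
    scores = corpus_embs @ state_emb
    return scores.topk(k).indices
```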

Then, run:
```bash
-python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --ckpt_path PATH_TO_REPROVER_CHECKPOINT --indexed-corpus-path PATH_TO_INDEXED_CORPUS --split test --num-cpus 8 --with-gpus
+# Do it separately for the two data splits.
+python scripts/convert_checkpoint.py generator --src $PATH_TO_REPROVER_CHECKPOINT --dst ./leandojo-lean4-retriever-tacgen-byt5-small
+python prover/evaluate.py --data-path data/leandojo_benchmark_4/random/ --gen_ckpt_path ./leandojo-lean4-retriever-tacgen-byt5-small --ret_ckpt_path ./leandojo-lean4-retriever-byt5-small --indexed-corpus-path $PATH_TO_INDEXED_CORPUS --split test --num-workers 5 --num-gpus 1
```


common.py · 9 additions, 59 deletions

@@ -13,7 +13,7 @@
from pytorch_lightning.utilities.deepspeed import (
    convert_zero_checkpoint_to_fp32_state_dict,
)
-from transformers import get_cosine_schedule_with_warmup
+from transformers import get_constant_schedule_with_warmup
from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam
from typing import Optional, List, Dict, Any, Tuple, Generator
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy
@@ -354,48 +354,14 @@ def get_all_pos_premises(annot_tac, corpus: Corpus) -> List[Premise]:
    return list(all_pos_premises)


-_SPACES_REGEX = re.compile(r"\s+", re.DOTALL)
-
-
-def normalize_spaces(s: str) -> str:
-    """Repalce any consecutive block of whitespace characters in ``s`` with a single whitespace."""
-    return _SPACES_REGEX.sub(" ", s).strip()
-
-
-def format_tactic(annot_tac: str, provenances, normalize: bool) -> str:
-    """Use full names for the all <a>...</a>."""
-    if normalize:
-        annot_tac = normalize_spaces(annot_tac)
-    if len(provenances) == 0:
-        return annot_tac
-
-    tac = ""
-    marks = list(re.finditer(r"<a>(?P<ident>.+?)</a>", annot_tac))
-
-    for i, (m, prov) in enumerate(zip_strict(marks, provenances)):
-        last_end = marks[i - 1].end() if i > 0 else 0
-        tac += annot_tac[last_end : m.start()] + "<a>" + prov["full_name"] + "</a>"
-
-    tac += annot_tac[marks[-1].end() :]
-    return tac
-
-
-def format_state(s: str) -> str:
-    m = re.match(r"\d+ goals", s)
-    if m is not None:
-        return s[m.end() :].strip()
-    else:
-        return s


def format_augmented_state(
-    s: str, premises: List[Premise], max_len: int, p_drop: float
+    s: str, premises: List[Premise], max_len: Optional[int] = None, p_drop: float = 0.0
) -> str:
    """Format a state with retrieved premises and drop some of them with probability ``p_drop``."""
-    s = format_state(s)

    aug_s = ""
    length = 0
+    if max_len is None:
+        max_len = 9999999999999999999999
    max_premises_len = max_len - len(bytes(s.encode("utf-8")))

    for p in premises:
@@ -429,22 +395,7 @@ def get_optimizers(
    logger.info("Optimizing with AdamW")
    optimizer = torch.optim.AdamW(parameters, lr=lr)

-    if trainer.max_steps != -1:
-        max_steps = trainer.max_steps
-    else:
-        assert trainer.max_epochs is not None
-        max_steps = (
-            trainer.max_epochs
-            * len(trainer.datamodule.train_dataloader())
-            // trainer.accumulate_grad_batches
-        )
-
-    scheduler = get_cosine_schedule_with_warmup(
-        optimizer,
-        num_warmup_steps=warmup_steps,
-        num_training_steps=max_steps,
-    )
-
+    scheduler = get_constant_schedule_with_warmup(optimizer, warmup_steps)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
@@ -462,14 +413,13 @@ def _is_deepspeed_checkpoint(path: str):

def load_checkpoint(model_cls, ckpt_path: str, device, freeze: bool):
    """Handle DeepSpeed checkpoints in model loading."""
-    if not _is_deepspeed_checkpoint(ckpt_path):
-        model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device)
-    else:
+    if _is_deepspeed_checkpoint(ckpt_path):
        with tempfile.TemporaryDirectory() as dirname:
            path = os.path.join(dirname, "lightning.cpkt")
            convert_zero_checkpoint_to_fp32_state_dict(ckpt_path, path)
-            model = model_cls.load_from_checkpoint(path, strict=False)
-            model = model.to(device)
+            model = model_cls.load_from_checkpoint(path, strict=False).to(device)
+    else:  # PyTorch Lightning checkpoints
+        model = model_cls.load_from_checkpoint(ckpt_path, strict=False).to(device)
    if freeze:
        model.freeze()
    return model
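For reference, a hypothetical usage of the updated `load_checkpoint`; the model class and checkpoint path are illustrative:
```python
import torch
from common import load_checkpoint
from retrieval.model import PremiseRetriever  # illustrative import

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = load_checkpoint(PremiseRetriever, "path/to/last.ckpt", device, freeze=True)
```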
generation/confs/cli_lean4_novel_premises.yaml

@@ -9,7 +9,11 @@ trainer:
      stage: 2
      offload_optimizer: false
      cpu_checkpointing: false
-  gradient_clip_val: 1.0
+  logger:
+    class_path: pytorch_lightning.loggers.WandbLogger
+    init_args:
+      name: null
+      save_dir: null
  max_steps: 500000
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 0
@@ -46,12 +50,10 @@ model:
data:
  data_path: data/leandojo_benchmark_4/novel_premises/
  corpus_path: data/leandojo_benchmark_4/corpus.jsonl
-  keep_marks: true
  preds_path: null
  batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices
  eval_batch_size: 64
  max_inp_seq_len: 2300
  max_oup_seq_len: 512
  p_drop: 0.5
-  normalize_tactics: true
  num_workers: 2
generation/confs/cli_lean4_random.yaml

@@ -9,7 +9,11 @@ trainer:
      stage: 2
      offload_optimizer: false
      cpu_checkpointing: false
-  gradient_clip_val: 1.0
+  logger:
+    class_path: pytorch_lightning.loggers.WandbLogger
+    init_args:
+      name: null
+      save_dir: null
  max_steps: 500000
  check_val_every_n_epoch: 1
  num_sanity_val_steps: 0
@@ -46,12 +50,10 @@ model:
data:
  data_path: data/leandojo_benchmark_4/random/
  corpus_path: data/leandojo_benchmark_4/corpus.jsonl
-  keep_marks: true
  preds_path: null
  batch_size: 8 # effective_batch_size == batch_size * accumulate_grad_batches * devices
  eval_batch_size: 64
  max_inp_seq_len: 2300
  max_oup_seq_len: 512
  p_drop: 0.5
-  normalize_tactics: true
  num_workers: 2