Commit
Fix lm_eval_harness for GPT models (bigscience-workshop#292)
conglongli authored Nov 15, 2023
1 parent 155ce98 commit 37050b8
Showing 7 changed files with 26 additions and 17 deletions.
3 changes: 2 additions & 1 deletion examples_deepspeed/MoE/ds_evalharness.sh
@@ -28,7 +28,7 @@ TASKS="lambada"
VOCAB_FILE=/data/Megatron-LM/data/gpt2-vocab.json
MERGE_FILE=/data/Megatron-LM/data/gpt2-merges.txt

-export HF_DATASETS_OFFLINE=1
+# export HF_DATASETS_OFFLINE=1

# Dummy arguments to make megatron happy. No need to configure them.
# The reason we don't need to configure them and many other arguments is
@@ -53,6 +53,7 @@ CMD="../../tasks/eval_harness/evaluate.py \
--no-load-rng \
--inference \
--disable-moe-token-dropping \
+--tokenizer-type GPT2BPETokenizer \
--adaptive_seq_len\
--eval_fp32\
--task_list $TASKS\
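For context, the added `--tokenizer-type GPT2BPETokenizer` pins the harness to the GPT-2 byte-level BPE defined by the vocab and merge files set above, rather than relying on checkpoint arguments. A minimal sketch of an equivalent construction, using the Hugging Face `GPT2Tokenizer` purely for illustration (the paths are the VOCAB_FILE/MERGE_FILE values from this script, not anything this commit adds):

```
# Illustration only: a GPT-2 byte-level BPE tokenizer built from the same
# vocab/merge files the script passes as VOCAB_FILE and MERGE_FILE.
from transformers import GPT2Tokenizer

tok = GPT2Tokenizer(
    vocab_file="/data/Megatron-LM/data/gpt2-vocab.json",
    merges_file="/data/Megatron-LM/data/gpt2-merges.txt",
)
print(tok.encode("lambada"))  # token ids under GPT-2 BPE
```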
8 changes: 4 additions & 4 deletions examples_deepspeed/MoE/readme_evalharness.md
@@ -11,11 +11,10 @@ This particular setup uses the normal deepspeed checkpoint and requires no conve
On login console with external network

Get lm-eval harness (https://github.com/EleutherAI/lm-evaluation-harness) and `best-download==0.0.7` needed to download some tasks.
+The package versions below are the ones we tested and found to work.
```
(maybe need pip install --upgrade pip)
-pip install best-download==0.0.7
-pip install lm-eval
-(previously we used "pip install git+https://github.com/EleutherAI/lm-evaluation-harness" to install, but later found the command above has less dependency issues)
+pip install best-download==0.0.7 lm-eval==0.2.0 datasets==1.15.1 transformers==4.20.1 huggingface-hub==0.8.1
```
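To confirm the pinned versions took effect, a small hypothetical check (distribution names as published on PyPI):

```
# Hypothetical sanity check for the pinned package versions above.
import importlib.metadata as md

pins = {"best-download": "0.0.7", "lm-eval": "0.2.0", "datasets": "1.15.1",
        "transformers": "4.20.1", "huggingface-hub": "0.8.1"}
for pkg, want in pins.items():
    have = md.version(pkg)
    print(f"{pkg}: {have}" + ("" if have == want else f" (expected {want})"))
```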

2. Pre-download needed datasets
@@ -33,7 +32,8 @@ Then install datasets for the tasks:
```
python ../../tasks/eval_harness/download.py --task_list hellaswag,lambada,triviaqa,webqs,winogrande,piqa,arc_challenge,arc_easy,openbookqa,race,boolq,cb,copa,rte,wic,wsc,multirc,record,anli_r1,anli_r2,anli_r3,wikitext,logiqa,mathqa,mc_taco,mrpc,prost,pubmedqa,qnli,qqp,sciq,sst,wnli
```
-and make sure that `export HF_DATASETS_OFFLINE=1`
+
+Previously we set `export HF_DATASETS_OFFLINE=1` to force offline mode after the manual download above. However, this can now trigger an online-verification error for some of the datasets, so it is recommended to enable offline mode only when necessary.
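As a sketch of the "only when necessary" advice, offline mode could be enabled conditionally; the cache location below is the `datasets` default and an assumption here:

```
# Sketch: set HF_DATASETS_OFFLINE only if a populated local cache exists.
# Must run before `import datasets` for the flag to take effect.
import os
from pathlib import Path

cache = Path(os.environ.get("HF_DATASETS_CACHE",
                            str(Path.home() / ".cache" / "huggingface" / "datasets")))
if cache.is_dir() and any(cache.iterdir()):
    os.environ["HF_DATASETS_OFFLINE"] = "1"  # everything needed is local
```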

<!-- If there are things like custom tokenizers, pre-download those too, e.g.:
5 changes: 3 additions & 2 deletions examples_deepspeed/compression/ds_evalharness.sh
@@ -1,4 +1,4 @@
-# This is an example zero-shot eval script. Please first read the readme_evalharness.md under the same directory.
+# This is an example zero-shot eval script. Please first read the readme_evalharness.md under the ../MoE directory.

# CHECKPOINT_PATH=/blob/users/minjiaz/compression_library/checkpoint/125M10L_Compression_Test_INT8_64gpu_lr6e-5_tokens5.25B_nocl_alpha-no_pp/global_step2000/
# CHECKPOINT_PATH=/blob/users/conglli/project/gpt3_with_pile/checkpoint/gpt3-with-pile-0.125B-lr-2.4e-3-minlr-6.0e-5-bs-2048-gpus-64-zero-0-mp-1-pp-1-no_pp-cl-startseqlen-72-step-27638-token-60B/global_step71000/
@@ -31,7 +31,7 @@ TASKS="lambada,wikitext"
VOCAB_FILE=/blob/data/the_pile_public_merged_nopreprocessing/gpt2-vocab.json
MERGE_FILE=/blob/data/the_pile_public_merged_nopreprocessing/gpt2-merges.txt

-export HF_DATASETS_OFFLINE=1
+# export HF_DATASETS_OFFLINE=1

# Dummy arguments to make megatron happy. No need to configure them.
# The reason we don't need to configure them and many other arguments is
@@ -56,6 +56,7 @@ CMD="../../tasks/eval_harness/evaluate.py \
--no-load-rng \
--inference \
--disable-moe-token-dropping \
+--tokenizer-type GPT2BPETokenizer \
--adaptive_seq_len\
--eval_fp32\
--task_list $TASKS\
@@ -27,7 +27,7 @@ if [ ! -f "$merge_file" ]; then
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt
fi

-export HF_DATASETS_OFFLINE=1
+# export HF_DATASETS_OFFLINE=1

dir2=$(dirname "$checkpoint_path")
dirname=$(basename "$dir2")/$(basename "$checkpoint_path")
@@ -58,6 +58,7 @@ command="../../../../tasks/eval_harness/evaluate.py \
--no-load-rng \
--inference \
--disable-moe-token-dropping \
+--tokenizer-type GPT2BPETokenizer \
--adaptive_seq_len \
--eval_fp32 \
--num_fewshot ${num_fewshot} \
@@ -48,6 +48,7 @@ num_fewshot=0
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
cuda_id=-1
total_mem=$(nvidia-smi --query-gpu=memory.total --format=csv -i 0 | grep -Eo [0-9]+)
+total_mem=$(( ${total_mem}*99/100 )) # somehow there could exist tiny (4MB or so) gpu memory leak

## Code below only works when you run each evalharness task on a single GPU.
## For multi-GPU evalharness, check Megatron-DeepSpeed/blob/main/examples_deepspeed/MoE/ds_evalharness.sh
@@ -43,6 +43,7 @@ batch_size=16
num_gpus=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
cuda_id=-1
total_mem=$(nvidia-smi --query-gpu=memory.total --format=csv -i 0 | grep -Eo [0-9]+)
+total_mem=$(( ${total_mem}*99/100 )) # somehow there could exist tiny (4MB or so) gpu memory leak

## Code below only works when you run each evalharness task on a single GPU.
## For multi-GPU evalharness, check Megatron-DeepSpeed/blob/main/examples_deepspeed/MoE/ds_evalharness.sh
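The `total_mem` adjustment added to both scripts keeps 1% of headroom so the GPU-selection loop below it does not count memory that a small leak has already consumed. The same computation in Python, for illustration (assumes `nvidia-smi` is on the PATH):

```
# Illustration of the bash line total_mem=$(( ${total_mem}*99/100 )).
import subprocess

out = subprocess.check_output(
    ["nvidia-smi", "--query-gpu=memory.total",
     "--format=csv,noheader,nounits", "-i", "0"], text=True)
total_mem = int(out.strip())     # MiB reported for GPU 0
usable = total_mem * 99 // 100   # keep 1% headroom for the tiny leak
print(usable)
```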
22 changes: 13 additions & 9 deletions tasks/eval_harness/evaluate.py
@@ -23,9 +23,10 @@
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
+from megatron.core.enums import ModelType
from megatron.core import mpu
from megatron.training import setup_model_and_optimizer, get_model
-from megatron.mpu.mappings import gather_from_tensor_model_parallel_region
+from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region

from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
@@ -222,8 +223,7 @@ def _model_call(self, inps):
a_output, *other_losses = self.model(tokens,
position_ids,
attention_mask,
-tokentype_ids=None,
-forward_method_parallel_output=False)
+tokentype_ids=None)
output.append(a_output)

if output is not None:
@@ -320,7 +320,7 @@ def load_ds_checkpoint_and_setup_megatron(extra_args_provider):
# avoid printing the arguments, since they will later be overridden.
_print_args = megatron.arguments._print_args
megatron.arguments._print_args = lambda *_args, **kwarg: None
-args = _parse_args(extra_args_provider)
+args = parse_args(extra_args_provider=extra_args_provider)

ds_checkpoint = DeepSpeedCheckpoint(args.load,
tp_degree=args.tensor_model_parallel_size,
@@ -340,20 +340,24 @@ def load_ds_checkpoint_and_setup_megatron(extra_args_provider):
cp_args.bf16 = False
cp_args.params_dtype = torch.float32

+cp_args.tokenizer_type = 'GPT2BPETokenizer'

override_args(args, cp_args, skip_keys, skip_if_specified)

# stop megatron from reparsing the arguments.
-megatron.global_vars._parse_args = lambda *_args, **kwarg: args
+megatron.arguments.parse_args = lambda *_args, **kwarg: args
+megatron.global_vars._ensure_var_is_not_initialized = lambda *_args, **kwarg: None
+megatron.global_vars._GLOBAL_ARGS = args

-initialize_megatron()
+initialize_megatron(extra_args_provider=extra_args_provider)
+megatron.global_vars._GLOBAL_ARGS = args
torch.distributed.barrier()

# Initializing megatron will update eg. tokenizer size. Override again.
override_args(args, cp_args, skip_keys, skip_if_specified)

# print final arguments.
-_print_args(args)
+_print_args("eval_harness arguments", args)
if args.deepspeed:

# Hack #3:
@@ -369,7 +373,7 @@ def load_ds_checkpoint_and_setup_megatron(extra_args_provider):

cp_path = args.load
args.load = None
-model, _, _ = setup_model_and_optimizer(model_provider)
+model, _, _ = setup_model_and_optimizer(model_provider, ModelType.encoder_or_decoder)
model = model[0]
zero_enabled = model._config.zero_enabled
model._config.zero_enabled = False
@@ -399,7 +403,7 @@ def tasks_args(parser):
group.add_argument('--eval_fp32', default = False, action='store_true', help='Should the evaluation run in fp32')
return parser

-from megatron.global_vars import _parse_args
+from megatron.arguments import parse_args

def main():
start = time.time()
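Taken together, the changes in `load_ds_checkpoint_and_setup_megatron` implement one pattern: parse the arguments once, then stub out Megatron's re-parsing hooks so `initialize_megatron()` reuses the checkpoint-corrected namespace. A condensed sketch of that pattern, with module and attribute names taken from the diff; the import path for `initialize_megatron` and the trivial `tasks_args` stand-in are assumptions, and this only runs inside a Megatron-DeepSpeed environment:

```
# Condensed sketch of the parse-once pattern used in the diff above.
import megatron.arguments
import megatron.global_vars
from megatron.initialize import initialize_megatron  # assumed import path

def tasks_args(parser):  # trivial stand-in for the real extra_args_provider
    return parser

args = megatron.arguments.parse_args(extra_args_provider=tasks_args)
# ... override fields of `args` from the DeepSpeed checkpoint here ...

# Stub out re-parsing so initialize_megatron() sees the same namespace.
megatron.arguments.parse_args = lambda *_a, **_k: args
megatron.global_vars._ensure_var_is_not_initialized = lambda *_a, **_k: None
megatron.global_vars._GLOBAL_ARGS = args

initialize_megatron(extra_args_provider=tasks_args)
megatron.global_vars._GLOBAL_ARGS = args  # re-pin after initialization
```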
