
Add adapter #1545

Open · wants to merge 24 commits into master
34 changes: 33 additions & 1 deletion scripts/classification/README.md
@@ -51,4 +51,36 @@ here are some results with their hyperparameters
| CoLA | Matthew Corr. | 2e-5 | 32 | 7800 | 10 | 59.23 | https://tensorboard.dev/experiment/33euRGh9SrW3p15JWgILnw/ |
| RTE | Accuracy | 2e-5 | 32 | 1800 | 10 | 69.67 | https://tensorboard.dev/experiment/XjTxr5anRrC1LMukLJJQ3g/|
| MRPC | Accuracy/F1 | 3e-5 | 32 | 7800 | 5 | 85.38/87.31 | https://tensorboard.dev/experiment/jEJFq2XXQ8SvCxt6eKIjwg/ |
| MNLI | Accuracy(m/mm) | 2e-5 | 48 | 7800 | 4 | 84.90/85.10 | https://tensorboard.dev/experiment/CZQlOBedRQeTZwn5o5fbKQ/ |


## Different fine-tuning methods
We also offer parameter-efficient fine-tuning methods that save training time and memory. Two such methods are currently available:
bias-only fine-tuning and adapter fine-tuning. To use them, pass the `--method` argument, for example:
```bash
python train_classification.py \
--model_name google_en_uncased_bert_base \
--method adapter \
--task_name mrpc \
--lr 4.5e-4 \
--batch_size 32 \
--do_train \
--do_eval \
--seed 7800 \
--epochs 10 \
--optimizer adamw \
--train_dir glue/mrpc/train.parquet \
--eval_dir glue/mrpc/dev.parquet \
--gpus 1
```
Here are some results for the different methods (a blank cell means we have not yet found suitable hyperparameters):

| Task Name | Metric | full | bias-finetune | adapter |
|-----------|-------------|-------------|-------------|-------------|
| SST | Accuracy | 93.23 | | 93.46 |
| STS | Pearson Corr. | 89.26 | 89.30 | 89.70 |
| CoLA | Matthew Corr. | 59.23 | | 61.20 |
| RTE | Accuracy | 69.67 | 69.31 | 70.75 |
| MRPC | Accuracy/F1 | 85.38/87.31 | 85.29/88.63 | 87.74/91.39|
| MNLI | Accuracy(m/mm) | 84.90/85.10 | | |
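
When `--method` is not `full`, the training script saves only the fine-tuned subset of parameters, and the checkpoint name ends with the method (e.g. `<model>_<task>_<step>_adapter.params`). Evaluation therefore has to be run with the same `--method` value so that the matching checkpoints are found; the remaining weights come from the pretrained backbone. Below is a minimal sketch of re-running evaluation on saved adapter checkpoints; the flag values are illustrative and assume the adapter training command above has already been run:

```bash
python train_classification.py \
--model_name google_en_uncased_bert_base \
--method adapter \
--task_name mrpc \
--do_eval \
--eval_dir glue/mrpc/dev.parquet \
--gpus 1
```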
91 changes: 77 additions & 14 deletions scripts/classification/train_classification.py
@@ -5,7 +5,9 @@
import json
import random
import pandas as pd
import mxnet.numpy_extension as _mx_npx
import os
import logging
import time
import argparse
@@ -92,13 +94,36 @@ def parse_args():
help='the path to training dataset')
parser.add_argument('--warmup_ratio', type=float, default=0.1,
help='Ratio of warmup steps in the learning rate scheduler.')
parser.add_argument('--method', type=str, default='full', choices=['full', 'bias', 'adapter', 'last_layer'],
help='Fine-tuning method: train the full model, only bias terms, only adapters, '
'or only the last classification layer. The classification head is always updated.')


args = parser.parse_args()
return args


def change_adapter_cfg(cfg, task):
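# Build a per-task adapter configuration and switch adapters on in the backbone
# config: two insertion points ('location_0' and 'location_1') are set up, each
# using a 'Basic' bottleneck adapter with 64 hidden units and GELU activation,
# registered under the current task name.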
adapter_config = {
'location_0':{
'adapter_fusion':False,
'pre_operator':False,
'task_names':[task.task_name],
task.task_name:{'type':'Basic','units':64, 'activation':'gelu'}},
'location_1':{
'adapter_fusion':False,
'pre_operator':False,
'task_names':[task.task_name],
task.task_name:{'type':'Basic','units':64, 'activation':'gelu'}}
}
cfg.defrost()
cfg.MODEL.use_adapter = True
cfg.MODEL.adapter_config = json.dumps(adapter_config)
cfg.freeze()
return cfg

def get_network(model_name,
ctx_l,
method='full',
checkpoint_path=None,
backbone_path=None,
task=None):
@@ -109,13 +134,15 @@
use_segmentation = 'roberta' not in model_name and 'xlmr' not in model_name
Model, cfg, tokenizer, download_params_path, _ = \
get_backbone(model_name, load_backbone=not backbone_path)

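# Rewrite the backbone config to include adapters before the model is constructed.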
if method == 'adapter':
cfg = change_adapter_cfg(cfg, task)
backbone = Model.from_cfg(cfg)
# Load local backbone parameters if backbone_path provided.
# Otherwise, download backbone parameters from gluon zoo.

backbone_params_path = backbone_path if backbone_path else download_params_path
if checkpoint_path is None:
backbone.load_parameters(backbone_params_path, ignore_extra=True,
allow_missing=(method != 'full'),
ctx=ctx_l, cast_dtype=True)
num_params, num_fixed_params \
= count_parameters(deduplicate_param_dict(backbone.collect_params()))
@@ -219,20 +246,23 @@ def train(args):
#random seed
set_seed(args.seed)
level = logging.INFO
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
detail_dir = os.path.join(args.output_dir, args.task_name)
if not os.path.exists(detail_dir):
os.mkdir(detail_dir)
logging_config(detail_dir,
name='train_{}_{}_{}_'.format(args.task_name, args.model_name, args.method) + str(rank), # avoid race
level=level,
console=(local_rank == 0))
logging.info(args)
cfg, tokenizer, classify_net, use_segmentation = \
get_network(args.model_name, ctx_l, args.method,
args.param_checkpoint,
args.backbone_path,
task)


logging.info('Prepare training data')
train_data, _ = get_task_data(args, task, tokenizer, segment='train')
train_batchify = bf.Group(bf.Group(bf.Pad(), bf.Pad(), bf.Stack()),
@@ -253,6 +283,26 @@
sampler=sampler)


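# Select the parameters to update for the chosen --method: 'full' trains
# everything, 'bias' only bias/LayerNorm terms, 'adapter' only adapter weights,
# and 'last_layer' nothing but the head; the classification head ('out_proj')
# is trained in every case. Everything outside this subset is frozen below by
# setting grad_req to 'null'.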
if args.method == 'full':
target_params_name = classify_net.collect_params().keys()
elif args.method == 'bias':
target_params_name = [key
for key in classify_net.collect_params() if
key.endswith('bias') or key.endswith('beta') or 'out_proj' in key]
elif args.method == 'adapter':
target_params_name = [key
for key in classify_net.collect_params() if
'adapter' in key or 'out_proj' in key]
elif args.method == 'last_layer':
target_params_name = [key
for key in classify_net.collect_params() if
'out_proj' in key]
for name in classify_net.collect_params():
if name not in target_params_name:
classify_net.collect_params()[name].grad_req = 'null'

target_params = {name:classify_net.collect_params()[name] for name in target_params_name}


param_dict = classify_net.collect_params()
# Do not apply weight decay to all the LayerNorm and bias
@@ -269,7 +319,7 @@
if local_rank == 0:
writer = SummaryWriter(logdir=os.path.join(args.output_dir,
args.task_name + '_tensorboard_' +
str(args.lr) + '_' + str(args.epochs) + '_' + str(args.method)))
if args.comm_backend == 'horovod':
# Horovod: fetch and broadcast parameters
hvd.broadcast_parameters(param_dict, root_rank=0)
@@ -290,10 +340,12 @@
optimizer_params = {'learning_rate': args.lr,
'wd': args.wd,
'lr_scheduler': lr_scheduler}


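# Hand only the selected parameter subset to the trainer, so optimizer state is
# created just for the parameters that are actually being updated.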
if args.comm_backend == 'horovod':
trainer = hvd.DistributedTrainer(target_params, args.optimizer, optimizer_params)
else:
trainer = mx.gluon.Trainer(target_params,
'adamw',
optimizer_params)

@@ -376,16 +428,22 @@ def train(args):
log_gnorm = 0
log_step = 0
if local_rank == 0 and (i == max_update - 1 or i%(max_update//args.epochs) == 0 and i>0):
ckpt_name = '{}_{}_{}_{}.params'.format(args.model_name,
args.task_name,
(i + 1),
args.method)

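# Save only the fine-tuned subset of parameters (the target subset) with
# numpy_extension.savez rather than a full-model checkpoint, which keeps
# per-task files small for the bias/adapter/last_layer methods.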
tmp_params = classify_net._collect_params_with_prefix()
params_saved = os.path.join(detail_dir, ckpt_name)
arg_dict = {key: tmp_params[key]._reduce() for key in target_params}
_mx_npx.savez(params_saved, **arg_dict)
logging.info('Params saved in: {}'.format(params_saved))
for metric in metrics:
metric.reset()

end_time = time.time()
logging.info('Total costs:{}'.format(end_time - start_time))



def evaluate(args):
@@ -410,19 +468,24 @@
str(ctx_l)))

cfg, tokenizer, classify_net, use_segmentation = \
get_network(args.model_name, ctx_l, args.method,
args.param_checkpoint,
args.backbone_path,
task)

candidate_ckpt = []
detail_dir = os.path.join(args.output_dir, args.task_name)
for name in os.listdir(detail_dir):
if name.endswith(args.method + '.params') and args.task_name in name and args.model_name in name:
candidate_ckpt.append(os.path.join(detail_dir, name))
candidate_ckpt.sort(reverse=False)
best_ckpt = {}
metrics = task.metric
def evaluate_by_ckpt(ckpt_name, best_ckpt):
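# Checkpoints may contain only the fine-tuned subset of parameters, so load
# them as a raw dict and pass allow_missing=True instead of going through
# load_parameters().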
loaded = _mx_npx.load(ckpt_name)
full_dict = {'params': loaded, 'filename': ckpt_name}
classify_net.load_dict(full_dict, ctx_l, allow_missing=True,
ignore_extra=True, cast_dtype=True)
logging.info('Prepare dev data')

dev_data, label = get_task_data(args, task, tokenizer, segment='eval')