Skip to content

Commit

Permalink
final push
Browse files Browse the repository at this point in the history
  • Loading branch information
thevasudevgupta committed Apr 12, 2021
1 parent e1899b5 commit 45125ec
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 36 deletions.
77 changes: 43 additions & 34 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ class TrainerConfig(DefaultArgs):
single_file: bool = False

src_lang: str = 'hi_IN'
# Token-sequence truncation lengths for the source / target sides.
# NOTE(review): the scrape carried both the old (32) and new (40) values;
# the post-image defaults (40) are kept here.
max_length: int = 40
max_target_length: int = 40

# Caps on the number of training / validation samples; -1 means "use all".
tr_max_samples: int = -1
val_max_samples: int = -1
Expand Down Expand Up @@ -106,25 +106,27 @@ class TrainerConfig(DefaultArgs):
add_layer_norm_after=False))


# IIT-Bombay Hindi→English parallel corpus config.
# Fix: the diff scrape duplicated the pre-image `tgt_file`/`src_file` keywords
# (old absolute `/data/...` paths) alongside the new relative paths — a
# repeated-keyword SyntaxError; only the post-image arguments are kept.
iitb_hin = TrainerConfig(tgt_file='parallel/IITB.en-hi.en',
                         src_file='parallel/IITB.en-hi.hi',
                         src_lang="hi_IN",
                         max_length=32,
                         max_target_length=32,
                         base_dir="iitb_base_dir")

# Bhasha PIB-v1.3 Hindi→English parallel corpus config.
# Fix: deduplicated the pre-image `tgt_file`/`src_file`/closing lines the diff
# scrape interleaved (repeated keyword arguments are a SyntaxError); kept the
# post-image values, including the new `test_size=0.03` hold-out split.
bhasha_hin = TrainerConfig(tgt_file="pib-v1.3/en-hi/train.en",
                           src_file="pib-v1.3/en-hi/train.hi",
                           src_lang="hi_IN",
                           max_length=40,
                           max_target_length=40,
                           test_size=0.03)

# Bhasha PIB-v1.3 Gujarati→English parallel corpus config.
# Fix: deduplicated the interleaved pre-image lines (duplicate keyword
# arguments are a SyntaxError); kept the post-image values, including the
# new `test_size=0.13` hold-out split.
bhasha_guj = TrainerConfig(tgt_file="pib-v1.3/en-gu/train.en",
                           src_file="pib-v1.3/en-gu/train.gu",
                           src_lang="gu_IN",
                           max_length=40,
                           max_target_length=40,
                           base_dir="guj_base_dir",
                           test_size=0.13)

config_adapt_sa_ffn = replace(bhasha_hin,
enc_ffn_adapter=True,
Expand Down Expand Up @@ -162,18 +164,6 @@ class TrainerConfig(DefaultArgs):
max_length=40,
max_target_length=40)

# NOTE(review): stale pre-image copy left over from the diff scrape — it is
# superseded by the later `best_adapters_hin` definition in this file (which
# additionally sets lr=1e-3 and max_epochs=5). The later assignment wins at
# import time, so this copy is dead; confirm and remove when reconciling.
best_adapters_hin = replace(freeze_model_hin,
enc_self_attn_adapter=True,
dec_ffn_adapter=True,
enc_tok_embed_adapter=True,
dec_tok_embed_adapter=True,
save_adapter_path="adapter",
# load_adapter_path="adapter.pt",
base_dir="final-best-adapters-hindi",
wandb_run_name="final-best-adapters-hin",
finetuned_id="offnote-mbart-adapters-hin-eng")


freeze_model_guj = replace(bhasha_guj,
base_dir="tr_dec-ffn_enc-attn_embed_hin2000,400",
wandb_run_name="tr_dec-ffn_enc-attn_embed_hin2000,400",
Expand All @@ -193,6 +183,19 @@ class TrainerConfig(DefaultArgs):
max_target_length=40)


# Best-performing adapter recipe for Hindi→English: enables the encoder
# self-attention, decoder FFN, and both token-embedding adapters on top of
# the frozen-model Hindi config, trained with a higher lr for 5 epochs.
_best_hin_overrides = dict(
    enc_self_attn_adapter=True,
    dec_ffn_adapter=True,
    enc_tok_embed_adapter=True,
    dec_tok_embed_adapter=True,
    save_adapter_path="adapter",
    # load_adapter_path="adapter.pt",  # enable to resume from saved adapters
    base_dir="final-best-adapters-hindi",
    wandb_run_name="final-best-adapters-hin",
    finetuned_id="offnote-mbart-adapters-hin-eng",
    lr=1e-3,
    max_epochs=5,
)
best_adapters_hin = replace(freeze_model_hin, **_best_hin_overrides)

best_adapters_guj = replace(freeze_model_guj,
enc_self_attn_adapter=True,
dec_ffn_adapter=True,
Expand All @@ -202,19 +205,25 @@ class TrainerConfig(DefaultArgs):
# load_adapter_path="adapter.pt",
base_dir="final-best-adapters-guj",
wandb_run_name="final-best-adapters-guj",
finetuned_id="offnote-mbart-adapters-guj-eng")

# NOTE(review): commented-out experiment config below is dead code — delete
# once it is confirmed nobody needs the embed-adapter run for reference.
# run = replace(freeze_model,
#                base_dir="embed-adapter_bhasha-hin0.1M,20K",
#                wandb_run_name="embed-adapter_bhasha-hin0.1M,20K",
#                enc_tok_embed_adapter=True,
#                dec_tok_embed_adapter=True,
#                # save_epoch_dir="epoch-wts",
#                save_adapter_path="adapter.pt")

# Tiny smoke-test config: caps training at 100 samples and validation at 30
# so the full pipeline can be exercised quickly.
check = replace(bhasha_hin, tr_max_samples=100, val_max_samples=30, wandb_run_name="random")
finetuned_id="offnote-mbart-adapters-guj-eng",
lr=1e-3,
max_epochs=5)

# Plain full fine-tune (no adapters) of mBART on the Bhasha en-gu corpus.
# The run name doubles as output dir, wandb run, and hub id.
_guj_full_name = "mbart-bhasha-guj-eng"
full_train_guj = replace(
    bhasha_guj,
    base_dir=_guj_full_name,
    wandb_run_name=_guj_full_name,
    finetuned_id=_guj_full_name,
    lr=5e-5,
    max_epochs=3,
)

# Plain full fine-tune (no adapters) of mBART on the Bhasha en-hi corpus.
# The run name doubles as output dir, wandb run, and hub id.
_hin_full_name = "mbart-bhasha-hin-eng"
full_train_hin = replace(
    bhasha_hin,
    base_dir=_hin_full_name,
    wandb_run_name=_hin_full_name,
    finetuned_id=_hin_full_name,
    lr=5e-5,
    max_epochs=3,
)

# this is getting called in `main.py`
# NOTE(review): `main` currently selects the tiny smoke-test config (`check`:
# 100 train / 30 val samples) — switch to a real config (e.g. `full_train_hin`
# or `best_adapters_hin`) for actual training runs.
main = check
# main = full_train_hin
# if sweep is defined then these args will work like default
# and will be overwritten by wandb
2 changes: 2 additions & 0 deletions src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

# python train.py --config "best_adapters_guj"
# python train.py --config "best_adapters_hin"
# python train.py --config "full_train_guj"
# python train.py --config "full_train_hin"

if __name__ == '__main__':

Expand Down
Binary file not shown.
7 changes: 5 additions & 2 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@
# `SAVE_STATE_WARNING` was removed from `torch.optim.lr_scheduler` in newer
# torch releases, so the import is attempted and a fallback is defined.
# Fixes: (a) dropped the interleaved pre-image unconditional import, which
# would raise before the try/except ever ran; (b) bare `except:` narrowed to
# `ImportError` so unrelated errors (e.g. KeyboardInterrupt) aren't swallowed;
# (c) the old `pass` left the name undefined on failure — define it as "" to
# match the <= 1.4.1 branch and avoid a NameError at the use site.
if version.parse(torch.__version__) <= version.parse("1.4.1"):
    SAVE_STATE_WARNING = ""
else:
    try:
        from torch.optim.lr_scheduler import SAVE_STATE_WARNING
    except ImportError:
        SAVE_STATE_WARNING = ""

logger = logging.get_logger(__name__)


Expand Down

0 comments on commit 45125ec

Please sign in to comment.