8bit Optimisers and resolved potential Env issues
The change to setup.py was made to resolve dependency issues on my machine: certain features used here are deprecated in later versions of lightning and torchmetrics. The pinned versions still provide these features but emit deprecation warnings.

Also, regarding 8bit: replacing nn.Linear() with its 8bit equivalent seems to break tensors, so I've opted to implement only the optimisers.
korakoe committed Mar 11, 2023
1 parent de94178 commit faa1281
Showing 5 changed files with 54 additions and 14 deletions.
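
As a minimal sketch of the swap this commit makes (assuming bitsandbytes is installed; the model and learning rate below are only illustrative), the 8-bit optimisers are drop-in replacements for their torch.optim counterparts:

    import torch
    import torch.nn as nn
    import bitsandbytes as bnb

    # illustrative model, not part of the repository
    model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10))

    use_8bit_adam = True
    if use_8bit_adam:
        # bnb.optim.Adam8bit keeps the optimiser state in 8 bits to reduce memory use;
        # as in the commit, only lr is passed and weight_decay is left at its default
        optim = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)
    else:
        optim = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0)
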
14 changes: 12 additions & 2 deletions muse_maskgit_pytorch/trainers/maskgit_trainer.py
@@ -14,6 +14,7 @@
 from torchvision.utils import make_grid, save_image
 from PIL import Image
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
+import bitsandbytes as bnb
 
 from einops import rearrange

@@ -70,6 +71,7 @@ def __init__(
         only_save_last_checkpoint=False,
         optimizer="Lion",
         weight_decay=0.0,
+        use_8bit_adam=False
     ):
         super().__init__(
             dataloader,
@@ -104,11 +106,19 @@

         # optimizers
         if optimizer == "Adam":
-            self.optim = Adam(transformer_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                self.optim = bnb.optim.Adam8bit(transformer_parameters, lr=lr)
+            else:
+                self.optim = Adam(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "AdamW":
-            self.optim = AdamW(transformer_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                self.optim = bnb.optim.AdamW8bit(transformer_parameters, lr=lr)
+            else:
+                self.optim = AdamW(transformer_parameters, lr=lr, weight_decay=weight_decay)
         elif optimizer == "Lion":
             self.optim = Lion(transformer_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                print("8bit is not supported with the Lion optimiser, Using standard Lion instead.")
         else:
             print(f"{optimizer} optimizer not supported yet.")

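
For clarity, the branching added above can be read as a single helper (a sketch only; get_optimizer is not part of the repository, and Lion is assumed to come from the lion-pytorch package the trainers already use):

    import bitsandbytes as bnb
    from torch.optim import Adam, AdamW
    from lion_pytorch import Lion  # assumed source of the Lion optimiser

    def get_optimizer(name, parameters, lr, weight_decay=0.0, use_8bit_adam=False):
        # Mirrors the logic added in maskgit_trainer.py: the 8-bit variants are created
        # without weight_decay, and Lion always falls back to the full-precision version.
        if name == "Adam":
            return bnb.optim.Adam8bit(parameters, lr=lr) if use_8bit_adam else Adam(parameters, lr=lr, weight_decay=weight_decay)
        if name == "AdamW":
            return bnb.optim.AdamW8bit(parameters, lr=lr) if use_8bit_adam else AdamW(parameters, lr=lr, weight_decay=weight_decay)
        if name == "Lion":
            if use_8bit_adam:
                print("8bit is not supported with the Lion optimiser, using standard Lion instead.")
            return Lion(parameters, lr=lr, weight_decay=weight_decay)
        # the trainer only prints a warning for unknown names; raising keeps the sketch explicit
        raise ValueError(f"{name} optimizer not supported yet.")
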
21 changes: 16 additions & 5 deletions muse_maskgit_pytorch/trainers/vqvae_trainers.py
@@ -13,7 +13,7 @@
 from torch.utils.data import DataLoader, random_split
 from torch.utils.tensorboard import SummaryWriter
 from torchvision.utils import make_grid, save_image
-
+import bitsandbytes as bnb
 
 from muse_maskgit_pytorch.vqgan_vae import VQGanVAE

@@ -74,6 +74,7 @@ def __init__(
         only_save_last_checkpoint=False,
         optimizer='Adam',
         weight_decay=0.0,
+        use_8bit_adam=False
     ):
         super().__init__(
             dataloader,
@@ -102,14 +103,24 @@

         # optimizers
         if optimizer == 'Adam':
-            self.optim = Adam(vae_parameters, lr=lr, weight_decay=weight_decay)
-            self.discr_optim = Adam(discr_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                self.optim = bnb.optim.Adam8bit(vae_parameters, lr=lr)
+                self.discr_optim = bnb.optim.Adam8bit(discr_parameters, lr=lr)
+            else:
+                self.optim = Adam(vae_parameters, lr=lr)
+                self.discr_optim = Adam(discr_parameters, lr=lr)
         elif optimizer == 'AdamW':
-            self.optim = AdamW(vae_parameters, lr=lr, weight_decay=weight_decay)
-            self.discr_optim = AdamW(discr_parameters, lr=lr)
+            if use_8bit_adam:
+                self.optim = bnb.optim.AdamW8bit(vae_parameters, lr=lr)
+                self.discr_optim = bnb.optim.AdamW8bit(discr_parameters, lr=lr)
+            else:
+                self.optim = AdamW(vae_parameters, lr=lr)
+                self.discr_optim = AdamW(discr_parameters, lr=lr)
         elif optimizer == 'Lion':
             self.optim = Lion(vae_parameters, lr=lr, weight_decay=weight_decay)
             self.discr_optim = Lion(discr_parameters, lr=lr, weight_decay=weight_decay)
+            if use_8bit_adam:
+                print("8bit is not supported with the Lion optimiser, Using standard Lion instead.")
         else:
             print(f"{optimizer} optimizer not supported yet.")

2 changes: 2 additions & 0 deletions setup.py
@@ -28,6 +28,8 @@
"pillow",
"sentencepiece",
"torch>=1.6",
"torchmetrics<0.8.0",
"pytorch-lightning<=1.7.7",
"taming-transformers>=0.0.1",
"transformers",
"torch>=1.6",
6 changes: 6 additions & 0 deletions train_muse_maskgit.py
@@ -140,6 +140,11 @@ def parse_args():
         choices=["no", "fp16", "bf16"],
         help="Precision to train on.",
     )
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether to use the 8bit adam optimiser",
+    )
     parser.add_argument(
         "--results_dir",
         type=str,
@@ -384,6 +389,7 @@ def main():
         only_save_last_checkpoint=args.only_save_last_checkpoint,
         optimizer=args.optimizer,
         weight_decay=args.weight_decay,
+        use_8bit_adam=args.use_8bit_adam
     )
 
     trainer.train()
25 changes: 18 additions & 7 deletions train_muse_vae.py
@@ -13,6 +13,9 @@

 import argparse
 
+import torch.nn as nn
+import bitsandbytes as bnb
+from accelerate import init_empty_weights
 
 def parse_args():
     # Create the parser
@@ -110,6 +113,11 @@ def parse_args():
         choices=["no", "fp16", "bf16"],
         help="Precision to train on.",
     )
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether to use the 8bit adam optimiser",
+    )
     parser.add_argument(
         "--results_dir",
         type=str,
@@ -188,15 +196,17 @@ def parse_args():
         default=None,
         help="Path to the last saved checkpoint. 'results/vae.steps.pt'",
     )
-    parser.add_argument("--optimizer",type=str,
+    parser.add_argument(
+        "--optimizer",type=str,
         default='Lion',
         help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Adam",
-    )
-    parser.add_argument("--weight_decay", type=float,
-        default=0.0,
-        help="Optimizer weight_decay to use. Default: 0.0",
-    )
-
+    )
+    parser.add_argument(
+        "--weight_decay", type=float,
+        default=0.0,
+        help="Optimizer weight_decay to use. Default: 0.0",
+    )
+
     # Parse the argument
     return parser.parse_args()

@@ -288,6 +298,7 @@ def main():
         validation_image_scale=args.validation_image_scale,
         only_save_last_checkpoint=args.only_save_last_checkpoint,
         optimizer=args.optimizer,
+        use_8bit_adam=args.use_8bit_adam
     )
 
     trainer.train()
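
For reference, the abandoned approach mentioned in the commit message, swapping nn.Linear for its bitsandbytes equivalent, would look roughly like the sketch below (hypothetical; replace_linear_with_8bit is not part of this commit, which keeps nn.Linear untouched):

    import torch.nn as nn
    import bitsandbytes as bnb

    def replace_linear_with_8bit(module, threshold=6.0):
        # Recursively swap nn.Linear layers for bnb.nn.Linear8bitLt.
        # This is the replacement the commit message reports as breaking tensors,
        # which is why only the 8-bit optimisers were kept.
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                new_layer = bnb.nn.Linear8bitLt(
                    child.in_features,
                    child.out_features,
                    bias=child.bias is not None,
                    has_fp16_weights=False,
                    threshold=threshold,
                )
                # NOTE: existing weights would still need to be copied/quantised here.
                setattr(module, name, new_layer)
            else:
                replace_linear_with_8bit(child, threshold)
        return module
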
