Added support for both 3- and 4-channel images. #57

Merged · 2 commits · Jul 12, 2023
5 changes: 4 additions & 1 deletion infer_vae.py
@@ -112,7 +112,9 @@
 parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
@@ -435,6 +437,7 @@ def main():
         center_crop=True if not args.no_center_crop and not args.random_crop else False,
         flip=not args.no_flip,
         random_crop=args.random_crop if args.random_crop else False,
+        alpha_channel=False if args.channels == 3 else True,
     )

     if args.input_image and not args.input_folder:
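The same flag-to-flag plumbing appears in all three scripts: `--channels` stays an integer for the VAE, and the dataset layer only needs a boolean. A minimal standalone sketch of that mapping (not the project's actual entry point):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
)
args = parser.parse_args(["--channels", "4"])  # simulate a CLI invocation

# Same derivation as in the scripts: 3 -> no alpha, anything else -> alpha.
alpha_channel = False if args.channels == 3 else True
print(alpha_channel)  # True
```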
28 changes: 25 additions & 3 deletions muse_maskgit_pytorch/dataset.py
@@ -45,15 +45,23 @@ def __init__(
         stream=False,
         using_taming=False,
         random_crop=False,
+        alpha_channel=True,
     ):
         super().__init__()
         self.dataset = dataset
         self.image_column = image_column
         self.stream = stream
         transform_list = [
-            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+            T.Lambda(
+                lambda img: img.convert("RGBA")
+                if alpha_channel and img.mode != "RGBA"
+                else img.convert("RGB")
+                if not alpha_channel and img.mode != "RGB"
+                else img
+            ),
             T.Resize(image_size),
         ]
+
         if flip:
             transform_list.append(T.RandomHorizontalFlip())
         if center_crop and not random_crop:
@@ -199,7 +207,15 @@ def __getitem__(self, index):

 class LocalTextImageDataset(Dataset):
     def __init__(
-        self, path, image_size, tokenizer, flip=True, center_crop=True, using_taming=False, random_crop=False
+        self,
+        path,
+        image_size,
+        tokenizer,
+        flip=True,
+        center_crop=True,
+        using_taming=False,
+        random_crop=False,
+        alpha_channel=False,
     ):
         super().__init__()
         self.tokenizer = tokenizer
@@ -229,7 +245,13 @@ def __init__(
             self.caption_pair.append(captions)

         transform_list = [
-            T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
+            T.Lambda(
+                lambda img: img.convert("RGBA")
+                if alpha_channel and img.mode != "RGBA"
+                else img.convert("RGB")
+                if not alpha_channel and img.mode != "RGB"
+                else img
+            ),
             T.Resize(image_size),
         ]
         if flip:
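The chained conditional inside `T.Lambda` reads subtly: `a if c1 else b if c2 else c` groups as `a if c1 else (b if c2 else c)`, so the order of the tests decides whether an already-RGBA image survives when the alpha channel is requested. A standalone sketch of the same decision table (the helper name is ours, not the library's):

```python
from PIL import Image

def normalize_mode(img: Image.Image, alpha_channel: bool) -> Image.Image:
    # Force RGBA when the alpha channel is requested, force RGB when it
    # is not, and leave images that already match untouched.
    if alpha_channel and img.mode != "RGBA":
        return img.convert("RGBA")
    if not alpha_channel and img.mode != "RGB":
        return img.convert("RGB")
    return img

for mode in ("RGB", "RGBA", "L", "P"):
    img = Image.new(mode, (4, 4))
    print(mode, "->",
          normalize_mode(img, alpha_channel=True).mode,
          normalize_mode(img, alpha_channel=False).mode)
# RGB  -> RGBA RGB
# RGBA -> RGBA RGB
# L    -> RGBA RGB
# P    -> RGBA RGB
```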
9 changes: 5 additions & 4 deletions muse_maskgit_pytorch/vqgan_vae.py
@@ -156,7 +156,7 @@ def forward(self, x):

 # discriminator
 class Discriminator(nn.Module):
-    def __init__(self, dims, channels=3, groups=16, init_kernel_size=5):
+    def __init__(self, dims, channels=4, groups=16, init_kernel_size=5):
         super().__init__()
         dim_pairs = zip(dims[:-1], dims[1:])

@@ -194,7 +194,7 @@ def __init__(
         self,
         dim: int,
         *,
-        channels=3,
+        channels=4,
         layers=4,
         layer_mults=None,
         num_resnet_blocks=1,
@@ -337,7 +337,7 @@ def __init__(
         *,
         dim: int,
         accelerator: Accelerator = None,
-        channels=3,
+        channels=4,
         layers=4,
         l2_recon_loss=False,
         use_hinge_loss=True,
@@ -407,7 +407,8 @@ def vgg(self):
         if exists(self._vgg):
             return self._vgg

-        vgg = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
+        vgg = torchvision.models.vgg16(pretrained=True)
+        vgg.features[0] = nn.Conv2d(self.channels, 64, kernel_size=3, stride=1, padding=1)
         vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
         self._vgg = vgg.to(self.device)
         return self._vgg
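torchvision's VGG16 hard-codes a 3-channel first convolution, so feeding 4-channel reconstructions into the perceptual loss would fail with a shape mismatch; hence the swapped `features[0]`. A minimal sketch of the same surgery, runnable offline (pretrained weights skipped here, unlike the diff; note the replacement conv is freshly initialized either way, even when channels is 3):

```python
import torch
import torchvision
from torch import nn

channels = 4  # RGBA input

vgg = torchvision.models.vgg16()  # the diff loads pretrained weights instead
# Replace the stock Conv2d(3, 64, ...) so the network accepts
# `channels`-channel tensors; all other layers are unchanged.
vgg.features[0] = nn.Conv2d(channels, 64, kernel_size=3, stride=1, padding=1)
# Drop the last two classifier layers, as in the diff, to expose features.
vgg.classifier = nn.Sequential(*vgg.classifier[:-2])

x = torch.randn(2, channels, 224, 224)
print(vgg(x).shape)  # torch.Size([2, 4096])
```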
5 changes: 4 additions & 1 deletion train_muse_maskgit.py
@@ -298,7 +298,9 @@
     help="Image Size.",
 )
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
@@ -812,6 +814,7 @@ def main():
             flip=False if args.no_flip else True,
             using_taming=False if not args.taming_model_path else True,
             random_crop=args.random_crop if args.random_crop else False,
+            alpha_channel=False if args.channels == 3 else True,
         )
     elif args.link:
         if not args.dataset_name:
5 changes: 4 additions & 1 deletion train_muse_vae.py
@@ -222,7 +222,9 @@
 )
 parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.")
 parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
-parser.add_argument("--channels", type=int, default=3, help="Number of channels for the VAE.")
+parser.add_argument(
+    "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
+)
 parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
 parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
 parser.add_argument(
@@ -563,6 +565,7 @@ def main():
             flip=not args.no_flip,
             stream=args.streaming,
             random_crop=args.random_crop,
+            alpha_channel=False if args.channels == 3 else True,
         )
     # dataloader

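End to end, `--channels 4` now yields RGBA tensors from the data pipeline. A quick standalone check that mirrors the transform_list built in dataset.py (the crop and tensor-conversion steps are assumed from the surrounding code, not shown in this diff):

```python
import torchvision.transforms as T
from PIL import Image

alpha_channel = True  # i.e. --channels 4
image_size = 256

transform = T.Compose([
    T.Lambda(
        lambda img: img.convert("RGBA")
        if alpha_channel and img.mode != "RGBA"
        else img.convert("RGB")
        if not alpha_channel and img.mode != "RGB"
        else img
    ),
    T.Resize(image_size),
    T.CenterCrop(image_size),   # assumed: dataset.py adds this when center_crop=True
    T.ToTensor(),               # assumed: applied before batching
])

img = Image.new("RGB", (320, 480))  # stand-in for a dataset image
tensor = transform(img)
print(tensor.shape)  # torch.Size([4, 256, 256])
```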