muse-maskgit: more cleanup
neggles committed May 2, 2023
1 parent a55ec95 commit 39f843b
Showing 1 changed file with 44 additions and 55 deletions.
99 changes: 44 additions & 55 deletions muse_maskgit_pytorch/muse_maskgit_pytorch.py
@@ -1,12 +1,14 @@
import math
from functools import partial
from os import PathLike
from pathlib import Path
from random import random
from typing import Callable, List, Optional, Union

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from accelerate import Accelerator
from beartype import beartype
from einops import rearrange, repeat
from torch import einsum, nn
@@ -17,15 +19,14 @@
from .vqgan_vae import VQGanVAE
from .vqgan_vae_taming import VQGanVAETaming

# helpers


# helpers
def exists(val):
return val is not None


def default(val, d):
return val if exists(val) else d
return val if val is not None else d


def eval_decorator(fn):
@@ -44,8 +45,6 @@ def l2norm(t):


# tensor helpers


def get_mask_subset_prob(mask, prob, min_mask=0):
batch, seq, device = *mask.shape, mask.device
num_to_mask = (mask.sum(dim=-1, keepdim=True) * prob).clamp(min=min_mask)
@@ -63,8 +62,6 @@ def get_mask_subset_prob(mask, prob, min_mask=0):


# classes


class LayerNorm(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -119,8 +116,7 @@ def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
def forward(self, x, context=None, context_mask=None):
assert not (exists(context) ^ self.cross_attend)

h, is_cross_attn = self.heads, exists(context)

h = self.heads
x = self.norm(x)

kv_input = context if self.cross_attend else x
@@ -185,21 +181,19 @@ def forward(self, x, context=None, context_mask=None):


# transformer - it's all we need


class Transformer(nn.Module):
def __init__(
self,
*,
num_tokens,
dim,
seq_len,
dim_out=None,
t5_name=DEFAULT_T5_NAME,
self_cond=False,
add_mask_id=False,
cache_path=None,
**kwargs
num_tokens: int,
dim: int,
seq_len: int,
dim_out: Optional[int] = None,
t5_name: str = DEFAULT_T5_NAME,
self_cond: bool = False,
add_mask_id: bool = False,
cache_path: PathLike = None,
**kwargs,
):
super().__init__()
self.dim = dim
@@ -230,7 +224,6 @@ def __init__(
)

# optional self conditioning

self.self_cond = self_cond
self.self_cond_to_init_embed = FeedForward(dim)

@@ -256,7 +249,7 @@ def forward_with_neg_prompt(
neg_text_embed: torch.Tensor,
cond_scale=3.0,
return_embed=False,
**kwargs
**kwargs,
):
neg_logits = self.forward(*args, neg_text_embed=neg_text_embed, cond_drop_prob=0.0, **kwargs)
pos_logits, embed = self.forward(
@@ -343,8 +336,6 @@ def forward(


# self critic wrapper


class SelfCritic(nn.Module):
def __init__(self, net):
super().__init__()
@@ -371,23 +362,21 @@ def forward(self, x, *args, labels=None, **kwargs):


# specialized transformers


class MaskGitTransformer(Transformer):
def __init__(self, *args, **kwargs):
assert "add_mask_id" not in kwargs
if kwargs.pop("add_mask_id", True) is not True:
raise ValueError("MaskGitTransformer does not accept add_mask_id argument")
super().__init__(*args, add_mask_id=True, **kwargs)


class TokenCritic(Transformer):
def __init__(self, *args, **kwargs):
assert "dim_out" not in kwargs
if kwargs.pop("dim_out", 1) != 1:
raise ValueError("TokenCritic does not accept dim_out argument")
super().__init__(*args, dim_out=1, **kwargs)


# classifier free guidance functions


def uniform(shape, min=0, max=1, device=None):
return torch.zeros(shape, device=device).float().uniform_(0, 1)

@@ -441,50 +430,46 @@ def __init__(
self,
image_size,
transformer: MaskGitTransformer,
accelerator: Optional[Accelerator] = None,
noise_schedule: Callable = cosine_schedule,
token_critic: Optional[TokenCritic] = None,
self_token_critic=False,
self_token_critic: bool = False,
vae: Optional[Union[VQGanVAE, VQGanVAETaming]] = None,
cond_vae: Optional[Union[VQGanVAE, VQGanVAETaming]] = None,
cond_image_size=None,
cond_drop_prob=0.5,
self_cond_prob=0.9,
no_mask_token_prob=0.0,
critic_loss_weight=1.0,
cond_image_size: Optional[int] = None,
cond_drop_prob: float = 0.5,
self_cond_prob: float = 0.9,
no_mask_token_prob: float = 0.0,
critic_loss_weight: float = 1.0,
):
super().__init__()
self.vae = vae.copy_for_eval() if exists(vae) else None
self.accelerator = accelerator

self.vae = vae.copy_for_eval() if vae is not None else None

if exists(cond_vae):
if cond_vae is not None:
if cond_image_size is None:
raise ValueError("cond_image_size must be specified if conditioning")
self.cond_vae = cond_vae.eval()
else:
self.cond_vae = self.vae

assert not (
exists(cond_vae) and not exists(cond_image_size)
), "cond_image_size must be specified if conditioning"

self.image_size = image_size
self.cond_image_size = cond_image_size
self.resize_image_for_cond_image = exists(cond_image_size)

self.cond_drop_prob = cond_drop_prob

self.transformer = transformer
self.self_cond = transformer.self_cond
assert (
self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens
), "transformer num_tokens must be set to be equal to the vae codebook size"
if not self.vae.codebook_size == self.cond_vae.codebook_size == transformer.num_tokens:
raise ValueError("transformer num_tokens must be set to be equal to the vae codebook size")

self.mask_id = transformer.mask_id
self.noise_schedule = noise_schedule

assert not (self_token_critic and exists(token_critic))
self.token_critic = token_critic

if self_token_critic:
self.token_critic = SelfCritic(transformer)

if token_critic and self_token_critic:
raise ValueError("cannot have both self_token_critic and token_critic")
self.token_critic = SelfCritic(transformer) if self_token_critic else token_critic
self.critic_loss_weight = critic_loss_weight

# self conditioning
@@ -495,12 +480,16 @@ def __init__(
self.no_mask_token_prob = no_mask_token_prob

def save(self, path):
self.accelerator.save(self.state_dict(), path)
if self.accelerator:
self.accelerator.save(self.state_dict(), path)
else:
torch.save(self.state_dict(), path)

def load(self, path):
path = Path(path)
assert path.exists()
state_dict = torch.load(str(path))
if not path.exists() and path.is_file():
raise ValueError(f"cannot find file {path} (does not exist or is not a file)")
state_dict = torch.load(str(path), map_location="cpu")
self.load_state_dict(state_dict)

@torch.no_grad()

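For context, the behavioural change in the final hunks is that save() and load() no longer require an Accelerator: save() falls back to torch.save() when none is attached, and load() now validates the path and maps the checkpoint to CPU before restoring. Below is a minimal usage sketch of that path, assuming the enclosing class is MaskGit (as in upstream muse-maskgit-pytorch) and that a matching transformer and trained vae have been built elsewhere; the image size, checkpoint filename, and variable names are illustrative assumptions, not part of the commit:

    # hypothetical usage of the updated save/load paths (not part of this diff)
    from muse_maskgit_pytorch import MaskGit

    maskgit = MaskGit(
        image_size=256,            # illustrative value
        transformer=transformer,   # a MaskGitTransformer built elsewhere
        vae=vae,                   # a trained VQGanVAE; its codebook size must equal transformer.num_tokens
        accelerator=None,          # no Accelerator attached -> save() uses torch.save()
    )

    maskgit.save("maskgit.pt")     # torch.save(self.state_dict(), path)
    maskgit.load("maskgit.pt")     # torch.load(..., map_location="cpu"), then load_state_dict

Passing an accelerate.Accelerator instead routes save() through accelerator.save(), which is the safer choice in a distributed run since only the main process writes the checkpoint.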