dynamic context noising
julian-q committed Nov 8, 2024
1 parent 83110fb commit abe07fa
Showing 6 changed files with 76 additions and 166 deletions.
4 changes: 1 addition & 3 deletions attention.py
@@ -45,9 +45,7 @@ def forward(self, x: torch.Tensor):

q, k, v = map(lambda t: t.contiguous(), (q, k, v))

x = F.scaled_dot_product_attention(
query=q, key=k, value=v, is_causal=self.is_causal
)
x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=self.is_causal)

x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
x = x.to(q.dtype)
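The attention.py change is purely a reflow of the F.scaled_dot_product_attention call onto one line. The surrounding rearrange is the interesting part: spatial positions are folded into the batch so that causal attention mixes information along the time axis only. A minimal sketch of that temporal-axial pattern, with hypothetical sizes and without the per-head projections and rotary embeddings the real module uses:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

B, T, H, W, h, d = 2, 4, 8, 8, 3, 16  # hypothetical sizes; model dim D = h * d

x = torch.randn(B, T, H, W, h * d)

# fold spatial positions into the batch so attention only attends along T
q = k = v = rearrange(x, "B T H W (h d) -> (B H W) h T d", h=h)

out = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=True)

# restore the (B, T, H, W, D) layout, as in attention.py
out = rearrange(out, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
assert out.shape == x.shape
```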
82 changes: 24 additions & 58 deletions dit.py
@@ -56,15 +56,13 @@ def __init__(
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten

self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def forward(self, x, random_sample=False):
B, C, H, W = x.shape
assert (
random_sample or (H == self.img_size[0] and W == self.img_size[1])
assert random_sample or (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
@@ -83,9 +81,7 @@ class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(
frequency_embedding_size, hidden_size, bias=True
), # hidden_size is diffusion model hidden size
nn.Linear(frequency_embedding_size, hidden_size, bias=True), # hidden_size is diffusion model hidden size
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
@@ -103,17 +99,13 @@ def timestep_embedding(t, dim, max_period=10000):
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
device=t.device
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding

def forward(self, t):
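The timestep_embedding edit is again only a reflow of the freqs expression; the math it preserves is the GLIDE-style sinusoidal embedding: a geometric ladder of frequencies from 1 down to roughly 1/max_period, with cosine and sine features concatenated. A standalone sketch of the same computation:

```python
import math
import torch

def sinusoidal_timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    """Map integer diffusion steps t of shape (N,) to (N, dim) sinusoidal features."""
    half = dim // 2
    # frequencies decay geometrically from 1 down to roughly 1/max_period
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(t.device)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad one zero column when dim is odd
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

emb = sinusoidal_timestep_embedding(torch.tensor([0, 10, 999]), dim=256)  # -> (3, 256)
```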
@@ -130,12 +122,8 @@ class FinalLayer(nn.Module):
def __init__(self, hidden_size, patch_size, out_channels):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(
hidden_size, patch_size * patch_size * out_channels, bias=True
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
@@ -173,9 +161,7 @@ def __init__(
act_layer=approx_gelu,
drop=0,
)
self.s_adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.s_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

self.t_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.t_attn = TemporalAxialAttention(
@@ -192,34 +178,24 @@ def __init__(
act_layer=approx_gelu,
drop=0,
)
self.t_adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
self.t_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

def forward(self, x, c):
B, T, H, W, D = x.shape

# spatial block
s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = (
self.s_adaLN_modulation(c).chunk(6, dim=-1)
)
x = x + gate(
self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa
)
x = x + gate(
self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp
s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = self.s_adaLN_modulation(c).chunk(
6, dim=-1
)
x = x + gate(self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa)
x = x + gate(self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp)

# temporal block
t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = (
self.t_adaLN_modulation(c).chunk(6, dim=-1)
)
x = x + gate(
self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa
)
x = x + gate(
self.t_mlp(modulate(self.t_norm2(x), t_shift_mlp, t_scale_mlp)), t_gate_mlp
t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c).chunk(
6, dim=-1
)
x = x + gate(self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa)
x = x + gate(self.t_mlp(modulate(self.t_norm2(x), t_shift_mlp, t_scale_mlp)), t_gate_mlp)

return x
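Both the spatial and temporal halves follow the DiT adaLN-Zero recipe: the conditioning vector c is projected to six chunks (shift, scale, and gate for the attention and for the MLP), the normalized activations are modulated before each sub-block, and the sub-block output is gated before the residual add. The modulate and gate helpers themselves are not shown in this diff; a rough sketch of what they are usually taken to do, with the (B, T, D) conditioning shape assumed:

```python
import torch

def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # x: (B, T, H, W, D); shift/scale: (B, T, D) from the adaLN MLP (shapes assumed)
    while shift.dim() < x.dim():
        shift, scale = shift.unsqueeze(-2), scale.unsqueeze(-2)
    return x * (1 + scale) + shift

def gate(x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    # scale the residual branch by a learned per-channel gate (near zero at init in adaLN-Zero)
    while g.dim() < x.dim():
        g = g.unsqueeze(-2)
    return g * x
```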

@@ -249,21 +225,13 @@ def __init__(
self.num_heads = num_heads
self.max_frames = max_frames

self.x_embedder = PatchEmbed(
input_h, input_w, patch_size, in_channels, hidden_size, flatten=False
)
self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
self.t_embedder = TimestepEmbedder(hidden_size)
frame_h, frame_w = self.x_embedder.grid_size

self.spatial_rotary_emb = RotaryEmbedding(
dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256
)
self.spatial_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
self.temporal_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads)
self.external_cond = (
nn.Linear(external_cond_dim, hidden_size)
if external_cond_dim > 0
else nn.Identity()
)
self.external_cond = nn.Linear(external_cond_dim, hidden_size) if external_cond_dim > 0 else nn.Identity()

self.blocks = nn.ModuleList(
[
@@ -340,9 +308,7 @@ def forward(self, x, t, external_cond=None):

# add spatial embeddings
x = rearrange(x, "b t c h w -> (b t) c h w")
x = self.x_embedder(
x
) # (B*T, C, H, W) -> (B*T, H/2, W/2, D) , C = 16, D = d_model
x = self.x_embedder(x) # (B*T, C, H, W) -> (B*T, H/2, W/2, D) , C = 16, D = d_model
# restore shape
x = rearrange(x, "(b t) h w d -> b t h w d", t=T)
# embed noise steps
68 changes: 30 additions & 38 deletions generate.py
@@ -21,6 +21,9 @@


def main(args):
torch.manual_seed(0)
torch.cuda.manual_seed(0)

# load DiT checkpoint
model = DiT_models["DiT-S/2"]()
print(f"loading Oasis-500M from oasis-ckpt={os.path.abspath(args.oasis_ckpt)}...")
@@ -42,14 +45,14 @@ def main(args):
vae = vae.to(device).eval()

# sampling params
B = 1
n_prompt_frames = args.n_prompt_frames
total_frames = args.num_frames
max_noise_level = 1000
ddim_noise_steps = args.ddim_steps
noise_range = torch.linspace(-1, max_noise_level - 1, ddim_noise_steps + 1)
noise_abs_max = 20
ctx_max_noise_idx = ddim_noise_steps // 10 * 3
stabilization_level = 15
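These few constants are where the commit's "dynamic context noising" lives: noise_range maps the ddim_noise_steps DDIM indices onto the 1000-level training schedule, and ctx_max_noise_idx caps how much noise the already-generated context frames receive at each denoising step (the stabilization_level constant corresponds to the fixed-level alternative that appears commented out further down). A small worked example with the script's default of 50 DDIM steps:

```python
import torch

max_noise_level = 1000
ddim_noise_steps = 50                                  # the --ddim-steps default
noise_range = torch.linspace(-1, max_noise_level - 1, ddim_noise_steps + 1)
ctx_max_noise_idx = ddim_noise_steps // 10 * 3         # = 15, i.e. 30% of the schedule

for noise_idx in reversed(range(1, ddim_noise_steps + 1)):
    # context frames track the current step's noise level, but never exceed the cap:
    # noise_idx = 50 -> ctx_noise_idx = 15, noise_idx = 7 -> ctx_noise_idx = 7
    ctx_noise_idx = min(noise_idx, ctx_max_noise_idx)
```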

# get prompt image/video
x = load_prompt(
@@ -58,30 +61,28 @@ def main(args):
n_prompt_frames=n_prompt_frames,
)
# get input action stream
actions = load_actions(args.actions_path, action_offset=args.video_offset)[
:, :total_frames
]
actions = load_actions(args.actions_path, action_offset=args.video_offset)[:, :total_frames]
# x = torch.load("xs_original_0.pt")[6:7]
# actions = torch.load("external_cond_0.pt")[6:7, :total_frames]
# actions[:, :1] = torch.zeros_like(actions[:, :1])

# sampling inputs
x = x.to(device)
actions = actions.to(device)

# vae encoding
B = x.shape[0]
H, W = x.shape[-2:]
scaling_factor = 0.07843137255
x = rearrange(x, "b t c h w -> (b t) c h w")
H, W = x.shape[-2:]
with torch.no_grad():
x = vae.encode(x * 2 - 1).mean * scaling_factor
x = rearrange(
x,
"(b t) (h w) c -> b t c h w",
t=n_prompt_frames,
h=H // vae.patch_size,
w=W // vae.patch_size,
)
with autocast("cuda", dtype=torch.half):
x = vae.encode(x * 2 - 1).mean * scaling_factor
x = rearrange(x, "(b t) (h w) c -> b t c h w", t=n_prompt_frames, h=H // vae.patch_size, w=W // vae.patch_size)
x = x[:, :n_prompt_frames]
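The VAE is applied frame by frame: the video is folded into the batch dimension, the encoder returns a token sequence of shape ((b t), (h w), c), and the second rearrange unpacks it into a (b, t, c, h, w) latent video. A toy shape check of that round trip (the 360x640 frame size, patch size 20, and 16 latent channels are assumptions for illustration, not read from this diff):

```python
import torch
from einops import rearrange

b, t, C, H, W = 1, 4, 3, 360, 640     # assumed prompt video shape
patch_size, c_latent = 20, 16         # assumed VAE patch size and latent channels

frames = torch.randn(b, t, C, H, W)
x = rearrange(frames, "b t c h w -> (b t) c h w")
# stand-in for `vae.encode(x * 2 - 1).mean * scaling_factor`
tokens = torch.randn(b * t, (H // patch_size) * (W // patch_size), c_latent)
latents = rearrange(tokens, "(b t) (h w) c -> b t c h w", t=t, h=H // patch_size, w=W // patch_size)
print(latents.shape)                  # torch.Size([1, 4, 16, 18, 32])
```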

# get alphas
betas = sigmoid_beta_schedule(max_noise_level).to(device)
betas = sigmoid_beta_schedule(max_noise_level).float().to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod = rearrange(alphas_cumprod, "T -> T 1 1 1")
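alphas_cumprod holds the cumulative products of the schedule's alphas, reshaped to (T, 1, 1, 1) so that indexing it with a (B, T) tensor of per-frame noise levels broadcasts cleanly over (B, T, C, H, W) latents; that is what lets every frame in the window sit at its own noise level. A toy illustration of the indexing (the linear betas below are a stand-in, not the repo's sigmoid_beta_schedule):

```python
import torch
from einops import rearrange

betas = torch.linspace(1e-4, 2e-2, 1000)      # stand-in schedule for illustration only
alphas_cumprod = rearrange(torch.cumprod(1.0 - betas, dim=0), "T -> T 1 1 1")

# context frames at a low noise level, the frame being generated at a high one
t = torch.tensor([[14, 14, 14, 979]])         # shape (B, T) = (1, 4)
coeff = alphas_cumprod[t]
print(coeff.shape)                            # torch.Size([1, 4, 1, 1, 1])
```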
@@ -96,15 +97,12 @@ def main(args):
for noise_idx in reversed(range(1, ddim_noise_steps + 1)):
# set up noise values
ctx_noise_idx = min(noise_idx, ctx_max_noise_idx)
t_ctx = torch.full(
(B, i), noise_range[ctx_noise_idx], dtype=torch.long, device=device
)
t = torch.full(
(B, 1), noise_range[noise_idx], dtype=torch.long, device=device
)
t_next = torch.full(
(B, 1), noise_range[noise_idx - 1], dtype=torch.long, device=device
)
t_ctx = torch.full((B, i), noise_range[ctx_noise_idx], dtype=torch.long, device=device)
# t_ctx = torch.full(
# (B, i), stabilization_level - 1, dtype=torch.long, device=device
# )
t = torch.full((B, 1), noise_range[noise_idx], dtype=torch.long, device=device)
t_next = torch.full((B, 1), noise_range[noise_idx - 1], dtype=torch.long, device=device)
t_next = torch.where(t_next < 0, t, t_next)
t = torch.cat([t_ctx, t], dim=1)
t_next = torch.cat([t_ctx, t_next], dim=1)
@@ -119,27 +117,23 @@ def main(args):
ctx_noise = torch.randn_like(x_curr[:, :-1])
ctx_noise = torch.clamp(ctx_noise, -noise_abs_max, +noise_abs_max)
x_curr[:, :-1] = (
alphas_cumprod[t[:, :-1]].sqrt() * x_curr[:, :-1]
+ (1 - alphas_cumprod[t[:, :-1]]).sqrt() * ctx_noise
alphas_cumprod[t[:, :-1]].sqrt() * x_curr[:, :-1] + (1 - alphas_cumprod[t[:, :-1]]).sqrt() * ctx_noise
)
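This is the noising half of the scheme: the context frames in x_curr are clean latents, so before every model call they are pushed back to the capped noise level through the forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, with the sampled noise clamped to plus/minus noise_abs_max. The same operation as a standalone helper (the function name is hypothetical):

```python
import torch

def renoise_context(x0: torch.Tensor, alphas_cumprod: torch.Tensor, t_ctx: torch.Tensor,
                    noise_abs_max: float = 20.0) -> torch.Tensor:
    """Push clean context latents x0 of shape (B, T, C, H, W) to noise levels t_ctx of shape (B, T)."""
    noise = torch.clamp(torch.randn_like(x0), -noise_abs_max, noise_abs_max)
    a = alphas_cumprod[t_ctx]          # (B, T, 1, 1, 1), see the bookkeeping above
    return a.sqrt() * x0 + (1 - a).sqrt() * noise
```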

# get model predictions
with torch.no_grad():
with autocast("cuda", dtype=torch.half):
v = model(x_curr, t, actions[:, start_frame : i + 1])

x_start = (
alphas_cumprod[t].sqrt() * x_curr - (1 - alphas_cumprod[t]).sqrt() * v
)
x_noise = ((1 / alphas_cumprod[t]).sqrt() * x_curr - x_start) / (
1 / alphas_cumprod[t] - 1
).sqrt()
x_start = alphas_cumprod[t].sqrt() * x_curr - (1 - alphas_cumprod[t]).sqrt() * v
x_noise = ((1 / alphas_cumprod[t]).sqrt() * x_curr - x_start) / (1 / alphas_cumprod[t] - 1).sqrt()

# get frame prediction
x_pred = (
alphas_cumprod[t_next].sqrt() * x_start
+ x_noise * (1 - alphas_cumprod[t_next]).sqrt()
)
alpha_next = alphas_cumprod[t_next]
alpha_next[:, :-1] = torch.ones_like(alpha_next[:, :-1])
if noise_idx == 1:
alpha_next[:, -1:] = torch.ones_like(alpha_next[:, -1:])
x_pred = alpha_next.sqrt() * x_start + x_noise * (1 - alpha_next).sqrt()
x[:, -1:] = x_pred[:, -1:]
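The update above is a v-prediction DDIM step: the model output v is converted into a clean-sample estimate and a noise estimate, and the sample is reassembled at the next noise level. The commit's new twist is alpha_next: it is forced to 1 for the context slots (and for the newest frame on the final step), so those positions come out of the step fully denoised. The core algebra, isolated as a sketch:

```python
import torch

def ddim_step_from_v(x_t: torch.Tensor, v: torch.Tensor,
                     alpha_bar_t: torch.Tensor, alpha_bar_next: torch.Tensor) -> torch.Tensor:
    """One deterministic DDIM update given a v-prediction (same algebra as the loop above)."""
    # v-parameterization: recover the clean-sample and noise estimates
    x_start = alpha_bar_t.sqrt() * x_t - (1 - alpha_bar_t).sqrt() * v
    eps = ((1 / alpha_bar_t).sqrt() * x_t - x_start) / (1 / alpha_bar_t - 1).sqrt()
    # move to the next (lower) noise level; alpha_bar_next = 1 returns x_start exactly
    return alpha_bar_next.sqrt() * x_start + (1 - alpha_bar_next).sqrt() * eps
```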

# vae decoding
@@ -212,9 +206,7 @@ def main(args):
help="What framerate should be used to save the output?",
default=20,
)
parse.add_argument(
"--ddim-steps", type=int, help="How many DDIM steps?", default=50
)
parse.add_argument("--ddim-steps", type=int, help="How many DDIM steps?", default=50)

args = parse.parse_args()
print("inference args:")