losses.py
# %%
import torch
import torch.nn as nn
# for diffusion, we use torch.nn.SmoothL1Loss


def poisson_loss(output, target):
    # Assuming output is the log rate (for numerical stability), convert to the rate
    rate = torch.exp(output)
    # Poisson negative log-likelihood, dropping the constant log(target!) term
    loss = torch.mean(rate - target * output)
    return loss
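

# Minimal sanity-check sketch (a hypothetical helper, assuming `output` is the log rate):
# with log_input=True and full=False, nn.PoissonNLLLoss computes
# mean(exp(input) - target * input), which should equal poisson_loss above.
def _check_poisson_loss_matches_builtin():
    torch.manual_seed(0)
    log_rate = torch.randn(4, 8, 100)            # log rates
    target = torch.poisson(torch.exp(log_rate))  # counts sampled from those rates
    builtin = nn.PoissonNLLLoss(log_input=True, full=False, reduction="mean")
    assert torch.allclose(poisson_loss(log_rate, target), builtin(log_rate, target), atol=1e-6)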


def neg_log_likelihood(output, target):
    # output has gone through a softplus, so it is already a non-negative rate
    # (full=True adds the Stirling approximation of the log(target!) term)
    loss = nn.PoissonNLLLoss(log_input=False, full=True, reduction="none")
    # sum over channels and time, average over the batch dimension
    return loss(output, target).sum() / output.size(0)
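

# Usage sketch (the [B, C, L] shapes are an assumption; the function only requires
# output and target to have the same shape): the result is the per-sample summed NLL,
# averaged over the batch.
def _neg_log_likelihood_example():
    rates = nn.functional.softplus(torch.randn(4, 8, 100))  # [B, C, L] non-negative rates
    counts = torch.poisson(rates)
    return neg_log_likelihood(rates, counts)  # scalar: sum over C and L, divided by B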


def latent_regularizer(z, cfg):
    """Regularizer that penalizes the squared difference between latents at
    neighbouring time steps. This returns the sum of squared differences, NOT the mean.

    Args:
        z: [B, C, L]
        cfg: OmegaConf object
    Returns:
        loss: scalar (torch.Tensor)
    """
    l2_reg = torch.sum(z**2)
    k = cfg.training.get("td_k", 5)  # number of time differences
    temporal_difference_loss = 0
    for i in range(1, k + 1):
        # squared difference between latents i steps apart, down-weighted by 1 / (1 + i)
        temporal_difference_loss += ((z[:, :, :-i] - z[:, :, i:]) ** 2 * (1 / (1 + i))).sum()
    # GP-prior-like smoothness term.
    # The returned value is later multiplied by training.latent_beta; dividing the temporal
    # term by latent_beta here cancels that factor, so latent_beta only scales l2_reg and the
    # temporal difference loss is effectively scaled by training.latent_td_beta alone.
    return l2_reg + temporal_difference_loss / (cfg.training.latent_beta) * (
        cfg.training.latent_td_beta
    )


# %%
if __name__ == "__main__":
    import lovely_tensors

    lovely_tensors.monkey_patch()
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"training": {"latent_beta": 1.0, "latent_td_beta": 1.0}})
    # with sharp latents
    z = torch.randn(4, 8, 1000)
    loss = latent_regularizer(z, cfg) / z.numel()
    print(loss)
    # with smooth latents:
    # upscale the last dim back to 1000 using linear interpolation
    z = torch.nn.functional.interpolate(
        z[:, :, :50], size=1000, mode="linear", align_corners=False
    )
    loss = latent_regularizer(z, cfg) / z.numel()
    print(loss)
    # the loss is higher for sharp latents, e.g. roughly 4.0 vs 0.8
# %%