-
Notifications
You must be signed in to change notification settings - Fork 0
/
latent_spaces.py
102 lines (83 loc) · 3.69 KB
/
latent_spaces.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Classes that combine spaces with specific probability densities."""
from typing import Callable, List
from spaces import Space
import torch
class LatentSpace:
    """Combines a topological space with a marginal and conditional density to sample from.

    The samplers are stored unbound and exposed through properties that
    prepend ``self.space`` to every call: ``latent.sample_marginal(*a, **kw)``
    invokes ``sample_marginal(latent.space, *a, **kw)``.
    """

    def __init__(
        self, space: "Space", sample_marginal: Callable, sample_conditional: Callable
    ):
        self.space = space
        # Either sampler may be None here; reading the corresponding property
        # then raises RuntimeError.
        self._sample_marginal = sample_marginal
        self._sample_conditional = sample_conditional

    @staticmethod
    def _check_callable(value, name):
        """Return ``value`` unchanged, raising TypeError if it is not callable."""
        # An explicit raise instead of ``assert``: asserts vanish under -O.
        if not callable(value):
            raise TypeError(f"{name} must be callable, got {type(value).__name__}")
        return value

    @property
    def sample_conditional(self):
        """Conditional sampler with ``self.space`` bound as its first argument.

        Raises:
            RuntimeError: if no conditional sampler was set.
        """
        if self._sample_conditional is None:
            raise RuntimeError("sample_conditional was not set")
        return lambda *args, **kwargs: self._sample_conditional(
            self.space, *args, **kwargs
        )

    @sample_conditional.setter
    def sample_conditional(self, value: Callable):
        """Replace the conditional sampler; raises TypeError if not callable."""
        self._sample_conditional = self._check_callable(value, "sample_conditional")

    @property
    def sample_marginal(self):
        """Marginal sampler with ``self.space`` bound as its first argument.

        Raises:
            RuntimeError: if no marginal sampler was set.
        """
        if self._sample_marginal is None:
            raise RuntimeError("sample_marginal was not set")
        return lambda *args, **kwargs: self._sample_marginal(
            self.space, *args, **kwargs
        )

    @sample_marginal.setter
    def sample_marginal(self, value: Callable):
        """Replace the marginal sampler; raises TypeError if not callable."""
        self._sample_marginal = self._check_callable(value, "sample_marginal")

    @property
    def dim(self):
        """Dimensionality of the underlying space (``self.space.dim``)."""
        return self.space.dim
class ProductLatentSpace(LatentSpace):
    """A latent space which is the cartesian product of other latent spaces.

    Each factor in ``self.spaces`` is sampled by its own sampler and the
    per-factor samples are concatenated along the last dimension.
    """

    def __init__(self, spaces: List[LatentSpace]):
        # NOTE(review): LatentSpace.__init__ is not called, so ``self.space``
        # and the private sampler attributes are never set. The sampling
        # methods and ``dim`` below are overridden and do not use them.
        self.spaces = spaces

    @staticmethod
    def _factor(values, i):
        """Return the part of ``values`` that belongs to factor ``i``.

        A 1-D ``values`` is indexed as ``values[i]``; otherwise column
        ``values[:, i]`` is taken.
        """
        return values[i] if len(values.shape) == 1 else values[:, i]

    def sample_conditional(self, means, params, size, **kwargs):
        """Sample every factor conditionally and concatenate the results.

        Args:
            means: Per-factor conditioning values, split via ``_factor``.
            params: Indexable with one entry per factor, passed as ``params``.
            size: Passed as ``size`` to every factor's sampler.
            **kwargs: Passed to every factor's sampler.

        Returns:
            The per-factor samples concatenated along the last dimension.
        """
        samples = [
            s.sample_conditional(
                mean=self._factor(means, i), params=params[i], size=size, **kwargs
            )
            for i, s in enumerate(self.spaces)
        ]
        return torch.cat(samples, -1)

    def sample_marginal(self, means, params, size, **kwargs):
        """Sample every factor from its marginal and concatenate the results.

        ``means`` is split per factor like in ``sample_conditional``
        (1-D input is now accepted as well); ``params[i]``, ``size`` and
        ``kwargs`` are passed to factor ``i``'s sampler.
        """
        samples = [
            s.sample_marginal(self._factor(means, i), params[i], size=size, **kwargs)
            for i, s in enumerate(self.spaces)
        ]
        return torch.cat(samples, -1)

    def _resample_children(self, x, std, parents):
        """Return ``x`` with each child in ``parents`` redrawn from its parent's sample.

        ``parents`` maps a child factor index to the (possibly negative) index
        of the ``x`` entry passed as the child's first sampler argument, with
        ``std[child]`` as the second. A factor that is not a child, or whose
        ``std`` entry is ``None``, keeps its sample from ``x``.
        """
        out = []
        for i, s in enumerate(self.spaces):
            if i in parents and std[i] is not None:
                out.append(s.sample_marginal(x[parents[i]], std[i]))
            else:
                out.append(x[i])
        return out

    def sample_marginal_causal(self, std, size, first_content, **kwargs):
        """Sample all factors independently, then redraw selected children from parents.

        Args:
            std: Indexable per factor; the entry for a child factor is passed
                as its second sampler argument, ``None`` keeps the child's
                independent sample.
            size: Passed as ``size`` to the initial independent draws only.
            first_content: Selects the parent set of the first-stage graph.
            **kwargs: Passed to the initial independent draws only.

        Returns:
            The second-stage samples concatenated along the last dimension.
        """
        x = [
            s.sample_marginal(
                torch.as_tensor([0.0]), torch.as_tensor([0.0]), size=size, **kwargs
            )
            for s in self.spaces
        ]
        # First-stage graph: child index -> parent index into x. With ten
        # factors, x[-4] is x[6] and x[-2] is x[8].
        if first_content:
            first_stage = {1: -4, 6: 1, 8: -4}
        else:
            first_stage = {1: -2, 6: -2, 8: 1}
        # NOTE(review): this stage's result is never returned, and the second
        # stage reads the independent samples ``x``, not this result. The call
        # is kept because it still draws random numbers. Confirm whether the
        # second stage was meant to build on it, or drop it.
        self._resample_children(x, std, first_stage)
        # Second-stage graph; this is what the method returns. Previously the
        # unconditioned factors were appended to the wrong list and were
        # missing from the output.
        final_final_x = self._resample_children(x, std, {0: 1, 5: -4, 7: -2})
        return torch.cat(final_final_x, -1)

    @property
    def dim(self):
        """Total dimensionality: the sum of the factors' ``dim``."""
        return sum(s.dim for s in self.spaces)