-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpower_spectrum.py
52 lines (40 loc) · 1.59 KB
/
power_spectrum.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import numpy as np
from metatensor.torch import TensorMap, Labels, TensorBlock
class PowerSpectrum(torch.nn.Module):
    """Per-species rotationally-invariant power spectrum built from an expansion.

    For every species ``a_i`` in ``all_species`` and every angular channel
    ``lam`` in ``0..l_max`` (inclusive), the values ``c`` of the block
    ``spex.block({"lam": lam, "a_i": a_i})`` are contracted over axis 1
    (the ``m`` axis) and scaled by ``1 / sqrt(2 * lam + 1)``::

        ps[i, a, b] = (2 * lam + 1) ** -0.5 * sum_m c[i, m, a] * c[i, m, b]

    Each ``(n_samples, n_props, n_props)`` result is flattened to
    ``(n_samples, n_props ** 2)``. These are concatenated over ``lam`` to
    give one output block per species.

    Args:
        l_max: highest angular channel to include (inclusive); must be >= 0.
        all_species: species identifiers. One output block is produced per
            entry, in this order. Must be non-empty.

    Raises:
        ValueError: if ``l_max`` is negative or ``all_species`` is empty.
    """

    def __init__(self, l_max, all_species):
        super().__init__()
        # Fail fast: with l_max < 0 the channel loop in forward() would never
        # run, and with no species there would be no block to build keys from.
        if l_max < 0:
            raise ValueError(f"l_max must be >= 0, got {l_max}")
        if len(all_species) == 0:
            raise ValueError("all_species must contain at least one species")
        self.l_max = l_max
        self.all_species = all_species

    def forward(self, spex):
        """Return a TensorMap holding one power-spectrum block per species.

        Args:
            spex: object exposing ``block({"lam": l, "a_i": a_i})`` that
                returns a block with rank-3 ``.values`` (contracted over
                axis 1) and ``.samples``. It is presumably a metatensor
                TensorMap keyed by ``lam`` and ``a_i`` — confirm against
                the caller.

        Returns:
            A TensorMap keyed by ``a_i``. Each block has:

            * no components;
            * ``sum_lam n_props(lam) ** 2`` properties;
            * the samples of that species' ``lam == l_max`` input block.

            NOTE(review): this assumes every ``lam`` block of a species
            shares the same samples — confirm.
        """
        keys = []
        blocks = []
        for a_i in self.all_species:
            ps_values_ai = []
            for l in range(self.l_max + 1):
                # Pure-Python scalar keeps the arithmetic torch-only
                # (no NumPy float64 mixed into tensor ops).
                cg = (2 * l + 1) ** -0.5
                block_ai_l = spex.block({"lam": l, "a_i": a_i})
                c_ai_l = block_ai_l.values
                # Equivalent to
                #   torch.einsum("ima, imb -> iab", c_ai_l, c_ai_l)
                # but the broadcast-and-sum form was found faster by the
                # original author.
                ps_ai_l = cg * torch.sum(
                    c_ai_l.unsqueeze(2) * c_ai_l.unsqueeze(3), dim=1
                )
                ps_ai_l = ps_ai_l.reshape(
                    c_ai_l.shape[0], c_ai_l.shape[2] ** 2
                )
                ps_values_ai.append(ps_ai_l)
            # torch.cat (rather than the newer torch.concatenate alias) works
            # on all torch versions.
            ps_values_ai = torch.cat(ps_values_ai, dim=-1)
            block = TensorBlock(
                values=ps_values_ai,
                # Samples are taken from the last (lam == l_max) block.
                samples=block_ai_l.samples,
                components=[],
                properties=Labels.range("property", ps_values_ai.shape[-1]),
            )
            keys.append([a_i])
            blocks.append(block)

        power_spectrum = TensorMap(
            keys=Labels(
                names=("a_i",),
                # Labels values are 32-bit integers; without an explicit
                # dtype torch.tensor would produce int64 here.
                values=torch.tensor(
                    keys, dtype=torch.int32, device=blocks[0].values.device
                ),
            ),
            blocks=blocks,
        )
        return power_spectrum