# utils.py — GP kernel inspection / formatting helpers (62 lines, 1.82 KB).
# NOTE: web-scrape artifacts (page navigation and line-number gutter) removed.
import gpytorch
import torch
import numpy as np
import gp_models
def get_lengthscales(kernel):
    """Extract the lengthscale(s) of ``kernel``, unwrapping ScaleKernels.

    Returns the raw ``lengthscale`` tensor for kernels that expose one, a
    nested list of floats (one inner list per additive group) for a
    ``gp_models.GeneralizedProjectionKernel``, and ``None`` otherwise.
    """
    # A ScaleKernel is just a wrapper — look inside it.
    if isinstance(kernel, gpytorch.kernels.ScaleKernel):
        return get_lengthscales(kernel.base_kernel)
    if kernel.has_lengthscale:
        return kernel.lengthscale
    if isinstance(kernel, gp_models.GeneralizedProjectionKernel):
        # One list of scalar lengthscales per sub-kernel group.
        # NOTE(review): assumes each group exposes `.base_kernel.kernels` —
        # matches the structure this code was written against; confirm if
        # gp_models changes.
        return [
            [component.lengthscale.item() for component in group.base_kernel.kernels]
            for group in kernel.kernel.kernels
        ]
    return None
def get_mixins(kernel):
    """Return the per-group outputscales ("mixins") of a projection kernel.

    For a ``gp_models.GeneralizedProjectionKernel`` this is a list of floats,
    one per additive group; a ScaleKernel is unwrapped and searched inside;
    any other kernel yields ``None``.
    """
    if isinstance(kernel, gp_models.GeneralizedProjectionKernel):
        return [group.outputscale.item() for group in kernel.kernel.kernels]
    if isinstance(kernel, gpytorch.kernels.ScaleKernel):
        # Unwrap and retry on the wrapped kernel.
        return get_mixins(kernel.base_kernel)
    return None
def get_outputscale(kernel):
    """Return ``kernel.outputscale`` if it is a ScaleKernel, else ``None``."""
    is_scale = isinstance(kernel, gpytorch.kernels.ScaleKernel)
    return kernel.outputscale if is_scale else None
def format_for_str(num_or_list, decimals=3):
    """Round numeric values for display.

    Accepts a float, a (possibly nested) list of floats, or a
    ``torch.Tensor``, and returns the same structure with every float
    rounded to ``decimals`` places. Any other leaf type (ints included,
    matching the original behavior) is rendered as ``''``.

    Bug fix: the recursive calls previously dropped ``decimals``, so a
    non-default precision was silently ignored for tensors and lists.
    """
    if isinstance(num_or_list, torch.Tensor):
        # Convert to (nested) Python lists, then recurse with the SAME precision.
        return format_for_str(num_or_list.tolist(), decimals)
    if isinstance(num_or_list, list):
        return [format_for_str(n, decimals) for n in num_or_list]
    if isinstance(num_or_list, float):
        return np.round(num_or_list, decimals)
    return ''
@torch.jit.script
def my_cdist(x1, x2):
    """Pairwise Euclidean distances between the rows of x1 and x2.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, clamping the
    squared distances at 1e-30 before the square root so tiny negative values
    from floating-point cancellation cannot produce NaNs.
    From Jacob Gardner here https://github.com/pytorch/pytorch/issues/15253
    """
    sq1 = x1.pow(2).sum(dim=-1, keepdim=True)
    sq2 = x2.pow(2).sum(dim=-1, keepdim=True)
    # -2 * x1 @ x2^T plus the two squared-norm terms broadcast in.
    cross = x1.matmul(x2.transpose(-2, -1))
    dist_sq = sq1 + sq2.transpose(-2, -1) - 2.0 * cross
    return dist_sq.clamp_min(1e-30).sqrt()