# Source file: distrib_distance.py (68 lines, 56 loc, 3.35 KB)
import torch
def rand_projections(dim, num_projections=1000):
    """Sample random projection directions on the unit sphere in R^dim.

    Draws Gaussian vectors and normalizes each row to unit Euclidean
    norm, yielding directions uniformly distributed on the sphere.

    Args:
        dim: dimensionality of the ambient space.
        num_projections: number of directions to draw.

    Returns:
        Tensor of shape (num_projections, dim) with unit-norm rows.
    """
    directions = torch.randn((num_projections, dim))
    row_norms = torch.sqrt(torch.sum(directions ** 2, dim=1, keepdim=True))
    return directions / row_norms
def rand_projections_diff_priv(dim, num_projections=1000, sigma_proj=1):
    """Sample unit-norm projection directions, scaling the raw Gaussians.

    Args:
        dim: dimensionality of the ambient space.
        num_projections: number of directions to draw.
        sigma_proj: standard deviation applied to the raw Gaussian draw.

    Returns:
        Tensor of shape (num_projections, dim) with unit-norm rows.

    NOTE(review): sigma_proj scales the Gaussians *before* row
    normalization, so mathematically it cancels out and the returned
    directions match rand_projections — presumably kept for the
    differential-privacy accounting; confirm with the callers.
    """
    raw = torch.randn((num_projections, dim)) * sigma_proj
    row_norms = torch.sqrt(torch.sum(raw ** 2, dim=1, keepdim=True))
    return raw / row_norms
def make_sample_size_equal(first_samples, second_samples):
    """Truncate the larger batch so both batches have equal size.

    The extra trailing samples of the larger batch are discarded; no
    shuffling or subsampling is performed.

    Args:
        first_samples: tensor of shape (n1, dim).
        second_samples: tensor of shape (n2, dim).

    Returns:
        Tuple (first_samples, second_samples), each with
        min(n1, n2) rows.
    """
    smallest = min(first_samples.shape[0], second_samples.shape[0])
    return first_samples[:smallest], second_samples[:smallest]
def sliced_wasserstein_distance(first_samples,
                                second_samples,
                                num_projections=1000,
                                p=1,
                                device='cuda'):
    """Monte-Carlo estimate of the p-sliced Wasserstein distance.

    Both sample sets are projected onto random unit directions; along
    each direction the 1-D p-Wasserstein distance is the distance
    between the sorted projected samples. The estimate averages
    |diff|^p over samples and projections and takes a single p-th root.

    Args:
        first_samples: tensor of shape (n1, dim).
        second_samples: tensor of shape (n2, dim); the larger batch is
            truncated so both share the smaller batch size.
        num_projections: number of random directions for the estimate.
        p: order of the Wasserstein distance (p >= 1).
        device: device for the projection/matmul work (default 'cuda';
            pass 'cpu' when no GPU is available).

    Returns:
        Scalar tensor holding the estimated distance.
    """
    first_samples, second_samples = make_sample_size_equal(first_samples, second_samples)
    dim = second_samples.size(1)
    projections = rand_projections(dim, num_projections).to(device)
    # Project each sample set onto every direction: (n_samples, num_projections).
    first_projections = first_samples.matmul(projections.transpose(0, 1))
    second_projections = second_samples.matmul(projections.transpose(0, 1))
    # 1-D optimal transport along each projection = pairing of sorted samples.
    sorted_diff = torch.abs(torch.sort(first_projections.transpose(0, 1), dim=1)[0] -
                            torch.sort(second_projections.transpose(0, 1), dim=1)[0])
    # FIX: the original took a p-th root per projection and immediately
    # re-raised to the p-th power before the final mean — a redundant
    # pow/root round-trip. Since per-projection sample counts are equal,
    # mean-of-means == overall mean, so one mean and one root suffice.
    return torch.pow(torch.mean(torch.pow(sorted_diff, p)), 1. / p)
def sliced_wasserstein_distance_diff_priv(first_samples,
                                          second_samples,
                                          num_projections=1000,
                                          p=1,
                                          device='cuda',
                                          sigma_proj=1,
                                          sigma_noise=0
                                          ):
    """Noisy (differentially private) sliced Wasserstein estimate.

    Same construction as sliced_wasserstein_distance, but independent
    Gaussian noise of scale sigma_noise is added to the projected
    coordinates of both sample sets before sorting.

    Args:
        first_samples: tensor of shape (n1, dim) — the data to protect.
        second_samples: tensor of shape (n2, dim) — the fake data; the
            larger batch is truncated so both share the smaller size.
        num_projections: number of random directions for the estimate.
        p: order of the Wasserstein distance (p >= 1).
        device: device for the projection/matmul work (default 'cuda';
            pass 'cpu' when no GPU is available).
        sigma_proj: scale passed to rand_projections_diff_priv.
        sigma_noise: std-dev of the Gaussian noise added to projections.

    Returns:
        Scalar tensor holding the estimated (noisy) distance.
    """
    first_samples, second_samples = make_sample_size_equal(first_samples, second_samples)
    dim = second_samples.size(1)
    nb_sample = second_samples.size(0)
    projections = rand_projections_diff_priv(dim, num_projections, sigma_proj)
    projections = projections.to(device)
    # Independent noise draws for each sample set, matching the
    # projected shape (n_samples, num_projections).
    noise = (torch.randn((nb_sample, num_projections)) * sigma_noise).to(device)
    noise2 = (torch.randn((nb_sample, num_projections)) * sigma_noise).to(device)
    first_projections = first_samples.matmul(projections.transpose(0, 1)) + noise
    second_projections = second_samples.matmul(projections.transpose(0, 1)) + noise2
    # 1-D optimal transport along each projection = pairing of sorted samples.
    sorted_diff = torch.abs(torch.sort(first_projections.transpose(0, 1), dim=1)[0] -
                            torch.sort(second_projections.transpose(0, 1), dim=1)[0])
    # FIX: the original took a p-th root per projection and immediately
    # re-raised to the p-th power before the final mean — a redundant
    # pow/root round-trip. Since per-projection sample counts are equal,
    # mean-of-means == overall mean, so one mean and one root suffice.
    return torch.pow(torch.mean(torch.pow(sorted_diff, p)), 1. / p)