Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/mackelab/labproject into main
Browse files Browse the repository at this point in the history
  • Loading branch information
manuelgloeckler committed Jun 6, 2024
2 parents cc142c9 + 742ae09 commit 7b5e3e1
Show file tree
Hide file tree
Showing 4 changed files with 478 additions and 0 deletions.
28 changes: 28 additions & 0 deletions configs/conf_mmd_main_scaling_experiment_different_kernels.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
exp_log_name: "MMD_main_scaling_kernel_experiment" # optional but recommended

# datasets to use
data: ["toy_2d" , "random", "random"]
augmentation: ['gauss', 'one_dim_shift', 'one_dim_shift',]

# number of samples and dimensions
n: [10000,10000,10000] # number of samples per dataset; note that the main figure uses 10k
d: [2, 10, 1000] # dimensions

mmd_bandwidth: [1,5,10] # only the rbf has the bandwidth parameter

# sample size experiments
experiments: ["ScaleSampleSizeMMD", "ScaleSampleSizeMMDlinear","ScaleSampleSizeMMDenergy"]
sample_size: [50, 100, 200, 500, 1000, 2000, 3000, 4000]
runs: 5 # number of repeated sample selections used for the error bars

# dimensionality experiments
experiments_dim: ["ScaleDimMMD", "ScaleDimMMDlinear", "ScaleDimMMDenergy"]
dim_sizes: [5, 10, 50, 100, 500, 1000]
runs_dim: 5 # number of repeated sample selections used for the error bars

# seed for reproducibility
seed: 0

# for the reduced sample size experiments
#sample_size: [8, 10, 20, 50, 80]
#n: [500, 500, 500]
409 changes: 409 additions & 0 deletions docs/notebooks/mmd/MMD_scaling_experiment_different_kernels.ipynb

Large diffs are not rendered by default.

26 changes: 26 additions & 0 deletions labproject/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
gaussian_kl_divergence,
c2st_nn,
compute_rbf_mmd,
compute_linear_mmd,
compute_energy_mmd,
)
from labproject.plotting import plot_scaling_metric_dimensionality, plot_scaling_metric_sample_size
from labproject.metrics.gaussian_squared_wasserstein import gaussian_squared_w2_distance
Expand Down Expand Up @@ -115,6 +117,16 @@ def __init__(self, min_dim=2, **kwargs):
super().__init__("MMD", compute_rbf_mmd, **kwargs)


class ScaleDimMMDenergy(ScaleDim):
    """ScaleDim experiment that uses ``compute_energy_mmd`` as its metric.

    The experiment is registered with the base class under the label
    ``"MMD"``, the same label the RBF-kernel variant uses.
    """

    def __init__(self, min_dim=2, **kwargs):
        """Initialise the base experiment with the energy-kernel MMD.

        Args:
            min_dim: Accepted for signature compatibility with the other
                ScaleDim experiments.
            **kwargs: Forwarded unchanged to ``ScaleDim.__init__``.
        """
        # NOTE(review): ``min_dim`` is not forwarded to the base class, so a
        # caller-supplied value has no effect here. The sample-size
        # experiments do forward their ``min_samples`` -- confirm whether
        # ``ScaleDim.__init__`` expects ``min_dim`` as well.
        super().__init__("MMD", compute_energy_mmd, **kwargs)


class ScaleDimMMDlinear(ScaleDim):
    """ScaleDim experiment that uses ``compute_linear_mmd`` as its metric.

    The experiment is registered with the base class under the label
    ``"MMD"``, the same label the RBF-kernel variant uses.
    """

    def __init__(self, min_dim=2, **kwargs):
        """Initialise the base experiment with the linear-kernel MMD.

        Args:
            min_dim: Accepted for signature compatibility with the other
                ScaleDim experiments.
            **kwargs: Forwarded unchanged to ``ScaleDim.__init__``.
        """
        # NOTE(review): ``min_dim`` is not forwarded to the base class, so a
        # caller-supplied value has no effect here. The sample-size
        # experiments do forward their ``min_samples`` -- confirm whether
        # ``ScaleDim.__init__`` expects ``min_dim`` as well.
        super().__init__("MMD", compute_linear_mmd, **kwargs)


"""class ScaleDimMMD(ScaleDim):
def __init__(self, min_dim=2, **kwargs):
super().__init__("FID", compute_rbf_mmd, **kwargs)"""
Expand Down Expand Up @@ -229,6 +241,20 @@ def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
)


class ScaleSampleSizeMMDenergy(ScaleSampleSize):
    """ScaleSampleSize experiment that uses ``compute_energy_mmd`` as metric.

    The experiment is registered with the base class under the label
    ``"MMD"``.
    """

    def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
        """Initialise the base experiment with the energy-kernel MMD.

        Args:
            min_samples: Forwarded to ``ScaleSampleSize.__init__``.
            sample_sizes: Forwarded to ``ScaleSampleSize.__init__``; ``None``
                leaves the choice to the base class.
            **kwargs: Any further options, forwarded unchanged.
        """
        super().__init__(
            "MMD",
            compute_energy_mmd,
            min_samples=min_samples,
            sample_sizes=sample_sizes,
            **kwargs,
        )


class ScaleSampleSizeMMDlinear(ScaleSampleSize):
    """ScaleSampleSize experiment that uses ``compute_linear_mmd`` as metric.

    The experiment is registered with the base class under the label
    ``"MMD"``.
    """

    def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
        """Initialise the base experiment with the linear-kernel MMD.

        Args:
            min_samples: Forwarded to ``ScaleSampleSize.__init__``.
            sample_sizes: Forwarded to ``ScaleSampleSize.__init__``; ``None``
                leaves the choice to the base class.
            **kwargs: Any further options, forwarded unchanged.
        """
        super().__init__(
            "MMD",
            compute_linear_mmd,
            min_samples=min_samples,
            sample_sizes=sample_sizes,
            **kwargs,
        )


class ScaleSampleSizeFID(ScaleSampleSize):
def __init__(self, min_samples=3, sample_sizes=None, **kwargs):
super().__init__(
Expand Down
15 changes: 15 additions & 0 deletions labproject/metrics/MMD_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ def linear_kernel(x, y):
return x @ y.t()


def energy_kernel(x, y):
    """Evaluate the distance-induced kernel ``|x| + |y| - |x - y|`` pairwise.

    Args:
        x: Tensor whose last dimension holds the features and whose
            second-to-last dimension indexes the samples, e.g. ``(n, d)``.
            Leading batch dimensions are allowed as far as ``torch.cdist``
            supports them.
        y: Tensor laid out like ``x``, e.g. ``(m, d)``.

    Returns:
        Tensor of shape ``(..., n, m)`` whose entry ``[..., i, j]`` is
        ``||x_i|| + ||y_j|| - ||x_i - y_j||`` (Euclidean norms).
    """
    x_norm = torch.linalg.norm(x, dim=-1)
    y_norm = torch.linalg.norm(y, dim=-1)
    # unsqueeze on the trailing axes (instead of [:, None] / [None, :]) gives
    # the same result for 2-D inputs and also broadcasts for batched inputs.
    return x_norm.unsqueeze(-1) + y_norm.unsqueeze(-2) - torch.cdist(x, y)


def median_heuristic(x, y):
    """Return the median of all pairwise distances between ``x`` and ``y``.

    Args:
        x: Sample tensor accepted by ``torch.cdist``, e.g. shape ``(n, d)``.
        y: Sample tensor accepted by ``torch.cdist``, e.g. shape ``(m, d)``.

    Returns:
        Scalar tensor holding the median entry of ``torch.cdist(x, y)``
        (Euclidean distance by default). For an even number of entries
        ``torch.median`` returns the lower of the two middle values.
    """
    return torch.median(torch.cdist(x, y))

Expand Down Expand Up @@ -71,3 +77,12 @@ def compute_linear_mmd_naive(x, y):
def compute_linear_mmd(x, y):
    """Return the squared Euclidean distance between the sample means.

    Args:
        x: Tensor of samples with the sample index on dimension 0.
        y: Tensor of samples with the sample index on dimension 0 and the
            same trailing shape as ``x``.

    Returns:
        Scalar tensor ``|| mean(x, 0) - mean(y, 0) ||_2 ** 2``.
    """
    delta = torch.mean(x, 0) - torch.mean(y, 0)
    return torch.norm(delta, 2) ** 2


@register_metric("mmd_energy")
def compute_energy_mmd(x, y):
    """Compute an MMD estimate between ``x`` and ``y`` with ``energy_kernel``.

    The estimate is ``mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))``.
    Each mean runs over every entry of the corresponding kernel matrix,
    including the diagonal of the within-sample matrices.

    Args:
        x: Sample tensor accepted by ``energy_kernel``, e.g. shape ``(n, d)``.
        y: Sample tensor accepted by ``energy_kernel``, e.g. shape ``(m, d)``.

    Returns:
        Scalar tensor with the MMD estimate.
    """
    x_kernel = energy_kernel(x, x)
    y_kernel = energy_kernel(y, y)
    xy_kernel = energy_kernel(x, y)
    mmd = torch.mean(x_kernel) + torch.mean(y_kernel) - 2 * torch.mean(xy_kernel)
    return mmd

0 comments on commit 7b5e3e1

Please sign in to comment.