Commit: Revision Release

sethbassetti committed May 31, 2024
1 parent 3f32ebb commit 61654af
Showing 11 changed files with 4,020 additions and 84 deletions.
10 changes: 6 additions & 4 deletions src/create_quantiles.py
@@ -29,13 +29,15 @@ def main(cfg: DictConfig):
     )
 
     # Extract the xarray dataset and denormalize it
-    xr_ds = dataset.xr_data.map(denorm).sel(time=slice(START_YEAR, END_YEAR))
+    xr_ds = dataset.xr_data.map(denorm).sel(time=slice(START_YEAR, END_YEAR)).compute()
+    breakpoint()
+    # Group by each day of the year
+    groups = xr_ds.groupby("time.dayofyear")
 
     # Compute the quantiles
-    quantiles = xr_ds.load().quantile(cfg.quantile, dim="time").drop_vars("quantile")
+    quantiles = groups.quantile(q=[0.9, 0.95, 0.99, 0.999], dim="time")
 
     # Save the quantiles
-    save_name = f"{cfg.var}_{int(cfg.quantile * 100)}.nc"
+    save_name = f"{cfg.var}_quantiles.nc"
     save_path = os.path.join(cfg.paths.quantile_dir, cfg.esm, save_name)
 
     # Delete the file if it already exists (avoids permission denied errors)
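As an aside, a minimal sketch of what the new day-of-year quantile computation does, on hypothetical toy data (the committed breakpoint() call is a debugging stop and is omitted here):

import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical toy data: three years of daily values on a small grid.
time = pd.date_range("2000-01-01", "2002-12-31", freq="D")
ds = xr.Dataset(
    {"tas": (("time", "lat", "lon"), np.random.rand(len(time), 4, 8))},
    coords={"time": time, "lat": np.arange(4), "lon": np.arange(8)},
)

# Group by calendar day (1-366), then take several upper quantiles over
# "time", mirroring the new groups.quantile(...) call above.
quantiles = ds.groupby("time.dayofyear").quantile(
    q=[0.9, 0.95, 0.99, 0.999], dim="time"
)
# quantiles has dims (dayofyear, quantile, lat, lon).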
702 changes: 702 additions & 0 deletions src/custom_diffusers/configuration_utils.py

Large diffs are not rendered by default.

1,160 changes: 1,160 additions & 0 deletions src/custom_diffusers/dpmsolver_multistep.py

Large diffs are not rendered by default.
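The two vendored files are not rendered here. Judging by their names alone they appear to track diffusers' configuration utilities and DPM-Solver multistep scheduler; for orientation, standard usage of the upstream scheduler (an assumption — the vendored copy may differ):

from diffusers import DDPMScheduler, DPMSolverMultistepScheduler

# Build a DPM-Solver multistep scheduler from an existing DDPM config
# and set a reduced number of inference steps.
ddpm = DDPMScheduler(num_train_timesteps=1000)
dpm = DPMSolverMultistepScheduler.from_config(ddpm.config)
dpm.set_timesteps(num_inference_steps=50)
print(len(dpm.timesteps))  # 50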

17 changes: 13 additions & 4 deletions src/data/climate_dataset.py
@@ -8,6 +8,7 @@
 import xarray as xr
 from torch.utils.data import Dataset, DataLoader
 from accelerate import Accelerator
+import dask
 
 # Constants for the minimum and maximum of our datasets
 MIN_MAX_CONSTANTS = {"tas": (-85.0, 60.0), "pr": (0.0, 6.0)}
@@ -17,12 +18,12 @@
 
 # Normalization and Inverse Normalization functions
 NORM_FN = {
-    "tas": lambda x: x / 20,
-    "pr": lambda x: np.log(1 + x),
+    "tas": lambda x: (x - 4.5) / 21.0,
+    "pr": lambda x: np.cbrt(x),
 }
 DENORM_FN = {
-    "tas": lambda x: x * 20,
-    "pr": lambda x: np.exp(x) - 1,
+    "tas": lambda x: x * 21.0 + 4.5,
+    "pr": lambda x: x**3,
 }
 
 # These functions transform the range of the data to [-1, 1]
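The new scalers replace the divide-by-20 / log(1 + x) pair with an affine transform for temperature and a cube-root transform for precipitation. A quick round-trip sanity check (a sketch, reusing the lambdas above):

import numpy as np

NORM_FN = {"tas": lambda x: (x - 4.5) / 21.0, "pr": lambda x: np.cbrt(x)}
DENORM_FN = {"tas": lambda x: x * 21.0 + 4.5, "pr": lambda x: x**3}

# Check that denormalization inverts normalization over the documented
# data ranges (MIN_MAX_CONSTANTS above).
tas = np.linspace(-85.0, 60.0, 7)
pr = np.linspace(0.0, 6.0, 7)
assert np.allclose(DENORM_FN["tas"](NORM_FN["tas"](tas)), tas)
assert np.allclose(DENORM_FN["pr"](NORM_FN["pr"](pr)), pr)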
@@ -75,9 +76,11 @@ def __init__(
         data_dir: str,
         scenario: str,
         vars: list[str],
+        spatial_resolution=None,
     ):
         self.seq_len = seq_len
         self.realizations = realizations
+        self.spatial_resolution = spatial_resolution
 
         self.data_dir = os.path.join(data_dir, esm, scenario)
 
@@ -98,6 +101,7 @@ def estimate_num_batches(self, batch_size: int) -> int:
 
     def load_data(self, realization: str):
         """Loads the data from the specified paths and returns it as an xarray Dataset."""
+
         realization_dir = os.path.join(self.data_dir, realization, "*.nc")
 
         # Open up the dataset and make sure it's sorted by time
@@ -108,6 +112,11 @@
 
         # Apply preprocessing and normalization
         self.xr_data = dataset.map(preprocess).map(normalize)
+
+        if self.spatial_resolution is not None:
+            with dask.config.set(**{'array.slicing.split_large_chunks' : False}):
+                self.xr_data = self.xr_data.coarsen(lon=3, lat=2).mean()
+
         self.tensor_data = self.convert_xarray_to_tensor(self.xr_data)
 
     def convert_xarray_to_tensor(self, ds: xr.Dataset) -> torch.Tensor:
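A sketch of the coarsening step in isolation, assuming a chunked dataset whose lat/lon sizes are divisible by the window (the dask option suppresses the large-chunk performance warning that slicing chunked arrays can raise):

import dask
import numpy as np
import xarray as xr

# Hypothetical 96 x 144 (lat x lon) grid, chunked with dask.
ds = xr.Dataset(
    {"tas": (("lat", "lon"), np.random.rand(96, 144))}
).chunk({"lat": 48, "lon": 72})

# Block-average every 2 cells in latitude and 3 in longitude, as the new
# spatial_resolution branch does.
with dask.config.set(**{"array.slicing.split_large_chunks": False}):
    coarse = ds.coarsen(lat=2, lon=3).mean()

print(coarse.sizes)  # lat: 48, lon: 48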
103 changes: 68 additions & 35 deletions src/generate.py
@@ -14,6 +14,7 @@
 from diffusers import DDPMScheduler
 import xarray as xr
 from tqdm import tqdm
+import pandas as pd
 
 # Local imports
 from data.climate_dataset import ClimateDataset
@@ -26,6 +27,17 @@
 realization_dict = {"gen": "r2", "val": "r2", "test": "r1"}
 
 
+def get_starting_index(directory: str) -> int:
+    """Goes through a directory of files named "member_i.nc" and returns the next available index."""
+    files = os.listdir(directory)
+    indices = [
+        int(file.split("_")[1].split(".")[0])
+        for file in files
+        if file.startswith("member")
+    ]
+    return max(indices) + 1 if indices else 0
+
+
 def create_batches(
     xr_ds: xr.Dataset,
     dataset: ClimateDataset,
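Usage-wise, get_starting_index lets repeated "gen" runs append new ensemble members instead of overwriting old ones. A self-contained sketch using a hypothetical "members" directory:

import os

# Simulate a directory left behind by two earlier runs.
os.makedirs("members", exist_ok=True)
for i in range(2):
    open(os.path.join("members", f"member_{i}.nc"), "a").close()

def get_starting_index(directory: str) -> int:
    """Returns the next free index among files named member_<i>.nc."""
    files = os.listdir(directory)
    indices = [
        int(file.split("_")[1].split(".")[0])
        for file in files
        if file.startswith("member")
    ]
    return max(indices) + 1 if indices else 0

print(get_starting_index("members"))  # -> 2, so the next file is member_2.nc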
@@ -92,6 +104,11 @@ def main(config: DictConfig) -> None:
     assert config.load_path, "Must specify a load path"
     assert os.path.isfile(config.load_path), "Invalid load path"
 
+    # Make sure num samples is 1 if gen mode is not gen
+    assert (
+        config.samples_per == 1 or config.gen_mode == "gen"
+    ), "Number of samples must be 1 for val and test"
+
     # Initialize all necessary objects
     accelerator = Accelerator(**config.accelerator, even_batches=False)
 
@@ -101,7 +118,8 @@
         scenario=config.scenario,
         data_dir=config.paths.data_dir,
         realizations=[realization_dict[config.gen_mode]],
-        vars=[config.variable],
+        vars=config.variables,
+        spatial_resolution=config.spatial_resolution
     )
     scheduler: DDPMScheduler = instantiate(config.scheduler)
     scheduler.set_timesteps(config.sample_steps)
@@ -110,8 +128,10 @@
         # Load the model from the checkpoint
         chkpt: Checkpoint = torch.load(config.load_path, map_location="cpu")
         model = chkpt["EMA"].eval()
+        model = model.to(accelerator.device)
     else:
         model = None
+
     # Grab the Xarray dataset from the dataset object
     xr_ds = dataset.xr_data.load()
 
@@ -128,43 +148,56 @@
     # Prepare the model and dataloader for distributed training
     model, dataloader = accelerator.prepare(model, dataloader)
 
-    gen_samples = []
-
-    for tensor_batch, coords in tqdm(
-        dataloader, disable=not accelerator.is_main_process
-    ):
-        if model is not None:
-            gen_months = generate_samples(
-                tensor_batch,
-                scheduler=scheduler,
-                sample_steps=config.sample_steps,
-                model=model,
-                disable=not accelerator.is_main_process,
-            )
-        else:
-            gen_months = tensor_batch
-
-        for i in range(len(gen_months)):
-            gen_samples.append(
-                dataset.convert_tensor_to_xarray(gen_months[i], coords=coords[i])
-            )
-
-    gen_samples = accelerator.gather_for_metrics(gen_samples)
-    gen_samples = xr.concat(gen_samples, "time").drop_vars("height").sortby("time")
-
-    if accelerator.is_main_process:
-        # Construct the save path
-        save_name = f"{config.gen_mode}_{config.save_name + '_' if config.save_name is not None else ''}{config.variable}_{config.start_year}-{config.end_year}.nc"
-        save_path = os.path.join(
-            config.paths.save_dir, config.esm, config.scenario, save_name
-        )
-        # Delete the file if it already exists (avoids permission denied errors)
-        if os.path.isfile(save_path):
-            os.remove(save_path)
-        # Save the generated samples
-        gen_samples.to_netcdf(save_path)
-
-        os.chmod(save_path, 0o770)
+    for i in tqdm(range(config.samples_per)):
+        gen_samples = []
+
+        for tensor_batch, coords in tqdm(
+            dataloader, disable=not accelerator.is_main_process
+        ):
+            tensor_batch = tensor_batch.to(accelerator.device)
+            if model is not None:
+                gen_months = generate_samples(
+                    tensor_batch,
+                    scheduler=scheduler,
+                    sample_steps=config.sample_steps,
+                    model=model,
+                    disable=True,
+                )
+            else:
+                gen_months = tensor_batch
+
+            for i in range(len(gen_months)):
+                gen_samples.append(
+                    dataset.convert_tensor_to_xarray(gen_months[i], coords=coords[i])
+                )
+
+        gen_samples = accelerator.gather_for_metrics(gen_samples)
+        gen_samples = xr.concat(gen_samples, "time").drop_vars("height").sortby("time")
+
+        if accelerator.is_main_process:
+
+            # If we are generating multiple samples, create a directory for them
+            save_name = f"{config.gen_mode}_{config.save_name + '_' if config.save_name is not None else ''}{'_'.join(config.variables)}_{config.start_year}-{config.end_year}.nc"
+            save_path = os.path.join(
+                config.paths.save_dir, config.esm, config.scenario, save_name
+            )
+            if config.gen_mode == "gen" and config.samples_per > 1:
+                save_dir = save_path.strip(".nc")
+                if not os.path.isdir(save_dir):
+                    os.mkdir(save_dir)
+
+                mem_index = get_starting_index(save_dir)
+                save_path = os.path.join(save_dir, f"member_{mem_index}.nc")
+
+            else:
+                # Delete the file if it already exists (avoids permission denied errors)
+                if os.path.isfile(save_path):
+                    os.remove(save_path)
+
+            # Save the generated samples
+            gen_samples.to_netcdf(save_path)
+
+            os.chmod(save_path, 0o770)
 
 
 if __name__ == "__main__":
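One caveat worth noting in the new save logic: str.strip(".nc") removes any leading/trailing characters in the set {'.', 'n', 'c'} rather than the ".nc" suffix, so a save name whose stem ends in one of those letters (e.g. "..._scen.nc") would be truncated. A suffix-safe sketch of the same member-directory routing, under the assumption that save_path always ends in ".nc":

import os

def member_save_path(save_path: str, gen_mode: str, samples_per: int) -> str:
    """Route multi-sample "gen" runs into a per-run directory of member_<i>.nc files."""
    if gen_mode == "gen" and samples_per > 1:
        # os.path.splitext drops exactly the extension, unlike str.strip(".nc").
        save_dir = os.path.splitext(save_path)[0]
        os.makedirs(save_dir, exist_ok=True)
        indices = [
            int(f.split("_")[1].split(".")[0])
            for f in os.listdir(save_dir)
            if f.startswith("member")
        ]
        next_index = max(indices) + 1 if indices else 0
        return os.path.join(save_dir, f"member_{next_index}.nc")
    return save_path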