EO Development (draft pull request) #3

Open · wants to merge 26 commits into base: master

Commits (26)
36007ea  Merge remote-tracking branch 'origin/development' into development-eo (annajungbluth, Apr 25, 2024)
0c2b26f  started eo training development (annajungbluth, Apr 25, 2024)
9993839  started testing training pipeline (annajungbluth, Apr 25, 2024)
3330469  wip - tested training pipeline (annajungbluth, Apr 25, 2024)
9071b24  made training pipeline run (annajungbluth, Apr 26, 2024)
d7f2c9d  moved geo dataset and editors into ITI repo (annajungbluth, Apr 26, 2024)
b6983b2  added new editor and modified training script (annajungbluth, Apr 27, 2024)
2879551  modified editor (annajungbluth, Apr 27, 2024)
12ee80e  tested callbacks (annajungbluth, Apr 28, 2024)
cdc4c7c  added normalisation steps to training script, and started writing a n… (lillif, May 13, 2024)
92c26d6  training script for miniset set up, mean std normalisation finished (lillif, May 14, 2024)
a18e68a  normalisation script finished, attempted training (lillif, May 17, 2024)
9d592bb  started hydra training file (annajungbluth, May 19, 2024)
6d24344  added normalization and fixed bugs in training script (annajungbluth, May 20, 2024)
1ddd8d5  merge with master (annajungbluth, Oct 3, 2024)
d7b31d0  fixed small merge bugs and added autoroot file (annajungbluth, Oct 3, 2024)
ad54e8e  Added file with dataset information (annajungbluth, Oct 4, 2024)
72c3f82  fixed goes metrics file (annajungbluth, Oct 18, 2024)
4533b53  updated summary files (annajungbluth, Oct 20, 2024)
2e72ef8  added new normalization routine and started first experiment (annajungbluth, Oct 20, 2024)
7a79a1e  reduced val data (annajungbluth, Oct 20, 2024)
d3a4ad4  optimized dataloader to reduce memory consumption (annajungbluth, Oct 24, 2024)
e2c5675  added normalization files for subset of data (annajungbluth, Oct 31, 2024)
625852a  debugging constant channels (annajungbluth, Oct 31, 2024)
8770404  started new experiment (annajungbluth, Oct 31, 2024)
5b98ee1  added seed to training script (annajungbluth, Nov 1, 2024)
Empty file added .project-root
Empty file.
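The empty .project-root file is the kind of marker usually consumed by a project-root helper (the commit log mentions an "autoroot file"). A minimal, hypothetical sketch, assuming a rootutils-style helper; the package choice and arguments below are not confirmed by this diff:

import rootutils

# Hypothetical sketch: locate the repository root via the empty .project-root marker.
# rootutils and these arguments are assumptions; the PR itself only adds the marker file.
root = rootutils.setup_root(
    __file__,                   # search upward from the calling script
    indicator=".project-root",  # the empty marker added in this PR
    pythonpath=True,            # put the root on sys.path so itipy imports resolve
)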
54 changes: 54 additions & 0 deletions config/example-hydra-config/data.yaml
@@ -0,0 +1,54 @@
A_data:
A_path: null
A_train_dataset:
_target_: iti.data.geo_datasets.GeoDataset # TODO: make specific msg dataset?
data_dir: null
editors: null # TODO: hard code in dataset?
splits_dict:
train:
years: [2020]
months: [10]
days: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
load_coords: False
load_cloudmask: False
A_val_dataset:
_target_: iti.data.geo_datasets.GeoDataset # TODO: make specific msg dataset?
data_dir: null
editors: null # TODO: hard code in dataset?
splits_dict:
train:
years: [2020]
months: [10]
days: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
load_coords: False
load_cloudmask: False
A_plot_settings: null

B_data:
B_path: null
B_train_dataset:
_target_: iti.data.geo_datasets.GeoDataset # TODO: make specific goes dataset?
data_dir: null
editors: null # TODO: hard code in dataset?
splits_dict:
train:
years: [2020]
months: [10]
days: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
load_coords: False
load_cloudmask: False
B_val_dataset:
_target_: iti.data.geo_datasets.GeoDataset # TODO: make specific goes dataset?
data_dir: null
editors: null # TODO: hard code in dataset?
splits_dict:
train:
years: [2020]
months: [10]
days: [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
load_coords: False
load_cloudmask: False
B_plot_settings: null

num_workers: 4
iterations_per_epoch: 1000
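For context, a minimal sketch of instantiating the A-side training dataset from this config group; the config path, placeholder values, and the Hydra/OmegaConf calls are illustrative assumptions, not part of the PR:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical usage sketch: fill the nulls left open by the example config
# and build the dataset named by _target_. Paths are placeholders.
cfg = OmegaConf.load("config/example-hydra-config/data.yaml")
cfg.A_data.A_train_dataset.data_dir = ["/path/to/msg/netcdf"]  # placeholder directory
cfg.A_data.A_train_dataset.editors = []                        # no extra editors

train_ds_A = instantiate(cfg.A_data.A_train_dataset)
print(len(train_ds_A))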
9 changes: 9 additions & 0 deletions config/example-hydra-config/model.yaml
@@ -0,0 +1,9 @@
model:
__target__: null
input_dim_a: 11
input_dim_b: 16
upsampling: 0
discriminator_mode: CHANNELS
lambda_diversity: 0
norm: 'none'
use_batch_statistic: False
1 change: 1 addition & 0 deletions config/example-hydra-config/train.yaml
@@ -0,0 +1 @@
base_dir: /home/freischem/outputs/miniset/
6 changes: 6 additions & 0 deletions config/example-hydra-config/wandb.yaml
@@ -0,0 +1,6 @@
experiment_name: null
tags: null
wandb_entity: null
wandb_project: null
wandb_name: null
wandb_id: null
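As a rough sketch of how these four groups (data, model, train, wandb) might be composed: a top-level config with a defaults list and a decorated entry point would typically consume them. The config name, path, and entry point below are assumptions rather than files in this PR:

import hydra
from omegaconf import DictConfig, OmegaConf

# Hypothetical entry point; assumes a top-level config.yaml whose defaults list
# pulls in the data, model, train, and wandb groups shown above.
@hydra.main(version_base=None, config_path="config/example-hydra-config", config_name="config")
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))  # the fully composed config
    print(cfg.model.input_dim_a)   # 11 in the example model.yaml
    print(cfg.base_dir)            # from train.yaml

if __name__ == "__main__":
    main()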
27 changes: 27 additions & 0 deletions config/msg_to_goes.yaml
@@ -0,0 +1,27 @@
base_dir: /home/anna.jungbluth/outputs/msg-to-goes/
data:
A_path: /mnt/disks/eo-data/msg/
# converted_A_path:
B_path: /mnt/disks/eo-data/goes/
# converted_B_path:
num_workers: 4
iterations_per_epoch: 10
patch_size: (256, 256)
skip_constant_channels: False # If True, patches with constant channels will be skipped. Note, this massively slows down training.
model:
input_dim_a: 7 # 11
input_dim_b: 9 # 16
upsampling: 0
discriminator_mode: CHANNELS
lambda_diversity: 0
norm: 'none' # 'in_rs_aff'
use_batch_statistic: False
logging:
wandb_entity: itieo
wandb_project: msg-to-goes
wandb_name: msg-to-goes-infrared-channels
training:
epochs: 100
normalization: # TODO: Change to avoid absolute paths
A_norm_dir: /home/anna.jungbluth/InstrumentToInstrument/dataset/msg_2020_hourly_subset.csv
B_norm_dir: /home/anna.jungbluth/InstrumentToInstrument/dataset/goes_2020_hourly_subset.csv
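One detail worth noting: plain YAML parses "patch_size: (256, 256)" as the string "(256, 256)", so downstream code has to convert it. A minimal sketch, assuming ast.literal_eval handles the conversion (the actual training script may parse it differently):

import ast
from omegaconf import OmegaConf

# Sketch: load the experiment config and recover patch_size as a tuple.
# The conversion via ast.literal_eval is an assumption, not code from this PR.
cfg = OmegaConf.load("config/msg_to_goes.yaml")
patch_size = ast.literal_eval(cfg.data.patch_size)  # "(256, 256)" -> (256, 256)
assert patch_size == (256, 256)
print(cfg.model.input_dim_a, cfg.model.input_dim_b)  # 7 and 9 in this experiment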
8,689 changes: 8,689 additions & 0 deletions dataset/goes_2020_hourly.csv

Large diffs are not rendered by default.

4,345 changes: 4,345 additions & 0 deletions dataset/goes_2020_hourly_subset.csv

Large diffs are not rendered by default.

8,707 changes: 8,707 additions & 0 deletions dataset/msg_2020_hourly.csv

Large diffs are not rendered by default.

4,353 changes: 4,353 additions & 0 deletions dataset/msg_2020_hourly_subset.csv

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions itipy/callback.py
@@ -102,7 +102,7 @@ def __init__(self, data, model, plot_settings_A=None, plot_settings_B=None, plot

plot_settings = [*plot_settings_A, *plot_settings_B, *plot_settings_A]

- super().__init__(data, model, path, plot_id, plot_settings, **kwargs)
+ super().__init__(data, model, plot_id, plot_settings, **kwargs)

def predict(self, x):
x_ab, x_aba = self.model.forwardABA(x)
@@ -138,7 +138,7 @@ def __init__(self, data, model, plot_settings_A=None, plot_settings_B=None, plot

plot_settings = [*plot_settings_B, *plot_settings_A, *plot_settings_B]

- super().__init__(data, model, path, plot_id, plot_settings, **kwargs)
+ super().__init__(data, model, plot_id, plot_settings, **kwargs)

def predict(self, x):
x_ba, x_bab = self.model.forwardBAB(x)
@@ -169,7 +169,7 @@ def __init__(self, data, model, plot_settings_A=None, plot_settings_B=None, plot

plot_settings = [*plot_settings_A, *plot_settings_B]

- super().__init__(data, model, path, plot_id, plot_settings, **kwargs)
+ super().__init__(data, model, plot_id, plot_settings, **kwargs)

def predict(self, input_data):
x_ab = self.model.forwardAB(input_data)
149 changes: 149 additions & 0 deletions itipy/data/geo_datasets.py
@@ -0,0 +1,149 @@
from __future__ import annotations
import collections
import collections.abc

# hyper needs the following four aliases to be set manually.
collections.Iterable = collections.abc.Iterable
collections.Mapping = collections.abc.Mapping
collections.MutableSet = collections.abc.MutableSet
collections.MutableMapping = collections.abc.MutableMapping

import logging
import torch
import numpy as np
import xarray as xr
from typing import List, Union, Dict
from loguru import logger

from itipy.data.editor import Editor
from itipy.data.geo_editor import CenterWeightedCropDatasetEditor
from itipy.data.dataset import BaseDataset
from itipy.data.geo_utils import get_split, get_list_filenames, _check_any_constant_channels, _check_all_constant_channels

class GeoDataset(BaseDataset):
def __init__(
self,
data_dir: List[str],
editors: List[Editor],
splits_dict: Dict,
ext: str="nc",
limit: int=None,
load_coords: bool=True,
load_cloudmask: bool=True,
patch_size: tuple[int, int] = (256, 256),
skip_constant_channels: bool = False, # Could be used for filtering out night time observations
**kwargs
):
"""
Initialize the GeoDataset class.

Args:
data_dir (List[str]): A list of directories containing the data files.
editors (List[Editor]): A list of editors for data preprocessing.
splits_dict (Dict): A dictionary specifying the splits (e.g. years, months, days) for the dataset.
ext (str, optional): The file extension of the data files. Defaults to "nc".
limit (int, optional): The maximum number of files to load. Defaults to None.
load_coords (bool, optional): Whether to load the coordinates. Defaults to True.
load_cloudmask (bool, optional): Whether to load the cloud mask. Defaults to True.
patch_size (tuple[int, int], optional): The size of the patches to crop. Defaults to (256, 256).
skip_constant_channels (bool, optional): Whether to skip a patch if any channel is constant. Defaults to False.
**kwargs: Additional keyword arguments.

"""
self.data_dir = data_dir
self.editors = editors
self.splits_dict = splits_dict
self.ext = ext
self.limit = limit
self.load_coords = load_coords
self.load_cloudmask = load_cloudmask
self.patch_size = patch_size
self.skip_constant_channels = skip_constant_channels

self.files = self.get_files()

self.crop = CenterWeightedCropDatasetEditor(patch_shape=self.patch_size)

super().__init__(
data=self.files,
editors=self.editors,
ext=self.ext,
limit=self.limit,
**kwargs
)

def get_files(self):
# Get filenames from data_dir
files = get_list_filenames(data_path=self.data_dir, ext=self.ext)
# split files based on split criteria
files = get_split(files=files, split_dict=self.splits_dict)
return files

def __len__(self):
return len(self.files)

def getIndex(self, data_dict, idx):
# Attempt applying editors
try:
return self.convertData(data_dict)
except Exception as ex:
logging.error('Unable to convert %s: %s' % (self.files[idx], ex))
raise ex

def __getitem__(self, idx):
data_dict = {}

max_attempts = 20
attempts = 1

while attempts <= max_attempts:
if attempts == max_attempts:
raise Exception("Could not load data after %d attempts." % max_attempts)
# Load dataset
ds: xr.Dataset = xr.load_dataset(self.files[idx], engine="netcdf4")
# Crop data before computing
ds = self.crop(ds)
# Extract data
data = ds.Rad.compute().to_numpy()
# Check if all channels are constant -> Always performed
all_constant = _check_all_constant_channels(data)
# Check if any channel is constant -> Only relevant if skip_constant_channels is True
any_constant = _check_any_constant_channels(data)
if all_constant or (self.skip_constant_channels and any_constant):
# Retry loading data
logger.info("Found constant channels in %s. Attempting with other files." % self.files[idx])
idx = np.random.randint(0, len(self.files))
attempts += 1
else:
break

data_dict["data"] = data
del data # Delete data to reduce memory usage
# Extract wavelengths
wavelengths = ds.band_wavelength.compute().to_numpy()
data_dict["wavelengths"] = wavelengths
del wavelengths # Delete data to reduce memory usage

# Extract coordinates
if self.load_coords:
latitude = ds.latitude.compute().to_numpy()
longitude = ds.longitude.compute().to_numpy()
coords = np.stack([latitude, longitude], axis=0)
data_dict["coords"] = coords
del latitude, longitude # Delete data to reduce memory usage
del coords # Delete data to reduce memory usage

# Extract cloud mask
if self.load_cloudmask:
cloud_mask = ds.cloud_mask.compute().to_numpy()
data_dict["cloud_mask"] = cloud_mask
del cloud_mask # Delete data to reduce memory usage

# Delete dataset to reduce memory usage
del ds

# Apply editors
data, _ = self.getIndex(data_dict, idx)
return data
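For reference, a minimal sketch of constructing the new dataset directly; the data directory is a placeholder and the split values simply mirror the example data.yaml above:

from itipy.data.geo_datasets import GeoDataset

# Hypothetical usage sketch; only GeoDataset and its argument names come from the diff above.
splits_dict = {
    "train": {
        "years": [2020],
        "months": [10],
        "days": list(range(20)),  # days 0-19, as in the example data.yaml train split
    }
}

ds = GeoDataset(
    data_dir=["/path/to/msg/netcdf"],  # placeholder; expects NetCDF files with a Rad variable
    editors=[],                        # no additional preprocessing editors
    splits_dict=splits_dict,
    load_coords=False,
    load_cloudmask=False,
    patch_size=(256, 256),
    skip_constant_channels=False,
)
print(len(ds))
sample = ds[0]  # cropped patch after the editor pipeline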


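The constant-channel helpers imported from itipy.data.geo_utils are not shown in this diff; as a loose illustration only, they plausibly reduce to per-channel min/max comparisons along the lines below (the real implementations may differ):

import numpy as np

# Illustrative guesses at the geo_utils helpers used above; treat these as assumptions.
def _check_all_constant_channels(data: np.ndarray) -> bool:
    # True if every channel (axis 0) is spatially constant, e.g. an all-fill patch.
    return bool(np.all(data.min(axis=(1, 2)) == data.max(axis=(1, 2))))

def _check_any_constant_channels(data: np.ndarray) -> bool:
    # True if at least one channel is spatially constant.
    return bool(np.any(data.min(axis=(1, 2)) == data.max(axis=(1, 2))))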