Commit message:

* nnjai support
* ruff format
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* changes as per requested
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* test_nnjai.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 98b9248 · commit a633936
Showing 4 changed files with 234 additions and 0 deletions.
graph_weather/__init__.py

@@ -1,4 +1,5 @@
"""Main import for the complete models"""

from .data.nnjai_wrapp import AMSUDataset, collate_fn
from .models.analysis import GraphWeatherAssimilator
from .models.forecast import GraphWeatherForecaster
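With this re-export, the new utilities are available from the package root as well as from their defining module. A minimal sketch of what the change enables (illustrative, not part of the commit):

# Both import paths now resolve to the same objects.
from graph_weather import AMSUDataset, collate_fn
from graph_weather.data.nnjai_wrapp import AMSUDataset as _AMSUDataset

assert AMSUDataset is _AMSUDataset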
graph_weather/data/__init__.py

@@ -1 +1,3 @@
"""Dataloaders and data processing utilities"""

from .nnjai_wrapp import AMSUDataset, collate_fn
graph_weather/data/nnjai_wrapp.py (new file)

@@ -0,0 +1,97 @@
"""
A custom PyTorch Dataset implementation for AMSU datasets.

This script defines a custom PyTorch Dataset (`AMSUDataset`) for working with AMSU datasets.
The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and
variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata.
"""

import numpy as np
import torch
from torch.utils.data import Dataset

try:
    from nnja import DataCatalog
except ImportError:
    print(
        "NNJA-AI library not installed. Please install with "
        "`pip install git+https://github.com/brightbandtech/nnja-ai.git`"
    )


class AMSUDataset(Dataset):
    """A custom PyTorch Dataset for handling AMSU data.

    This dataset retrieves observations and their metadata, filtered by the provided time and
    variable descriptors.
    """

    def __init__(self, dataset_name, time, primary_descriptors, additional_variables):
        """Initialize the AMSU dataset loader.

        Args:
            dataset_name: Name of the dataset to load.
            time: Specific timestamp to filter the data.
            primary_descriptors: List of primary descriptor variables to include (e.g.,
                OBS_TIMESTAMP, LAT, LON).
            additional_variables: List of additional variables to include in metadata.
        """
        self.dataset_name = dataset_name
        self.time = time
        self.primary_descriptors = primary_descriptors
        self.additional_variables = additional_variables

        # Load data catalog and dataset
        self.catalog = DataCatalog(skip_manifest=True)
        self.dataset = self.catalog[self.dataset_name]
        self.dataset.load_manifest()

        self.dataset = self.dataset.sel(
            time=self.time, variables=self.primary_descriptors + self.additional_variables
        )
        self.dataframe = self.dataset.load_dataset(engine="pandas")

        for col in primary_descriptors:
            if col not in self.dataframe.columns:
                raise ValueError(f"The dataset must include a '{col}' column.")

        self.metadata_columns = [
            col for col in self.dataframe.columns if col not in self.primary_descriptors
        ]

    def __len__(self):
        """Return the total number of samples in the dataset."""
        return len(self.dataframe)

    def __getitem__(self, index):
        """Return the observation and metadata for a given index.

        Args:
            index: Index of the observation to retrieve.

        Returns:
            A dictionary containing timestamp, latitude, longitude, and metadata.
        """
        row = self.dataframe.iloc[index]
        time = row["OBS_TIMESTAMP"].timestamp()
        latitude = row["LAT"]
        longitude = row["LON"]
        metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32)

        return {
            "timestamp": torch.tensor(time, dtype=torch.float32),
            "latitude": torch.tensor(latitude, dtype=torch.float32),
            "longitude": torch.tensor(longitude, dtype=torch.float32),
            "metadata": torch.from_numpy(metadata),
        }


def collate_fn(batch):
    """Custom collate function to handle batching of dictionary data.

    Args:
        batch: List of dictionaries from __getitem__.

    Returns:
        Single dictionary with batched tensors.
    """
    return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()}
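The dataset and collate function above slot into a standard PyTorch DataLoader. A minimal usage sketch, assuming the nnja-ai library is installed and using the dataset name, time, and variables that the tests below exercise:

from torch.utils.data import DataLoader

from graph_weather.data.nnjai_wrapp import AMSUDataset, collate_fn

# Illustrative values mirroring the tests below; any catalog entry with the
# required OBS_TIMESTAMP/LAT/LON columns should work the same way.
dataset = AMSUDataset(
    dataset_name="amsua-1bamua-NC021023",
    time="2021-01-01 00Z",
    primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
    additional_variables=["TMBR_00001", "TMBR_00002"],
)
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

batch = next(iter(loader))
# batch["timestamp"], batch["latitude"], batch["longitude"] have shape (4,);
# batch["metadata"] has shape (4, 2), one column per additional variable.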
test_nnjai.py (new file)

@@ -0,0 +1,134 @@
"""
Tests for the nnjai_wrapp module in the graph_weather package.

This file contains unit tests for AMSUDataset and collate_fn functions.
"""

from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest
import torch

from graph_weather.data.nnjai_wrapp import AMSUDataset, collate_fn


# Mock the DataCatalog to avoid actual data loading
@pytest.fixture
def mock_datacatalog():
    """
    Fixture to mock the DataCatalog for unit tests to avoid actual data loading.

    This mock provides a mock dataset with predefined columns and values.
    """
    with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock:
        # Mock dataset structure
        mock_df = MagicMock()
        mock_df.columns = ["OBS_TIMESTAMP", "LAT", "LON", "TMBR_00001", "TMBR_00002"]

        # Define a mock row
        class MockRow:
            def __getitem__(self, key):
                data = {
                    "OBS_TIMESTAMP": datetime.now(),
                    "LAT": 45.0,
                    "LON": -120.0,
                    "TMBR_00001": 250.0,
                    "TMBR_00002": 260.0,
                }
                return data.get(key, None)

        # Configure mock dataset
        mock_row = MockRow()
        mock_df.iloc = MagicMock()
        mock_df.iloc.__getitem__.return_value = mock_row
        mock_df.__len__.return_value = 100

        mock_dataset = MagicMock()
        mock_dataset.load_dataset.return_value = mock_df
        mock_dataset.sel.return_value = mock_dataset
        mock_dataset.load_manifest = MagicMock()

        mock.return_value.__getitem__.return_value = mock_dataset
        yield mock


def test_amsu_dataset(mock_datacatalog):
    """
    Test the AMSUDataset class to ensure proper data loading and tensor structure.

    This test validates the AMSUDataset class for its ability to load the dataset
    correctly, check for the appropriate tensor properties, and ensure the keys
    and data types match expectations.
    """
    # Initialize dataset parameters
    dataset_name = "amsua-1bamua-NC021023"
    time = "2021-01-01 00Z"
    primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"]
    additional_variables = ["TMBR_00001", "TMBR_00002"]

    dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables)

    # Test dataset length
    assert len(dataset) > 0, "Dataset should not be empty."

    item = dataset[0]
    expected_keys = {"timestamp", "latitude", "longitude", "metadata"}
    assert set(item.keys()) == expected_keys, "Dataset item keys are not as expected."

    # Validate tensor properties
    assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor."
    assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32."
    assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor."

    assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor."
    assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32."
    assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor."

    assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor."
    assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32."
    assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor."

    assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor."
    assert item["metadata"].shape == (
        len(additional_variables),
    ), f"Metadata shape mismatch. Expected ({len(additional_variables)},)."
    assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32."


def test_collate_function():
    """
    Test the collate_fn function to ensure proper batching of dataset items.

    This test checks that collate_fn properly batches the timestamp, latitude,
    longitude, and metadata fields of the dataset, ensuring correct shapes and data types.
    """
    # Mock a batch of items
    batch_size = 4
    metadata_size = 2
    mock_batch = [
        {
            "timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32),
            "latitude": torch.tensor(45.0, dtype=torch.float32),
            "longitude": torch.tensor(-120.0, dtype=torch.float32),
            "metadata": torch.randn(metadata_size, dtype=torch.float32),
        }
        for _ in range(batch_size)
    ]

    # Collate the batch
    batched = collate_fn(mock_batch)

    # Validate batched shapes and types
    assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch."
    assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch."
    assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch."
    assert batched["metadata"].shape == (
        batch_size,
        metadata_size,
    ), "Metadata batch shape mismatch."

    assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch."
    assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch."
    assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch."
    assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch."