From a6339368368c6cf3e4b731773b1fef4df458c25e Mon Sep 17 00:00:00 2001
From: Yuvraaj Narula <49155095+yuvraajnarula@users.noreply.github.com>
Date: Wed, 22 Jan 2025 14:43:02 +0530
Subject: [PATCH] nnjai support (#129)

* nnjai support

* ruff format

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changes as per requested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test_nnjai.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
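Usage note (commentary, not part of the diff): a minimal sketch of how the
new dataset and collate function are meant to be wired into a standard
PyTorch DataLoader. The dataset name, time, and variable lists below are
the illustrative values used in tests/test_nnjai.py, not the only valid
ones, and the optional nnja-ai dependency must be installed.

    from torch.utils.data import DataLoader

    from graph_weather.data.nnjai_wrapp import AMSUDataset, collate_fn

    dataset = AMSUDataset(
        dataset_name="amsua-1bamua-NC021023",
        time="2021-01-01 00Z",
        primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
        additional_variables=["TMBR_00001", "TMBR_00002"],
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

    for batch in loader:
        # Each batch is a dict of stacked tensors: timestamp, latitude and
        # longitude are (batch,); metadata is (batch, n_additional_variables).
        print(batch["metadata"].shape)
        break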
+ """ + self.dataset_name = dataset_name + self.time = time + self.primary_descriptors = primary_descriptors + self.additional_variables = additional_variables + + # Load data catalog and dataset + self.catalog = DataCatalog(skip_manifest=True) + self.dataset = self.catalog[self.dataset_name] + self.dataset.load_manifest() + + self.dataset = self.dataset.sel( + time=self.time, variables=self.primary_descriptors + self.additional_variables + ) + self.dataframe = self.dataset.load_dataset(engine="pandas") + + for col in primary_descriptors: + if col not in self.dataframe.columns: + raise ValueError(f"The dataset must include a '{col}' column.") + + self.metadata_columns = [ + col for col in self.dataframe.columns if col not in self.primary_descriptors + ] + + def __len__(self): + """Return the total number of samples in the dataset.""" + return len(self.dataframe) + + def __getitem__(self, index): + """Return the observation and metadata for a given index. + + Args: + index: Index of the observation to retrieve. + + Returns: + A dictionary containing timestamp, latitude, longitude, and metadata. + """ + row = self.dataframe.iloc[index] + time = row["OBS_TIMESTAMP"].timestamp() + latitude = row["LAT"] + longitude = row["LON"] + metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32) + + return { + "timestamp": torch.tensor(time, dtype=torch.float32), + "latitude": torch.tensor(latitude, dtype=torch.float32), + "longitude": torch.tensor(longitude, dtype=torch.float32), + "metadata": torch.from_numpy(metadata), + } + + +def collate_fn(batch): + """Custom collate function to handle batching of dictionary data. + + Args: + batch: List of dictionaries from __getitem__ + + Returns: + Single dictionary with batched tensors + """ + return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()} diff --git a/tests/test_nnjai.py b/tests/test_nnjai.py new file mode 100644 index 00000000..2cbf7508 --- /dev/null +++ b/tests/test_nnjai.py @@ -0,0 +1,134 @@ +""" +Tests for the nnjai_wrapp module in the graph_weather package. + +This file contains unit tests for AMSUDataset and collate_fn functions. +""" + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from graph_weather.data.nnjai_wrapp import AMSUDataset, collate_fn + + +# Mock the DataCatalog to avoid actual data loading +@pytest.fixture +def mock_datacatalog(): + """ + Fixture to mock the DataCatalog for unit tests to avoid actual data loading. + + This mock provides a mock dataset with predefined columns and values. 
+ """ + with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock: + # Mock dataset structure + mock_df = MagicMock() + mock_df.columns = ["OBS_TIMESTAMP", "LAT", "LON", "TMBR_00001", "TMBR_00002"] + + # Define a mock row + class MockRow: + def __getitem__(self, key): + data = { + "OBS_TIMESTAMP": datetime.now(), + "LAT": 45.0, + "LON": -120.0, + "TMBR_00001": 250.0, + "TMBR_00002": 260.0, + } + return data.get(key, None) + + # Configure mock dataset + mock_row = MockRow() + mock_df.iloc = MagicMock() + mock_df.iloc.__getitem__.return_value = mock_row + mock_df.__len__.return_value = 100 + + mock_dataset = MagicMock() + mock_dataset.load_dataset.return_value = mock_df + mock_dataset.sel.return_value = mock_dataset + mock_dataset.load_manifest = MagicMock() + + mock.return_value.__getitem__.return_value = mock_dataset + yield mock + + +def test_amsu_dataset(mock_datacatalog): + """ + Test the AMSUDataset class to ensure proper data loading and tensor structure. + + This test validates the AMSUDataset class for its ability to load the dataset + correctly, check for the appropriate tensor properties, and ensure the keys + and data types match expectations. + """ + # Initialize dataset parameters + dataset_name = "amsua-1bamua-NC021023" + time = "2021-01-01 00Z" + primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"] + additional_variables = ["TMBR_00001", "TMBR_00002"] + + dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables) + + # Test dataset length + assert len(dataset) > 0, "Dataset should not be empty." + + item = dataset[0] + expected_keys = {"timestamp", "latitude", "longitude", "metadata"} + assert set(item.keys()) == expected_keys, "Dataset item keys are not as expected." + + # Validate tensor properties + assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor." + assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32." + assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor." + + assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor." + assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32." + assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor." + + assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor." + assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32." + assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor." + + assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor." + assert item["metadata"].shape == ( + len(additional_variables), + ), f"Metadata shape mismatch. Expected ({len(additional_variables)},)." + assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32." + + +def test_collate_function(): + """ + Test the collate_fn function to ensure proper batching of dataset items. + + This test checks that the collate_fn properly batches the timestamp, latitude, + longitude, and metadata fields of the dataset, ensuring correct shapes and data types. 
+ """ + # Mock a batch of items + batch_size = 4 + metadata_size = 2 + mock_batch = [ + { + "timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32), + "latitude": torch.tensor(45.0, dtype=torch.float32), + "longitude": torch.tensor(-120.0, dtype=torch.float32), + "metadata": torch.randn(metadata_size, dtype=torch.float32), + } + for _ in range(batch_size) + ] + + # Collate the batch + batched = collate_fn(mock_batch) + + # Validate batched shapes and types + assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch." + assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch." + assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch." + assert batched["metadata"].shape == ( + batch_size, + metadata_size, + ), "Metadata batch shape mismatch." + + assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch." + assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch." + assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch." + assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch."