nnjai support (#129)
* nnjai support

* ruff format

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changes as requested

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test_nnjai.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yuvraajnarula and pre-commit-ci[bot] authored Jan 22, 2025
1 parent 98b9248 commit a633936
Showing 4 changed files with 234 additions and 0 deletions.
1 change: 1 addition & 0 deletions graph_weather/__init__.py
@@ -1,4 +1,5 @@
"""Main import for the complete models"""

from .data.nnjai_wrapp import AMSUDataset, collate_fn
from .models.analysis import GraphWeatherAssimilator
from .models.forecast import GraphWeatherForecaster
2 changes: 2 additions & 0 deletions graph_weather/data/__init__.py
@@ -1 +1,3 @@
"""Dataloaders and data processing utilities"""

from .nnjai_wrapp import AMSUDataset, collate_fn
97 changes: 97 additions & 0 deletions graph_weather/data/nnjai_wrapp.py
@@ -0,0 +1,97 @@
"""
A custom PyTorch Dataset implementation for AMSU datasets.
This script defines a custom PyTorch Dataset (`AMSUDataset`) for working with AMSU datasets.
The dataset is loaded via the nnja library's `DataCatalog` and filtered for specific times and
variables. Each data point consists of a timestamp, latitude, longitude, and associated metadata.
"""

import numpy as np
import torch
from torch.utils.data import Dataset

try:
from nnja import DataCatalog
except ImportError:
print(
"NNJA-AI library not installed. Please install with `pip install git+https://github.com/brightbandtech/nnja-ai.git`"
)


class AMSUDataset(Dataset):
"""A custom PyTorch Dataset for handling AMSU data.
This dataset retrieves observations and their metadata, filtered by the provided time and
variable descriptors.
"""

def __init__(self, dataset_name, time, primary_descriptors, additional_variables):
"""Initialize the AMSU dataset loader.
Args:
dataset_name: Name of the dataset to load.
time: Specific timestamp to filter the data.
primary_descriptors: List of primary descriptor variables to include (e.g.,
OBS_TIMESTAMP, LAT, LON).
additional_variables: List of additional variables to include in metadata.
"""
self.dataset_name = dataset_name
self.time = time
self.primary_descriptors = primary_descriptors
self.additional_variables = additional_variables

# Load data catalog and dataset
self.catalog = DataCatalog(skip_manifest=True)
self.dataset = self.catalog[self.dataset_name]
self.dataset.load_manifest()

self.dataset = self.dataset.sel(
time=self.time, variables=self.primary_descriptors + self.additional_variables
)
self.dataframe = self.dataset.load_dataset(engine="pandas")

for col in primary_descriptors:
if col not in self.dataframe.columns:
raise ValueError(f"The dataset must include a '{col}' column.")

self.metadata_columns = [
col for col in self.dataframe.columns if col not in self.primary_descriptors
]

def __len__(self):
"""Return the total number of samples in the dataset."""
return len(self.dataframe)

def __getitem__(self, index):
"""Return the observation and metadata for a given index.
Args:
index: Index of the observation to retrieve.
Returns:
A dictionary containing timestamp, latitude, longitude, and metadata.
"""
row = self.dataframe.iloc[index]
time = row["OBS_TIMESTAMP"].timestamp()
latitude = row["LAT"]
longitude = row["LON"]
metadata = np.array([row[col] for col in self.metadata_columns], dtype=np.float32)

return {
"timestamp": torch.tensor(time, dtype=torch.float32),
"latitude": torch.tensor(latitude, dtype=torch.float32),
"longitude": torch.tensor(longitude, dtype=torch.float32),
"metadata": torch.from_numpy(metadata),
}


def collate_fn(batch):
"""Custom collate function to handle batching of dictionary data.
Args:
batch: List of dictionaries from __getitem__
Returns:
Single dictionary with batched tensors
"""
return {key: torch.stack([item[key] for item in batch]) for key in batch[0].keys()}
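For reference, a minimal usage sketch (not part of this commit) showing how the dataset and collate function fit together with a standard PyTorch DataLoader. The dataset name, time, and variable lists are borrowed from the tests below; actual catalog contents may differ, and the nnja-ai library must be installed for this to run.

# Usage sketch only; example values are taken from tests/test_nnjai.py below.
from torch.utils.data import DataLoader

from graph_weather.data.nnjai_wrapp import AMSUDataset, collate_fn

dataset = AMSUDataset(
    dataset_name="amsua-1bamua-NC021023",
    time="2021-01-01 00Z",
    primary_descriptors=["OBS_TIMESTAMP", "LAT", "LON"],
    additional_variables=["TMBR_00001", "TMBR_00002"],
)
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)

batch = next(iter(loader))
# batch["timestamp"], batch["latitude"], batch["longitude"] have shape (4,)
# batch["metadata"] has shape (4, 2): one row per observation, one column per variable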
134 changes: 134 additions & 0 deletions tests/test_nnjai.py
@@ -0,0 +1,134 @@
"""
Tests for the nnjai_wrapp module in the graph_weather package.
This file contains unit tests for AMSUDataset and collate_fn functions.
"""

from datetime import datetime
from unittest.mock import MagicMock, patch

import pytest
import torch

from graph_weather.data.nnjai_wrapp import AMSUDataset, collate_fn


# Mock the DataCatalog to avoid actual data loading
@pytest.fixture
def mock_datacatalog():
"""
Fixture to mock the DataCatalog for unit tests to avoid actual data loading.
This mock provides a mock dataset with predefined columns and values.
"""
with patch("graph_weather.data.nnjai_wrapp.DataCatalog") as mock:
# Mock dataset structure
mock_df = MagicMock()
mock_df.columns = ["OBS_TIMESTAMP", "LAT", "LON", "TMBR_00001", "TMBR_00002"]

# Define a mock row
class MockRow:
def __getitem__(self, key):
data = {
"OBS_TIMESTAMP": datetime.now(),
"LAT": 45.0,
"LON": -120.0,
"TMBR_00001": 250.0,
"TMBR_00002": 260.0,
}
return data.get(key, None)

# Configure mock dataset
mock_row = MockRow()
mock_df.iloc = MagicMock()
mock_df.iloc.__getitem__.return_value = mock_row
mock_df.__len__.return_value = 100

mock_dataset = MagicMock()
mock_dataset.load_dataset.return_value = mock_df
mock_dataset.sel.return_value = mock_dataset
mock_dataset.load_manifest = MagicMock()

mock.return_value.__getitem__.return_value = mock_dataset
yield mock


def test_amsu_dataset(mock_datacatalog):
"""
Test the AMSUDataset class to ensure proper data loading and tensor structure.
This test validates the AMSUDataset class for its ability to load the dataset
correctly, check for the appropriate tensor properties, and ensure the keys
and data types match expectations.
"""
# Initialize dataset parameters
dataset_name = "amsua-1bamua-NC021023"
time = "2021-01-01 00Z"
primary_descriptors = ["OBS_TIMESTAMP", "LAT", "LON"]
additional_variables = ["TMBR_00001", "TMBR_00002"]

dataset = AMSUDataset(dataset_name, time, primary_descriptors, additional_variables)

# Test dataset length
assert len(dataset) > 0, "Dataset should not be empty."

item = dataset[0]
expected_keys = {"timestamp", "latitude", "longitude", "metadata"}
assert set(item.keys()) == expected_keys, "Dataset item keys are not as expected."

# Validate tensor properties
assert isinstance(item["timestamp"], torch.Tensor), "Timestamp should be a tensor."
assert item["timestamp"].dtype == torch.float32, "Timestamp should have dtype float32."
assert item["timestamp"].ndim == 0, "Timestamp should be a scalar tensor."

assert isinstance(item["latitude"], torch.Tensor), "Latitude should be a tensor."
assert item["latitude"].dtype == torch.float32, "Latitude should have dtype float32."
assert item["latitude"].ndim == 0, "Latitude should be a scalar tensor."

assert isinstance(item["longitude"], torch.Tensor), "Longitude should be a tensor."
assert item["longitude"].dtype == torch.float32, "Longitude should have dtype float32."
assert item["longitude"].ndim == 0, "Longitude should be a scalar tensor."

assert isinstance(item["metadata"], torch.Tensor), "Metadata should be a tensor."
assert item["metadata"].shape == (
len(additional_variables),
), f"Metadata shape mismatch. Expected ({len(additional_variables)},)."
assert item["metadata"].dtype == torch.float32, "Metadata should have dtype float32."


def test_collate_function():
"""
Test the collate_fn function to ensure proper batching of dataset items.
This test checks that the collate_fn properly batches the timestamp, latitude,
longitude, and metadata fields of the dataset, ensuring correct shapes and data types.
"""
# Mock a batch of items
batch_size = 4
metadata_size = 2
mock_batch = [
{
"timestamp": torch.tensor(datetime.now().timestamp(), dtype=torch.float32),
"latitude": torch.tensor(45.0, dtype=torch.float32),
"longitude": torch.tensor(-120.0, dtype=torch.float32),
"metadata": torch.randn(metadata_size, dtype=torch.float32),
}
for _ in range(batch_size)
]

# Collate the batch
batched = collate_fn(mock_batch)

# Validate batched shapes and types
assert batched["timestamp"].shape == (batch_size,), "Timestamp batch shape mismatch."
assert batched["latitude"].shape == (batch_size,), "Latitude batch shape mismatch."
assert batched["longitude"].shape == (batch_size,), "Longitude batch shape mismatch."
assert batched["metadata"].shape == (
batch_size,
metadata_size,
), "Metadata batch shape mismatch."

assert batched["timestamp"].dtype == torch.float32, "Timestamp dtype mismatch."
assert batched["latitude"].dtype == torch.float32, "Latitude dtype mismatch."
assert batched["longitude"].dtype == torch.float32, "Longitude dtype mismatch."
assert batched["metadata"].dtype == torch.float32, "Metadata dtype mismatch."
