Skip to content

Commit

Permalink
Sup3rFixer layer adjusted to avoid casting to numpy array
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Jan 23, 2025
1 parent c7a06b6 commit f60dc9f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
44 changes: 44 additions & 0 deletions phygnn/layers/custom_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,50 @@ def call(x, hi_res_adder):
return x + hi_res_adder


class Sup3rFixer(tf.keras.layers.Layer):
"""Layer to fix certain values for a sup3r model in the middle of a
super resolution forward pass. This is used to condition models on sparse
observation data."""

def __init__(self, name=None):
"""
Parameters
----------
name : str | None
Unique str identifier of the fixer layer. Usually the name of the
hi-resolution feature used in the fixing.
"""
super().__init__(name=name)

@staticmethod
def call(x, hi_res_fixer, feature_index):
"""Fixes hi-resolution data for the tensor x in the middle of a
sup3r resolution network.
Parameters
----------
x : tf.Tensor
Input tensor
hi_res_fixer : tf.Tensor | np.ndarray
This should be a 4D array for spatial enhancement model or 5D array
for a spatiotemporal enhancement model (obs, spatial_1, spatial_2,
(temporal), 1) that can be used to fix values of x.
feature_index : int
The index of the feature to fix. This assumes that x has the same
number of channels and indexing as the model output features.
Returns
-------
x : tf.Tensor
Output tensor with the hi_res_fixer used to fix values of x.
"""
mask = tf.math.is_nan(hi_res_fixer[..., 0])
arrs = [x[..., i] for i in range(x.shape[-1])]
arrs[feature_index] = tf.where(
mask, x[..., feature_index], hi_res_fixer[..., 0])
return tf.stack(arrs, axis=-1)


class Sup3rConcat(tf.keras.layers.Layer):
"""Layer to concatenate a high-resolution feature to a sup3r model in the
middle of a super resolution forward pass."""
Expand Down
33 changes: 27 additions & 6 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,27 @@
"""
import os
from tempfile import TemporaryDirectory

import numpy as np
import pytest
import tensorflow as tf

from phygnn import TfModel
from phygnn.layers.custom_layers import (
ExpandDims,
FlattenAxis,
FunctionalLayer,
GaussianAveragePooling2D,
GaussianNoiseAxis,
LogTransform,
SigLin,
SkipConnection,
SpatioTemporalExpansion,
Sup3rFixer,
TileLayer,
FunctionalLayer,
GaussianAveragePooling2D,
SigLin,
LogTransform,
UnitConversion,
)
from phygnn.layers.handlers import HiddenLayers, Layers
from phygnn import TfModel


@pytest.mark.parametrize(
Expand Down Expand Up @@ -224,7 +226,7 @@ def test_temporal_depth_to_time(t_mult, s_mult, t_roll):
t_roll=t_roll)
n_filters = 2 * s_mult**2 * t_mult
shape = (1, 4, 4, 3, n_filters)
n = np.product(shape)
n = np.prod(shape)
x = np.arange(n).reshape(shape)
y = layer(x)
assert y.shape[0] == x.shape[0]
Expand Down Expand Up @@ -622,3 +624,22 @@ def test_unit_conversion():
# bad number of scalar values
layer = UnitConversion(adder=0, scalar=[100, 1, 1])
y = layer(x)


def test_fixer_layer():
"""Make sure ``Sup3rFixer`` layer works properly"""
x = np.random.uniform(0, 1, (1, 10, 10, 4)).astype(np.float32)
y = np.random.uniform(0, 1, (1, 10, 10, 1)).astype(np.float32)
mask = np.random.choice([False, True], (1, 10, 10), p=[0.1, 0.9])
y[mask] = np.nan

x = tf.convert_to_tensor(x)
y = tf.convert_to_tensor(y)

layer = Sup3rFixer()
out = layer(x, y, feature_index=0).numpy()

assert tf.reduce_any(tf.math.is_nan(y))
assert np.allclose(out[..., 0][~mask], y[..., 0][~mask])
assert x.shape == out.shape
assert not tf.reduce_any(tf.math.is_nan(out))

0 comments on commit f60dc9f

Please sign in to comment.