diff --git a/phygnn/layers/custom_layers.py b/phygnn/layers/custom_layers.py index 712994a..d92c1d0 100644 --- a/phygnn/layers/custom_layers.py +++ b/phygnn/layers/custom_layers.py @@ -897,6 +897,50 @@ def call(x, hi_res_adder): return x + hi_res_adder +class Sup3rFixer(tf.keras.layers.Layer): + """Layer to fix certain values for a sup3r model in the middle of a + super resolution forward pass. This is used to condition models on sparse + observation data.""" + + def __init__(self, name=None): + """ + Parameters + ---------- + name : str | None + Unique str identifier of the fixer layer. Usually the name of the + hi-resolution feature used in the fixing. + """ + super().__init__(name=name) + + @staticmethod + def call(x, hi_res_fixer, feature_index): + """Fixes hi-resolution data for the tensor x in the middle of a + sup3r resolution network. + + Parameters + ---------- + x : tf.Tensor + Input tensor + hi_res_fixer : tf.Tensor | np.ndarray + This should be a 4D array for spatial enhancement model or 5D array + for a spatiotemporal enhancement model (obs, spatial_1, spatial_2, + (temporal), 1) that can be used to fix values of x. + feature_index : int + The index of the feature to fix. This assumes that x has the same + number of channels and indexing as the model output features. + + Returns + ------- + x : tf.Tensor + Output tensor with the hi_res_fixer used to fix values of x. + """ + mask = tf.math.is_nan(hi_res_fixer[..., 0]) + arrs = [x[..., i] for i in range(x.shape[-1])] + arrs[feature_index] = tf.where( + mask, x[..., feature_index], hi_res_fixer[..., 0]) + return tf.stack(arrs, axis=-1) + + class Sup3rConcat(tf.keras.layers.Layer): """Layer to concatenate a high-resolution feature to a sup3r model in the middle of a super resolution forward pass.""" diff --git a/tests/test_layers.py b/tests/test_layers.py index 10758de..064c948 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -3,25 +3,27 @@ """ import os from tempfile import TemporaryDirectory + import numpy as np import pytest import tensorflow as tf +from phygnn import TfModel from phygnn.layers.custom_layers import ( ExpandDims, FlattenAxis, + FunctionalLayer, + GaussianAveragePooling2D, GaussianNoiseAxis, + LogTransform, + SigLin, SkipConnection, SpatioTemporalExpansion, + Sup3rFixer, TileLayer, - FunctionalLayer, - GaussianAveragePooling2D, - SigLin, - LogTransform, UnitConversion, ) from phygnn.layers.handlers import HiddenLayers, Layers -from phygnn import TfModel @pytest.mark.parametrize( @@ -224,7 +226,7 @@ def test_temporal_depth_to_time(t_mult, s_mult, t_roll): t_roll=t_roll) n_filters = 2 * s_mult**2 * t_mult shape = (1, 4, 4, 3, n_filters) - n = np.product(shape) + n = np.prod(shape) x = np.arange(n).reshape(shape) y = layer(x) assert y.shape[0] == x.shape[0] @@ -622,3 +624,22 @@ def test_unit_conversion(): # bad number of scalar values layer = UnitConversion(adder=0, scalar=[100, 1, 1]) y = layer(x) + + +def test_fixer_layer(): + """Make sure ``Sup3rFixer`` layer works properly""" + x = np.random.uniform(0, 1, (1, 10, 10, 4)).astype(np.float32) + y = np.random.uniform(0, 1, (1, 10, 10, 1)).astype(np.float32) + mask = np.random.choice([False, True], (1, 10, 10), p=[0.1, 0.9]) + y[mask] = np.nan + + x = tf.convert_to_tensor(x) + y = tf.convert_to_tensor(y) + + layer = Sup3rFixer() + out = layer(x, y, feature_index=0).numpy() + + assert tf.reduce_any(tf.math.is_nan(y)) + assert np.allclose(out[..., 0][~mask], y[..., 0][~mask]) + assert x.shape == out.shape + assert not tf.reduce_any(tf.math.is_nan(out))