diff --git a/keras/trainers/data_adapters/generator_data_adapter.py b/keras/trainers/data_adapters/generator_data_adapter.py index fe3c8766edf..1fbdc37326a 100644 --- a/keras/trainers/data_adapters/generator_data_adapter.py +++ b/keras/trainers/data_adapters/generator_data_adapter.py @@ -1,5 +1,6 @@ import itertools +import numpy as np import tree from keras.trainers.data_adapters.data_adapter import DataAdapter @@ -38,7 +39,12 @@ def get_tensor_spec(x): ) shape = list(shape) shape[0] = None # The batch size is not guaranteed to be static. - return tf.TensorSpec(shape=shape, dtype=x.dtype.name) + if isinstance(x, tf.RaggedTensor): + return tf.RaggedTensorSpec(shape=shape, dtype=x.dtype.name) + if isinstance(x, tf.SparseTensor) or is_scipy_sparse(x): + return tf.SparseTensorSpec(shape=shape, dtype=x.dtype.name) + else: + return tf.TensorSpec(shape=shape, dtype=x.dtype.name) self._output_signature = tree.map_structure(get_tensor_spec, data) @@ -49,10 +55,20 @@ def get_numpy_iterator(self): def get_tf_dataset(self): from keras.utils.module_utils import tensorflow as tf + def convert_to_tf(batch): + if is_scipy_sparse(batch): + batch = scipy_sparse_to_tf_sparse(batch) + return batch + + def get_tf_iterator(): + for batch in self.generator: + batch = tree.map_structure(convert_to_tf, batch) + yield batch + if self._output_signature is None: self._set_tf_output_signature() ds = tf.data.Dataset.from_generator( - self.get_numpy_iterator, + get_tf_iterator, output_signature=self._output_signature, ) ds = ds.prefetch(tf.data.AUTOTUNE) @@ -70,3 +86,21 @@ def batch_size(self): def peek_and_restore(generator): element = next(generator) return element, itertools.chain([element], generator) + + +def is_scipy_sparse(x): + return x.__class__.__module__.startswith("scipy.sparse") and hasattr( + x, "tocoo" + ) + + +def scipy_sparse_to_tf_sparse(x): + from keras.utils.module_utils import tensorflow as tf + + sparse_coo = x.tocoo() + row, col = sparse_coo.row, sparse_coo.col + data, shape = sparse_coo.data, sparse_coo.shape + indices = np.concatenate( + (np.expand_dims(row, axis=1), np.expand_dims(col, axis=1)), axis=1 + ) + return tf.SparseTensor(indices, data, shape) diff --git a/keras/trainers/data_adapters/generator_data_adapter_test.py b/keras/trainers/data_adapters/generator_data_adapter_test.py index 8537c74c97b..aeb4c375bdb 100644 --- a/keras/trainers/data_adapters/generator_data_adapter_test.py +++ b/keras/trainers/data_adapters/generator_data_adapter_test.py @@ -1,6 +1,7 @@ import math import numpy as np +import scipy import tensorflow as tf from absl.testing import parameterized @@ -88,3 +89,49 @@ def test_basic_flow(self, use_sample_weight): for i in range(by.shape[0]): sample_order.append(by[i, 0]) self.assertAllClose(sample_order, list(range(64))) + + def test_tf_sparse_tensors(self): + def generate_tf(): + for i in range(4): + x = tf.SparseTensor( + indices=[[0, 0], [1, 2]], + values=[1.0, 2.0], + dense_shape=(2, 4), + ) + y = tf.SparseTensor( + indices=[[0, 0], [1, 1]], + values=[3.0, 4.0], + dense_shape=(2, 2), + ) + yield x, y + + adapter = generator_data_adapter.GeneratorDataAdapter(generate_tf()) + ds = adapter.get_tf_dataset() + for batch in ds: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, tf.SparseTensor) + self.assertIsInstance(by, tf.SparseTensor) + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + + def test_scipy_sparse_tensors(self): + def generate_scipy(): + for i in range(4): + x = scipy.sparse.coo_matrix( + ([1.0, 2.0], ([0, 1], [0, 2])), shape=[2, 4] + ) + y = scipy.sparse.coo_matrix( + ([3.0, 4.0], ([0, 1], [0, 1])), shape=[2, 2] + ) + yield x, y + + adapter = generator_data_adapter.GeneratorDataAdapter(generate_scipy()) + ds = adapter.get_tf_dataset() + for batch in ds: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, tf.SparseTensor) + self.assertIsInstance(by, tf.SparseTensor) + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) diff --git a/keras/trainers/data_adapters/tf_dataset_adapter.py b/keras/trainers/data_adapters/tf_dataset_adapter.py index 5690e2c5b5b..abfa8ded51a 100644 --- a/keras/trainers/data_adapters/tf_dataset_adapter.py +++ b/keras/trainers/data_adapters/tf_dataset_adapter.py @@ -36,8 +36,15 @@ def __init__(self, dataset, class_weight=None, distribution=None): self._dataset = dataset def get_numpy_iterator(self): + from keras.utils.module_utils import tensorflow as tf + + def convert_to_numpy(x): + if isinstance(x, tf.SparseTensor): + x = tf.sparse.to_dense(x) + return x.numpy() + for batch in self._dataset: - yield tree.map_structure(lambda x: x.numpy(), batch) + yield tree.map_structure(convert_to_numpy, batch) def get_tf_dataset(self): return self._dataset diff --git a/keras/trainers/data_adapters/tf_dataset_adapter_test.py b/keras/trainers/data_adapters/tf_dataset_adapter_test.py index 4aa22233e6b..d1dd94d1dc9 100644 --- a/keras/trainers/data_adapters/tf_dataset_adapter_test.py +++ b/keras/trainers/data_adapters/tf_dataset_adapter_test.py @@ -252,3 +252,30 @@ def test_distribute_dataset(self): else: self.assertEqual(tuple(bx.shape), (2, 4)) self.assertEqual(tuple(by.shape), (2, 2)) + + def test_tf_sparse_tensors(self): + x = tf.SparseTensor( + indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 4) + ) + y = tf.SparseTensor( + indices=[[0, 0], [1, 1]], values=[3.0, 4.0], dense_shape=(2, 2) + ) + base_ds = tf.data.Dataset.from_tensors((x, y)) + adapter = tf_dataset_adapter.TFDatasetAdapter(base_ds) + + gen = adapter.get_numpy_iterator() + for batch in gen: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, np.ndarray) + self.assertIsInstance(by, np.ndarray) + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2)) + ds = adapter.get_tf_dataset() + for batch in ds: + self.assertEqual(len(batch), 2) + bx, by = batch + self.assertIsInstance(bx, tf.SparseTensor) + self.assertIsInstance(by, tf.SparseTensor) + self.assertEqual(bx.shape, (2, 4)) + self.assertEqual(by.shape, (2, 2))