Skip to content

Commit

Permalink
Add support for Tensorflow SparseTensors: data adapters. (#18719)
Browse files Browse the repository at this point in the history
`GeneratorDataAdapter` now accepts `tf.SparseTensor`s and `scipy.sparse` matrices.

`TFDatasetAdapter` now accepts `tf.SparseTensor`s.
  • Loading branch information
hertschuh authored Nov 1, 2023
1 parent 588d3c5 commit 13ef4a8
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 3 deletions.
38 changes: 36 additions & 2 deletions keras/trainers/data_adapters/generator_data_adapter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools

import numpy as np
import tree

from keras.trainers.data_adapters.data_adapter import DataAdapter
Expand Down Expand Up @@ -38,7 +39,12 @@ def get_tensor_spec(x):
)
shape = list(shape)
shape[0] = None # The batch size is not guaranteed to be static.
return tf.TensorSpec(shape=shape, dtype=x.dtype.name)
if isinstance(x, tf.RaggedTensor):
return tf.RaggedTensorSpec(shape=shape, dtype=x.dtype.name)
if isinstance(x, tf.SparseTensor) or is_scipy_sparse(x):
return tf.SparseTensorSpec(shape=shape, dtype=x.dtype.name)
else:
return tf.TensorSpec(shape=shape, dtype=x.dtype.name)

self._output_signature = tree.map_structure(get_tensor_spec, data)

Expand All @@ -49,10 +55,20 @@ def get_numpy_iterator(self):
def get_tf_dataset(self):
from keras.utils.module_utils import tensorflow as tf

def convert_to_tf(batch):
if is_scipy_sparse(batch):
batch = scipy_sparse_to_tf_sparse(batch)
return batch

def get_tf_iterator():
for batch in self.generator:
batch = tree.map_structure(convert_to_tf, batch)
yield batch

if self._output_signature is None:
self._set_tf_output_signature()
ds = tf.data.Dataset.from_generator(
self.get_numpy_iterator,
get_tf_iterator,
output_signature=self._output_signature,
)
ds = ds.prefetch(tf.data.AUTOTUNE)
Expand All @@ -70,3 +86,21 @@ def batch_size(self):
def peek_and_restore(generator):
    """Peek at the first element of an iterator without losing it.

    Args:
        generator: An iterator. `next()` is called on it exactly once.

    Returns:
        A tuple `(first, restored)` where `first` is the element that was
        consumed and `restored` is an iterator that yields `first` again
        followed by the remaining elements of `generator`.

    Raises:
        StopIteration: Propagated from `next()` when `generator` is already
            exhausted.
    """
    first = next(generator)
    # Put the consumed element back in front of the rest of the stream.
    restored = itertools.chain((first,), generator)
    return first, restored


def is_scipy_sparse(x):
    """Return whether `x` looks like a scipy sparse matrix.

    The check is done by name so that scipy does not have to be imported:
    the class of `x` must be defined in a module whose name starts with
    `"scipy.sparse"`, and `x` must expose a `tocoo` attribute.

    Args:
        x: Any object.

    Returns:
        `True` if both conditions hold, `False` otherwise.
    """
    module_name = x.__class__.__module__
    if not module_name.startswith("scipy.sparse"):
        return False
    return hasattr(x, "tocoo")


def scipy_sparse_to_tf_sparse(x):
    """Convert a scipy sparse matrix into a `tf.SparseTensor`.

    Args:
        x: An object exposing `tocoo()`, whose result provides `row`, `col`,
            `data` and `shape` attributes.

    Returns:
        A `tf.SparseTensor` built from the COO row/column indices, the
        stored values and the shape of `x`.
    """
    # TensorFlow is imported lazily through the Keras module utilities.
    from keras.utils.module_utils import tensorflow as tf

    coo = x.tocoo()
    # Pair up the row and column index of each stored value along a new
    # second axis, giving one (row, col) entry per value.
    indices = np.stack((coo.row, coo.col), axis=1)
    return tf.SparseTensor(indices, coo.data, coo.shape)
47 changes: 47 additions & 0 deletions keras/trainers/data_adapters/generator_data_adapter_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math

import numpy as np
import scipy
import tensorflow as tf
from absl.testing import parameterized

Expand Down Expand Up @@ -88,3 +89,49 @@ def test_basic_flow(self, use_sample_weight):
for i in range(by.shape[0]):
sample_order.append(by[i, 0])
self.assertAllClose(sample_order, list(range(64)))

def test_tf_sparse_tensors(self):
    """A generator of `tf.SparseTensor` pairs yields sparse batches."""

    def sparse_pairs():
        # Four identical (x, y) pairs of sparse tensors.
        for _ in range(4):
            features = tf.SparseTensor(
                indices=[[0, 0], [1, 2]],
                values=[1.0, 2.0],
                dense_shape=(2, 4),
            )
            targets = tf.SparseTensor(
                indices=[[0, 0], [1, 1]],
                values=[3.0, 4.0],
                dense_shape=(2, 2),
            )
            yield features, targets

    adapter = generator_data_adapter.GeneratorDataAdapter(sparse_pairs())
    for batch in adapter.get_tf_dataset():
        self.assertEqual(len(batch), 2)
        batch_x, batch_y = batch
        # Sparsity must be preserved through the tf.data pipeline.
        self.assertIsInstance(batch_x, tf.SparseTensor)
        self.assertIsInstance(batch_y, tf.SparseTensor)
        self.assertEqual(batch_x.shape, (2, 4))
        self.assertEqual(batch_y.shape, (2, 2))

def test_scipy_sparse_tensors(self):
    """A generator of scipy COO matrices yields `tf.SparseTensor` batches."""

    def scipy_pairs():
        # Four identical (x, y) pairs of scipy COO matrices.
        for _ in range(4):
            features = scipy.sparse.coo_matrix(
                ([1.0, 2.0], ([0, 1], [0, 2])), shape=[2, 4]
            )
            targets = scipy.sparse.coo_matrix(
                ([3.0, 4.0], ([0, 1], [0, 1])), shape=[2, 2]
            )
            yield features, targets

    adapter = generator_data_adapter.GeneratorDataAdapter(scipy_pairs())
    for batch in adapter.get_tf_dataset():
        self.assertEqual(len(batch), 2)
        batch_x, batch_y = batch
        # scipy inputs are expected to come out as TensorFlow sparse tensors.
        self.assertIsInstance(batch_x, tf.SparseTensor)
        self.assertIsInstance(batch_y, tf.SparseTensor)
        self.assertEqual(batch_x.shape, (2, 4))
        self.assertEqual(batch_y.shape, (2, 2))
9 changes: 8 additions & 1 deletion keras/trainers/data_adapters/tf_dataset_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,15 @@ def __init__(self, dataset, class_weight=None, distribution=None):
self._dataset = dataset

def get_numpy_iterator(self):
    """Iterate over the wrapped dataset, yielding NumPy batches.

    Each batch of `self._dataset` is yielded exactly once, with every
    element of its (possibly nested) structure converted via `.numpy()`.
    `tf.SparseTensor` elements are densified with `tf.sparse.to_dense`
    before the conversion.

    Yields:
        The batch structure with each element replaced by the result of
        its `.numpy()` call.
    """
    from keras.utils.module_utils import tensorflow as tf

    def convert_to_numpy(x):
        # Densify sparse tensors first so `.numpy()` is called on the
        # dense result.
        if isinstance(x, tf.SparseTensor):
            x = tf.sparse.to_dense(x)
        return x.numpy()

    for batch in self._dataset:
        # Single yield per batch: the former
        # `tree.map_structure(lambda x: x.numpy(), batch)` yield bypassed
        # the sparse handling and duplicated every batch.
        yield tree.map_structure(convert_to_numpy, batch)

def get_tf_dataset(self):
return self._dataset
Expand Down
27 changes: 27 additions & 0 deletions keras/trainers/data_adapters/tf_dataset_adapter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,30 @@ def test_distribute_dataset(self):
else:
self.assertEqual(tuple(bx.shape), (2, 4))
self.assertEqual(tuple(by.shape), (2, 2))

def test_tf_sparse_tensors(self):
    """Sparse datasets densify for NumPy and stay sparse for tf.data."""
    sparse_x = tf.SparseTensor(
        indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 4)
    )
    sparse_y = tf.SparseTensor(
        indices=[[0, 0], [1, 1]], values=[3.0, 4.0], dense_shape=(2, 2)
    )
    adapter = tf_dataset_adapter.TFDatasetAdapter(
        tf.data.Dataset.from_tensors((sparse_x, sparse_y))
    )

    def check_batches(batches, expected_type):
        # Every batch is an (x, y) pair of `expected_type` with the
        # original dense shapes.
        for batch in batches:
            self.assertEqual(len(batch), 2)
            batch_x, batch_y = batch
            self.assertIsInstance(batch_x, expected_type)
            self.assertIsInstance(batch_y, expected_type)
            self.assertEqual(batch_x.shape, (2, 4))
            self.assertEqual(batch_y.shape, (2, 2))

    # The NumPy path yields dense arrays.
    check_batches(adapter.get_numpy_iterator(), np.ndarray)
    # The tf.data path keeps the tensors sparse.
    check_batches(adapter.get_tf_dataset(), tf.SparseTensor)

0 comments on commit 13ef4a8

Please sign in to comment.