-
Hi, I'm trying out JAX's sparse array support and found some surprising memory behavior for sparse-sparse matrix multiplication. When multiplying two 10,000 x 10,000 sparse matrices at 1% density, JAX errors out saying it can't allocate the required ~1 TB of memory. This seems like far too much memory to be requesting (scipy does the same product within about 0.5 GB), so I'm wondering whether I've set this up correctly. Here's what I tried:

```python
from scipy import sparse
import numpy as np

import jax
jax.config.update('jax_platform_name', 'cpu')
import jax.experimental.sparse as jax_sparse

X = sparse.random(
    10_000,
    10_000,
    density=0.01,
    format='csr',
    random_state=np.random.default_rng(),
    dtype=np.float32
)
Y = sparse.random(
    10_000,
    10_000,
    density=0.01,
    format='csr',
    random_state=np.random.default_rng(),
    dtype=np.float32
)

X_jax = jax_sparse.BCOO.from_scipy_sparse(X)
Y_jax = jax_sparse.BCOO.from_scipy_sparse(Y)

X_jax @ Y_jax
```

Full traceback:

```
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[11], line 1
----> 1 X_jax @ Y_jax
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/experimental/sparse/transform.py:465, in _sparsify_with_interpreter.<locals>.wrapped(*args, **params)
463 spenv = SparsifyEnv()
464 spvalues = arrays_to_spvalues(spenv, args)
--> 465 spvalues_out, out_tree = f_raw(spenv, *spvalues, **params)
466 out = spvalues_to_arrays(spenv, spvalues_out)
467 return tree_unflatten(out_tree, out)
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/experimental/sparse/transform.py:450, in sparsify_raw.<locals>.wrapped(spenv, *spvalues, **params)
448 wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(f, params), in_tree)
449 jaxpr, out_avals_flat, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals_flat)
--> 450 result = eval_sparse(jaxpr, consts, spvalues_flat, spenv)
451 if len(out_avals_flat) != len(result):
452 raise Exception("Internal: eval_sparse does not return expected number of arguments. "
453 "Got {result} for avals {out_avals_flat}")
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/experimental/sparse/transform.py:427, in eval_sparse(jaxpr, consts, spvalues, spenv)
425 if prim not in sparse_rules_bcoo:
426 _raise_unimplemented_primitive(prim)
--> 427 out = sparse_rules_bcoo[prim](spenv, *invals, **eqn.params)
428 else:
429 out_bufs = prim.bind(*(spenv.data(val) for val in invals), **eqn.params)
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/experimental/sparse/transform.py:538, in _standard_sparse_rule.<locals>._sparse_rule(spenv, *spvalues, **kwds)
537 def _sparse_rule(spenv, *spvalues, **kwds):
--> 538 result = sparse_op(*spvalues_to_arrays(spenv, spvalues), **kwds)
539 return arrays_to_spvalues(spenv, result if prim.multiple_results else [result])
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/experimental/sparse/bcoo.py:633, in bcoo_dot_general(***failed resolving arguments***)
630 if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
631 shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
632 dimension_numbers)
--> 633 bufs = _bcoo_spdot_general(lhs.data, lhs.indices, rhs.data, rhs.indices,
634 lhs_spinfo=lhs._info, rhs_spinfo=rhs._info,
635 dimension_numbers=dimension_numbers,
636 preferred_element_type=preferred_element_type)
637 return BCOO(bufs, shape=shape)
638 elif isinstance(lhs, BCOO):
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/experimental/sparse/bcoo.py:1094, in _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data, rhs_indices, lhs_spinfo, rhs_spinfo, dimension_numbers, preferred_element_type)
1090 cdims = (api_util._ensure_index_tuple(lhs_contract),
1091 api_util._ensure_index_tuple(rhs_contract))
1092 bdims = (api_util._ensure_index_tuple(lhs_batch),
1093 api_util._ensure_index_tuple(rhs_batch))
-> 1094 return bcoo_spdot_general_p.bind(lhs_data, lhs_indices, rhs_data, rhs_indices,
1095 lhs_spinfo=lhs_spinfo, rhs_spinfo=rhs_spinfo,
1096 dimension_numbers=(cdims, bdims),
1097 preferred_element_type=preferred_element_type)
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/_src/core.py:386, in Primitive.bind(self, *args, **params)
383 def bind(self, *args, **params):
384 assert (not config.jax_enable_checks or
385 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 386 return self.bind_with_trace(find_top_trace(args), args, params)
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/_src/core.py:389, in Primitive.bind_with_trace(self, trace, args, params)
388 def bind_with_trace(self, trace, args, params):
--> 389 out = trace.process_primitive(self, map(trace.full_raise, args), params)
390 return map(full_lower, out) if self.multiple_results else full_lower(out)
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/_src/core.py:821, in EvalTrace.process_primitive(self, primitive, tracers, params)
820 def process_primitive(self, primitive, tracers, params):
--> 821 return primitive.impl(*tracers, **params)
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/experimental/sparse/bcoo.py:1181, in _bcoo_spdot_general_impl(lhs_data, lhs_indices, rhs_data, rhs_indices, lhs_spinfo, rhs_spinfo, dimension_numbers, preferred_element_type)
1179 func = nfold_vmap(func, lhs.n_batch - len(lhs_batch), in_axes=(0, 0, None, None))
1180 func = nfold_vmap(func, len(lhs_batch))
-> 1181 return func(lhs_data, lhs_indices, rhs_data, rhs_indices)
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/experimental/sparse/bcoo.py:1124, in _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, lhs_spinfo, rhs_spinfo, lhs_contracting, rhs_contracting, out_nse)
1119 rhs_j = rhs_indices[:, jnp.array(remaining(range(rhs.n_sparse), rhs_contracting), dtype=int)]
1121 # TODO(jakevdp): can we do this more efficiently than using an outer product? Note that
1122 # jnp.isin() currently doesn't help much, because it also does all() over an outer
1123 # comparison.
-> 1124 overlap = (lhs_i[:, None] == rhs_i[None, :]).all(-1)
1125 lhs_fill_value = jnp.expand_dims(
1126 jnp.array([lhs_shape[d] for d in lhs_contracting], dtype=lhs_i.dtype),
1127 range(lhs_i.ndim - 1))
1128 rhs_fill_value = jnp.expand_dims(
1129 jnp.array([rhs_shape[d] for d in rhs_contracting], dtype=rhs_i.dtype),
1130 range(rhs_i.ndim - 1))
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py:256, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
254 args = (other, self) if swap else (self, other)
255 if isinstance(other, _accepted_binop_types):
--> 256 return binary_op(*args)
257 if isinstance(other, _rejected_binop_types):
258 raise TypeError(f"unsupported operand type(s) for {opchar}: "
259 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
[... skipping hidden 10 frame]
File /mnt/workspace/mambaforge/envs/jax/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py:1229, in ExecuteReplicated.__call__(self, *args)
1224 self._handle_token_bufs(
1225 results.disassemble_prefix_into_single_device_arrays(
1226 len(self.ordered_effects)),
1227 results.consume_token())
1228 else:
-> 1229 results = self.xla_executable.execute_sharded(input_bufs)
1230 if dispatch.needs_check_special():
1231 out_arrays = results.disassemble_into_single_device_arrays()
XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 1000000000000 bytes.
```
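For reference, the scipy comparison is just the direct product of the same two CSR matrices. Here is a minimal sketch of that baseline (the byte count below is only the storage of the result, not peak memory during the multiply):

```python
# scipy baseline: CSR @ CSR allocates memory roughly proportional to the
# number of nonzeros in the result, not to nse_lhs * nse_rhs.
Z = X @ Y  # scipy.sparse CSR product of the matrices defined above

print(Z.nnz)  # stored nonzeros in the result
print(Z.data.nbytes + Z.indices.nbytes + Z.indptr.nbytes)  # result storage in bytes
```

With these sizes each output entry is nonzero with probability roughly 1 - (1 - 0.01**2)**10000 ≈ 0.63, i.e. on the order of 60 million stored nonzeros, which at 8 bytes per stored element is consistent with the ~0.5 GB figure.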
Replies: 4 comments 16 replies
-
Yes, this is a known issue: sparse-sparse matmul uses `nse_1 * nse_2` memory complexity. Unfortunately, there are no sparse matrix primitives in XLA, so it's hard to do much better than this in general. It's one of the reasons that these tools have not graduated from the `jax.experimental` namespace.

I've actually thought about removing sparse-sparse matmul completely, because its performance tends to surprise people. What do you think?
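For concreteness, the 1,000,000,000,000-byte request matches the outer comparison at `bcoo.py:1124` in the traceback (`overlap = (lhs_i[:, None] == rhs_i[None, :]).all(-1)`): with 1% density each operand has nse = 10,000 × 10,000 × 0.01 = 1,000,000 stored elements, and that comparison materializes an nse × nse boolean array. A quick back-of-the-envelope check:

```python
# Rough check of the RESOURCE_EXHAUSTED size: the sparse-sparse rule compares
# every stored index of the lhs against every stored index of the rhs, so a
# boolean array of shape (nse_lhs, nse_rhs) gets materialized.
nse = int(10_000 * 10_000 * 0.01)  # stored elements per operand: 1_000_000
overlap_bytes = nse * nse * 1      # one byte per bool
print(overlap_bytes)               # 1_000_000_000_000 -> the ~1 TB in the error
```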
-
Hi @jakevdp. Are there any developments on this? My takeaway from the above is that sparse matrix multiplication is not implemented efficiently. Is this still the case? Are there any efficient workarounds yet?
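Not a maintainer answer, but since the scipy product fits in memory for this problem size, one stop-gap sketch (only applicable when the product doesn't need to run inside `jit`/`grad`) is to do the sparse-sparse multiply on the host with scipy and convert the result back:

```python
# Stop-gap sketch: perform the sparse-sparse product in scipy on the host,
# then wrap the result as a BCOO for downstream JAX code. This sidesteps the
# nse_lhs * nse_rhs allocation but cannot be traced, jitted, or differentiated.
Z = X @ Y                                     # scipy CSR @ CSR from the original example
Z_jax = jax_sparse.BCOO.from_scipy_sparse(Z)  # back to a JAX sparse array
```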
-
Curious how the …
-
Hi @jakevdp. Will sparse-sparse matrix multiplication perform better when using a GPU? If so, how does it compare with cuSPARSE?