diff --git a/python/src/spark_rapids_ml/core.py b/python/src/spark_rapids_ml/core.py
index 644c88a7..455aa6e7 100644
--- a/python/src/spark_rapids_ml/core.py
+++ b/python/src/spark_rapids_ml/core.py
@@ -752,7 +752,7 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame:
                 if concated_nnz > np.iinfo(np.int32).max:
                     logger.warn(
                         f"The number of non-zero values of a partition exceeds the int32 index dtype. \
-                        cupyx csr_matrix currently does not promote the dtype to int64 when concatenated; \
+                        cupyx csr_matrix currently does not support int64 indices (https://github.com/cupy/cupy/issues/3513); \
                         keeping as scipy csr_matrix to avoid overflow."
                     )
                 else:
diff --git a/python/src/spark_rapids_ml/umap.py b/python/src/spark_rapids_ml/umap.py
index 2fc68498..2ce0abf5 100644
--- a/python/src/spark_rapids_ml/umap.py
+++ b/python/src/spark_rapids_ml/umap.py
@@ -1154,7 +1154,7 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
                 if concated_nnz > np.iinfo(np.int32).max:
                     logger.warn(
                         f"The number of non-zero values of a partition exceeds the int32 index dtype. \
-                        cupyx csr_matrix currently does not promote the dtype to int64 when concatenated; \
+                        cupyx csr_matrix currently does not support int64 indices (https://github.com/cupy/cupy/issues/3513); \
                         keeping as scipy csr_matrix to avoid overflow."
                     )
                 else:
diff --git a/python/tests/test_umap.py b/python/tests/test_umap.py
index a4cd0f17..2e93d8ab 100644
--- a/python/tests/test_umap.py
+++ b/python/tests/test_umap.py
@@ -373,7 +373,9 @@ def test_params(tmp_path: str, default_params: bool) -> None:
 def test_umap_model_persistence(
     sparse_fit: bool, gpu_number: int, tmp_path: str
 ) -> None:
+    import pyspark
     from cuml.datasets import make_blobs
+    from packaging import version
 
     with CleanSparkSession() as spark:
 
@@ -381,6 +383,14 @@ def test_umap_model_persistence(
         n_cols = 200
 
         if sparse_fit:
+            if version.parse(pyspark.__version__) < version.parse("3.4.0"):
+                import logging
+
+                err_msg = ("pyspark < 3.4 is detected. Cannot import pyspark `unwrap_udt` function for SparseVector. "
+                           "The test case will be skipped. Please install pyspark>=3.4.")
+                logging.info(err_msg)
+                return
+
             data, input_raw_data = _load_sparse_binary_data(n_rows, n_cols, 30)
             df = spark.createDataFrame(data, ["features"])
         else:
@@ -429,6 +439,17 @@ def test_umap_chunking(
         )
 
         if sparse_fit:
+            import pyspark
+            from packaging import version
+
+            if version.parse(pyspark.__version__) < version.parse("3.4.0"):
+                import logging
+
+                err_msg = ("pyspark < 3.4 is detected. Cannot import pyspark `unwrap_udt` function for SparseVector. "
+                           "The test case will be skipped. Please install pyspark>=3.4.")
+                logging.info(err_msg)
+                return
+
             data, input_raw_data = _load_sparse_binary_data(n_rows, n_cols, 30)
             df = spark.createDataFrame(data, ["features"])
             nbytes = input_raw_data.data.nbytes
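
For reviewers, here is a minimal, self-contained sketch of the guard the first two hunks touch. It assumes a list of scipy CSR chunks; the names concat_csr and chunks are illustrative, not the actual spark-rapids-ml helpers:

    import numpy as np
    import scipy.sparse

    def concat_csr(chunks):
        # The combined nnz decides whether cupyx's int32 indices can
        # still address the concatenated matrix.
        concated_nnz = sum(int(c.nnz) for c in chunks)
        if concated_nnz > np.iinfo(np.int32).max:
            # scipy promotes indptr/indices to int64 as needed, while cupyx
            # csr_matrix does not support int64 indices (cupy/cupy#3513),
            # so the data stays on the host to avoid index overflow.
            return scipy.sparse.vstack(chunks, format="csr")
        import cupyx.scipy.sparse as cu_sparse
        # Every index still fits in int32, so the chunks can move to the GPU.
        return cu_sparse.vstack([cu_sparse.csr_matrix(c) for c in chunks], format="csr")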
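
The test hunks gate the sparse path on pyspark >= 3.4, where pyspark.sql.functions.unwrap_udt (needed to read SparseVector columns) first shipped. A standalone sketch of that check, with the helper name sparse_vector_supported made up for illustration:

    import logging

    import pyspark
    from packaging import version

    def sparse_vector_supported() -> bool:
        # unwrap_udt, required to unpack SparseVector columns, was added
        # in pyspark 3.4; older versions cannot run the sparse tests.
        if version.parse(pyspark.__version__) < version.parse("3.4.0"):
            logging.info(
                "pyspark < 3.4 is detected. Cannot import pyspark `unwrap_udt` "
                "function for SparseVector. The test case will be skipped. "
                "Please install pyspark>=3.4."
            )
            return False
        return True

Note that returning early makes the test pass silently; calling pytest.skip with the same message would surface the skip in the test report instead.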