Commit
Merge pull request #780 from NVIDIA/branch-24.10
[auto-merge] branch-24.10 to branch-24.12 [skip ci] [bot]
nvauto authored Nov 11, 2024
2 parents e5cdca3 + cc08a39 commit 3421199
Showing 3 changed files with 23 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/src/spark_rapids_ml/core.py
@@ -752,7 +752,7 @@ def _train_udf(pdf_iter: Iterator[pd.DataFrame]) -> pd.DataFrame:
     if concated_nnz > np.iinfo(np.int32).max:
         logger.warn(
             f"The number of non-zero values of a partition exceeds the int32 index dtype. \
-            cupyx csr_matrix currently does not promote the dtype to int64 when concatenated; \
+            cupyx csr_matrix currently does not support int64 indices (https://github.com/cupy/cupy/issues/3513); \
             keeping as scipy csr_matrix to avoid overflow."
         )
     else:
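
For context on the change above: the guarded conversion keeps a concatenated partition as a scipy csr_matrix whenever its non-zero count would overflow int32 indices, since cupyx cannot hold int64 indices. A minimal, self-contained sketch of that pattern follows; the function name `concat_csr_partitions` is illustrative, not part of spark-rapids-ml.

    import numpy as np
    import scipy.sparse

    def concat_csr_partitions(partitions):
        """Concatenate CSR partitions; fall back to scipy CSR on int32 overflow."""
        concated = scipy.sparse.vstack(partitions, format="csr")
        if concated.nnz > np.iinfo(np.int32).max:
            # cupyx csr_matrix does not support int64 indices
            # (https://github.com/cupy/cupy/issues/3513), so stay on the CPU.
            return concated
        import cupyx.scipy.sparse as cpx_sparse  # requires a CUDA-enabled cupy
        return cpx_sparse.csr_matrix(concated)

The nnz check must happen before the GPU conversion: once the data is handed to cupyx, the int32 index arrays would already have overflowed.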
2 changes: 1 addition & 1 deletion python/src/spark_rapids_ml/umap.py
@@ -1154,7 +1154,7 @@ def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
     if concated_nnz > np.iinfo(np.int32).max:
         logger.warn(
             f"The number of non-zero values of a partition exceeds the int32 index dtype. \
-            cupyx csr_matrix currently does not promote the dtype to int64 when concatenated; \
+            cupyx csr_matrix currently does not support int64 indices (https://github.com/cupy/cupy/issues/3513); \
             keeping as scipy csr_matrix to avoid overflow."
         )
     else:
21 changes: 21 additions & 0 deletions python/tests/test_umap.py
@@ -373,14 +373,24 @@ def test_params(tmp_path: str, default_params: bool) -> None:
 def test_umap_model_persistence(
     sparse_fit: bool, gpu_number: int, tmp_path: str
 ) -> None:
+    import pyspark
     from cuml.datasets import make_blobs
+    from packaging import version
 
     with CleanSparkSession() as spark:
 
         n_rows = 5000
         n_cols = 200
 
         if sparse_fit:
+            if version.parse(pyspark.__version__) < version.parse("3.4.0"):
+                import logging
+
+                err_msg = ("pyspark < 3.4 is detected. Cannot import pyspark `unwrap_udt` function for SparseVector. "
+                           "The test case will be skipped. Please install pyspark>=3.4.")
+                logging.info(err_msg)
+                return
+
             data, input_raw_data = _load_sparse_binary_data(n_rows, n_cols, 30)
             df = spark.createDataFrame(data, ["features"])
         else:
@@ -429,6 +439,17 @@ def test_umap_chunking(
     )
 
     if sparse_fit:
+        import pyspark
+        from packaging import version
+
+        if version.parse(pyspark.__version__) < version.parse("3.4.0"):
+            import logging
+
+            err_msg = ("pyspark < 3.4 is detected. Cannot import pyspark `unwrap_udt` function for SparseVector. "
+                       "The test case will be skipped. Please install pyspark>=3.4.")
+            logging.info(err_msg)
+            return
+
         data, input_raw_data = _load_sparse_binary_data(n_rows, n_cols, 30)
         df = spark.createDataFrame(data, ["features"])
         nbytes = input_raw_data.data.nbytes
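
The early `return` added in both test hunks logs at INFO level and lets the gated test finish as passed. A hedged alternative sketch (not what this PR does) is to route the same version check through pytest's skip machinery, so the gated case shows up as skipped in test reports; `require_pyspark_unwrap_udt` is a hypothetical helper name.

    import pyspark
    import pytest
    from packaging import version

    def require_pyspark_unwrap_udt() -> None:
        """Skip the calling test when pyspark lacks `unwrap_udt` (added in 3.4)."""
        if version.parse(pyspark.__version__) < version.parse("3.4.0"):
            pytest.skip(
                "pyspark < 3.4 is detected; `unwrap_udt` for SparseVector is "
                "unavailable. Please install pyspark>=3.4."
            )

Called at the top of a sparse-fit test, such a helper would replace the inline if/return block and keep the skip reason visible in the pytest summary rather than counting the test as a pass.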
