Skip to content

Commit

Permalink
Make make_spark_converter supports creating converter from a saved …
Browse files Browse the repository at this point in the history
…dataframe path (#787)

Signed-off-by: Weichen Xu [email protected]

Make make_spark_converter supports creating converter from a saved dataframe path.
In this case, we can skip the step of materializing spark dataframe that might be slow.
  • Loading branch information
WeichenXu123 authored Jan 30, 2023
1 parent 42f4af9 commit d337fee
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 14 deletions.
1 change: 1 addition & 0 deletions docs/release-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Release notes

Release 0.12.2 (unreleased)
===========================
- `PR 787 <https://github.com/uber/petastorm/pull/787>`_: ``make_spark_converter`` supports creating converter from a saved Spark DataFrame path.


Release 0.12.1
Expand Down
54 changes: 41 additions & 13 deletions petastorm/spark/spark_dataset_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,22 @@ def _check_url(dir_url):
'Please prepend "file://" for local filesystem.'.format(dir_url))


def _normalize_databricks_dbfs_url(url, err_msg):
if not (
url.startswith("file:/dbfs/") or
url.startswith("file:///dbfs/") or
url.startswith("dbfs:///") or
(url.startswith("dbfs:/") and not url.startswith("dbfs://"))
):
raise ValueError(err_msg)
if url.startswith("dbfs:///"):
# convert it to a dbfs fuse path
url = "file:/dbfs/" + url[len("dbfs:///"):]
elif url.startswith("dbfs:/"):
url = "file:/dbfs/" + url[len("dbfs:/"):]
return url


def _check_parent_cache_dir_url(dir_url):
"""Check dir url whether is suitable to be used as parent cache directory."""
_check_url(dir_url)
Expand Down Expand Up @@ -485,7 +501,7 @@ def _cache_df_or_retrieve_cache_data_url(df, parent_cache_dir_url,
If not, cache the df into the cache_dir in parquet format and return the
cache file path.
Use atexit to delete the cache before the python interpreter exits.
:param df: A :class:`DataFrame` object.
:param df: A :class:`pyspark.sql.DataFrame` object.
:param parquet_row_group_size_bytes: An int denoting the number of bytes
in a parquet row group.
:param compression_codec: Specify compression codec.
Expand Down Expand Up @@ -666,7 +682,10 @@ def make_spark_converter(
`SparkDatasetConverter.delete`, and when the spark application exit,
it will try best effort to delete the materialized dataframe data.
:param df: The :class:`DataFrame` object to be converted.
:param df: The :class:`pyspark.sql.DataFrame` object to be converted,
or a string of path pointing to the directory that stores the dataframe data
as parquet format, on databricks runtime, the path must be a dbfs
fuse path like 'file:/dbfs/xxx' or a dbfs path like 'dbfs:/xxx'.
:param parquet_row_group_size_bytes: An int denoting the number of bytes
in a parquet row group when materializing the dataframe.
:param compression_codec: Specify compression codec.
Expand All @@ -683,26 +702,35 @@ def make_spark_converter(

parent_cache_dir_url = _get_parent_cache_dir_url()

# TODO: Improve default behavior to be automatically choosing the best way.
compression_codec = compression_codec or "uncompressed"
if isinstance(df, str):
dataset_dir_url = df
if 'DATABRICKS_RUNTIME_VERSION' in os.environ:
dataset_dir_url = _normalize_databricks_dbfs_url(
dataset_dir_url,
"On databricks runtime, if `df` argument is a string, it must be a dbfs "
"fuse path like 'file:/dbfs/xxx' or a dbfs path like 'dbfs:/xxx'."
)
else:
# TODO: Improve default behavior to be automatically choosing the best way.
compression_codec = compression_codec or "uncompressed"

if compression_codec.lower() not in \
['uncompressed', 'bzip2', 'gzip', 'lz4', 'snappy', 'deflate']:
raise RuntimeError(
"compression_codec should be None or one of the following values: "
"'uncompressed', 'bzip2', 'gzip', 'lz4', 'snappy', 'deflate'")
if compression_codec.lower() not in \
['uncompressed', 'bzip2', 'gzip', 'lz4', 'snappy', 'deflate']:
raise RuntimeError(
"compression_codec should be None or one of the following values: "
"'uncompressed', 'bzip2', 'gzip', 'lz4', 'snappy', 'deflate'")

dataset_cache_dir_url = _cache_df_or_retrieve_cache_data_url(
df, parent_cache_dir_url, parquet_row_group_size_bytes, compression_codec, dtype)
dataset_dir_url = _cache_df_or_retrieve_cache_data_url(
df, parent_cache_dir_url, parquet_row_group_size_bytes, compression_codec, dtype)

# TODO: improve this by read parquet file metadata to get count
# Currently spark can make sure to only read the minimal column
# so count will usually be fast.
spark = _get_spark_session()
spark_df = spark.read.parquet(dataset_cache_dir_url)
spark_df = spark.read.parquet(dataset_dir_url)

dataset_size = spark_df.count()
parquet_file_url_list = list(spark_df._jdf.inputFiles())
_check_dataset_file_median_size(parquet_file_url_list)

return SparkDatasetConverter(dataset_cache_dir_url, parquet_file_url_list, dataset_size)
return SparkDatasetConverter(dataset_dir_url, parquet_file_url_list, dataset_size)
31 changes: 30 additions & 1 deletion petastorm/tests/test_spark_dataset_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
_check_dataset_file_median_size, _check_parent_cache_dir_url,
_check_rank_and_size_consistent_with_horovod, _check_url,
_get_horovod_rank_and_size, _get_spark_session, _make_sub_dir_url,
register_delete_dir_handler, _wait_file_available)
register_delete_dir_handler, _wait_file_available,
_normalize_databricks_dbfs_url,
)

from unittest import mock

Expand Down Expand Up @@ -649,3 +651,30 @@ def map_fn(_):

with pytest.raises(py4j.protocol.Py4JJavaError):
spark_test_ctx.spark.sparkContext.parallelize(range(1), 1).map(map_fn).collect()


def test_make_spark_convert_from_saved_df_path(spark_test_ctx):
df1 = spark_test_ctx.spark.range(100, 101)
output_path = \
os.path.join(spark_test_ctx.tempdir, "test_make_spark_convert_from_saved_df_path")
df1.write.parquet("file:" + output_path)
converter1 = make_spark_converter(output_path)

def map_fn(_):
with converter1.make_torch_dataloader(num_epochs=1) as dataloader:
for batch in dataloader:
ret = batch["id"][0]
return ret

result = spark_test_ctx.spark.sparkContext.parallelize(range(1), 1) \
.map(map_fn).collect()[0]
assert result == 100


def test_normalize_databricks_dbfs_url():
assert _normalize_databricks_dbfs_url("dbfs:/xx/yy", "") == "file:/dbfs/xx/yy"
assert _normalize_databricks_dbfs_url("dbfs:///xx/yy", "") == "file:/dbfs/xx/yy"
assert _normalize_databricks_dbfs_url("file:/dbfs/xx/yy", "") == "file:/dbfs/xx/yy"
assert _normalize_databricks_dbfs_url("file:///dbfs/xx/yy", "") == "file:///dbfs/xx/yy"
with pytest.raises(ValueError):
_normalize_databricks_dbfs_url("dbfs://xx/yy", "")

0 comments on commit d337fee

Please sign in to comment.