Commit ba19ae1

feat: Support hf:// in read_(csv|ipc|ndjson) functions (pola-rs#1…)

nameexhaustion authored and atigbadr committed Jul 23, 2024
1 parent fc2a873 commit ba19ae1

Showing 11 changed files with 284 additions and 76 deletions.
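
What the change enables, as a minimal sketch (the hf:// dataset paths below are hypothetical placeholders):

import polars as pl

# Hypothetical Hugging Face dataset paths; any hf://datasets/... file works.
df_csv = pl.read_csv("hf://datasets/some-user/some-repo/data.csv")
df_ipc = pl.read_ipc("hf://datasets/some-user/some-repo/data.arrow")
df_ndjson = pl.read_ndjson("hf://datasets/some-user/some-repo/data.ndjson")
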
12 changes: 4 additions & 8 deletions crates/polars-io/src/path_utils/mod.rs

@@ -146,6 +146,8 @@ pub fn expand_paths_hive(
     if is_cloud || { cfg!(not(target_family = "windows")) && config::force_async() } {
         #[cfg(feature = "cloud")]
         {
+            use polars_utils::_limit_path_len_io_err;
+
             use crate::cloud::object_path_from_string;

             if first_path.starts_with("hf://") {
@@ -199,14 +201,8 @@ pub fn expand_paths_hive(
                 // indistinguishable from an empty directory.
                 let path = PathBuf::from(path);
                 if !path.is_dir() {
-                    path.metadata().map_err(|err| {
-                        let msg =
-                            Some(format!("{}: {}", err, path.to_str().unwrap()).into());
-                        PolarsError::IO {
-                            error: err.into(),
-                            msg,
-                        }
-                    })?;
+                    path.metadata()
+                        .map_err(|err| _limit_path_len_io_err(&path, err))?;
                 }
             }

5 changes: 1 addition & 4 deletions crates/polars-io/src/utils/other.rs

@@ -168,10 +168,7 @@ pub(crate) fn update_row_counts3(dfs: &mut [DataFrame], heights: &[IdxSize], off
 }

 #[cfg(feature = "json")]
-pub(crate) fn overwrite_schema(
-    schema: &mut Schema,
-    overwriting_schema: &Schema,
-) -> PolarsResult<()> {
+pub fn overwrite_schema(schema: &mut Schema, overwriting_schema: &Schema) -> PolarsResult<()> {
     for (k, value) in overwriting_schema.iter() {
         *schema.try_get_mut(k)? = value.clone();
     }
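
The now-public overwrite_schema helper replaces the dtype of each overridden column and errors if that column is absent, which is what `schema.try_get_mut(k)?` enforces. A rough Python sketch of that merge rule, for illustration only (these names are not the Polars API):

def overwrite_schema(schema: dict[str, object], overrides: dict[str, object]) -> None:
    # Mirrors the Rust helper: an override may change a column's dtype,
    # but never introduce a column the inferred schema does not have.
    for name, dtype in overrides.items():
        if name not in schema:
            # Rust's `schema.try_get_mut(k)?` fails the same way
            raise KeyError(f"column {name!r} not found in schema")
        schema[name] = dtype
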
10 changes: 10 additions & 0 deletions crates/polars-lazy/src/scan/ndjson.rs

@@ -18,6 +18,7 @@ pub struct LazyJsonLineReader {
     pub(crate) low_memory: bool,
     pub(crate) rechunk: bool,
     pub(crate) schema: Option<SchemaRef>,
+    pub(crate) schema_overwrite: Option<SchemaRef>,
     pub(crate) row_index: Option<RowIndex>,
     pub(crate) infer_schema_length: Option<NonZeroUsize>,
     pub(crate) n_rows: Option<usize>,
@@ -38,6 +39,7 @@ impl LazyJsonLineReader {
             low_memory: false,
             rechunk: false,
             schema: None,
+            schema_overwrite: None,
             row_index: None,
             infer_schema_length: NonZeroUsize::new(100),
             ignore_errors: false,
@@ -82,6 +84,13 @@ impl LazyJsonLineReader {
         self
     }

+    /// Set the JSON file's schema
+    #[must_use]
+    pub fn with_schema_overwrite(mut self, schema_overwrite: Option<SchemaRef>) -> Self {
+        self.schema_overwrite = schema_overwrite;
+        self
+    }
+
     /// Reduce memory usage at the expense of performance
     #[must_use]
     pub fn low_memory(mut self, toggle: bool) -> Self {
@@ -129,6 +138,7 @@ impl LazyFileListReader for LazyJsonLineReader {
             low_memory: self.low_memory,
             ignore_errors: self.ignore_errors,
             schema: self.schema,
+            schema_overwrite: self.schema_overwrite,
         };

         let scan_type = FileScan::NDJson {
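
On the Python side, this plumbing is what lets the lazy NDJSON path honor schema overrides. A minimal sketch, assuming a hypothetical dataset path and column name:

import polars as pl

# Infer most of the schema, but force `id` to be read as a string.
df = pl.read_ndjson(
    "hf://datasets/some-user/some-repo/data.ndjson",
    schema_overrides={"id": pl.String},
)
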
7 changes: 6 additions & 1 deletion crates/polars-plan/src/plans/conversion/scans.rs

@@ -325,7 +325,7 @@ pub(super) fn ndjson_file_info(
     };
     let mut reader = std::io::BufReader::new(f);

-    let (reader_schema, schema) = if let Some(schema) = ndjson_options.schema.take() {
+    let (mut reader_schema, schema) = if let Some(schema) = ndjson_options.schema.take() {
         if file_options.row_index.is_none() {
             (schema.clone(), schema.clone())
         } else {
@@ -340,6 +340,11 @@
         prepare_schemas(schema, file_options.row_index.as_ref())
     };

+    if let Some(overwriting_schema) = &ndjson_options.schema_overwrite {
+        let schema = Arc::make_mut(&mut reader_schema);
+        overwrite_schema(schema, overwriting_schema)?;
+    }
+
     Ok(FileInfo::new(
         schema,
         Some(Either::Right(reader_schema)),

1 change: 1 addition & 0 deletions crates/polars-plan/src/plans/options.rs

@@ -359,4 +359,5 @@ pub struct NDJsonReadOptions {
     pub low_memory: bool,
     pub ignore_errors: bool,
     pub schema: Option<SchemaRef>,
+    pub schema_overwrite: Option<SchemaRef>,
 }
6 changes: 3 additions & 3 deletions crates/polars-utils/src/io.rs

@@ -4,7 +4,7 @@ use std::path::Path;

 use polars_error::*;

-fn map_err(path: &Path, err: io::Error) -> PolarsError {
+pub fn _limit_path_len_io_err(path: &Path, err: io::Error) -> PolarsError {
     let path = path.to_string_lossy();
     let msg = if path.len() > 88 {
         let truncated_path: String = path.chars().skip(path.len() - 88).collect();
@@ -19,12 +19,12 @@ pub fn open_file<P>(path: P) -> PolarsResult<File>
 where
     P: AsRef<Path>,
 {
-    File::open(&path).map_err(|err| map_err(path.as_ref(), err))
+    File::open(&path).map_err(|err| _limit_path_len_io_err(path.as_ref(), err))
 }

 pub fn create_file<P>(path: P) -> PolarsResult<File>
 where
     P: AsRef<Path>,
 {
-    File::create(&path).map_err(|err| map_err(path.as_ref(), err))
+    File::create(&path).map_err(|err| _limit_path_len_io_err(path.as_ref(), err))
 }
97 changes: 82 additions & 15 deletions py-polars/polars/io/csv/functions.py

@@ -1,11 +1,13 @@
 from __future__ import annotations

 import contextlib
+import os
 from io import BytesIO, StringIO
 from pathlib import Path
 from typing import IO, TYPE_CHECKING, Any, Callable, Mapping, Sequence

 import polars._reexport as pl
+import polars.functions as F
 from polars._utils.deprecation import deprecate_renamed_parameter
 from polars._utils.various import (
     _process_null_values,
@@ -419,45 +421,110 @@ def read_csv(
     if not infer_schema:
         infer_schema_length = 0

-    with prepare_file_arg(
-        source,
-        encoding=encoding,
-        use_pyarrow=False,
-        raise_if_empty=raise_if_empty,
-        storage_options=storage_options,
-    ) as data:
-        df = _read_csv_impl(
-            data,
+    # TODO: scan_csv doesn't support a "dtype slice" (i.e. list[DataType])
+    schema_overrides_is_list = isinstance(schema_overrides, Sequence)
+    encoding_supported_in_lazy = encoding in {"utf8", "utf8-lossy"}
+
+    if (
+        # Check that it is not a BytesIO object
+        isinstance(v := source, (str, Path))
+    ) and (
+        # HuggingFace only for now ⊂( ◜◒◝ )⊃
+        str(v).startswith("hf://")
+        # Also dispatch on FORCE_ASYNC, so that this codepath gets run
+        # through by our test suite during CI.
+        or (
+            os.getenv("POLARS_FORCE_ASYNC") == "1"
+            and not schema_overrides_is_list
+            and encoding_supported_in_lazy
+        )
+        # TODO: We can't dispatch this for all paths due to a few reasons:
+        # * `scan_csv` does not support compressed files
+        # * The `storage_options` configuration keys are different between
+        #   fsspec and object_store (would require a breaking change)
+    ):
+        if schema_overrides_is_list:
+            msg = "passing a list to `schema_overrides` is unsupported for hf:// paths"
+            raise ValueError(msg)
+        if not encoding_supported_in_lazy:
+            msg = f"unsupported encoding {encoding} for hf:// paths"
+            raise ValueError(msg)
+
+        lf = _scan_csv_impl(
+            source,  # type: ignore[arg-type]
             has_header=has_header,
-            columns=columns if columns else projection,
             separator=separator,
             comment_prefix=comment_prefix,
             quote_char=quote_char,
             skip_rows=skip_rows,
-            schema_overrides=schema_overrides,
+            schema_overrides=schema_overrides,  # type: ignore[arg-type]
             schema=schema,
             null_values=null_values,
             missing_utf8_is_empty_string=missing_utf8_is_empty_string,
             ignore_errors=ignore_errors,
             try_parse_dates=try_parse_dates,
-            n_threads=n_threads,
             infer_schema_length=infer_schema_length,
-            batch_size=batch_size,
             n_rows=n_rows,
-            encoding=encoding if encoding == "utf8-lossy" else "utf8",
+            encoding=encoding,  # type: ignore[arg-type]
             low_memory=low_memory,
             rechunk=rechunk,
             skip_rows_after_header=skip_rows_after_header,
             row_index_name=row_index_name,
             row_index_offset=row_index_offset,
-            sample_size=sample_size,
             eol_char=eol_char,
             raise_if_empty=raise_if_empty,
             truncate_ragged_lines=truncate_ragged_lines,
             decimal_comma=decimal_comma,
             glob=glob,
         )

+        if columns:
+            lf = lf.select(columns)
+        elif projection:
+            lf = lf.select(F.nth(projection))
+
+        df = lf.collect()
+
+    else:
+        with prepare_file_arg(
+            source,
+            encoding=encoding,
+            use_pyarrow=False,
+            raise_if_empty=raise_if_empty,
+            storage_options=storage_options,
+        ) as data:
+            df = _read_csv_impl(
+                data,
+                has_header=has_header,
+                columns=columns if columns else projection,
+                separator=separator,
+                comment_prefix=comment_prefix,
+                quote_char=quote_char,
+                skip_rows=skip_rows,
+                schema_overrides=schema_overrides,
+                schema=schema,
+                null_values=null_values,
+                missing_utf8_is_empty_string=missing_utf8_is_empty_string,
+                ignore_errors=ignore_errors,
+                try_parse_dates=try_parse_dates,
+                n_threads=n_threads,
+                infer_schema_length=infer_schema_length,
+                batch_size=batch_size,
+                n_rows=n_rows,
+                encoding=encoding if encoding == "utf8-lossy" else "utf8",
+                low_memory=low_memory,
+                rechunk=rechunk,
+                skip_rows_after_header=skip_rows_after_header,
+                row_index_name=row_index_name,
+                row_index_offset=row_index_offset,
+                sample_size=sample_size,
+                eol_char=eol_char,
+                raise_if_empty=raise_if_empty,
+                truncate_ragged_lines=truncate_ragged_lines,
+                decimal_comma=decimal_comma,
+                glob=glob,
+            )
+
     if new_columns:
         return _update_columns(df, new_columns)
     return df
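
For an eligible source, read_csv now routes through the lazy engine, with column selection applied as a select on the scan. Roughly, using a hypothetical path:

import polars as pl

# For hf:// paths this call...
df = pl.read_csv("hf://datasets/some-user/some-repo/data.csv", columns=["a", "b"])

# ...behaves like the explicit lazy pipeline:
df = (
    pl.scan_csv("hf://datasets/some-user/some-repo/data.csv")
    .select(["a", "b"])
    .collect()
)

Per the guards above, passing a list of dtypes as `schema_overrides`, or an encoding other than utf8/utf8-lossy, raises a ValueError for hf:// paths rather than silently falling back.
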
42 changes: 38 additions & 4 deletions py-polars/polars/io/ipc/functions.py

@@ -1,10 +1,12 @@
 from __future__ import annotations

 import contextlib
+import os
 from pathlib import Path
 from typing import IO, TYPE_CHECKING, Any, Sequence

 import polars._reexport as pl
+import polars.functions as F
 from polars._utils.deprecation import deprecate_renamed_parameter
 from polars._utils.various import (
     is_str_sequence,
@@ -29,8 +31,6 @@
     from polars._typing import SchemaDict


-@deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4")
-@deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4")
 def read_ipc(
     source: str | Path | IO[bytes] | bytes,
     *,
@@ -92,6 +92,42 @@ def read_ipc(
     That means that you cannot write to the same filename.
     E.g. `pl.read_ipc("my_file.arrow").write_ipc("my_file.arrow")` will fail.
     """
+    if (
+        # Check that it is not a BytesIO object
+        isinstance(v := source, (str, Path))
+    ) and (
+        # HuggingFace only for now ⊂( ◜◒◝ )⊃
+        (is_hf := str(v).startswith("hf://"))
+        # Also dispatch on FORCE_ASYNC, so that this codepath gets run
+        # through by our test suite during CI.
+        or os.getenv("POLARS_FORCE_ASYNC") == "1"
+        # TODO: Dispatch all paths to `scan_ipc` - this will need a breaking
+        # change to the `storage_options` parameter.
+    ):
+        if is_hf and use_pyarrow:
+            msg = "`use_pyarrow=True` is not supported for Hugging Face"
+            raise ValueError(msg)
+
+        lf = scan_ipc(
+            source,  # type: ignore[arg-type]
+            n_rows=n_rows,
+            memory_map=memory_map,
+            storage_options=storage_options,
+            row_index_name=row_index_name,
+            row_index_offset=row_index_offset,
+            rechunk=rechunk,
+        )
+
+        if columns:
+            if isinstance(columns[0], int):
+                lf = lf.select(F.nth(columns))  # type: ignore[arg-type]
+            else:
+                lf = lf.select(columns)
+
+        df = lf.collect()
+
+        return df
+
     if use_pyarrow and n_rows and not memory_map:
         msg = "`n_rows` cannot be used with `use_pyarrow=True` and `memory_map=False`"
         raise ValueError(msg)
@@ -305,8 +341,6 @@ def read_ipc_schema(source: str | Path | IO[bytes] | bytes) -> dict[str, DataTyp
     return _read_ipc_schema(source)


-@deprecate_renamed_parameter("row_count_name", "row_index_name", version="0.20.4")
-@deprecate_renamed_parameter("row_count_offset", "row_index_offset", version="0.20.4")
 def scan_ipc(
     source: str | Path | list[str] | list[Path],
     *,
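
read_ipc gains the same dispatch, plus a guard that rejects the pyarrow reader for Hugging Face sources. A short sketch with a hypothetical path:

import polars as pl

df = pl.read_ipc("hf://datasets/some-user/some-repo/data.arrow")  # dispatches to scan_ipc

# The pyarrow-backed reader is explicitly rejected for hf:// paths:
try:
    pl.read_ipc("hf://datasets/some-user/some-repo/data.arrow", use_pyarrow=True)
except ValueError as exc:
    print(exc)  # `use_pyarrow=True` is not supported for Hugging Face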