Skip to content

Commit

Permalink
WIP: Support include_file_paths in read_csv
Browse files Browse the repository at this point in the history
  • Loading branch information
alonme committed Oct 19, 2024
1 parent f88bd6a commit 7096edc
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 3 deletions.
8 changes: 8 additions & 0 deletions crates/polars-io/src/csv/read/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub struct CsvReadOptions {
pub raise_if_empty: bool,
pub ignore_errors: bool,
pub fields_to_cast: Vec<Field>,
pub include_file_paths: Option<PlSmallStr>,
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -81,6 +82,7 @@ impl Default for CsvReadOptions {
raise_if_empty: true,
ignore_errors: false,
fields_to_cast: vec![],
include_file_paths: None,
}
}
}
Expand Down Expand Up @@ -222,6 +224,12 @@ impl CsvReadOptions {
self
}

/// Include the path of the source file(s) as a `String` column with this name.
/// Pass `None` (the default) to omit the column entirely.
pub fn with_include_file_paths(mut self, include_file_paths: Option<PlSmallStr>) -> Self {
self.include_file_paths = include_file_paths;
self
}

/// Continue with next batch when a ParserError is encountered.
pub fn with_ignore_errors(mut self, ignore_errors: bool) -> Self {
self.ignore_errors = ignore_errors;
Expand Down
27 changes: 27 additions & 0 deletions crates/polars-io/src/csv/read/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,33 @@ where
let mut csv_reader = self.core_reader()?;
let mut df = csv_reader.as_df()?;

if let Some(col) = &self.options.include_file_paths {
// TODO: fix this - handle "open-file" vs "in-mem" - see `to_include_path_name`
let name = self
.options
.path
.as_ref()
.and_then(|path| path.to_str())
.unwrap_or("not a file");

if df.get_column_index(col).is_some() {
polars_bail!(
Duplicate: r#"column name for file paths "{}" conflicts with column name from file"#,
col
);
}

// TODO: add safety comment
// SAFETY:
unsafe {
df.with_column_unchecked(Column::new_scalar(
col.clone(),
Scalar::new(DataType::String, AnyValue::StringOwned(name.into())),
df.height(),
));
}
}

// Important that this rechunk is never done in parallel.
// As that leads to great memory overhead.
if rechunk && df.n_chunks() > 1 {
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-python/src/dataframe/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl PyDataFrame {
skip_rows, projection, separator, rechunk, columns, encoding, n_threads, path,
overwrite_dtype, overwrite_dtype_slice, low_memory, comment_prefix, quote_char,
null_values, missing_utf8_is_empty_string, try_parse_dates, skip_rows_after_header,
row_index, eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma, schema)
row_index, eol_char, raise_if_empty, truncate_ragged_lines, decimal_comma, schema, include_file_paths)
)]
pub fn read_csv(
py: Python,
Expand Down Expand Up @@ -65,6 +65,7 @@ impl PyDataFrame {
truncate_ragged_lines: bool,
decimal_comma: bool,
schema: Option<Wrap<Schema>>,
include_file_paths: Option<String>,
) -> PyResult<Self> {
let null_values = null_values.map(|w| w.0);
let eol_char = eol_char.as_bytes()[0];
Expand Down Expand Up @@ -113,6 +114,7 @@ impl PyDataFrame {
.with_skip_rows_after_header(skip_rows_after_header)
.with_row_index(row_index)
.with_raise_if_empty(raise_if_empty)
.with_include_file_paths(include_file_paths.map(|x| x.into()))
.with_parse_options(
CsvParseOptions::default()
.with_separator(separator.as_bytes()[0])
Expand Down
8 changes: 8 additions & 0 deletions py-polars/polars/io/csv/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def read_csv(
truncate_ragged_lines: bool = False,
decimal_comma: bool = False,
glob: bool = True,
include_file_paths: str | None = None,
) -> DataFrame:
r"""
Read a CSV file into a DataFrame.
Expand Down Expand Up @@ -208,6 +209,8 @@ def read_csv(
Parse floats using a comma as the decimal separator instead of a period.
glob
Expand path given via globbing rules.
include_file_paths
Include the path of the source file(s) as a column with this name.

Returns
-------
Expand Down Expand Up @@ -486,6 +489,7 @@ def read_csv(
truncate_ragged_lines=truncate_ragged_lines,
decimal_comma=decimal_comma,
glob=glob,
include_file_paths=include_file_paths,
)

if columns:
Expand Down Expand Up @@ -532,6 +536,7 @@ def read_csv(
truncate_ragged_lines=truncate_ragged_lines,
decimal_comma=decimal_comma,
glob=glob,
include_file_paths=include_file_paths,
)

if new_columns:
Expand Down Expand Up @@ -570,6 +575,7 @@ def _read_csv_impl(
truncate_ragged_lines: bool = False,
decimal_comma: bool = False,
glob: bool = True,
include_file_paths: str | None = None,
) -> DataFrame:
path: str | None
if isinstance(source, (str, Path)):
Expand Down Expand Up @@ -634,6 +640,7 @@ def _read_csv_impl(
truncate_ragged_lines=truncate_ragged_lines,
decimal_comma=decimal_comma,
glob=glob,
include_file_paths=include_file_paths,
)
if columns is None:
return scan.collect()
Expand Down Expand Up @@ -678,6 +685,7 @@ def _read_csv_impl(
truncate_ragged_lines=truncate_ragged_lines,
decimal_comma=decimal_comma,
schema=schema,
include_file_paths=include_file_paths,
)
return wrap_df(pydf)

Expand Down
34 changes: 32 additions & 2 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import zlib
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal as D
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, TypedDict

Expand All @@ -22,8 +23,6 @@
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from pathlib import Path

from polars._typing import TimeUnit
from tests.unit.conftest import MemoryUsage

Expand Down Expand Up @@ -2299,3 +2298,34 @@ def test_read_csv_cast_unparsable_later(
df.write_csv(f)
f.seek(0)
assert df.equals(pl.read_csv(f, schema={"x": dtype}))


@pytest.mark.write_disk
@pytest.mark.parametrize(("number_of_files"), [1, 2])
def test_read_csv_include_file_name(tmp_path: Path, number_of_files: int) -> None:
    """`include_file_paths` adds a path column and raises on a name clash."""
    tmp_path.mkdir(exist_ok=True)
    dfs: list[pl.DataFrame] = []

    # Write one CSV per requested file; each expected frame carries its own
    # absolute source path in a "path" column (dropped before writing to disk).
    for x in ["1", "2"][:number_of_files]:
        path = (tmp_path / f"{x}.csv").absolute()
        dfs.append(pl.DataFrame({"x": 10 * [x]}).with_columns(path=pl.lit(str(path))))
        dfs[-1].drop("path").write_csv(path)

    expected = pl.concat(dfs)
    assert expected.columns == ["x", "path"]

    # A single file is read directly; multiple files go through glob expansion.
    read_csv_path = str(tmp_path / ("1.csv" if number_of_files == 1 else "*.csv"))

    # Requesting a path-column name that already exists in the file must raise
    # rather than silently overwrite the data column.
    with pytest.raises(
        pl.exceptions.DuplicateError,
        match=r'column name for file paths "x" conflicts with column name from file',
    ):
        pl.read_csv(read_csv_path, include_file_paths="x")

    res = pl.read_csv(
        read_csv_path, include_file_paths="path", schema=expected.drop("path").schema
    )
    assert_frame_equal(res, expected)

0 comments on commit 7096edc

Please sign in to comment.