Skip to content

Commit

Permalink
fix(python): Raise suitable error when invalid column passed to `get_…
Browse files Browse the repository at this point in the history
…column_index` (#17868)
  • Loading branch information
alexander-beedie authored Jul 26, 2024
1 parent 10ea973 commit 43bf944
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 3 deletions.
2 changes: 2 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4754,6 +4754,8 @@ def get_column_index(self, name: str) -> int:
... )
>>> df.get_column_index("ham")
2
>>> df.get_column_index("sandwich") # doctest: +SKIP
ColumnNotFoundError: sandwich
"""
return self._df.get_column_index(name)

Expand Down
7 changes: 5 additions & 2 deletions py-polars/src/dataframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,11 @@ impl PyDataFrame {
}
}

pub fn get_column_index(&self, name: &str) -> Option<usize> {
self.df.get_column_index(name)
pub fn get_column_index(&self, name: &str) -> PyResult<usize> {
Ok(self
.df
.try_get_column_index(name)
.map_err(PyPolarsErr::from)?)
}

pub fn get_column(&self, name: &str) -> PyResult<PySeries> {
Expand Down
11 changes: 10 additions & 1 deletion py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2832,7 +2832,6 @@ def test_from_records_u64_12329() -> None:

def test_negative_slice_12642() -> None:
df = pl.DataFrame({"x": range(5)})

assert_frame_equal(df.slice(-2, 1), df.tail(2).head(1))


Expand All @@ -2841,3 +2840,13 @@ def test_iter_columns() -> None:
iter_columns = df.iter_columns()
assert_series_equal(next(iter_columns), pl.Series("a", [1, 1, 2]))
assert_series_equal(next(iter_columns), pl.Series("b", [4, 5, 6]))


def test_get_column_index() -> None:
df = pl.DataFrame({"actual": [1001], "expected": [1000]})

assert df.get_column_index("actual") == 0
assert df.get_column_index("expected") == 1

with pytest.raises(ColumnNotFoundError, match="missing"):
df.get_column_index("missing")

0 comments on commit 43bf944

Please sign in to comment.