Skip to content

Commit

Permalink
Ensure types match before calling member?
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Nov 28, 2023
1 parent 680be40 commit 01b78a5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
13 changes: 11 additions & 2 deletions lib/explorer/series.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3901,6 +3901,7 @@ defmodule Explorer.Series do
do: {:f, 64}

defp cast_to_ordered_series(:date, %Date{}), do: :date
defp cast_to_ordered_series(:time, %Time{}), do: :time

defp cast_to_ordered_series({:datetime, _}, %NaiveDateTime{}),
do: {:datetime, :microsecond}
Expand All @@ -3909,6 +3910,9 @@ defmodule Explorer.Series do
when is_integer(value),
do: :integer

defp cast_to_ordered_series({:duration, _}, %Explorer.Duration{}),
do: :duration

defp cast_to_ordered_series(_dtype, _value),
do: nil

Expand Down Expand Up @@ -5470,8 +5474,13 @@ defmodule Explorer.Series do
"""
@doc type: :list_wise
@spec member?(Series.t(), Explorer.Backend.Series.valid_types()) :: Series.t()
def member?(%Series{dtype: {:list, _}} = series, value),
do: apply_series(series, :member?, [value])
def member?(%Series{dtype: {:list, dtype}} = series, value) do
if cast_to_comparable_series(dtype, value) do
apply_series(series, :member?, [value])
else
dtype_mismatch_error("member?/2", series, value)
end
end

def member?(%Series{dtype: dtype}, _value),
do: dtype_error("member?/2", dtype, [{:list, :_}])
Expand Down
1 change: 1 addition & 0 deletions test/explorer/series_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4483,6 +4483,7 @@ defmodule Explorer.SeriesTest do
series = Series.from_list([[1.0], [1.0, 2.0]])

assert series |> Series.member?(2.0) |> Series.to_list() == [false, true]
assert series |> Series.member?(2) |> Series.to_list() == [false, true]
end

test "works with booleans" do
Expand Down

0 comments on commit 01b78a5

Please sign in to comment.