Skip to content

Commit

Permalink
fix(api): ensure that round behavior is consistent across backends
Browse files Browse the repository at this point in the history
BREAKING CHANGE: The `round` method no longer accepts arbitrary integer expressions for the number of digits to round to; the value must be a Python integer value.
  • Loading branch information
cpcloud committed Oct 22, 2024
1 parent 3252b38 commit 45aef36
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 12 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def round(op, **kw):
arg = translate(op.arg, **kw)
typ = PolarsType.from_ibis(op.dtype)
digits = _literal_value(op.digits)
return arg.round(digits or 0).cast(typ)
return arg.round(digits).cast(typ)

Check warning on line 641 in ibis/backends/polars/compiler.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/polars/compiler.py#L641

Added line #L641 was not covered by tests


@translate.register(ops.Radians)
Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,5 +533,8 @@ def visit_ArrayFlatten(self, op, *, arg):
def visit_RandomUUID(self, op, **kw):
return self.f.anon.uuid()

def visit_Round(self, op, *, arg, digits):
return self.cast(self.f.round(arg, digits), op.dtype)


compiler = DataFusionCompiler()
3 changes: 3 additions & 0 deletions ibis/backends/sql/compilers/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,5 +196,8 @@ def visit_TimestampFromYMDHMS(
)
)

def visit_Round(self, op, *, arg, digits):
return self.cast(self.f.round(arg, digits), op.dtype)


compiler = DruidCompiler()
3 changes: 3 additions & 0 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,5 +712,8 @@ def visit_TableUnnest(
def visit_StringToTime(self, op, *, arg, format_str):
return self.cast(self.f.str_to_time(arg, format_str), to=dt.time)

def visit_Round(self, op, *, arg, digits):
return self.cast(self.f.round(arg, digits), op.dtype)


compiler = DuckDBCompiler()
3 changes: 3 additions & 0 deletions ibis/backends/sql/compilers/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,5 +537,8 @@ def visit_DateDelta(self, op, *, left, right, part):
)
return self.f._ibis_date_delta(left, right)

def visit_Round(self, op, *, arg, digits):
return self.cast(self.f.round(arg, digits), op.dtype)


compiler = SQLiteCompiler()
3 changes: 3 additions & 0 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,5 +692,8 @@ def visit_ArraySum(self, op, *, arg):
def visit_ArrayMean(self, op, *, arg):
return self.visit_ArraySumAgg(op, arg=arg, output=operator.truediv)

def visit_Round(self, op, *, arg, digits):
return self.cast(self.f.round(arg, digits), op.dtype)


compiler = TrinoCompiler()
22 changes: 22 additions & 0 deletions ibis/backends/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,3 +1516,25 @@ def test_bitwise_not_col(backend, alltypes, df):
result = expr.execute()
expected = ~df.int_col
backend.assert_series_equal(result, expected.rename("tmp"))


def test_column_round_is_integer(con):
t = ibis.memtable({"x": [1.2, 3.4]})
expr = t.x.round().cast(int)
result = con.execute(expr)

one, three = sorted(result.tolist())

assert one == 1
assert isinstance(one, int)

assert three == 3
assert isinstance(three, int)


def test_scalar_round_is_integer(con):
expr = ibis.literal(1.2).round().cast(int)
result = con.execute(expr)

assert result == 1
assert isinstance(result, int)
21 changes: 13 additions & 8 deletions ibis/expr/operations/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,24 @@ class Round(Value):
"""Round a value."""

arg: StrictNumeric
# TODO(kszucs): the default should be 0 instead of being None
digits: Optional[Integer] = None
digits: int = 0

shape = rlz.shape_like("arg")

@property
def dtype(self):
if self.arg.dtype.is_decimal():
return self.arg.dtype
elif self.digits is None:
return dt.int64
else:
return dt.double
digits = self.digits
arg_dtype = self.arg.dtype

if arg_dtype.is_decimal():
return arg_dtype.copy(scale=digits)

nullable = arg_dtype.nullable

if not digits:
return dt.int64.copy(nullable=nullable)

return dt.double.copy(nullable=nullable)


@public
Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __neg__(self) -> NumericValue:
"""
return self.negate()

def round(self, digits: int | IntegerValue | None = None) -> NumericValue:
def round(self, digits: int = 0) -> NumericValue:
"""Round values to an indicated number of decimal places.
Parameters
Expand Down
4 changes: 2 additions & 2 deletions ibis/tests/expr/test_sql_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,11 @@ def test_sign(functional_alltypes, lineitem):
def test_round(functional_alltypes, lineitem):
result = functional_alltypes.double_col.round()
assert isinstance(result, ir.IntegerColumn)
assert result.op().args[1] is None
assert result.op().args[1] == 0

result = functional_alltypes.double_col.round(2)
assert isinstance(result, ir.FloatingColumn)
assert result.op().args[1] == ibis.literal(2).op()
assert result.op().args[1] == 2

# Even integers are double (at least in Impala, check with other DB
# implementations)
Expand Down

0 comments on commit 45aef36

Please sign in to comment.