From 72044a6842dd848b923bbbff66de10a92e35ecbd Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 2 Apr 2024 10:38:55 +0200 Subject: [PATCH] feat: remove hand-optimizations from queries --- queries/pandas/q1.py | 53 +++++++++++++++----------------------------- queries/pandas/q5.py | 13 ++++++----- queries/polars/q2.py | 2 +- 3 files changed, 27 insertions(+), 41 deletions(-) diff --git a/queries/pandas/q1.py b/queries/pandas/q1.py index c368fc6..6314542 100644 --- a/queries/pandas/q1.py +++ b/queries/pandas/q1.py @@ -18,51 +18,34 @@ def query() -> pd.DataFrame: nonlocal lineitem lineitem = lineitem() - lineitem_filtered = lineitem.loc[ - :, - [ - "l_quantity", - "l_extendedprice", - "l_discount", - "l_tax", - "l_returnflag", - "l_linestatus", - "l_shipdate", - "l_orderkey", - ], - ] - sel = lineitem_filtered.l_shipdate <= VAR1 - lineitem_filtered = lineitem_filtered[sel] - lineitem_filtered["sum_qty"] = lineitem_filtered.l_quantity - lineitem_filtered["sum_base_price"] = lineitem_filtered.l_extendedprice - lineitem_filtered["avg_qty"] = lineitem_filtered.l_quantity - lineitem_filtered["avg_price"] = lineitem_filtered.l_extendedprice - lineitem_filtered["sum_disc_price"] = lineitem_filtered.l_extendedprice * ( + sel = lineitem.l_shipdate <= VAR1 + lineitem_filtered = lineitem[sel] + + # This is lenient towards pandas as normally an optimizer should decide + # that this could be computed before the groupby aggregation. + # Other implementations don't enjoy this benefit. + lineitem_filtered["disc_price"] = lineitem_filtered.l_extendedprice * ( 1 - lineitem_filtered.l_discount ) - lineitem_filtered["sum_charge"] = ( + lineitem_filtered["charge"] = ( lineitem_filtered.l_extendedprice * (1 - lineitem_filtered.l_discount) * (1 + lineitem_filtered.l_tax) ) - lineitem_filtered["avg_disc"] = lineitem_filtered.l_discount - lineitem_filtered["count_order"] = lineitem_filtered.l_orderkey - gb = lineitem_filtered.groupby(["l_returnflag", "l_linestatus"]) + gb = lineitem_filtered.groupby(["l_returnflag", "l_linestatus"], as_index=False) total = gb.agg( - { - "sum_qty": "sum", - "sum_base_price": "sum", - "sum_disc_price": "sum", - "sum_charge": "sum", - "avg_qty": "mean", - "avg_price": "mean", - "avg_disc": "mean", - "count_order": "count", - } + sum_qty=pd.NamedAgg(column="l_quantity", aggfunc="sum"), + sum_base_price=pd.NamedAgg(column="l_extendedprice", aggfunc="sum"), + sum_disc_price=pd.NamedAgg(column="disc_price", aggfunc="sum"), + sum_charge=pd.NamedAgg(column="charge", aggfunc="sum"), + avg_qty=pd.NamedAgg(column="l_quantity", aggfunc="mean"), + avg_price=pd.NamedAgg(column="l_extendedprice", aggfunc="mean"), + avg_disc=pd.NamedAgg(column="l_discount", aggfunc="mean"), + count_order=pd.NamedAgg(column="l_orderkey", aggfunc="size"), ) - result_df = total.reset_index().sort_values(["l_returnflag", "l_linestatus"]) + result_df = total.sort_values(["l_returnflag", "l_linestatus"]) return result_df # type: ignore[no-any-return] diff --git a/queries/pandas/q5.py b/queries/pandas/q5.py index 1946d3d..a91c579 100644 --- a/queries/pandas/q5.py +++ b/queries/pandas/q5.py @@ -42,12 +42,9 @@ def query() -> pd.DataFrame: supplier_ds = supplier_ds() rsel = region_ds.r_name == "ASIA" - osel = (orders_ds.o_orderdate >= date1) & (orders_ds.o_orderdate < date2) - forders = orders_ds[osel] - fregion = region_ds[rsel] - jn1 = fregion.merge(nation_ds, left_on="r_regionkey", right_on="n_regionkey") + jn1 = region_ds.merge(nation_ds, left_on="r_regionkey", right_on="n_regionkey") jn2 = jn1.merge(customer_ds, left_on="n_nationkey", right_on="c_nationkey") - jn3 = jn2.merge(forders, left_on="c_custkey", right_on="o_custkey") + jn3 = jn2.merge(orders_ds, left_on="c_custkey", right_on="o_custkey") jn4 = jn3.merge(line_item_ds, left_on="o_orderkey", right_on="l_orderkey") jn5 = supplier_ds.merge( jn4, @@ -55,8 +52,14 @@ def query() -> pd.DataFrame: right_on=["l_suppkey", "n_nationkey"], ) jn5["revenue"] = jn5.l_extendedprice * (1.0 - jn5.l_discount) + jn5 = jn5[ + (jn5.o_orderdate >= date1) + & (jn5.o_orderdate < date2) + & (jn5.r_name == rsel) + ] gb = jn5.groupby("n_name", as_index=False)["revenue"].sum() result_df = gb.sort_values("revenue", ascending=False) + return result_df # type: ignore[no-any-return] utils.run_query(Q_NUM, query) diff --git a/queries/polars/q2.py b/queries/polars/q2.py index 4646ba7..4552d90 100644 --- a/queries/polars/q2.py +++ b/queries/polars/q2.py @@ -24,7 +24,7 @@ def q() -> None: .filter(pl.col("p_size") == var_1) .filter(pl.col("p_type").str.ends_with(var_2)) .filter(pl.col("r_name") == var_3) - ).cache() + ) final_cols = [ "s_acctbal",