From f3dbbf3557ff3ed09dda206f7795d6412b92f0c7 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 21 Feb 2024 04:22:08 +0100 Subject: [PATCH] Address Polars deprecation warnings / modernize syntax (#76) --- polars_queries/q1.py | 38 +++++++++++++++++------------------ polars_queries/q10.py | 44 ++++++++++++++++++----------------------- polars_queries/q11.py | 2 +- polars_queries/q12.py | 20 +++++++++---------- polars_queries/q13.py | 6 +++--- polars_queries/q14.py | 2 +- polars_queries/q15.py | 4 ++-- polars_queries/q16.py | 4 ++-- polars_queries/q17.py | 2 +- polars_queries/q18.py | 18 ++++++++--------- polars_queries/q2.py | 10 +++------- polars_queries/q20.py | 2 +- polars_queries/q21.py | 2 +- polars_queries/q22.py | 8 +++----- polars_queries/q3.py | 14 ++++++------- polars_queries/q4.py | 6 +++--- polars_queries/q5.py | 2 +- polars_queries/q6.py | 2 +- polars_queries/q7.py | 4 ++-- polars_queries/q8.py | 14 +++++-------- polars_queries/q9.py | 18 +++++++---------- polars_queries/utils.py | 22 ++++++++++----------- pyproject.toml | 1 + 23 files changed, 109 insertions(+), 136 deletions(-) diff --git a/polars_queries/q1.py b/polars_queries/q1.py index adc0214..3a0bff7 100644 --- a/polars_queries/q1.py +++ b/polars_queries/q1.py @@ -12,28 +12,26 @@ def q(): q = utils.get_line_item_ds() q_final = ( q.filter(pl.col("l_shipdate") <= var_1) - .group_by(["l_returnflag", "l_linestatus"]) + .group_by("l_returnflag", "l_linestatus") .agg( - [ - pl.sum("l_quantity").alias("sum_qty"), - pl.sum("l_extendedprice").alias("sum_base_price"), - (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) - .sum() - .alias("sum_disc_price"), - ( - pl.col("l_extendedprice") - * (1.0 - pl.col("l_discount")) - * (1.0 + pl.col("l_tax")) - ) - .sum() - .alias("sum_charge"), - pl.mean("l_quantity").alias("avg_qty"), - pl.mean("l_extendedprice").alias("avg_price"), - pl.mean("l_discount").alias("avg_disc"), - pl.count().alias("count_order"), - ], + pl.sum("l_quantity").alias("sum_qty"), + pl.sum("l_extendedprice").alias("sum_base_price"), + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) + .sum() + .alias("sum_disc_price"), + ( + pl.col("l_extendedprice") + * (1.0 - pl.col("l_discount")) + * (1.0 + pl.col("l_tax")) + ) + .sum() + .alias("sum_charge"), + pl.mean("l_quantity").alias("avg_qty"), + pl.mean("l_extendedprice").alias("avg_price"), + pl.mean("l_discount").alias("avg_disc"), + pl.len().alias("count_order"), ) - .sort(["l_returnflag", "l_linestatus"]) + .sort("l_returnflag", "l_linestatus") ) utils.run_query(Q_NUM, q_final) diff --git a/polars_queries/q10.py b/polars_queries/q10.py index a1e26ec..781d27c 100644 --- a/polars_queries/q10.py +++ b/polars_queries/q10.py @@ -23,38 +23,32 @@ def q(): .filter(pl.col("o_orderdate").is_between(var_1, var_2, closed="left")) .filter(pl.col("l_returnflag") == "R") .group_by( - [ - "c_custkey", - "c_name", - "c_acctbal", - "c_phone", - "n_name", - "c_address", - "c_comment", - ] + "c_custkey", + "c_name", + "c_acctbal", + "c_phone", + "n_name", + "c_address", + "c_comment", ) .agg( - [ - (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) - .sum() - .round(2) - .alias("revenue") - ] + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) + .sum() + .round(2) + .alias("revenue") ) .with_columns( pl.col("c_address").str.strip_chars(), pl.col("c_comment").str.strip_chars() ) .select( - [ - "c_custkey", - "c_name", - "revenue", - "c_acctbal", - "n_name", - "c_address", - "c_phone", - "c_comment", - ] + "c_custkey", + "c_name", + "revenue", + "c_acctbal", + "n_name", + "c_address", + "c_phone", + "c_comment", ) .sort(by="revenue", descending=True) .limit(20) diff --git a/polars_queries/q11.py b/polars_queries/q11.py index 99d10fc..637769a 100644 --- a/polars_queries/q11.py +++ b/polars_queries/q11.py @@ -34,7 +34,7 @@ def q(): .with_columns(pl.lit(1).alias("lit")) .join(res_2, on="lit") .filter(pl.col("value") > pl.col("tmp")) - .select(["ps_partkey", "value"]) + .select("ps_partkey", "value") .sort("value", descending=True) ) diff --git a/polars_queries/q12.py b/polars_queries/q12.py index f47a6c7..816aff2 100644 --- a/polars_queries/q12.py +++ b/polars_queries/q12.py @@ -23,19 +23,17 @@ def q(): .filter(pl.col("l_shipdate") < pl.col("l_commitdate")) .filter(pl.col("l_receiptdate").is_between(var_3, var_4, closed="left")) .with_columns( - [ - pl.when(pl.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"])) - .then(1) - .otherwise(0) - .alias("high_line_count"), - pl.when(pl.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"]).not_()) - .then(1) - .otherwise(0) - .alias("low_line_count"), - ] + pl.when(pl.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"])) + .then(1) + .otherwise(0) + .alias("high_line_count"), + pl.when(pl.col("o_orderpriority").is_in(["1-URGENT", "2-HIGH"]).not_()) + .then(1) + .otherwise(0) + .alias("low_line_count"), ) .group_by("l_shipmode") - .agg([pl.col("high_line_count").sum(), pl.col("low_line_count").sum()]) + .agg(pl.col("high_line_count").sum(), pl.col("low_line_count").sum()) .sort("l_shipmode") ) diff --git a/polars_queries/q13.py b/polars_queries/q13.py index fc2fda1..75c7259 100644 --- a/polars_queries/q13.py +++ b/polars_queries/q13.py @@ -20,9 +20,9 @@ def q(): .group_by("c_custkey") .agg(pl.col("o_orderkey").count().alias("c_count")) .group_by("c_count") - .count() - .select([pl.col("c_count"), pl.col("count").alias("custdist")]) - .sort(["custdist", "c_count"], descending=[True, True]) + .len() + .select(pl.col("c_count"), pl.col("len").alias("custdist")) + .sort(by=["custdist", "c_count"], descending=[True, True]) ) utils.run_query(Q_NUM, q_final) diff --git a/polars_queries/q14.py b/polars_queries/q14.py index 985c0c8..7986085 100644 --- a/polars_queries/q14.py +++ b/polars_queries/q14.py @@ -21,7 +21,7 @@ def q(): ( 100.00 * pl.when(pl.col("p_type").str.contains("PROMO*")) - .then((pl.col("l_extendedprice") * (1 - pl.col("l_discount")))) + .then(pl.col("l_extendedprice") * (1 - pl.col("l_discount"))) .otherwise(0) .sum() / (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).sum() diff --git a/polars_queries/q15.py b/polars_queries/q15.py index d51ccba..4190033 100644 --- a/polars_queries/q15.py +++ b/polars_queries/q15.py @@ -24,14 +24,14 @@ def q(): .sum() .alias("total_revenue") ) - .select([pl.col("l_suppkey").alias("supplier_no"), pl.col("total_revenue")]) + .select(pl.col("l_suppkey").alias("supplier_no"), pl.col("total_revenue")) ) q_final = ( supplier_ds.join(revenue_ds, left_on="s_suppkey", right_on="supplier_no") .filter(pl.col("total_revenue") == pl.col("total_revenue").max()) .with_columns(pl.col("total_revenue").round(2)) - .select(["s_suppkey", "s_name", "s_address", "s_phone", "total_revenue"]) + .select("s_suppkey", "s_name", "s_address", "s_phone", "total_revenue") .sort("s_suppkey") ) diff --git a/polars_queries/q16.py b/polars_queries/q16.py index fdc30a3..30df55d 100644 --- a/polars_queries/q16.py +++ b/polars_queries/q16.py @@ -23,8 +23,8 @@ def q(): .filter(pl.col("p_size").is_in([49, 14, 23, 45, 19, 3, 36, 9])) .join(supplier_ds, left_on="ps_suppkey", right_on="s_suppkey", how="left") .filter(pl.col("ps_suppkey_right").is_null()) - .group_by(["p_brand", "p_type", "p_size"]) - .agg([pl.col("ps_suppkey").n_unique().alias("supplier_cnt")]) + .group_by("p_brand", "p_type", "p_size") + .agg(pl.col("ps_suppkey").n_unique().alias("supplier_cnt")) .sort( by=["supplier_cnt", "p_brand", "p_type", "p_size"], descending=[True, False, False, False], diff --git a/polars_queries/q17.py b/polars_queries/q17.py index 8884a18..fed39a5 100644 --- a/polars_queries/q17.py +++ b/polars_queries/q17.py @@ -21,7 +21,7 @@ def q(): q_final = ( res_1.group_by("p_partkey") .agg((0.2 * pl.col("l_quantity").mean()).alias("avg_quantity")) - .select([pl.col("p_partkey").alias("key"), pl.col("avg_quantity")]) + .select(pl.col("p_partkey").alias("key"), pl.col("avg_quantity")) .join(res_1, left_on="key", right_on="p_partkey") .filter(pl.col("l_quantity") < pl.col("avg_quantity")) .select((pl.col("l_extendedprice").sum() / 7.0).round(2).alias("avg_yearly")) diff --git a/polars_queries/q18.py b/polars_queries/q18.py index f70dddf..a970ab7 100644 --- a/polars_queries/q18.py +++ b/polars_queries/q18.py @@ -16,23 +16,21 @@ def q(): line_item_ds.group_by("l_orderkey") .agg(pl.col("l_quantity").sum().alias("sum_quantity")) .filter(pl.col("sum_quantity") > var_1) - .select([pl.col("l_orderkey").alias("key"), pl.col("sum_quantity")]) + .select(pl.col("l_orderkey").alias("key"), pl.col("sum_quantity")) .join(orders_ds, left_on="key", right_on="o_orderkey") .join(line_item_ds, left_on="key", right_on="l_orderkey") .join(customer_ds, left_on="o_custkey", right_on="c_custkey") .group_by("c_name", "o_custkey", "key", "o_orderdate", "o_totalprice") .agg(pl.col("l_quantity").sum().alias("col6")) .select( - [ - pl.col("c_name"), - pl.col("o_custkey").alias("c_custkey"), - pl.col("key").alias("o_orderkey"), - pl.col("o_orderdate").alias("o_orderdat"), - pl.col("o_totalprice"), - pl.col("col6"), - ] + pl.col("c_name"), + pl.col("o_custkey").alias("c_custkey"), + pl.col("key").alias("o_orderkey"), + pl.col("o_orderdate").alias("o_orderdat"), + pl.col("o_totalprice"), + pl.col("col6"), ) - .sort(["o_totalprice", "o_orderdat"], descending=[True, False]) + .sort(by=["o_totalprice", "o_orderdat"], descending=[True, False]) .limit(100) ) diff --git a/polars_queries/q2.py b/polars_queries/q2.py index 05f4a99..d374b7d 100644 --- a/polars_queries/q2.py +++ b/polars_queries/q2.py @@ -39,19 +39,15 @@ def q(): q_final = ( result_q1.group_by("p_partkey") - .agg(pl.min("ps_supplycost").alias("ps_supplycost")) - .join( - result_q1, - left_on=["p_partkey", "ps_supplycost"], - right_on=["p_partkey", "ps_supplycost"], - ) + .agg(pl.min("ps_supplycost")) + .join(result_q1, on=["p_partkey", "ps_supplycost"]) .select(final_cols) .sort( by=["s_acctbal", "n_name", "s_name", "p_partkey"], descending=[True, False, False, False], ) .limit(100) - .with_columns(pl.col(pl.datatypes.Utf8).str.strip_chars().name.keep()) + .with_columns(pl.col(pl.String).str.strip_chars().name.keep()) ) utils.run_query(Q_NUM, q_final) diff --git a/polars_queries/q20.py b/polars_queries/q20.py index 524f0e7..01ed4b6 100644 --- a/polars_queries/q20.py +++ b/polars_queries/q20.py @@ -42,7 +42,7 @@ def q(): .select(pl.col("ps_suppkey").unique()) .join(res_3, left_on="ps_suppkey", right_on="s_suppkey") .with_columns(pl.col("s_address").str.strip_chars()) - .select(["s_name", "s_address"]) + .select("s_name", "s_address") .sort("s_name") ) diff --git a/polars_queries/q21.py b/polars_queries/q21.py index 3cdc69e..2347983 100644 --- a/polars_queries/q21.py +++ b/polars_queries/q21.py @@ -34,7 +34,7 @@ def q(): .filter(pl.col("n_name") == var_1) .filter(pl.col("o_orderstatus") == "F") .group_by("s_name") - .agg(pl.count().alias("numwait")) + .agg(pl.len().alias("numwait")) .sort(by=["numwait", "s_name"], descending=[True, False]) .limit(100) ) diff --git a/polars_queries/q22.py b/polars_queries/q22.py index 9094623..184f5b7 100644 --- a/polars_queries/q22.py +++ b/polars_queries/q22.py @@ -12,7 +12,7 @@ def q(): res_1 = ( customer_ds.with_columns(pl.col("c_phone").str.slice(0, 2).alias("cntrycode")) .filter(pl.col("cntrycode").str.contains("13|31|23|29|30|18|17")) - .select(["c_acctbal", "c_custkey", "cntrycode"]) + .select("c_acctbal", "c_custkey", "cntrycode") ) res_2 = ( @@ -33,10 +33,8 @@ def q(): .filter(pl.col("c_acctbal") > pl.col("avg_acctbal")) .group_by("cntrycode") .agg( - [ - pl.col("c_acctbal").count().alias("numcust"), - pl.col("c_acctbal").sum().round(2).alias("totacctbal"), - ] + pl.col("c_acctbal").count().alias("numcust"), + pl.col("c_acctbal").sum().round(2).alias("totacctbal"), ) .sort("cntrycode") ) diff --git a/polars_queries/q3.py b/polars_queries/q3.py index 1efa95c..624abac 100644 --- a/polars_queries/q3.py +++ b/polars_queries/q3.py @@ -24,15 +24,13 @@ def q(): .with_columns( (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("revenue") ) - .group_by(["o_orderkey", "o_orderdate", "o_shippriority"]) - .agg([pl.sum("revenue")]) + .group_by("o_orderkey", "o_orderdate", "o_shippriority") + .agg(pl.sum("revenue")) .select( - [ - pl.col("o_orderkey").alias("l_orderkey"), - "revenue", - "o_orderdate", - "o_shippriority", - ] + pl.col("o_orderkey").alias("l_orderkey"), + "revenue", + "o_orderdate", + "o_shippriority", ) .sort(by=["revenue", "o_orderdate"], descending=[True, False]) .limit(10) diff --git a/polars_queries/q4.py b/polars_queries/q4.py index 2dbbde9..d53bf8b 100644 --- a/polars_queries/q4.py +++ b/polars_queries/q4.py @@ -20,9 +20,9 @@ def q(): .filter(pl.col("l_commitdate") < pl.col("l_receiptdate")) .unique(subset=["o_orderpriority", "l_orderkey"]) .group_by("o_orderpriority") - .agg(pl.count().alias("order_count")) - .sort(by="o_orderpriority") - .with_columns(pl.col("order_count").cast(pl.datatypes.Int64)) + .agg(pl.len().alias("order_count")) + .sort("o_orderpriority") + .with_columns(pl.col("order_count").cast(pl.Int64)) ) utils.run_query(Q_NUM, q_final) diff --git a/polars_queries/q5.py b/polars_queries/q5.py index 93c2ea6..abb4a21 100644 --- a/polars_queries/q5.py +++ b/polars_queries/q5.py @@ -35,7 +35,7 @@ def q(): (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("revenue") ) .group_by("n_name") - .agg([pl.sum("revenue")]) + .agg(pl.sum("revenue")) .sort(by="revenue", descending=True) ) diff --git a/polars_queries/q6.py b/polars_queries/q6.py index cf9d84c..9009811 100644 --- a/polars_queries/q6.py +++ b/polars_queries/q6.py @@ -23,7 +23,7 @@ def q(): .with_columns( (pl.col("l_extendedprice") * pl.col("l_discount")).alias("revenue") ) - .select(pl.sum("revenue").alias("revenue")) + .select(pl.sum("revenue")) ) utils.run_query(Q_NUM, q_final) diff --git a/polars_queries/q7.py b/polars_queries/q7.py index 44a45e2..4b8e7f3 100644 --- a/polars_queries/q7.py +++ b/polars_queries/q7.py @@ -47,8 +47,8 @@ def q(): (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("volume") ) .with_columns(pl.col("l_shipdate").dt.year().alias("l_year")) - .group_by(["supp_nation", "cust_nation", "l_year"]) - .agg([pl.sum("volume").alias("revenue")]) + .group_by("supp_nation", "cust_nation", "l_year") + .agg(pl.sum("volume").alias("revenue")) .sort(by=["supp_nation", "cust_nation", "l_year"]) ) diff --git a/polars_queries/q8.py b/polars_queries/q8.py index 8723b6a..3dc9e24 100644 --- a/polars_queries/q8.py +++ b/polars_queries/q8.py @@ -16,8 +16,8 @@ def q(): nation_ds = utils.get_nation_ds() region_ds = utils.get_region_ds() - n1 = nation_ds.select(["n_nationkey", "n_regionkey"]) - n2 = nation_ds.clone().select(["n_nationkey", "n_name"]) + n1 = nation_ds.select("n_nationkey", "n_regionkey") + n2 = nation_ds.clone().select("n_nationkey", "n_name") q_final = ( part_ds.join(line_item_ds, left_on="p_partkey", right_on="l_partkey") @@ -35,13 +35,9 @@ def q(): ) .filter(pl.col("p_type") == "ECONOMY ANODIZED STEEL") .select( - [ - pl.col("o_orderdate").dt.year().alias("o_year"), - (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias( - "volume" - ), - pl.col("n_name").alias("nation"), - ] + pl.col("o_orderdate").dt.year().alias("o_year"), + (pl.col("l_extendedprice") * (1 - pl.col("l_discount"))).alias("volume"), + pl.col("n_name").alias("nation"), ) .with_columns( pl.when(pl.col("nation") == "BRAZIL") diff --git a/polars_queries/q9.py b/polars_queries/q9.py index ce3dcc4..0f55248 100644 --- a/polars_queries/q9.py +++ b/polars_queries/q9.py @@ -1,5 +1,3 @@ -from datetime import datetime - import polars as pl from polars_queries import utils @@ -27,16 +25,14 @@ def q(): .join(nation_ds, left_on="s_nationkey", right_on="n_nationkey") .filter(pl.col("p_name").str.contains("green")) .select( - [ - pl.col("n_name").alias("nation"), - pl.col("o_orderdate").dt.year().alias("o_year"), - ( - pl.col("l_extendedprice") * (1 - pl.col("l_discount")) - - pl.col("ps_supplycost") * pl.col("l_quantity") - ).alias("amount"), - ] + pl.col("n_name").alias("nation"), + pl.col("o_orderdate").dt.year().alias("o_year"), + ( + pl.col("l_extendedprice") * (1 - pl.col("l_discount")) + - pl.col("ps_supplycost") * pl.col("l_quantity") + ).alias("amount"), ) - .group_by(["nation", "o_year"]) + .group_by("nation", "o_year") .agg(pl.sum("amount").round(2).alias("sum_profit")) .sort(by=["nation", "o_year"], descending=[False, True]) ) diff --git a/polars_queries/utils.py b/polars_queries/utils.py index ed6162a..926ce82 100644 --- a/polars_queries/utils.py +++ b/polars_queries/utils.py @@ -1,6 +1,6 @@ import os import timeit -from os.path import join +from pathlib import Path import polars as pl from linetimer import CodeTimer, linetimer @@ -20,7 +20,7 @@ STREAMING = bool(os.environ.get("STREAMING", False)) -def _scan_ds(path: str): +def _scan_ds(path: Path): path = f"{path}.{FILE_TYPE}" if FILE_TYPE == "parquet": scan = pl.scan_parquet(path) @@ -37,7 +37,7 @@ def _scan_ds(path: str): def get_query_answer( query: int, base_dir: str = ANSWERS_PARQUET_BASE_DIR ) -> pl.LazyFrame: - return pl.scan_parquet(join(base_dir, f"q{query}.parquet")) + return pl.scan_parquet(Path(base_dir) / f"q{query}.parquet") def test_results(q_num: int, result_df: pl.DataFrame): @@ -47,35 +47,35 @@ def test_results(q_num: int, result_df: pl.DataFrame): def get_line_item_ds(base_dir: str = DATASET_BASE_DIR) -> pl.LazyFrame: - return _scan_ds(join(base_dir, "lineitem")) + return _scan_ds(Path(base_dir) / "lineitem") def get_orders_ds(base_dir: str = DATASET_BASE_DIR) -> pl.LazyFrame: - return _scan_ds(join(base_dir, "orders")) + return _scan_ds(Path(base_dir) / "orders") def get_customer_ds(base_dir: str = DATASET_BASE_DIR) -> pl.LazyFrame: - return _scan_ds(join(base_dir, "customer")) + return _scan_ds(Path(base_dir) / "customer") def get_region_ds(base_dir: str = DATASET_BASE_DIR) -> pl.LazyFrame: - return _scan_ds(join(base_dir, "region")) + return _scan_ds(Path(base_dir) / "region") def get_nation_ds(base_dir: str = DATASET_BASE_DIR) -> pl.LazyFrame: - return _scan_ds(join(base_dir, "nation")) + return _scan_ds(Path(base_dir) / "nation") def get_supplier_ds(base_dir: str = DATASET_BASE_DIR) -> pl.LazyFrame: - return _scan_ds(join(base_dir, "supplier")) + return _scan_ds(Path(base_dir) / "supplier") def get_part_ds(base_dir: str = DATASET_BASE_DIR) -> pl.LazyFrame: - return _scan_ds(join(base_dir, "part")) + return _scan_ds(Path(base_dir) / "part") def get_part_supp_ds(base_dir: str = DATASET_BASE_DIR) -> pl.LazyFrame: - return _scan_ds(join(base_dir, "partsupp")) + return _scan_ds(Path(base_dir) / "partsupp") def run_query(q_num: int, lp: pl.LazyFrame): diff --git a/pyproject.toml b/pyproject.toml index e9e0170..e2ab17a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,6 @@ [tool.ruff] line-length = 88 +target-version = "py312" fix = true [tool.ruff.lint]