From 27836132c97b3324b4ed5969ac9fd08751fbc8af Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Dec 2024 23:27:36 -0700 Subject: [PATCH] chore: Make query stage / shuffle code easier to understand (#54) --- datafusion_ray/context.py | 7 +++-- src/planner.rs | 2 +- src/query_stage.rs | 42 ++++++++++++-------------- src/shuffle/codec.rs | 2 +- src/shuffle/writer.rs | 10 +++---- testdata/expected-plans/q1.txt | 2 +- testdata/expected-plans/q10.txt | 2 +- testdata/expected-plans/q11.txt | 2 +- testdata/expected-plans/q12.txt | 2 +- testdata/expected-plans/q13.txt | 2 +- testdata/expected-plans/q16.txt | 2 +- testdata/expected-plans/q18.txt | 2 +- testdata/expected-plans/q2.txt | 2 +- testdata/expected-plans/q20.txt | 2 +- testdata/expected-plans/q21.txt | 2 +- testdata/expected-plans/q22.txt | 2 +- testdata/expected-plans/q3.txt | 2 +- testdata/expected-plans/q4.txt | 2 +- testdata/expected-plans/q5.txt | 2 +- testdata/expected-plans/q7.txt | 2 +- testdata/expected-plans/q8.txt | 2 +- testdata/expected-plans/q9.txt | 2 +- tests/test_context.py | 52 ++++++++++++++++----------------- 23 files changed, 72 insertions(+), 77 deletions(-) diff --git a/datafusion_ray/context.py b/datafusion_ray/context.py index 0070220..8d354ff 100644 --- a/datafusion_ray/context.py +++ b/datafusion_ray/context.py @@ -50,7 +50,7 @@ def execute_query_stage( # if the query stage has a single output partition then we need to execute for the output # partition, otherwise we need to execute in parallel for each input partition - concurrency = stage.get_input_partition_count() + concurrency = stage.get_execution_partition_count() output_partitions_count = stage.get_output_partition_count() if output_partitions_count == 1: # reduce stage @@ -159,5 +159,6 @@ def plan(self, execution_plan: Any) -> List[pa.RecordBatch]: ) _, partitions = ray.get(future) # assert len(partitions) == 1, len(partitions) - result_set = ray.get(partitions[0]) - return result_set + record_batches = ray.get(partitions[0]) + # filter out empty batches + return [batch for batch in record_batches if batch.num_rows > 0] diff --git a/src/planner.rs b/src/planner.rs index 954d8e2..c1e7b41 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -399,7 +399,7 @@ mod test { let query_stage = graph.query_stages.get(&id).unwrap(); output.push_str(&format!( "Query Stage #{id} ({} -> {}):\n{}\n", - query_stage.get_input_partition_count(), + query_stage.get_execution_partition_count(), query_stage.get_output_partition_count(), displayable(query_stage.plan.as_ref()).indent(false) )); diff --git a/src/query_stage.rs b/src/query_stage.rs index 05c090b..a5c9a08 100644 --- a/src/query_stage.rs +++ b/src/query_stage.rs @@ -16,7 +16,7 @@ // under the License. use crate::context::serialize_execution_plan; -use crate::shuffle::{ShuffleCodec, ShuffleReaderExec}; +use crate::shuffle::{ShuffleCodec, ShuffleReaderExec, ShuffleWriterExec}; use datafusion::error::Result; use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, Partitioning}; use datafusion::prelude::SessionContext; @@ -60,8 +60,8 @@ impl PyQueryStage { self.stage.get_child_stage_ids() } - pub fn get_input_partition_count(&self) -> usize { - self.stage.get_input_partition_count() + pub fn get_execution_partition_count(&self) -> usize { + self.stage.get_execution_partition_count() } pub fn get_output_partition_count(&self) -> usize { @@ -75,16 +75,6 @@ pub struct QueryStage { pub plan: Arc, } -fn _get_output_partition_count(plan: &dyn ExecutionPlan) -> usize { - // UnknownPartitioning and HashPartitioning with empty expressions will - // both return 1 partition. - match plan.properties().output_partitioning() { - Partitioning::UnknownPartitioning(_) => 1, - Partitioning::Hash(expr, _) if expr.is_empty() => 1, - p => p.partition_count(), - } -} - impl QueryStage { pub fn new(id: usize, plan: Arc) -> Self { Self { id, plan } @@ -96,21 +86,27 @@ impl QueryStage { ids } - /// Get the input partition count. This is the same as the number of concurrent tasks - /// when we schedule this query stage for execution - pub fn get_input_partition_count(&self) -> usize { - if self.plan.children().is_empty() { - // leaf node (file scan) - self.plan.output_partitioning().partition_count() + /// Get the number of partitions that can be executed in parallel + pub fn get_execution_partition_count(&self) -> usize { + if let Some(shuffle) = self.plan.as_any().downcast_ref::() { + // use the partitioning of the input to the shuffle write because we are + // really executing that and then using the shuffle writer to repartition + // the output + shuffle.input_plan.output_partitioning().partition_count() } else { - self.plan.children()[0] - .output_partitioning() - .partition_count() + // for any other plan, use its output partitioning + self.plan.output_partitioning().partition_count() } } pub fn get_output_partition_count(&self) -> usize { - _get_output_partition_count(self.plan.as_ref()) + // UnknownPartitioning and HashPartitioning with empty expressions will + // both return 1 partition. + match self.plan.properties().output_partitioning() { + Partitioning::UnknownPartitioning(_) => 1, + Partitioning::Hash(expr, _) if expr.is_empty() => 1, + p => p.partition_count(), + } } } diff --git a/src/shuffle/codec.rs b/src/shuffle/codec.rs index 79af0b8..0420428 100644 --- a/src/shuffle/codec.rs +++ b/src/shuffle/codec.rs @@ -102,7 +102,7 @@ impl PhysicalExtensionCodec for ShuffleCodec { }; PlanType::ShuffleReader(reader) } else if let Some(writer) = node.as_any().downcast_ref::() { - let plan = PhysicalPlanNode::try_from_physical_plan(writer.plan.clone(), self)?; + let plan = PhysicalPlanNode::try_from_physical_plan(writer.input_plan.clone(), self)?; let partitioning = encode_partitioning_scheme(writer.properties().output_partitioning())?; let writer = ShuffleWriterExecNode { diff --git a/src/shuffle/writer.rs b/src/shuffle/writer.rs index 069f99d..0e0f984 100644 --- a/src/shuffle/writer.rs +++ b/src/shuffle/writer.rs @@ -47,7 +47,7 @@ use std::sync::Arc; #[derive(Debug)] pub struct ShuffleWriterExec { pub stage_id: usize, - pub(crate) plan: Arc, + pub(crate) input_plan: Arc, /// Output partitioning properties: PlanProperties, /// Directory to write shuffle files from @@ -84,7 +84,7 @@ impl ShuffleWriterExec { Self { stage_id, - plan, + input_plan: plan, properties, shuffle_dir: shuffle_dir.to_string(), metrics: ExecutionPlanMetricsSet::new(), @@ -98,11 +98,11 @@ impl ExecutionPlan for ShuffleWriterExec { } fn schema(&self) -> SchemaRef { - self.plan.schema() + self.input_plan.schema() } fn children(&self) -> Vec<&Arc> { - vec![&self.plan] + vec![&self.input_plan] } fn with_new_children( @@ -122,7 +122,7 @@ impl ExecutionPlan for ShuffleWriterExec { self.stage_id ); - let mut stream = self.plan.execute(input_partition, context)?; + let mut stream = self.input_plan.execute(input_partition, context)?; let write_time = MetricBuilder::new(&self.metrics).subset_time("write_time", input_partition); let repart_time = diff --git a/testdata/expected-plans/q1.txt b/testdata/expected-plans/q1.txt index 282d5da..6f78394 100644 --- a/testdata/expected-plans/q1.txt +++ b/testdata/expected-plans/q1.txt @@ -42,7 +42,7 @@ ShuffleWriterExec(stage_id=1, output_partitioning=Hash([Column { name: "l_return CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=0, input_partitioning=Hash([Column { name: "l_returnflag", index: 0 }, Column { name: "l_linestatus", index: 1 }], 2)) -Query Stage #2 (2 -> 1): +Query Stage #2 (1 -> 1): SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST, l_linestatus@1 ASC NULLS LAST] ShuffleReaderExec(stage_id=1, input_partitioning=Hash([Column { name: "l_returnflag", index: 0 }, Column { name: "l_linestatus", index: 1 }], 2)) diff --git a/testdata/expected-plans/q10.txt b/testdata/expected-plans/q10.txt index 046f69e..3825561 100644 --- a/testdata/expected-plans/q10.txt +++ b/testdata/expected-plans/q10.txt @@ -117,7 +117,7 @@ ShuffleWriterExec(stage_id=7, output_partitioning=Hash([Column { name: "c_custke CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column { name: "c_custkey", index: 0 }, Column { name: "c_name", index: 1 }, Column { name: "c_acctbal", index: 2 }, Column { name: "c_phone", index: 3 }, Column { name: "n_name", index: 4 }, Column { name: "c_address", index: 5 }, Column { name: "c_comment", index: 6 }], 2)) -Query Stage #8 (2 -> 1): +Query Stage #8 (1 -> 1): SortPreservingMergeExec: [revenue@2 DESC], fetch=20 ShuffleReaderExec(stage_id=7, input_partitioning=Hash([Column { name: "c_custkey", index: 0 }, Column { name: "c_name", index: 1 }, Column { name: "c_acctbal", index: 3 }, Column { name: "c_phone", index: 6 }, Column { name: "n_name", index: 4 }, Column { name: "c_address", index: 5 }, Column { name: "c_comment", index: 7 }], 2)) diff --git a/testdata/expected-plans/q11.txt b/testdata/expected-plans/q11.txt index 74f74d7..2972d52 100644 --- a/testdata/expected-plans/q11.txt +++ b/testdata/expected-plans/q11.txt @@ -167,7 +167,7 @@ ShuffleWriterExec(stage_id=10, output_partitioning=Hash([Column { name: "ps_part CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=9, input_partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 2)) -Query Stage #11 (2 -> 1): +Query Stage #11 (1 -> 1): SortPreservingMergeExec: [value@1 DESC] ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "ps_partkey", index: 0 }], 2)) diff --git a/testdata/expected-plans/q12.txt b/testdata/expected-plans/q12.txt index c7ae269..4cf0596 100644 --- a/testdata/expected-plans/q12.txt +++ b/testdata/expected-plans/q12.txt @@ -65,7 +65,7 @@ ShuffleWriterExec(stage_id=3, output_partitioning=Hash([Column { name: "l_shipmo CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column { name: "l_shipmode", index: 0 }], 2)) -Query Stage #4 (2 -> 1): +Query Stage #4 (1 -> 1): SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST] ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "l_shipmode", index: 0 }], 2)) diff --git a/testdata/expected-plans/q13.txt b/testdata/expected-plans/q13.txt index 366db12..da7e93a 100644 --- a/testdata/expected-plans/q13.txt +++ b/testdata/expected-plans/q13.txt @@ -70,7 +70,7 @@ ShuffleWriterExec(stage_id=3, output_partitioning=Hash([Column { name: "c_count" CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column { name: "c_count", index: 0 }], 2)) -Query Stage #4 (2 -> 1): +Query Stage #4 (1 -> 1): SortPreservingMergeExec: [custdist@1 DESC, c_count@0 DESC] ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "c_count", index: 0 }], 2)) diff --git a/testdata/expected-plans/q16.txt b/testdata/expected-plans/q16.txt index 24ecb18..b26e9a4 100644 --- a/testdata/expected-plans/q16.txt +++ b/testdata/expected-plans/q16.txt @@ -107,7 +107,7 @@ ShuffleWriterExec(stage_id=6, output_partitioning=Hash([Column { name: "p_brand" CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=5, input_partitioning=Hash([Column { name: "p_brand", index: 0 }, Column { name: "p_type", index: 1 }, Column { name: "p_size", index: 2 }], 2)) -Query Stage #7 (2 -> 1): +Query Stage #7 (1 -> 1): SortPreservingMergeExec: [supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST, p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST] ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column { name: "p_brand", index: 0 }, Column { name: "p_type", index: 1 }, Column { name: "p_size", index: 2 }], 2)) diff --git a/testdata/expected-plans/q18.txt b/testdata/expected-plans/q18.txt index 30179d0..a5d28e8 100644 --- a/testdata/expected-plans/q18.txt +++ b/testdata/expected-plans/q18.txt @@ -104,7 +104,7 @@ ShuffleWriterExec(stage_id=6, output_partitioning=Hash([Column { name: "c_name", CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=5, input_partitioning=Hash([Column { name: "c_name", index: 0 }, Column { name: "c_custkey", index: 1 }, Column { name: "o_orderkey", index: 2 }, Column { name: "o_orderdate", index: 3 }, Column { name: "o_totalprice", index: 4 }], 2)) -Query Stage #7 (2 -> 1): +Query Stage #7 (1 -> 1): SortPreservingMergeExec: [o_totalprice@4 DESC, o_orderdate@3 ASC NULLS LAST], fetch=100 ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column { name: "c_name", index: 0 }, Column { name: "c_custkey", index: 1 }, Column { name: "o_orderkey", index: 2 }, Column { name: "o_orderdate", index: 3 }, Column { name: "o_totalprice", index: 4 }], 2)) diff --git a/testdata/expected-plans/q2.txt b/testdata/expected-plans/q2.txt index bc0713c..9778441 100644 --- a/testdata/expected-plans/q2.txt +++ b/testdata/expected-plans/q2.txt @@ -252,7 +252,7 @@ ShuffleWriterExec(stage_id=17, output_partitioning=Hash([Column { name: "p_partk CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=16, input_partitioning=Hash([Column { name: "ps_partkey", index: 1 }, Column { name: "min(partsupp.ps_supplycost)", index: 0 }], 2)) -Query Stage #18 (2 -> 1): +Query Stage #18 (1 -> 1): SortPreservingMergeExec: [s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1 ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], fetch=100 ShuffleReaderExec(stage_id=17, input_partitioning=Hash([Column { name: "p_partkey", index: 3 }], 2)) diff --git a/testdata/expected-plans/q20.txt b/testdata/expected-plans/q20.txt index 13b21c8..e1bc54c 100644 --- a/testdata/expected-plans/q20.txt +++ b/testdata/expected-plans/q20.txt @@ -142,7 +142,7 @@ ShuffleWriterExec(stage_id=8, output_partitioning=Hash([], 2)) CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=7, input_partitioning=Hash([Column { name: "ps_suppkey", index: 0 }], 2)) -Query Stage #9 (2 -> 1): +Query Stage #9 (1 -> 1): SortPreservingMergeExec: [s_name@0 ASC NULLS LAST] ShuffleReaderExec(stage_id=8, input_partitioning=Hash([], 2)) diff --git a/testdata/expected-plans/q21.txt b/testdata/expected-plans/q21.txt index b88bccc..8d6798f 100644 --- a/testdata/expected-plans/q21.txt +++ b/testdata/expected-plans/q21.txt @@ -172,7 +172,7 @@ ShuffleWriterExec(stage_id=10, output_partitioning=Hash([Column { name: "s_name" CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=9, input_partitioning=Hash([Column { name: "s_name", index: 0 }], 2)) -Query Stage #11 (2 -> 1): +Query Stage #11 (1 -> 1): SortPreservingMergeExec: [numwait@1 DESC, s_name@0 ASC NULLS LAST], fetch=100 ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "s_name", index: 0 }], 2)) diff --git a/testdata/expected-plans/q22.txt b/testdata/expected-plans/q22.txt index da693fb..7ad4ae1 100644 --- a/testdata/expected-plans/q22.txt +++ b/testdata/expected-plans/q22.txt @@ -91,7 +91,7 @@ ShuffleWriterExec(stage_id=4, output_partitioning=Hash([Column { name: "cntrycod CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "cntrycode", index: 0 }], 2)) -Query Stage #5 (2 -> 1): +Query Stage #5 (1 -> 1): SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST] ShuffleReaderExec(stage_id=4, input_partitioning=Hash([Column { name: "cntrycode", index: 0 }], 2)) diff --git a/testdata/expected-plans/q3.txt b/testdata/expected-plans/q3.txt index f9039d3..3af2ea0 100644 --- a/testdata/expected-plans/q3.txt +++ b/testdata/expected-plans/q3.txt @@ -97,7 +97,7 @@ ShuffleWriterExec(stage_id=5, output_partitioning=Hash([Column { name: "l_orderk CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=4, input_partitioning=Hash([Column { name: "l_orderkey", index: 0 }, Column { name: "o_orderdate", index: 1 }, Column { name: "o_shippriority", index: 2 }], 2)) -Query Stage #6 (2 -> 1): +Query Stage #6 (1 -> 1): SortPreservingMergeExec: [revenue@1 DESC, o_orderdate@2 ASC NULLS LAST], fetch=10 ShuffleReaderExec(stage_id=5, input_partitioning=Hash([Column { name: "l_orderkey", index: 0 }, Column { name: "o_orderdate", index: 2 }, Column { name: "o_shippriority", index: 3 }], 2)) diff --git a/testdata/expected-plans/q4.txt b/testdata/expected-plans/q4.txt index 20460e4..2504483 100644 --- a/testdata/expected-plans/q4.txt +++ b/testdata/expected-plans/q4.txt @@ -70,7 +70,7 @@ ShuffleWriterExec(stage_id=3, output_partitioning=Hash([Column { name: "o_orderp CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column { name: "o_orderpriority", index: 0 }], 2)) -Query Stage #4 (2 -> 1): +Query Stage #4 (1 -> 1): SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST] ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name: "o_orderpriority", index: 0 }], 2)) diff --git a/testdata/expected-plans/q5.txt b/testdata/expected-plans/q5.txt index 2bacb27..3e66ddb 100644 --- a/testdata/expected-plans/q5.txt +++ b/testdata/expected-plans/q5.txt @@ -167,7 +167,7 @@ ShuffleWriterExec(stage_id=11, output_partitioning=Hash([Column { name: "n_name" CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "n_name", index: 0 }], 2)) -Query Stage #12 (2 -> 1): +Query Stage #12 (1 -> 1): SortPreservingMergeExec: [revenue@1 DESC] ShuffleReaderExec(stage_id=11, input_partitioning=Hash([Column { name: "n_name", index: 0 }], 2)) diff --git a/testdata/expected-plans/q7.txt b/testdata/expected-plans/q7.txt index 43bc031..9321b1b 100644 --- a/testdata/expected-plans/q7.txt +++ b/testdata/expected-plans/q7.txt @@ -176,7 +176,7 @@ ShuffleWriterExec(stage_id=11, output_partitioning=Hash([Column { name: "supp_na CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "supp_nation", index: 0 }, Column { name: "cust_nation", index: 1 }, Column { name: "l_year", index: 2 }], 2)) -Query Stage #12 (2 -> 1): +Query Stage #12 (1 -> 1): SortPreservingMergeExec: [supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC NULLS LAST, l_year@2 ASC NULLS LAST] ShuffleReaderExec(stage_id=11, input_partitioning=Hash([Column { name: "supp_nation", index: 0 }, Column { name: "cust_nation", index: 1 }, Column { name: "l_year", index: 2 }], 2)) diff --git a/testdata/expected-plans/q8.txt b/testdata/expected-plans/q8.txt index e9f5b91..c7ec1ec 100644 --- a/testdata/expected-plans/q8.txt +++ b/testdata/expected-plans/q8.txt @@ -230,7 +230,7 @@ ShuffleWriterExec(stage_id=15, output_partitioning=Hash([Column { name: "o_year" CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=14, input_partitioning=Hash([Column { name: "o_year", index: 0 }], 2)) -Query Stage #16 (2 -> 1): +Query Stage #16 (1 -> 1): SortPreservingMergeExec: [o_year@0 ASC NULLS LAST] ShuffleReaderExec(stage_id=15, input_partitioning=Hash([Column { name: "o_year", index: 0 }], 2)) diff --git a/testdata/expected-plans/q9.txt b/testdata/expected-plans/q9.txt index 2c713b3..fa087f1 100644 --- a/testdata/expected-plans/q9.txt +++ b/testdata/expected-plans/q9.txt @@ -166,7 +166,7 @@ ShuffleWriterExec(stage_id=11, output_partitioning=Hash([Column { name: "nation" CoalesceBatchesExec: target_batch_size=8192 ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name: "nation", index: 0 }, Column { name: "o_year", index: 1 }], 2)) -Query Stage #12 (2 -> 1): +Query Stage #12 (1 -> 1): SortPreservingMergeExec: [nation@0 ASC NULLS LAST, o_year@1 DESC] ShuffleReaderExec(stage_id=11, input_partitioning=Hash([Column { name: "nation", index: 0 }, Column { name: "o_year", index: 1 }], 2)) diff --git a/tests/test_context.py b/tests/test_context.py index ecc3324..602f761 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -17,42 +17,42 @@ from datafusion_ray.context import DatafusionRayContext from datafusion import SessionContext, SessionConfig, RuntimeConfig, col, lit, functions as F +import pytest +@pytest.fixture +def df_ctx(): + """Fixture to create a DataFusion context.""" + # used fixed partition count so that tests are deterministic on different environments + config = SessionConfig().with_target_partitions(4) + return SessionContext(config=config) -def test_basic_query_succeed(): - df_ctx = SessionContext() - ctx = DatafusionRayContext(df_ctx) +@pytest.fixture +def ctx(df_ctx): + """Fixture to create a Datafusion Ray context.""" + return DatafusionRayContext(df_ctx) + +def test_basic_query_succeed(df_ctx, ctx): df_ctx.register_csv("tips", "examples/tips.csv", has_header=True) - # TODO why does this return a single batch and not a list of batches? record_batches = ctx.sql("SELECT * FROM tips") - assert record_batches[0].num_rows == 244 + assert len(record_batches) <= 4 + num_rows = sum(batch.num_rows for batch in record_batches) + assert num_rows == 244 -def test_aggregate_csv(): - df_ctx = SessionContext() - ctx = DatafusionRayContext(df_ctx) +def test_aggregate_csv(df_ctx, ctx): df_ctx.register_csv("tips", "examples/tips.csv", has_header=True) record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker") - assert isinstance(record_batches, list) - # TODO why does this return many empty batches? - num_rows = 0 - for record_batch in record_batches: - num_rows += record_batch.num_rows + assert len(record_batches) <= 4 + num_rows = sum(batch.num_rows for batch in record_batches) assert num_rows == 4 -def test_aggregate_parquet(): - df_ctx = SessionContext() - ctx = DatafusionRayContext(df_ctx) +def test_aggregate_parquet(df_ctx, ctx): df_ctx.register_parquet("tips", "examples/tips.parquet") record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by sex, smoker") - # TODO why does this return many empty batches? - num_rows = 0 - for record_batch in record_batches: - num_rows += record_batch.num_rows + assert len(record_batches) <= 4 + num_rows = sum(batch.num_rows for batch in record_batches) assert num_rows == 4 -def test_aggregate_parquet_dataframe(): - df_ctx = SessionContext() - ray_ctx = DatafusionRayContext(df_ctx) +def test_aggregate_parquet_dataframe(df_ctx, ctx): df = df_ctx.read_parquet(f"examples/tips.parquet") df = ( df.aggregate( @@ -62,12 +62,10 @@ def test_aggregate_parquet_dataframe(): .filter(col("day") != lit("Dinner")) .aggregate([col("sex"), col("smoker")], [F.avg(col("tip_pct")).alias("avg_pct")]) ) - ray_results = ray_ctx.plan(df.execution_plan()) + ray_results = ctx.plan(df.execution_plan()) df_ctx.create_dataframe([ray_results]).show() -def test_no_result_query(): - df_ctx = SessionContext() - ctx = DatafusionRayContext(df_ctx) +def test_no_result_query(df_ctx, ctx): df_ctx.register_csv("tips", "examples/tips.csv", has_header=True) ctx.sql("CREATE VIEW tips_view AS SELECT * FROM tips")