diff --git a/src/query_stage.rs b/src/query_stage.rs index 05c090b..f70b0c0 100644 --- a/src/query_stage.rs +++ b/src/query_stage.rs @@ -23,6 +23,7 @@ use datafusion::prelude::SessionContext; use datafusion_proto::bytes::physical_plan_from_bytes_with_extension_codec; use pyo3::prelude::*; use pyo3::types::PyBytes; +use std::collections::HashSet; use std::sync::Arc; #[pyclass(name = "QueryStage", module = "datafusion_ray", subclass)] @@ -99,14 +100,23 @@ impl QueryStage { /// Get the input partition count. This is the same as the number of concurrent tasks /// when we schedule this query stage for execution pub fn get_input_partition_count(&self) -> usize { - if self.plan.children().is_empty() { - // leaf node (file scan) - self.plan.output_partitioning().partition_count() - } else { - self.plan.children()[0] - .output_partitioning() - .partition_count() + let mut output_partition_counts = HashSet::new(); + + for child in self.plan.children() { + output_partition_counts.insert(child.output_partitioning().partition_count()); + if output_partition_counts.len() > 1 { + panic!( + "Children plan of {:#?} have a distinct output partitioning partition count", + self.plan + ); + } } + // If this stage is a leaf node (file scan), it won't have children + // so we return the partition count of the plan itself + output_partition_counts + .into_iter() + .next() + .unwrap_or(self.plan.output_partitioning().partition_count()) } pub fn get_output_partition_count(&self) -> usize {