Skip to content

Commit

Permalink
Defensively verifying that all child plans have the same count of…
Browse files Browse the repository at this point in the history
… output partitions
  • Loading branch information
edmondop committed Dec 14, 2024
1 parent 151a0e2 commit 2018ad4
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions src/query_stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use datafusion::prelude::SessionContext;
use datafusion_proto::bytes::physical_plan_from_bytes_with_extension_codec;
use pyo3::prelude::*;
use pyo3::types::PyBytes;
use std::collections::HashSet;
use std::sync::Arc;

#[pyclass(name = "QueryStage", module = "datafusion_ray", subclass)]
Expand Down Expand Up @@ -99,14 +100,23 @@ impl QueryStage {
/// Get the input partition count. This is the same as the number of concurrent tasks
/// when we schedule this query stage for execution
pub fn get_input_partition_count(&self) -> usize {
if self.plan.children().is_empty() {
// leaf node (file scan)
self.plan.output_partitioning().partition_count()
} else {
self.plan.children()[0]
.output_partitioning()
.partition_count()
let mut output_partition_counts = HashSet::new();

for child in self.plan.children() {
output_partition_counts.insert(child.output_partitioning().partition_count());
if output_partition_counts.len() > 1 {
panic!(
"Children plan of {:#?} have a distinct output partitioning partition count",
self.plan
);
}
}
// If this stage is a leaf node (file scan), it won't have children
// so we return the partition count of the plan itself
output_partition_counts
.into_iter()
.next()
.unwrap_or(self.plan.output_partitioning().partition_count())
}

pub fn get_output_partition_count(&self) -> usize {
Expand Down

0 comments on commit 2018ad4

Please sign in to comment.