
Commit

revert more changes
andygrove committed Nov 16, 2024
1 parent 6f796aa commit 0266eff
Showing 13 changed files with 160 additions and 606 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

(Generated file; diff not rendered.)

1 change: 1 addition & 0 deletions Cargo.toml
@@ -37,6 +37,7 @@ log = "0.4"
prost = "0.13"
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
tokio = { version = "1.40", features = ["macros", "rt", "rt-multi-thread", "sync"] }
uuid = "1.11.0"

[build-dependencies]
prost-types = "0.13"
83 changes: 45 additions & 38 deletions datafusion_ray/context.py
@@ -29,23 +29,31 @@
from datafusion import SessionContext


def schedule_execution(
graph: ExecutionGraph,
stage_id: int,
is_final_stage: bool,
) -> list[ray.ObjectRef]:
stage = graph.get_query_stage(stage_id)
@ray.remote(num_cpus=0)
def execute_query_stage(
query_stages: list[QueryStage],
stage_id: int
) -> tuple[int, list[ray.ObjectRef]]:
"""
Execute a query stage on the workers.
Returns the stage ID, and a list of futures for the output partitions of the query stage.
"""
stage = QueryStage(stage_id, query_stages[stage_id])

# execute child stages first
# A list of (stage ID, list of futures) for each child stage
# Each list is a 2-D array of (input partitions, output partitions).
child_outputs = []
child_futures = []
for child_id in stage.get_child_stage_ids():
child_outputs.append((child_id, schedule_execution(graph, child_id, False)))
# child_outputs.append((child_id, schedule_execution(graph, child_id)))
child_futures.append(
execute_query_stage.remote(query_stages, child_id)
)

# if the query stage has a single output partition then we need to execute for the output
# partition, otherwise we need to execute in parallel for each input partition
concurrency = stage.get_input_partition_count()
output_partitions_count = stage.get_output_partition_count()
if is_final_stage:
if output_partitions_count == 1:
# reduce stage
print("Forcing reduce stage concurrency from {} to 1".format(concurrency))
concurrency = 1

@@ -55,41 +63,34 @@ def schedule_execution(
)
)

# A list of (stage ID, list of futures) for each child stage
# Each list is a 2-D array of (input partitions, output partitions).
child_outputs = ray.get(child_futures)

def _get_worker_inputs(
part: int,
) -> tuple[list[tuple[int, int, int]], list[ray.ObjectRef]]:
ids = []
futures = []
for child_stage_id, child_futures in child_outputs:
for i, lst in enumerate(child_futures):
if isinstance(lst, list):
for j, f in enumerate(lst):
if concurrency == 1 or j == part:
# If concurrency is 1, pass in all shuffle partitions. Otherwise,
# only pass in the partitions that match the current worker partition.
ids.append((child_stage_id, i, j))
futures.append(f)
elif concurrency == 1 or part == 0:
ids.append((child_stage_id, i, 0))
futures.append(lst)
return ids, futures

# if we are using disk-based shuffle, wait for the child stages to finish
# writing the shuffle files to disk first.
ray.get([f for _, lst in child_outputs for f in lst])

# schedule the actual execution workers
plan_bytes = stage.get_execution_plan_bytes()
plan_bytes = datafusion_ray.serialize_execution_plan(stage.get_execution_plan())
futures = []
opt = {}
# TODO not sure why we had this but my Ray cluster could not find suitable resource
# until I commented this out
# opt["resources"] = {"worker": 1e-3}
opt["num_returns"] = output_partitions_count
for part in range(concurrency):
ids, inputs = _get_worker_inputs(part)
futures.append(
execute_query_partition.options(**opt).remote(
stage_id, plan_bytes, part, ids, *inputs
)
)
return futures

return stage_id, futures


@ray.remote
@@ -157,15 +158,21 @@ def sql(self, sql: str) -> pa.RecordBatch:

graph = self.ctx.plan(execution_plan)
final_stage_id = graph.get_final_query_stage().id()
partitions = schedule_execution(graph, final_stage_id, True)
# assert len(partitions) == 1, len(partitions)
result_set = ray.get(partitions[0])
return result_set

def plan(self, physical_plan: Any) -> pa.RecordBatch:
graph = self.ctx.plan(physical_plan)
final_stage_id = graph.get_final_query_stage().id()
partitions = schedule_execution(graph, final_stage_id, True)
# serialize the query stages and store in Ray object store
query_stages = [
datafusion_ray.serialize_execution_plan(
graph.get_query_stage(i).get_execution_plan()
)
for i in range(final_stage_id + 1)
]
# schedule execution
future = execute_query_stage.remote(
query_stages,
final_stage_id,
self.use_ray_shuffle,
)
_, partitions = ray.get(future)
# assert len(partitions) == 1, len(partitions)
result_set = ray.get(partitions[0])
return result_set
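The restored scheduler in context.py is recursive: each execute_query_stage task submits its child stages as Ray tasks, waits on their futures, and only then launches its own partition workers, so shuffle files are written to disk before any parent stage reads them. Below is a minimal, self-contained sketch of that fan-out/join pattern; the CHILDREN map and the run_stage name are hypothetical stand-ins for the QueryStage plumbing above, not part of the project.

import ray

ray.init(ignore_reinit_error=True)

# Hypothetical stage dependency graph: stage id -> child stage ids.
# The real code derives this from QueryStage.get_child_stage_ids().
CHILDREN = {2: [0, 1], 1: [], 0: []}


@ray.remote(num_cpus=0)
def run_stage(stage_id: int) -> tuple[int, list]:
    # Recursively schedule every child stage first, as execute_query_stage does.
    child_futures = [run_stage.remote(child_id) for child_id in CHILDREN[stage_id]]
    # Block until all child stages have produced their outputs.
    child_outputs = ray.get(child_futures)
    # "Execute" this stage; the real code launches execute_query_partition
    # workers here and returns their futures instead of computed values.
    result = {"stage": stage_id, "child_stages": [cid for cid, _ in child_outputs]}
    return stage_id, [result]


final_stage_id, outputs = ray.get(run_stage.remote(2))
print(final_stage_id, outputs)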
33 changes: 16 additions & 17 deletions src/context.rs
@@ -16,8 +16,7 @@
// under the License.

use crate::planner::{make_execution_graph, PyExecutionGraph};
use crate::shuffle::{RayShuffleReaderExec, ShuffleCodec};
use datafusion::arrow::pyarrow::FromPyArrow;
use crate::shuffle::{ShuffleCodec, ShuffleReaderExec};
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::{DataFusionError, Result};
@@ -175,7 +174,7 @@ fn _set_inputs_for_ray_shuffle_reader(
plan: Arc<dyn ExecutionPlan>,
input_partitions: &Bound<'_, PyList>,
) -> Result<()> {
if let Some(reader_exec) = plan.as_any().downcast_ref::<RayShuffleReaderExec>() {
if let Some(reader_exec) = plan.as_any().downcast_ref::<ShuffleReaderExec>() {
let exec_stage_id = reader_exec.stage_id;
// iterate over inputs, wrap in PyBytes and set as input objects
for item in input_partitions.iter() {
@@ -192,20 +191,20 @@
if stage_id != exec_stage_id {
continue;
}
let part = pytuple
.get_item(1)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.downcast::<PyLong>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?
.extract::<usize>()
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
let batch = RecordBatch::from_pyarrow_bound(
&pytuple
.get_item(2)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?,
)
.map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
reader_exec.add_input_partition(part, batch)?;
// let part = pytuple
// .get_item(1)
// .map_err(|e| DataFusionError::Execution(format!("{}", e)))?
// .downcast::<PyLong>()
// .map_err(|e| DataFusionError::Execution(format!("{}", e)))?
// .extract::<usize>()
// .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
// let batch = RecordBatch::from_pyarrow_bound(
// &pytuple
// .get_item(2)
// .map_err(|e| DataFusionError::Execution(format!("{}", e)))?,
// )
// .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
//reader_exec.add_input_partition(part, batch)?;
}
} else {
for child in plan.children() {
20 changes: 17 additions & 3 deletions src/planner.rs
@@ -17,7 +17,7 @@

use crate::query_stage::PyQueryStage;
use crate::query_stage::QueryStage;
use crate::shuffle::{RayShuffleReaderExec, RayShuffleWriterExec};
use crate::shuffle::{ShuffleReaderExec, ShuffleWriterExec};
use datafusion::error::Result;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::repartition::RepartitionExec;
@@ -29,6 +29,7 @@ use pyo3::prelude::*;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use uuid::Uuid;

#[pyclass(name = "ExecutionGraph", module = "datafusion_ray", subclass)]
pub struct PyExecutionGraph {
@@ -200,11 +201,15 @@ fn create_shuffle_exchange(
// introduce shuffle to produce one output partition
let stage_id = graph.next_id();

// create temp dir for stage shuffle files
let temp_dir = create_temp_dir(stage_id)?;

let shuffle_writer_input = plan.clone();
let shuffle_writer: Arc<dyn ExecutionPlan> = Arc::new(RayShuffleWriterExec::new(
let shuffle_writer: Arc<dyn ExecutionPlan> = Arc::new(ShuffleWriterExec::new(
stage_id,
shuffle_writer_input,
partitioning_scheme.clone(),
&temp_dir,
));

debug!(
@@ -214,13 +219,22 @@

let stage_id = graph.add_query_stage(stage_id, shuffle_writer);
// replace the plan with a shuffle reader
Ok(Arc::new(RayShuffleReaderExec::new(
Ok(Arc::new(ShuffleReaderExec::new(
stage_id,
plan.schema(),
partitioning_scheme,
&temp_dir,
)))
}

fn create_temp_dir(stage_id: usize) -> Result<String> {
let uuid = Uuid::new_v4();
let temp_dir = format!("/tmp/ray-sql-{uuid}-stage-{stage_id}");
debug!("Creating temp shuffle dir: {temp_dir}");
std::fs::create_dir(&temp_dir)?;
Ok(temp_dir)
}

#[cfg(test)]
mod test {
use super::*;
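The shuffle staging directory created by create_temp_dir above follows the pattern /tmp/ray-sql-{uuid}-stage-{stage_id}, with std::fs::create_dir failing if the path already exists. For illustration only, a rough Python equivalent of that helper (the project itself does this in Rust) might look like:

import os
import uuid


def create_temp_dir(stage_id: int) -> str:
    # Same naming convention as src/planner.rs: a fresh UUID per shuffle
    # exchange plus the stage id, so concurrent queries do not collide.
    temp_dir = f"/tmp/ray-sql-{uuid.uuid4()}-stage-{stage_id}"
    os.mkdir(temp_dir)  # like std::fs::create_dir, errors if the dir exists
    return temp_dir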
4 changes: 2 additions & 2 deletions src/query_stage.rs
@@ -16,7 +16,7 @@
// under the License.

use crate::context::serialize_execution_plan;
use crate::shuffle::{RayShuffleReaderExec, ShuffleCodec};
use crate::shuffle::{ShuffleCodec, ShuffleReaderExec};
use datafusion::error::Result;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use datafusion::prelude::SessionContext;
@@ -111,7 +111,7 @@ impl QueryStage {
}

fn collect_child_stage_ids(plan: &dyn ExecutionPlan, ids: &mut Vec<usize>) {
if let Some(shuffle_reader) = plan.as_any().downcast_ref::<RayShuffleReaderExec>() {
if let Some(shuffle_reader) = plan.as_any().downcast_ref::<ShuffleReaderExec>() {
ids.push(shuffle_reader.stage_id);
} else {
for child_plan in plan.children() {
9 changes: 2 additions & 7 deletions src/shuffle/codec.rs
@@ -16,13 +16,8 @@
// under the License.

use crate::protobuf::ray_sql_exec_node::PlanType;
use crate::protobuf::{
RaySqlExecNode, ShuffleReaderExecNode,
ShuffleWriterExecNode,
};
use crate::shuffle::{
ShuffleReaderExec, ShuffleWriterExec,
};
use crate::protobuf::{RaySqlExecNode, ShuffleReaderExecNode, ShuffleWriterExecNode};
use crate::shuffle::{ShuffleReaderExec, ShuffleWriterExec};
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{DataFusionError, Result};
use datafusion::execution::runtime_env::RuntimeEnv;
77 changes: 74 additions & 3 deletions src/shuffle/mod.rs
@@ -15,13 +15,84 @@
// specific language governing permissions and limitations
// under the License.

use arrow::record_batch::RecordBatch;
use datafusion::arrow;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::Result;
use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
use futures::Stream;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::macros::support::thread_rng_n;

mod codec;
mod ray_shuffle;
mod reader;
mod writer;

pub use codec::ShuffleCodec;
pub use ray_shuffle::RayShuffleReaderExec;
pub use ray_shuffle::RayShuffleWriterExec;
pub use reader::ShuffleReaderExec;
pub use writer::ShuffleWriterExec;

/// CombinedRecordBatchStream can be used to combine a Vec of SendableRecordBatchStreams into one
pub struct CombinedRecordBatchStream {
/// Schema wrapped by Arc
schema: SchemaRef,
/// Stream entries
entries: Vec<SendableRecordBatchStream>,
}

impl CombinedRecordBatchStream {
/// Create an CombinedRecordBatchStream
pub fn new(schema: SchemaRef, entries: Vec<SendableRecordBatchStream>) -> Self {
Self { schema, entries }
}
}

impl RecordBatchStream for CombinedRecordBatchStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
}

impl Stream for CombinedRecordBatchStream {
type Item = Result<RecordBatch>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use Poll::*;

let start = thread_rng_n(self.entries.len() as u32) as usize;
let mut idx = start;

for _ in 0..self.entries.len() {
let stream = self.entries.get_mut(idx).unwrap();

match Pin::new(stream).poll_next(cx) {
Ready(Some(val)) => return Ready(Some(val)),
Ready(None) => {
// Remove the entry
self.entries.swap_remove(idx);

// Check if this was the last entry, if so the cursor needs
// to wrap
if idx == self.entries.len() {
idx = 0;
} else if idx < start && start <= self.entries.len() {
// The stream being swapped into the current index has
// already been polled, so skip it.
idx = idx.wrapping_add(1) % self.entries.len();
}
}
Pending => {
idx = idx.wrapping_add(1) % self.entries.len();
}
}
}

// If the map is empty, then the stream is complete.
if self.entries.is_empty() {
Ready(None)
} else {
Pending
}
}
}
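CombinedRecordBatchStream merges several record batch streams by polling them round-robin from a random starting index and removing entries as they finish. A simplified synchronous Python analogue of that merge, with the async Pending bookkeeping omitted, could look like this:

import random


def round_robin_merge(sources):
    # Synchronous sketch of CombinedRecordBatchStream's polling loop.
    iterators = [iter(src) for src in sources]
    while iterators:
        idx = random.randrange(len(iterators))  # random start, like thread_rng_n
        try:
            yield next(iterators[idx])
        except StopIteration:
            iterators.pop(idx)  # drop the exhausted entry, like swap_remove


# Example: interleave three toy "streams" of batches.
print(list(round_robin_merge([[1, 2], [10], [100, 200, 300]])))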
(Remaining changed files not shown.)
