diff --git a/datafusion_ray/context.py b/datafusion_ray/context.py
index 5cbaf86..954d665 100644
--- a/datafusion_ray/context.py
+++ b/datafusion_ray/context.py
@@ -92,80 +92,6 @@ def _get_worker_inputs(
     return futures
 
 
-@ray.remote(num_cpus=0)
-def execute_query_stage(
-    query_stages: list[QueryStage],
-    stage_id: int,
-) -> tuple[int, list[ray.ObjectRef]]:
-    """
-    Execute a query stage on the workers.
-
-    Returns the stage ID, and a list of futures for the output partitions of the query stage.
-    """
-    stage = QueryStage(stage_id, query_stages[stage_id])
-
-    # execute child stages first
-    child_futures = []
-    for child_id in stage.get_child_stage_ids():
-        child_futures.append(execute_query_stage.remote(query_stages, child_id))
-
-    # if the query stage has a single output partition then we need to execute for the output
-    # partition, otherwise we need to execute in parallel for each input partition
-    concurrency = stage.get_input_partition_count()
-    output_partitions_count = stage.get_output_partition_count()
-    if output_partitions_count == 1:
-        # reduce stage
-        print("Forcing reduce stage concurrency from {} to 1".format(concurrency))
-        concurrency = 1
-
-    print(
-        "Scheduling query stage #{} with {} input partitions and {} output partitions".format(
-            stage.id(), concurrency, output_partitions_count
-        )
-    )
-
-    # A list of (stage ID, list of futures) for each child stage
-    # Each list is a 2-D array of (input partitions, output partitions).
-    child_outputs = ray.get(child_futures)
-
-    def _get_worker_inputs(
-        part: int,
-    ) -> tuple[list[tuple[int, int, int]], list[ray.ObjectRef]]:
-        ids = []
-        futures = []
-        for child_stage_id, child_futures in child_outputs:
-            for i, lst in enumerate(child_futures):
-                if isinstance(lst, list):
-                    for j, f in enumerate(lst):
-                        if concurrency == 1 or j == part:
-                            # If concurrency is 1, pass in all shuffle partitions. Otherwise,
-                            # only pass in the partitions that match the current worker partition.
-                            ids.append((child_stage_id, i, j))
-                            futures.append(f)
-                elif concurrency == 1 or part == 0:
-                    ids.append((child_stage_id, i, 0))
-                    futures.append(lst)
-        return ids, futures
-
-    # schedule the actual execution workers
-    plan_bytes = stage.get_execution_plan_bytes()
-    futures = []
-    opt = {}
-    # TODO not sure why we had this but my Ray cluster could not find suitable resource
-    # until I commented this out
-    # opt["resources"] = {"worker": 1e-3}
-    opt["num_returns"] = output_partitions_count
-    for part in range(concurrency):
-        ids, inputs = _get_worker_inputs(part)
-        futures.append(
-            execute_query_partition.options(**opt).remote(
-                stage_id, plan_bytes, part, ids, *inputs
-            )
-        )
-
-    return stage_id, futures
-
-
 @ray.remote
 def execute_query_partition(
     stage_id: int,