stop double allocating memory in shuffle
andygrove committed Feb 26, 2025
1 parent 133df5c commit 276e1e2
Showing 1 changed file with 10 additions and 63 deletions.
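Summary (drawn from the diff and the TODO it removes): buffered shuffle memory was being reserved twice against the memory pool, once by a `ShuffleRepartitioner`-level `MemoryReservation` and again by each `PartitionBuffer`'s own reservation (apache/datafusion-comet#1448). This commit deletes the repartitioner-level reservation, along with the `partition_id` plumbing that existed only to name it, leaving the per-partition reservations as the single source of truth for memory accounting.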
native/core/src/execution/shuffle/shuffle_writer.rs (73 changes: 10 additions & 63 deletions)
@@ -187,7 +187,6 @@ impl ExecutionPlan for ShuffleWriterExec {
         futures::stream::once(
             external_shuffle(
                 input,
-                partition,
                 self.output_data_file.clone(),
                 self.output_index_file.clone(),
                 self.partitioning.clone(),

@@ -206,7 +205,6 @@ impl ExecutionPlan for ShuffleWriterExec {
 #[allow(clippy::too_many_arguments)]
 async fn external_shuffle(
     mut input: SendableRecordBatchStream,
-    partition_id: usize,
     output_data_file: String,
     output_index_file: String,
     partitioning: Partitioning,

@@ -217,7 +215,6 @@ async fn external_shuffle(
 ) -> Result<SendableRecordBatchStream> {
     let schema = input.schema();
     let mut repartitioner = ShuffleRepartitioner::try_new(
-        partition_id,
         output_data_file,
         output_index_file,
         Arc::clone(&schema),
Expand Down Expand Up @@ -295,7 +292,6 @@ struct ShuffleRepartitioner {
num_output_partitions: usize,
runtime: Arc<RuntimeEnv>,
metrics: ShuffleRepartitionerMetrics,
reservation: MemoryReservation,
/// Hashes for each row in the current batch
hashes_buf: Vec<u32>,
/// Partition ids for each row in the current batch

@@ -307,7 +303,6 @@ struct ShuffleRepartitioner {
 impl ShuffleRepartitioner {
     #[allow(clippy::too_many_arguments)]
     pub fn try_new(
-        partition_id: usize,
         output_data_file: String,
         output_index_file: String,
         schema: SchemaRef,

@@ -319,9 +314,6 @@ impl ShuffleRepartitioner {
         enable_fast_encoding: bool,
     ) -> Result<Self> {
         let num_output_partitions = partitioning.partition_count();
-        let reservation = MemoryConsumer::new(format!("ShuffleRepartitioner[{}]", partition_id))
-            .with_can_spill(true)
-            .register(&runtime.memory_pool);

         let mut hashes_buf = Vec::with_capacity(batch_size);
         let mut partition_ids = Vec::with_capacity(batch_size);
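To see why the registration deleted above double-counts, here is a minimal, hypothetical sketch against DataFusion's memory-pool API. The `GreedyMemoryPool`, the pool size, and the 106,496-byte figure (taken from the tests further down) are illustrative assumptions, not code from this repository:

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn main() -> Result<()> {
    // A fixed-size pool standing in for the runtime's memory pool.
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(1024 * 1024));

    // Each PartitionBuffer already holds its own reservation for its rows.
    let mut partition_rsv = MemoryConsumer::new("PartitionBuffer[0]").register(&pool);
    partition_rsv.try_grow(106_496)?;

    // Before this commit, ShuffleRepartitioner reserved the same bytes again.
    let mut repartitioner_rsv = MemoryConsumer::new("ShuffleRepartitioner[0]")
        .with_can_spill(true)
        .register(&pool);
    repartitioner_rsv.try_grow(106_496)?;

    // The pool now believes 212_992 bytes are in use for 106_496 buffered
    // bytes, which in effect halves the usable shuffle memory budget.
    assert_eq!(212_992, pool.reserved());
    Ok(())
}
```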
Expand Down Expand Up @@ -353,7 +345,6 @@ impl ShuffleRepartitioner {
num_output_partitions,
runtime,
metrics,
reservation,
hashes_buf,
partition_ids,
batch_size,

@@ -473,41 +464,12 @@ impl ShuffleRepartitioner {
                 .enumerate()
                 .filter(|(_, (start, end))| start < end)
             {
-                let mut mem_diff = self
-                    .append_rows_to_partition(
-                        input.columns(),
-                        &shuffled_partition_ids[start..end],
-                        partition_id,
-                    )
-                    .await?;
-
-                if mem_diff > 0 {
-                    let mem_increase = mem_diff as usize;
-
-                    let try_grow = {
-                        let mut mempool_timer = self.metrics.mempool_time.timer();
-                        let result = self.reservation.try_grow(mem_increase);
-                        mempool_timer.stop();
-                        result
-                    };
-
-                    if try_grow.is_err() {
-                        self.spill().await?;
-                        let mut mempool_timer = self.metrics.mempool_time.timer();
-                        self.reservation.free();
-                        self.reservation.try_grow(mem_increase)?;
-                        mempool_timer.stop();
-                        mem_diff = 0;
-                    }
-                }
-
-                if mem_diff < 0 {
-                    let mem_used = self.reservation.size();
-                    let mem_decrease = mem_used.min(-mem_diff as usize);
-                    let mut mempool_timer = self.metrics.mempool_time.timer();
-                    self.reservation.shrink(mem_decrease);
-                    mempool_timer.stop();
-                }
+                self.append_rows_to_partition(
+                    input.columns(),
+                    &shuffled_partition_ids[start..end],
+                    partition_id,
+                )
+                .await?;
             }
         }
         Partitioning::UnknownPartitioning(n) if *n == 1 => {
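The deleted block above implemented a grow-or-spill loop at the repartitioner level; the same idea survives inside `PartitionBuffer` (visible in the retry after `self.spill()` in `append_rows_to_partition` further down). A rough, standalone sketch of the pattern, with `spill` as a stand-in closure rather than the real async spill path:

```rust
use datafusion::error::Result;
use datafusion::execution::memory_pool::MemoryReservation;

/// Try to reserve `bytes` more; on failure, spill buffered data, release the
/// stale reservation, and retry once. A second failure propagates as an error.
fn grow_or_spill(
    reservation: &mut MemoryReservation,
    bytes: usize,
    spill: impl FnOnce() -> Result<()>,
) -> Result<()> {
    if reservation.try_grow(bytes).is_err() {
        spill()?; // flush buffered rows to disk to free pool budget
        reservation.free(); // drop the accounting for the spilled data
        reservation.try_grow(bytes)?; // retry now that budget is available
    }
    Ok(())
}
```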

@@ -594,11 +556,6 @@ impl ShuffleRepartitioner {

         write_time.stop();

-        let mut mempool_timer = self.metrics.mempool_time.timer();
-        let used = self.reservation.size();
-        self.reservation.shrink(used);
-        mempool_timer.stop();
-
         elapsed_compute.stop();

         // shuffle writer always has empty output

@@ -610,7 +567,10 @@
     }

     fn used(&self) -> usize {
-        self.reservation.size()
+        self.buffered_partitions
+            .iter()
+            .map(|b| b.reservation.size())
+            .sum()
     }

     fn spilled_bytes(&self) -> usize {
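With the repartitioner-level reservation gone, `used()` is derived by summing the per-partition reservations, so the reported figure and the bytes actually buffered can no longer drift apart.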

@@ -641,7 +601,6 @@ impl ShuffleRepartitioner {
         for p in &mut self.buffered_partitions {
             spilled_bytes += p.spill(&self.runtime, &self.metrics)?;
         }
-        self.reservation.free();

         self.metrics.spill_count.add(1);
         self.metrics.spilled_bytes.add(spilled_bytes);

@@ -675,10 +634,6 @@ impl ShuffleRepartitioner {
         // spill partitions and retry.
         self.spill().await?;

-        let mut mempool_timer = self.metrics.mempool_time.timer();
-        self.reservation.free();
-        mempool_timer.stop();
-
         start_index = new_start;
         let output = &mut self.buffered_partitions[partition_id];
         output_ret = output.append_rows(columns, indices, start_index, &self.metrics);

@@ -1125,7 +1080,6 @@ mod test {
         let runtime_env = create_runtime(memory_limit);
         let metrics_set = ExecutionPlanMetricsSet::new();
         let mut repartitioner = ShuffleRepartitioner::try_new(
-            0,
             "/tmp/data.out".to_string(),
             "/tmp/index.out".to_string(),
             batch.schema(),

@@ -1145,11 +1099,6 @@
         assert!(repartitioner.buffered_partitions[0].spill_file.is_none());
         assert!(repartitioner.buffered_partitions[1].spill_file.is_none());

-        // TODO: note that we are currently double counting the memory usage
-        // because we reserve the memory twice - once at the repartitioner level
-        // and then again in each PartitionBuffer
-        // https://github.com/apache/datafusion-comet/issues/1448
-        assert_eq!(212992, repartitioner.reservation.size());
         assert_eq!(
             106496,
             repartitioner.buffered_partitions[0].reservation.size()

@@ -1166,14 +1115,12 @@
         assert!(repartitioner.buffered_partitions[1].spill_file.is_some());

         // after spill, all reservations should be freed
-        assert_eq!(0, repartitioner.reservation.size());
         assert_eq!(0, repartitioner.buffered_partitions[0].reservation.size());
         assert_eq!(0, repartitioner.buffered_partitions[1].reservation.size());

         // insert another batch after spilling
         repartitioner.insert_batch(batch.clone()).await.unwrap();

-        assert_eq!(212992, repartitioner.reservation.size());
         assert_eq!(
             106496,
             repartitioner.buffered_partitions[0].reservation.size()
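A note on the test constants: each of the two partition buffers reserves 106,496 bytes, so total usage is 2 × 106,496 = 212,992 bytes. Before this change, the repartitioner-level reservation reported the same 212,992 bytes a second time, which is the double counting tracked in apache/datafusion-comet#1448; after it, the per-partition reservations alone account for the buffered memory.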
