From 59819ae4ae9188fe4770ed6c44779e16283bfa7d Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 28 Feb 2025 17:32:53 +0800 Subject: [PATCH] [fix] fix cross_bwdB_bwdW --- .../pipeline/schedule/dualpipe_schedule.py | 237 +++++++++--------- 1 file changed, 119 insertions(+), 118 deletions(-) diff --git a/colossalai/pipeline/schedule/dualpipe_schedule.py b/colossalai/pipeline/schedule/dualpipe_schedule.py index 6da600c68141..c470e9fe1baa 100644 --- a/colossalai/pipeline/schedule/dualpipe_schedule.py +++ b/colossalai/pipeline/schedule/dualpipe_schedule.py @@ -99,7 +99,7 @@ def get_pipe_first_b_w(self, stage_pipe: List[ScheduledNode], chunk: int = 0): else: stage_pipe_temp.append(node) stage_pipe = stage_pipe_temp[::-1] # node from last fully B to ... - + # print(f"stage_pipe {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in stage_pipe]}") if chunk == 0: # get first d for node in stage_pipe: @@ -1201,93 +1201,93 @@ def bwdB_step(pipeline_schedule: List[List[ScheduledNode]]): ########### Pipe_Stage 3.2 ########### def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]): - for stage in range(0, self.n_stage // 2): - first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=0) - # print(f"stage {stage} Up first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ") - u_queue_w, u_queue_b, d_queue_w = [], [], [] - ### 1.Get W nodes, then merge up/down W nodes ### - # get up W nodes: [first_u: mbs//2] - for _ in range(first_u, self.n_micro // 2): - curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 - u_queue_w.append( - ScheduledNode( - type="W", - chunk=0, - stage=stage, - minibatch=_, - start_time=curr_time, - completion_time=curr_time + self.one_time_unit, - ) - ) - curr_time += self.one_time_unit - # get down W nodes: [first_d: mbs//2] Bwd W to W Queue - for _ in range(first_d, self.n_micro // 2): - curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 - d_queue_w.append( - ScheduledNode( - type="W", - chunk=1, - stage=stage, - minibatch=_, - start_time=curr_time, - completion_time=curr_time + self.one_time_unit, - ) - ) - curr_time += self.one_time_unit - ### 2.Get B nodes, then cross with W ### - for _ in range(last_u, self.n_micro // 2): - curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 - u_queue_b.append( - ScheduledNode( - type="B", - chunk=0, - stage=stage, - minibatch=_ + 1, - start_time=curr_time, - completion_time=curr_time + self.one_time_unit, - ) - ) - curr_time += self.one_time_unit - # if stage % 2 == 0: u_queue_w first, then d_queue_w - if stage % 2 == 0: - w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w) - wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b) - # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B' - cut_idx = len(wb_nodes) - for _ in range(len(wb_nodes)): - if ( - wb_nodes[_].minibatch == (self.n_micro // 2 - 1) - and wb_nodes[_].type == "B" - and wb_nodes[_].chunk == 0 - ): - cut_idx = _ - break - wb_nodes = wb_nodes[: cut_idx + 1] - # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") - # else: d_queue_w first, then u_queue_w - else: - w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w) - wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b) - # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B' - cut_idx = len(wb_nodes) - for _ in range(len(wb_nodes)): - if ( - wb_nodes[_].minibatch == (self.n_micro // 2 - 1) - and wb_nodes[_].type == "B" - and wb_nodes[_].chunk == 0 - ): - cut_idx = _ - break - wb_nodes = wb_nodes[: cut_idx + 1] - # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") + # for stage in range(0, self.n_stage // 2): + # first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=0) + # # print(f"stage {stage} Up first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ") + # u_queue_w, u_queue_b, d_queue_w = [], [], [] + # ### 1.Get W nodes, then merge up/down W nodes ### + # # get up W nodes: [first_u: mbs//2] + # for _ in range(first_u, self.n_micro // 2): + # curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # u_queue_w.append( + # ScheduledNode( + # type="W", + # chunk=0, + # stage=stage, + # minibatch=_, + # start_time=curr_time, + # completion_time=curr_time + self.one_time_unit, + # ) + # ) + # curr_time += self.one_time_unit + # # get down W nodes: [first_d: mbs//2] Bwd W to W Queue + # for _ in range(first_d, self.n_micro // 2): + # curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # d_queue_w.append( + # ScheduledNode( + # type="W", + # chunk=1, + # stage=stage, + # minibatch=_, + # start_time=curr_time, + # completion_time=curr_time + self.one_time_unit, + # ) + # ) + # curr_time += self.one_time_unit + # ### 2.Get B nodes, then cross with W ### + # for _ in range(last_u, self.n_micro // 2): + # curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # u_queue_b.append( + # ScheduledNode( + # type="B", + # chunk=0, + # stage=stage, + # minibatch=_ + 1, + # start_time=curr_time, + # completion_time=curr_time + self.one_time_unit, + # ) + # ) + # curr_time += self.one_time_unit + # # if stage % 2 == 0: u_queue_w first, then d_queue_w + # if stage % 2 == 0: + # w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w) + # wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b) + # # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B' + # cut_idx = len(wb_nodes) + # for _ in range(len(wb_nodes)): + # if ( + # wb_nodes[_].minibatch == (self.n_micro // 2 - 1) + # and wb_nodes[_].type == "B" + # and wb_nodes[_].chunk == 0 + # ): + # cut_idx = _ + # break + # wb_nodes = wb_nodes[: cut_idx + 1] + # # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") + # # else: d_queue_w first, then u_queue_w + # else: + # w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w) + # wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b) + # # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B' + # cut_idx = len(wb_nodes) + # for _ in range(len(wb_nodes)): + # if ( + # wb_nodes[_].minibatch == (self.n_micro // 2 - 1) + # and wb_nodes[_].type == "B" + # and wb_nodes[_].chunk == 0 + # ): + # cut_idx = _ + # break + # wb_nodes = wb_nodes[: cut_idx + 1] + # # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") for stage in range(self.n_stage // 2, self.n_stage): first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=1) print(f"stage {stage} Down first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ") d_queue_w, d_queue_b, u_queue_w = [], [], [] ### 1.Get W nodes, then merge down/up W nodes ### - # get down W nodes: [first_d: mbs//2] chunk 1 - for _ in range(self.n_micro // 2, first_d): + # get down W nodes: [last_d: mbs] chunk 1 + for _ in range(last_d, self.n_micro // 2): curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 d_queue_w.append( ScheduledNode( @@ -1300,10 +1300,11 @@ def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]): ) ) curr_time += self.one_time_unit + # print(f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]}") # get up W nodes: [first_u: mbs//2] chunk 0 - for _ in range(self.n_micro // 2, first_u): + for _ in range(first_u, self.n_micro // 2): curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 - d_queue_w.append( + u_queue_w.append( ScheduledNode( type="W", chunk=0, @@ -1315,9 +1316,9 @@ def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]): ) curr_time += self.one_time_unit ### 2.Get B nodes, then cross with W ### - for _ in range(self.n_micro // 2, last_d): + for _ in range(last_d, self.n_micro // 2): curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 - u_queue_b.append( + d_queue_b.append( ScheduledNode( type="B", chunk=1, @@ -1328,41 +1329,41 @@ def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]): ) ) curr_time += self.one_time_unit - print(f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]}") + # print( # f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]} d_queue_b {[_.minibatch for _ in d_queue_b]} u_queue_w {[_.minibatch for _ in u_queue_w]}" # ) - if stage % 2 == 0: - w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w) - wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b) - # clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B' - cut_idx = len(wb_nodes) - for _ in range(len(wb_nodes)): - if ( - wb_nodes[_].minibatch == (self.n_micro // 2 - 1) - and wb_nodes[_].type == "B" - and wb_nodes[_].chunk == 1 - ): - cut_idx = _ - break - wb_nodes = wb_nodes[: cut_idx + 1] - # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") - # else: d_queue_w first, then u_queue_w - else: - w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w) - wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b) - # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B' - cut_idx = len(wb_nodes) - for _ in range(len(wb_nodes)): - if ( - wb_nodes[_].minibatch == (self.n_micro // 2 - 1) - and wb_nodes[_].type == "B" - and wb_nodes[_].chunk == 1 - ): - cut_idx = _ - break - wb_nodes = wb_nodes[: cut_idx + 1] - # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") + # if stage % 2 == 0: + # w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w) + # wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b) + # # clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B' + # cut_idx = len(wb_nodes) + # for _ in range(len(wb_nodes)): + # if ( + # wb_nodes[_].minibatch == (self.n_micro // 2 - 1) + # and wb_nodes[_].type == "B" + # and wb_nodes[_].chunk == 1 + # ): + # cut_idx = _ + # break + # wb_nodes = wb_nodes[: cut_idx + 1] + # # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") + # # else: d_queue_w first, then u_queue_w + # else: + # w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w) + # wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b) + # # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B' + # cut_idx = len(wb_nodes) + # for _ in range(len(wb_nodes)): + # if ( + # wb_nodes[_].minibatch == (self.n_micro // 2 - 1) + # and wb_nodes[_].type == "B" + # and wb_nodes[_].chunk == 1 + # ): + # cut_idx = _ + # break + # wb_nodes = wb_nodes[: cut_idx + 1] + # # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") ########### Pipe_Stage 3.3 ########### def bwdW_step(pipeline_schedule: List[List[ScheduledNode]]):