Skip to content

Commit

Permalink
[fix] fix cross_bwdB_bwdW
Browse files Browse the repository at this point in the history
  • Loading branch information
duanjunwen committed Feb 28, 2025
1 parent 6977bf5 commit 59819ae
Showing 1 changed file with 119 additions and 118 deletions.
237 changes: 119 additions & 118 deletions colossalai/pipeline/schedule/dualpipe_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get_pipe_first_b_w(self, stage_pipe: List[ScheduledNode], chunk: int = 0):
else:
stage_pipe_temp.append(node)
stage_pipe = stage_pipe_temp[::-1] # node from last fully B to ...

# print(f"stage_pipe {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in stage_pipe]}")
if chunk == 0:
# get first d
for node in stage_pipe:
Expand Down Expand Up @@ -1201,93 +1201,93 @@ def bwdB_step(pipeline_schedule: List[List[ScheduledNode]]):

########### Pipe_Stage 3.2 ###########
def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]):
for stage in range(0, self.n_stage // 2):
first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=0)
# print(f"stage {stage} Up first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
u_queue_w, u_queue_b, d_queue_w = [], [], []
### 1.Get W nodes, then merge up/down W nodes ###
# get up W nodes: [first_u: mbs//2]
for _ in range(first_u, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
u_queue_w.append(
ScheduledNode(
type="W",
chunk=0,
stage=stage,
minibatch=_,
start_time=curr_time,
completion_time=curr_time + self.one_time_unit,
)
)
curr_time += self.one_time_unit
# get down W nodes: [first_d: mbs//2] Bwd W to W Queue
for _ in range(first_d, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
d_queue_w.append(
ScheduledNode(
type="W",
chunk=1,
stage=stage,
minibatch=_,
start_time=curr_time,
completion_time=curr_time + self.one_time_unit,
)
)
curr_time += self.one_time_unit
### 2.Get B nodes, then cross with W ###
for _ in range(last_u, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
u_queue_b.append(
ScheduledNode(
type="B",
chunk=0,
stage=stage,
minibatch=_ + 1,
start_time=curr_time,
completion_time=curr_time + self.one_time_unit,
)
)
curr_time += self.one_time_unit
# if stage % 2 == 0: u_queue_w first, then d_queue_w
if stage % 2 == 0:
w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
# clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
cut_idx = len(wb_nodes)
for _ in range(len(wb_nodes)):
if (
wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
and wb_nodes[_].type == "B"
and wb_nodes[_].chunk == 0
):
cut_idx = _
break
wb_nodes = wb_nodes[: cut_idx + 1]
# print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# else: d_queue_w first, then u_queue_w
else:
w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
# clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
cut_idx = len(wb_nodes)
for _ in range(len(wb_nodes)):
if (
wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
and wb_nodes[_].type == "B"
and wb_nodes[_].chunk == 0
):
cut_idx = _
break
wb_nodes = wb_nodes[: cut_idx + 1]
# print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# for stage in range(0, self.n_stage // 2):
# first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=0)
# # print(f"stage {stage} Up first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
# u_queue_w, u_queue_b, d_queue_w = [], [], []
# ### 1.Get W nodes, then merge up/down W nodes ###
# # get up W nodes: [first_u: mbs//2]
# for _ in range(first_u, self.n_micro // 2):
# curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
# u_queue_w.append(
# ScheduledNode(
# type="W",
# chunk=0,
# stage=stage,
# minibatch=_,
# start_time=curr_time,
# completion_time=curr_time + self.one_time_unit,
# )
# )
# curr_time += self.one_time_unit
# # get down W nodes: [first_d: mbs//2] Bwd W to W Queue
# for _ in range(first_d, self.n_micro // 2):
# curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
# d_queue_w.append(
# ScheduledNode(
# type="W",
# chunk=1,
# stage=stage,
# minibatch=_,
# start_time=curr_time,
# completion_time=curr_time + self.one_time_unit,
# )
# )
# curr_time += self.one_time_unit
# ### 2.Get B nodes, then cross with W ###
# for _ in range(last_u, self.n_micro // 2):
# curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
# u_queue_b.append(
# ScheduledNode(
# type="B",
# chunk=0,
# stage=stage,
# minibatch=_ + 1,
# start_time=curr_time,
# completion_time=curr_time + self.one_time_unit,
# )
# )
# curr_time += self.one_time_unit
# # if stage % 2 == 0: u_queue_w first, then d_queue_w
# if stage % 2 == 0:
# w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
# wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
# # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
# cut_idx = len(wb_nodes)
# for _ in range(len(wb_nodes)):
# if (
# wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
# and wb_nodes[_].type == "B"
# and wb_nodes[_].chunk == 0
# ):
# cut_idx = _
# break
# wb_nodes = wb_nodes[: cut_idx + 1]
# # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# # else: d_queue_w first, then u_queue_w
# else:
# w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
# wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b)
# # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
# cut_idx = len(wb_nodes)
# for _ in range(len(wb_nodes)):
# if (
# wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
# and wb_nodes[_].type == "B"
# and wb_nodes[_].chunk == 0
# ):
# cut_idx = _
# break
# wb_nodes = wb_nodes[: cut_idx + 1]
# # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")

for stage in range(self.n_stage // 2, self.n_stage):
first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=1)
print(f"stage {stage} Down first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ")
d_queue_w, d_queue_b, u_queue_w = [], [], []
### 1.Get W nodes, then merge down/up W nodes ###
# get down W nodes: [first_d: mbs//2] chunk 1
for _ in range(self.n_micro // 2, first_d):
# get down W nodes: [last_d: mbs//2] chunk 1
for _ in range(last_d, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
d_queue_w.append(
ScheduledNode(
Expand All @@ -1300,10 +1300,11 @@ def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]):
)
)
curr_time += self.one_time_unit
# print(f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]}")
# get up W nodes: [first_u: mbs//2] chunk 0
for _ in range(self.n_micro // 2, first_u):
for _ in range(first_u, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
d_queue_w.append(
u_queue_w.append(
ScheduledNode(
type="W",
chunk=0,
Expand All @@ -1315,9 +1316,9 @@ def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]):
)
curr_time += self.one_time_unit
### 2.Get B nodes, then cross with W ###
for _ in range(self.n_micro // 2, last_d):
for _ in range(last_d, self.n_micro // 2):
curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
u_queue_b.append(
d_queue_b.append(
ScheduledNode(
type="B",
chunk=1,
Expand All @@ -1328,41 +1329,41 @@ def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]):
)
)
curr_time += self.one_time_unit
print(f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]}")

# print(
# f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]} d_queue_b {[_.minibatch for _ in d_queue_b]} u_queue_w {[_.minibatch for _ in u_queue_w]}"
# )
if stage % 2 == 0:
w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
# clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B'
cut_idx = len(wb_nodes)
for _ in range(len(wb_nodes)):
if (
wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
and wb_nodes[_].type == "B"
and wb_nodes[_].chunk == 1
):
cut_idx = _
break
wb_nodes = wb_nodes[: cut_idx + 1]
# print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# else: d_queue_w first, then u_queue_w
else:
w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
# clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B'
cut_idx = len(wb_nodes)
for _ in range(len(wb_nodes)):
if (
wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
and wb_nodes[_].type == "B"
and wb_nodes[_].chunk == 1
):
cut_idx = _
break
wb_nodes = wb_nodes[: cut_idx + 1]
# print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# if stage % 2 == 0:
# w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w)
# wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
# # clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B'
# cut_idx = len(wb_nodes)
# for _ in range(len(wb_nodes)):
# if (
# wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
# and wb_nodes[_].type == "B"
# and wb_nodes[_].chunk == 1
# ):
# cut_idx = _
# break
# wb_nodes = wb_nodes[: cut_idx + 1]
# # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")
# # else: u_queue_w first, then d_queue_w
# else:
# w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w)
# wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b)
# # clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B'
# cut_idx = len(wb_nodes)
# for _ in range(len(wb_nodes)):
# if (
# wb_nodes[_].minibatch == (self.n_micro // 2 - 1)
# and wb_nodes[_].type == "B"
# and wb_nodes[_].chunk == 1
# ):
# cut_idx = _
# break
# wb_nodes = wb_nodes[: cut_idx + 1]
# # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}")

########### Pipe_Stage 3.3 ###########
def bwdW_step(pipeline_schedule: List[List[ScheduledNode]]):
Expand Down

0 comments on commit 59819ae

Please sign in to comment.