torch.compile + CUDA Graph optimization for bs=1 #272

Open · wants to merge 4 commits into base: main
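This PR threads a two-slot compiled_text_decoder list and an explicit model reference through the generation call chain so the prefill and single-step decoder passes can be compiled once and reused across calls. A minimal sketch of how a caller would presumably drive the new path — the translator and s2t_model names and the inference-mode wrapper are illustrative assumptions, not part of the diff:

import torch

# Two-slot list mirroring the diff: slot 0 holds the compiled prefill decoder,
# slot 1 the compiled single-step decoder; both are populated lazily on first use.
compiled_text_decoder = [None, None]

with torch.inference_mode():
    texts, generator_output = translator.batch_convert(
        source_seqs,                  # (N, S) source tokens/features on the GPU
        source_padding_mask,
        compiled_text_decoder=compiled_text_decoder,
        s2t_model_list=[s2t_model],   # model passed explicitly instead of via self.model
    )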
123 changes: 90 additions & 33 deletions src/fairseq2/generation/beam_search.py
@@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

Check failure on line 1 in src/fairseq2/generation/beam_search.py (GitHub Actions / Lint Python / Lint): would reformat
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -154,10 +154,9 @@
def __call__(
self, prompt_seqs: Tensor, prompt_padding_mask: Optional[PaddingMask]
) -> SequenceGeneratorOutput:
op = _BeamSearchSequenceGeneratorOp(

Check failure on line 157 in src/fairseq2/generation/beam_search.py (GitHub Actions / Lint Python / Lint): Missing positional argument "step_hooks" in call to "_BeamSearchSequenceGeneratorOp"
self.model,
prompt_seqs,

Check failure on line 158 in src/fairseq2/generation/beam_search.py (GitHub Actions / Lint Python / Lint): Argument 1 to "_BeamSearchSequenceGeneratorOp" has incompatible type "Tensor"; expected "DecoderModel"
prompt_padding_mask,

Check failure on line 159 in src/fairseq2/generation/beam_search.py (GitHub Actions / Lint Python / Lint): Argument 2 to "_BeamSearchSequenceGeneratorOp" has incompatible type "Optional[PaddingMask]"; expected "Tensor"
self.algorithm,
self.beam_size,
self.min_gen_len,
@@ -295,9 +294,11 @@
source_padding_mask: Optional[PaddingMask],
prompt_seqs: Tensor,
prompt_padding_mask: Optional[PaddingMask],
compiled_text_decoder: Optional[list] = None,
model = None
) -> Seq2SeqGeneratorOutput:
# (P, S)
encoder_output, encoder_padding_mask = self.model.encode(
encoder_output, encoder_padding_mask = model.encode(
source_seqs, source_padding_mask
)

@@ -323,7 +324,7 @@
)

op = _BeamSearchSeq2SeqGeneratorOp(
self.model,
model,
encoder_output,
encoder_padding_mask,
prompt_seqs,
@@ -344,7 +345,19 @@
self._step_hooks,
)

hypotheses = op()
for layer in model.decoder.layers.drop_iter():
if compiled_text_decoder[0] is None:
# 1024 is hard-coded as the maximum self-attention sequence length for optimal performance; adjust as needed.
layer.self_attn.cache_k = torch.zeros((5, layer.self_attn.num_heads, 1024, layer.self_attn.head_dim), dtype=torch.half).cuda()
layer.self_attn.cache_v = torch.zeros((5, layer.self_attn.num_heads, 1024, layer.self_attn.head_dim), dtype=torch.half).cuda()
# 256 is hard-coded as the maximum cross-attention sequence length for optimal performance; adjust as needed.
layer.encoder_decoder_attn.cache_k = torch.zeros((5, layer.encoder_decoder_attn.num_heads, 256, layer.encoder_decoder_attn.head_dim), dtype=torch.half).cuda()
layer.encoder_decoder_attn.cache_v = torch.zeros((5, layer.encoder_decoder_attn.num_heads, 256, layer.encoder_decoder_attn.head_dim), dtype=torch.half).cuda()
layer.self_attn.kv_cache = False
layer.encoder_decoder_attn.kv_cache = False

Check failure on line 358 in src/fairseq2/generation/beam_search.py (GitHub Actions / Lint Python / Lint): blank line contains whitespace
hypotheses = op(compiled_text_decoder, model)


return Seq2SeqGeneratorOutput(hypotheses, encoder_output, encoder_padding_mask)

@@ -580,12 +593,34 @@

# Holds the sequences that have reached EOS.
self.output = [[] for _ in range(num_prompts)]

Check failure on line 596 in src/fairseq2/generation/beam_search.py (GitHub Actions / Lint Python / Lint): blank line contains whitespace
def params_for_incremental_gen(self, prev_pos: int, cur_pos: int, device: torch.device):
valid_seq_pos = torch.arange(prev_pos, cur_pos, device=device)

# 1024 is hard-coded as the maximum sequence length for optimal performance; adjust as needed.
mask = torch.full(
(1, 1, 1, 1024), False, device=device
)
mask[:, :, :, :valid_seq_pos.item() + 1] = True
return mask, valid_seq_pos

Check failure on line 606 in src/fairseq2/generation/beam_search.py (GitHub Actions / Lint Python / Lint): blank line contains whitespace
def __call__(self, compiled_text_decoder = None, model = None) -> List[List[Hypothesis]]:
if compiled_text_decoder[0] is None:
compiled_text_decoder[0] = torch.compile(model.decoder.forward, mode='max-autotune')

def __call__(self) -> List[List[Hypothesis]]:
self._prepare_state()
self._prepare_state(model, compiled_text_decoder[0])

prev_pos = self.min_prompt_len-1
for self.step_nr in range(self.min_prompt_len, self.max_seq_len):
if not self._step():
cuda_graph_mask, valid_seq_pos = self.params_for_incremental_gen(
prev_pos, self.step_nr, self.seqs.device)

if compiled_text_decoder[1] is None and self.step_nr > self.min_prompt_len:
compiled_text_decoder[1] = torch.compile(model.decoder.forward2, mode='max-autotune')

output = self._step(cuda_graph_mask, valid_seq_pos, compiled_text_decoder[0] if self.step_nr==self.min_prompt_len else compiled_text_decoder[1], model)
prev_pos = self.step_nr
if not output:
break

# Sort the hypotheses by their scores before returning.
@@ -594,12 +629,12 @@

return self.output

def _prepare_state(self) -> None:
def _prepare_state(self, model, cuda_graph = None) -> None:
# Fast-forward to the first step that needs to be generated.
if self.min_prompt_len > 1:
self._prefill()
self._prefill(model, cuda_graph=cuda_graph)

def _prefill(self) -> None:
def _prefill(self, model, cuda_graph=None) -> None:
chunk_begin = 0

prefill_len = self.min_prompt_len
@@ -613,7 +648,15 @@

chunk_end = chunk_begin + chunk_size

model_output = self._decode(self.seqs[:, chunk_begin:chunk_end])
# 1024 is hard-coded as the maximum sequence length for optimal performance; adjust as needed.
mask = torch.full(
(1, 1, 1, 1024), False, device=self.seqs.device
)
mask[:, :, :, chunk_begin:chunk_end] = True

valid_seq_pos = torch.arange(chunk_begin, chunk_end, device=self.seqs.device)

model_output = self._decode(self.seqs[:, chunk_begin:chunk_end], mask, valid_seq_pos, cuda_graph, model)

self.state_bag.increment_step_nr(chunk_size)

@@ -657,9 +700,9 @@
for hook in self.step_hooks.values():
hook(self.prompt_indices, seqs, step_scores, prefill=True)

def _step(self) -> bool:
def _step(self, cuda_graph_mask, valid_seq_pos, cuda_graph, model) -> bool:
# Generate the next step output.
model_output = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr])
model_output = self._decode(self.seqs[:, self.step_nr - 1 : self.step_nr], cuda_graph_mask, valid_seq_pos, cuda_graph, model)

self.state_bag.increment_step_nr()

@@ -740,7 +783,7 @@
# (N_new)
next_step = BeamStep.merge(beam_next_step_list)

self._reorder_state(next_step.seq_indices)
self._reorder_state(next_step.seq_indices, model)

# Record the current step.
self.seqs[:, self.step_nr] = next_step.vocab_indices
@@ -825,7 +868,7 @@
return next_step.first(self.beam_size)

@abstractmethod
def _decode(self, seqs: Tensor) -> SequenceModelOutput:
def _decode(self, seqs: Tensor, cuda_graph_mask: Tensor, valid_seq_pos: Tensor, cuda_graph, model) -> SequenceModelOutput:
...

def _finish_sequence(self, seq_idx: int, score: Tensor) -> bool:
@@ -873,8 +916,22 @@
# beam.
return len(hypotheses) == self.beam_size

def _reorder_state(self, new_order: Tensor) -> None:
self.state_bag.reorder(new_order)
def _reorder_state(self, new_order: Tensor, model=None) -> None:
cache_ks = []
cache_vs = []
for layer in model.decoder.layers.drop_iter():
cache_ks.append(layer.self_attn.cache_k)
cache_vs.append(layer.self_attn.cache_v)
cache_ks.append(layer.encoder_decoder_attn.cache_k)
cache_vs.append(layer.encoder_decoder_attn.cache_v)

@torch.compile(mode='max-autotune-no-cudagraphs')
def reorder(k, new_order):
for i in range(len(k)):
k[i].copy_(k[i].index_select(0, new_order))

reorder(cache_ks, new_order)
reorder(cache_vs, new_order)

# (N) -> (N - F)
if self.prompt_lens is not None:
@@ -937,17 +994,15 @@
step_hooks,
)

self.model = model

@override
def _decode(self, seqs: Tensor) -> SequenceModelOutput:
decoder_output, decoder_padding_mask = self.model.decode(
def _decode(self, seqs: Tensor, cuda_graph_mask: Tensor, valid_seq_pos: Tensor, cuda_graph, model) -> SequenceModelOutput:
decoder_output, decoder_padding_mask = model.decode(
seqs,
None, # We never use PAD in incremental decoding.
state_bag=self.state_bag,
)

return self.model.project(decoder_output, decoder_padding_mask)
return model.project(decoder_output, decoder_padding_mask)


class _BeamSearchSeq2SeqGeneratorOp(_BeamSearchSequenceGeneratorOpBase):
@@ -996,26 +1051,28 @@
step_processors,
step_hooks,
)

self.model = model
self.encoder_output = encoder_output
self.encoder_padding_mask = encoder_padding_mask

Check failure on line 1054 in src/fairseq2/generation/beam_search.py (GitHub Actions / Lint Python / Lint): blank line contains whitespace
# 256 is hard-coded as the maximum cross-attention sequence length for optimal performance; adjust as needed.
self.encoder_output = torch.cat((encoder_output, torch.zeros((encoder_output.shape[0], 256-encoder_output.shape[1], encoder_output.shape[2]), device=encoder_output.device, dtype=encoder_output.dtype)), 1)
self.encoder_padding_mask = PaddingMask(torch.tensor([encoder_output.shape[1]], device=encoder_output.device), batch_seq_len=256)

@override
def _decode(self, seqs: Tensor) -> SequenceModelOutput:
decoder_output, decoder_padding_mask = self.model.decode(
def _decode(self, seqs: Tensor, cuda_graph_mask: Tensor, valid_seq_pos: Tensor, cuda_graph = None, model = None) -> SequenceModelOutput:
decoder_output, decoder_padding_mask = model.decode(
seqs,
None, # We never use PAD in incremental decoding.
self.encoder_output,
self.encoder_padding_mask,
self.encoder_padding_mask.materialize(),
state_bag=self.state_bag,
cuda_graph_mask=cuda_graph_mask,
valid_seq_pos=valid_seq_pos,
compiled_decoder=cuda_graph,
)

return self.model.project(decoder_output, decoder_padding_mask)
return model.project(decoder_output, decoder_padding_mask)

@override
def _reorder_state(self, new_order: Tensor) -> None:
super()._reorder_state(new_order)
def _reorder_state(self, new_order: Tensor, model=None) -> None:
super()._reorder_state(new_order, model=model)

self.encoder_output = self.encoder_output.index_select(dim=0, index=new_order)

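For reference, the fixed-shape masking helper added above (params_for_incremental_gen) can be exercised in isolation. Below is a self-contained sketch under the same assumption of a hard-coded 1024 maximum sequence length; the standalone function is an illustration, not the class method itself:

import torch

MAX_SEQ_LEN = 1024  # mirrors the hard-coded maximum in the PR

def params_for_incremental_gen(prev_pos: int, cur_pos: int, device: torch.device):
    # Positions decoded in this step (a single position in steady-state decoding).
    valid_seq_pos = torch.arange(prev_pos, cur_pos, device=device)
    # Boolean mask over a fixed-length key axis, so every decode step has the same
    # tensor shapes; this is what lets CUDA graphs replay the captured step.
    mask = torch.full((1, 1, 1, MAX_SEQ_LEN), False, device=device)
    mask[:, :, :, :cur_pos] = True  # equivalent to valid_seq_pos.item() + 1 for a one-token step
    return mask, valid_seq_pos

mask, pos = params_for_incremental_gen(prev_pos=7, cur_pos=8, device=torch.device("cpu"))
print(mask.shape, pos)  # torch.Size([1, 1, 1, 1024]) tensor([7])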
8 changes: 6 additions & 2 deletions src/fairseq2/generation/text.py
@@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

Check failure on line 1 in src/fairseq2/generation/text.py (GitHub Actions / Lint Python / Lint): would reformat
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -68,6 +68,8 @@
self,
source_seqs: Tensor,
source_padding_mask: Optional[PaddingMask],
compiled_text_decoder: Optional[list] = None,

Check failure on line 71 in src/fairseq2/generation/text.py (GitHub Actions / Lint Python / Lint): Implicit generic "Any". Use "typing.List" and specify generic parameters
s2t_model_list: Optional[list] = None

Check failure on line 72 in src/fairseq2/generation/text.py (GitHub Actions / Lint Python / Lint): Implicit generic "Any". Use "typing.List" and specify generic parameters
) -> Tuple[List[StringLike], Seq2SeqGeneratorOutput]:
"""A subclass should call this method for actual text conversion.

@@ -88,8 +90,8 @@
# (S) -> (N, S)
target_prefix_seqs = self.target_prefix_seq.expand(batch_size, -1)

generator_output = self.generator(

Check failure on line 93 in src/fairseq2/generation/text.py (GitHub Actions / Lint Python / Lint): Unexpected keyword argument "compiled_text_decoder" for "__call__" of "Seq2SeqGenerator"
Check failure on line 93 in src/fairseq2/generation/text.py (GitHub Actions / Lint Python / Lint): Unexpected keyword argument "model" for "__call__" of "Seq2SeqGenerator"
source_seqs, source_padding_mask, target_prefix_seqs, None
source_seqs, source_padding_mask, target_prefix_seqs, None, compiled_text_decoder=compiled_text_decoder, model=s2t_model_list[0]

Check failure on line 94 in src/fairseq2/generation/text.py (GitHub Actions / Lint Python / Lint): Value of type "Optional[List[Any]]" is not indexable
)

texts: List[StringLike] = []
@@ -126,10 +128,12 @@

return texts[0], generator_output

def batch_convert(

Check failure on line 131 in src/fairseq2/generation/text.py (GitHub Actions / Lint Python / Lint): Function is missing a type annotation for one or more arguments
self,
source_seqs: Tensor,
source_padding_mask: Optional[PaddingMask],
compiled_text_decoder = None,
s2t_model_list: Optional[list] = None

Check failure on line 136 in src/fairseq2/generation/text.py (GitHub Actions / Lint Python / Lint): Implicit generic "Any". Use "typing.List" and specify generic parameters
) -> Tuple[List[StringLike], Seq2SeqGeneratorOutput]:
"""
:param source_seqs:
@@ -149,7 +153,7 @@
"`source_seqs` must contain at least one element, but is empty instead."
)

return self._do_convert(source_seqs, source_padding_mask)
return self._do_convert(source_seqs, source_padding_mask, compiled_text_decoder=compiled_text_decoder, s2t_model_list=s2t_model_list)


@final
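The _reorder_state change in beam_search.py above compiles the in-place cache shuffling with mode='max-autotune-no-cudagraphs'. A toy sketch of that pattern on a single pre-allocated cache tensor — the shapes and surviving-beam indices below are made up for illustration:

import torch

@torch.compile(mode="max-autotune-no-cudagraphs")
def reorder(caches, new_order):
    # Reorder each cache along the beam dimension in place so the buffers keep
    # their identity (and shape) across beam-search steps.
    for cache in caches:
        cache.copy_(cache.index_select(0, new_order))

beam, heads, max_len, head_dim = 5, 8, 1024, 64
cache_k = torch.zeros(beam, heads, max_len, head_dim)
new_order = torch.tensor([0, 0, 2, 1, 4])  # hypothetical surviving beams after one step
reorder([cache_k], new_order)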
12 changes: 9 additions & 3 deletions src/fairseq2/nn/transformer/attention.py
@@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

Check failure on line 1 in src/fairseq2/nn/transformer/attention.py (GitHub Actions / Lint Python / Lint): would reformat
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -126,7 +126,10 @@
is_causal = False

if key_padding_mask is not None:
mask = key_padding_mask.materialize()
if isinstance(key_padding_mask, PaddingMask):
mask = key_padding_mask.materialize()
else:
mask = key_padding_mask

# (N, S_kv) -> (N, 1, 1, S_kv)
mask = mask[:, None, None, :]
@@ -151,10 +154,13 @@
mask = attn_mask.materialize()
elif attn_mask is not None:
# ([H], S, S_kv)
mask = attn_mask.materialize()
if isinstance(attn_mask, AttentionMask):
mask = attn_mask.materialize()
else:
mask = attn_mask
else:
mask = None

Check failure on line 163 in src/fairseq2/nn/transformer/attention.py (GitHub Actions / Lint Python / Lint): blank line contains whitespace
attn = F.scaled_dot_product_attention( # type: ignore[attr-defined]
seqs,
keys,
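With the change above, a precomputed boolean tensor can flow straight into scaled_dot_product_attention instead of always materializing a PaddingMask or AttentionMask. A small sketch of that call pattern; the shapes below are illustrative, not taken from the diff:

import torch
import torch.nn.functional as F

N, H, S, S_kv, D = 1, 8, 1, 1024, 64          # one query position against a fixed-size cache
q = torch.randn(N, H, S, D)
k = torch.randn(N, H, S_kv, D)
v = torch.randn(N, H, S_kv, D)

# Boolean mask, broadcast over batch and heads: True = attend, False = masked out.
attn_mask = torch.zeros(1, 1, S, S_kv, dtype=torch.bool)
attn_mask[..., :16] = True                     # e.g. only the first 16 cached positions are valid

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)                               # torch.Size([1, 8, 1, 64])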
54 changes: 46 additions & 8 deletions src/fairseq2/nn/transformer/decoder.py
@@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

Check failure on line 1 in src/fairseq2/nn/transformer/decoder.py (GitHub Actions / Lint Python / Lint): would reformat
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
@@ -213,6 +213,8 @@
encoder_padding_mask: Optional[PaddingMask] = None,
*,
state_bag: Optional[IncrementalStateBag] = None,
cuda_graph_mask: Optional[Tensor] = None,
valid_seq_pos: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[PaddingMask]]:
if self._layer_output_hooks and self.layers.drop_p > 0.0:
raise RuntimeError(
@@ -221,23 +223,58 @@

num_layers = len(self.layers)

if self.self_attn_mask_factory is None:
self_attn_mask = None
else:
self_attn_mask = self.self_attn_mask_factory(
seqs, keys=seqs, training=self.training, state_bag=state_bag
)

for layer_idx, layer in enumerate(self.layers.drop_iter()):
seqs, padding_mask = layer(
seqs,
padding_mask,
self_attn_mask,
cuda_graph_mask,
encoder_output,
encoder_padding_mask,
state_bag=state_bag,
valid_seq_pos=valid_seq_pos,
)
for hook in self._layer_output_hooks.values():
if not hook(layer_idx, seqs, padding_mask, num_layers):
break

if self.layer_norm is not None:
seqs = self.layer_norm(seqs)

return seqs, padding_mask


@finaloverride
def forward2(
self,
seqs: Tensor,
padding_mask: Optional[PaddingMask],
encoder_output: Optional[Tensor] = None,
encoder_padding_mask: Optional[PaddingMask] = None,
*,
state_bag: Optional[IncrementalStateBag] = None,
cuda_graph_mask: Optional[Tensor] = None,
valid_seq_pos: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[PaddingMask]]:
if self._layer_output_hooks and self.layers.drop_p > 0.0:
raise RuntimeError(
"The layer output hooks cannot be run when LayerDrop is enabled."
)

num_layers = len(self.layers)


for layer_idx, layer in enumerate(self.layers.drop_iter()):
seqs, padding_mask = layer.forward2(
seqs,
padding_mask,
# self_attn_mask,
cuda_graph_mask,
encoder_output,
encoder_padding_mask,
state_bag=state_bag,
valid_seq_pos=valid_seq_pos,
)

Check failure on line 277 in src/fairseq2/nn/transformer/decoder.py (GitHub Actions / Lint Python / Lint): blank line contains whitespace
for hook in self._layer_output_hooks.values():
if not hook(layer_idx, seqs, padding_mask, num_layers):
break
@@ -259,3 +296,4 @@
s = f"{s}, self_attn_mask_factory={self_attn_mask_factory}"

return f"{s}, norm_order={self.norm_order}"

Check failure on line 299 in src/fairseq2/nn/transformer/decoder.py (GitHub Actions / Lint Python / Lint): blank line contains whitespace
Check failure on line 299 in src/fairseq2/nn/transformer/decoder.py (GitHub Actions / Lint Python / Lint): no newline at end of file
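Taken together, the pre-allocated cache_k/cache_v buffers, the fixed-length masks, and the separate forward/forward2 entry points keep every post-prefill decode step shape-static, which is what lets torch.compile's max-autotune mode capture it with CUDA graphs. A toy sketch of that capture pattern — the module below is a stand-in for the real decoder, not fairseq2 code:

import torch

class ToyDecodeStep(torch.nn.Module):
    def __init__(self, dim: int = 64, max_len: int = 1024):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        # Static cache buffer: written in place each step so shapes never change.
        self.register_buffer("cache_k", torch.zeros(1, max_len, dim))

    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        k = self.proj(x)
        self.cache_k.index_copy_(1, pos, k)    # in-place update keeps the buffer's shape static
        return self.cache_k.sum(dim=1)         # stand-in for attending over the cache

step = ToyDecodeStep()
compiled_step = torch.compile(step, mode="max-autotune")  # enables CUDA graphs when run on GPU

x = torch.randn(1, 1, 64)
for t in range(4):
    out = compiled_step(x, torch.tensor([t]))  # identical shapes every call -> graph replay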