Skip to content

Commit

Permalink
[xla] fix dynamic size propagation in while loops
Browse files Browse the repository at this point in the history
Some cases with heavily nested control flow did not correctly propagate dynamic
sizes.

PiperOrigin-RevId: 578944255
  • Loading branch information
tensorflower-gardener authored and copybara-github committed Nov 2, 2023
1 parent a28aa0a commit 98fee3e
Show file tree
Hide file tree
Showing 2 changed files with 409 additions and 0 deletions.
14 changes: 14 additions & 0 deletions xla/service/dynamic_dimension_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2308,6 +2308,20 @@ Status DynamicDimensionInferenceVisitor::HandleWhile(HloInstruction* hlo) {
}
HloInstruction* new_body_root = hlo->while_body()->AddInstruction(
HloInstruction::CreateTuple(new_root_operands));
for (int i = 0; i < original_tuple_count; ++i) {
TF_RETURN_IF_ERROR(ForEachDynamicDimension(
body_root,
[&](ShapeIndex index, int64_t dimension,
HloInstruction* dynamic_size) -> Status {
SetDynamicSize(new_body_root, index, dimension, dynamic_size);
if (index.empty() || index.front() != i) {
return OkStatus();
}
index.pop_front();
SetDynamicSize(new_root_operands[i], index, dimension, dynamic_size);
return OkStatus();
}));
}
hlo->while_body()->set_root_instruction(new_body_root);
MarkAsChanged();

Expand Down
Loading

0 comments on commit 98fee3e

Please sign in to comment.