diff --git a/numba_rvsdg/core/datastructures/ast_transforms.py b/numba_rvsdg/core/datastructures/ast_transforms.py index 89df096..c31a5cf 100644 --- a/numba_rvsdg/core/datastructures/ast_transforms.py +++ b/numba_rvsdg/core/datastructures/ast_transforms.py @@ -578,7 +578,7 @@ def function(a: int) -> None self.codegen(node.orelse) # Set jump_target of current block, whatever it may be. - self.current_block.set_jump_targets(exit_index) + self.seal_block(exit_index) # Create exit block and leave open for modification self.add_block(exit_index) diff --git a/numba_rvsdg/tests/test_ast_transforms.py b/numba_rvsdg/tests/test_ast_transforms.py index 3228f25..1200cc3 100644 --- a/numba_rvsdg/tests/test_ast_transforms.py +++ b/numba_rvsdg/tests/test_ast_transforms.py @@ -758,6 +758,95 @@ def function(a: int): } self.compare(function, expected, empty={"7"}) + def test_for_with_nested_for_else(self): + def function(a: bool) -> int: + c = 1 + for i in range(1): + for j in range(1): + if a: + c *= 3 + break # This break decides, if True skip continue. + else: + c *= 5 + continue # Causes break below to be skipped. + c *= 7 + break # Causes the else below to be skipped + else: + c *= 9 # Not breaking in inner loop leads here + return c + + self.assertEqual(function(True), 3 * 7) + self.assertEqual(function(False), 5 * 9) + expected = { + "0": { + "instructions": [ + "c = 1", + "__iterator_1__ = iter(range(1))", + "i = None", + ], + "jump_targets": ["1"], + "name": "0", + }, + "1": { + "instructions": [ + "__iter_last_1__ = i", + "i = next(__iterator_1__, '__sentinel__')", + "i != '__sentinel__'", + ], + "jump_targets": ["2", "3"], + "name": "1", + }, + "2": { + "instructions": [ + "__iterator_5__ = iter(range(1))", + "j = None", + ], + "jump_targets": ["5"], + "name": "2", + }, + "3": { + "instructions": ["i = __iter_last_1__", "c *= 9"], + "jump_targets": ["4"], + "name": "3", + }, + "4": { + "instructions": ["return c"], + "jump_targets": [], + "name": "4", + }, + "5": { + "instructions": [ + "__iter_last_5__ = j", + "j = next(__iterator_5__, '__sentinel__')", + "j != '__sentinel__'", + ], + "jump_targets": ["6", "7"], + "name": "5", + }, + "6": { + "instructions": ["a"], + "jump_targets": ["9", "5"], + "name": "6", + }, + "7": { + "instructions": ["j = __iter_last_5__", "c *= 5"], + "jump_targets": ["1"], + "name": "7", + }, + "8": { + "instructions": ["c *= 7"], + "jump_targets": ["4"], + "name": "8", + }, + "9": { + "instructions": ["c *= 3"], + "jump_targets": ["8"], + "name": "9", + }, + } + + self.compare(function, expected, empty={"11", "10"}) + def test_for_with_nested_else_return_break_and_continue(self): def function(a: int, b: int, c: int, d: int, e: int, f: int) -> int: for i in range(2):