-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into migrate-to-uv
- Loading branch information
Showing
14 changed files
with
171 additions
and
102 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
67 changes: 67 additions & 0 deletions
67
src/gt4py/next/iterator/transforms/fixed_point_transformation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# GT4Py - GridTools Framework | ||
# | ||
# Copyright (c) 2014-2024, ETH Zurich | ||
# All rights reserved. | ||
# | ||
# Please, refer to the LICENSE file in the root directory. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
import dataclasses | ||
import enum | ||
from typing import ClassVar, Optional, Type | ||
|
||
from gt4py import eve | ||
from gt4py.next.iterator import ir | ||
from gt4py.next.iterator.type_system import inference as itir_type_inference | ||
|
||
|
||
@dataclasses.dataclass(frozen=True, kw_only=True) | ||
class FixedPointTransformation(eve.NodeTranslator): | ||
""" | ||
Transformation pass that transforms until no transformation is applicable anymore. | ||
""" | ||
|
||
#: Enum of all transformation (names). The transformations need to be defined as methods | ||
#: named `transform_<NAME>`. | ||
Transformation: ClassVar[Type[enum.Flag]] | ||
|
||
#: All transformations enabled in this instance, e.g. `Transformation.T1 & Transformation.T2`. | ||
#: Usually the default value is chosen to be all transformations. | ||
enabled_transformations: enum.Flag | ||
|
||
def visit(self, node, **kwargs): | ||
node = super().visit(node, **kwargs) | ||
return self.fp_transform(node, **kwargs) if isinstance(node, ir.Node) else node | ||
|
||
def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node: | ||
""" | ||
Transform node until a fixed point is reached, e.g. no transformation is applicable anymore. | ||
""" | ||
while True: | ||
new_node = self.transform(node, **kwargs) | ||
if new_node is None: | ||
break | ||
assert new_node != node | ||
node = new_node | ||
return node | ||
|
||
def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]: | ||
""" | ||
Transform node once. | ||
Execute transformations until one is applicable. As soon as a transformation occured | ||
the function will return the transformed node. Note that the transformation itself | ||
may call other transformations on child nodes again. | ||
""" | ||
for transformation in self.Transformation: | ||
if self.enabled_transformations & transformation: | ||
assert isinstance(transformation.name, str) | ||
method = getattr(self, f"transform_{transformation.name.lower()}") | ||
result = method(node, **kwargs) | ||
if result is not None: | ||
assert ( | ||
result is not node | ||
) # transformation should have returned None, since nothing changed | ||
itir_type_inference.reinfer(result) | ||
return result | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.