Skip to content

Commit

Permalink
Merge branch 'main' into migrate-to-uv
Browse files Browse the repository at this point in the history
  • Loading branch information
egparedes committed Jan 28, 2025
2 parents 5181601 + e378877 commit 3ad6f97
Show file tree
Hide file tree
Showing 14 changed files with 171 additions and 102 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/deploy-release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
id-token: write
steps:
- name: Download wheel
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: gt4py-dist
path: dist
Expand All @@ -60,7 +60,7 @@ jobs:
id-token: write
steps:
- name: Download wheel
uses: actions/download-artifact@v3
uses: actions/download-artifact@v4
with:
name: gt4py-dist
path: dist
Expand Down
3 changes: 2 additions & 1 deletion src/gt4py/next/ffront/fbuiltins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import functools
import inspect
import math
import operator
from builtins import bool, float, int, tuple # noqa: A004 shadowing a Python built-in
from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast

Expand Down Expand Up @@ -203,7 +204,7 @@ def astype(
return core_defs.dtype(type_).scalar_type(value)


_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs}
_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs, "neg": operator.neg}
UNARY_MATH_NUMBER_BUILTIN_NAMES: Final = [*_UNARY_MATH_NUMBER_BUILTIN_IMPL.keys()]

_UNARY_MATH_FP_BUILTIN_IMPL: Final = {
Expand Down
12 changes: 6 additions & 6 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,12 @@ def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.")
return self._lower_and_map("not_", node.operand)

return self._lower_and_map(
node.op.value,
foast.Constant(value="0", type=dtype, location=node.location),
node.operand,
)
if node.op in [dialect_ast_enums.UnaryOperator.USUB]:
return self._lower_and_map("neg", node.operand)
if node.op in [dialect_ast_enums.UnaryOperator.UADD]:
return self.visit(node.operand)
else:
raise NotImplementedError(f"Unary operator '{node.op}' is not supported.")

    def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall:
        """Lower a FOAST binary operation to an ITIR call.

        The operator's string value (e.g. ``plus``) is used as the builtin name
        and both operands are lowered via ``_lower_and_map``.
        """
        return self._lower_and_map(node.op.value, node.left, node.right)
Expand Down
7 changes: 6 additions & 1 deletion src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,11 @@ def trunc(*args):
raise BackendNotSelectedError()


@builtin_dispatch
def neg(*args):
    """Arithmetic negation builtin; dispatched to the selected backend's implementation.

    Like the other builtin stubs in this module, calling it without a backend
    selected raises ``BackendNotSelectedError``.
    """
    raise BackendNotSelectedError()


@builtin_dispatch
def isfinite(*args):
raise BackendNotSelectedError()
Expand Down Expand Up @@ -397,7 +402,7 @@ def bool(*args): # noqa: A001 [builtin-variable-shadowing]
raise BackendNotSelectedError()


UNARY_MATH_NUMBER_BUILTINS = {"abs"}
UNARY_MATH_NUMBER_BUILTINS = {"abs", "neg"}
UNARY_LOGICAL_BUILTINS = {"not_"}
UNARY_MATH_FP_BUILTINS = {
"sin",
Expand Down
8 changes: 8 additions & 0 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,13 @@ def not_(a):
return not a


@builtins.neg.register(EMBEDDED)
def neg(a):
    """Embedded implementation of the ``neg`` builtin (element-wise arithmetic negation).

    ``np.negative`` handles plain scalars, NumPy arrays and ``Column`` values
    uniformly, so no type dispatch is needed here.
    """
    # The original `if isinstance(a, Column)` branch returned the exact same
    # expression as the fall-through path, so the conditional was redundant.
    return np.negative(a)


@builtins.gamma.register(EMBEDDED)
def gamma(a):
gamma_ = np.vectorize(math.gamma)
Expand Down Expand Up @@ -538,6 +545,7 @@ def promote_scalars(val: CompositeOfScalarOrField):
"and_": operator.and_,
"or_": operator.or_,
"xor_": operator.xor,
"neg": operator.neg,
}
decorator = getattr(builtins, math_builtin_name).register(EMBEDDED)
impl: Callable
Expand Down
74 changes: 24 additions & 50 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ir_makers as im,
misc as ir_misc,
)
from gt4py.next.iterator.transforms import fixed_point_transformation
from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas, inline_lambda
from gt4py.next.iterator.type_system import inference as itir_type_inference
from gt4py.next.type_system import type_info, type_specifications as ts
Expand Down Expand Up @@ -86,8 +87,10 @@ def _is_trivial_or_tuple_thereof_expr(node: ir.Node) -> bool:
# go through all available transformation and apply them. However the final result here still
# reads a little convoluted and is also different to how we write other transformations. We
# should revisit the pattern here and try to find a more general mechanism.
@dataclasses.dataclass(frozen=True)
class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator):
@dataclasses.dataclass(frozen=True, kw_only=True)
class CollapseTuple(
fixed_point_transformation.FixedPointTransformation, eve.PreserveLocationVisitor
):
"""
Simplifies `make_tuple`, `tuple_get` calls.
Expand All @@ -98,7 +101,7 @@ class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator):
# TODO(tehrengruber): This Flag mechanism is a little low level. What we actually want
# is something like a pass manager, where for each pattern we have a corresponding
# transformation, etc.
class Flag(enum.Flag):
class Transformation(enum.Flag):
#: `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t`
COLLAPSE_MAKE_TUPLE_TUPLE_GET = enum.auto()
#: `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i`
Expand Down Expand Up @@ -137,12 +140,12 @@ class Flag(enum.Flag):
INLINE_TRIVIAL_LET = enum.auto()

@classmethod
def all(self) -> CollapseTuple.Flag:
def all(self) -> CollapseTuple.Transformation:
return functools.reduce(operator.or_, self.__members__.values())

uids: eve_utils.UIDGenerator
ignore_tuple_size: bool
flags: Flag = Flag.all() # noqa: RUF009 [function-call-in-dataclass-default-argument]
enabled_transformations: Transformation = Transformation.all() # noqa: RUF009 [function-call-in-dataclass-default-argument]

PRESERVED_ANNEX_ATTRS = ("type",)

Expand All @@ -155,8 +158,8 @@ def apply(
remove_letified_make_tuple_elements: bool = True,
offset_provider_type: Optional[common.OffsetProviderType] = None,
within_stencil: Optional[bool] = None,
# manually passing flags is mostly for allowing separate testing of the modes
flags: Optional[Flag] = None,
# manually passing enabled transformations is mostly for allowing separate testing of the modes
enabled_transformations: Optional[Transformation] = None,
# allow sym references without a symbol declaration, mostly for testing
allow_undeclared_symbols: bool = False,
uids: Optional[eve_utils.UIDGenerator] = None,
Expand All @@ -174,7 +177,7 @@ def apply(
to remove left-overs from `LETIFY_MAKE_TUPLE_ELEMENTS` transformation.
`(λ(_tuple_el_1, _tuple_el_2) → {_tuple_el_1, _tuple_el_2})(1, 2)` -> {1, 2}`
"""
flags = flags or cls.flags
enabled_transformations = enabled_transformations or cls.enabled_transformations
offset_provider_type = offset_provider_type or {}
uids = uids or eve_utils.UIDGenerator()

Expand All @@ -194,7 +197,7 @@ def apply(

new_node = cls(
ignore_tuple_size=ignore_tuple_size,
flags=flags,
enabled_transformations=enabled_transformations,
uids=uids,
).visit(node, within_stencil=within_stencil)

Expand All @@ -210,45 +213,17 @@ def apply(

return new_node

def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node:
def visit(self, node, **kwargs):
if cpm.is_call_to(node, "as_fieldop"):
kwargs = {**kwargs, "within_stencil": True}

node = self.generic_visit(node, **kwargs)
return self.fp_transform(node, **kwargs)

def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node:
while True:
new_node = self.transform(node, **kwargs)
if new_node is None:
break
assert new_node != node
node = new_node
return node

def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
if not isinstance(node, ir.FunCall):
return None

for transformation in self.Flag:
if self.flags & transformation:
assert isinstance(transformation.name, str)
method = getattr(self, f"transform_{transformation.name.lower()}")
result = method(node, **kwargs)
if result is not None:
assert (
result is not node
) # transformation should have returned None, since nothing changed
itir_type_inference.reinfer(result)
return result
return None
return super().visit(node, **kwargs)

def transform_collapse_make_tuple_tuple_get(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
if node.fun == ir.SymRef(id="make_tuple") and all(
isinstance(arg, ir.FunCall) and arg.fun == ir.SymRef(id="tuple_get")
for arg in node.args
if cpm.is_call_to(node, "make_tuple") and all(
cpm.is_call_to(arg, "tuple_get") for arg in node.args
):
# `make_tuple(tuple_get(0, t), tuple_get(1, t), ..., tuple_get(N-1,t))` -> `t`
assert isinstance(node.args[0], ir.FunCall)
Expand All @@ -275,10 +250,9 @@ def transform_collapse_tuple_get_make_tuple(
self, node: ir.FunCall, **kwargs
) -> Optional[ir.Node]:
if (
node.fun == ir.SymRef(id="tuple_get")
and isinstance(node.args[1], ir.FunCall)
and node.args[1].fun == ir.SymRef(id="make_tuple")
cpm.is_call_to(node, "tuple_get")
and isinstance(node.args[0], ir.Literal)
and cpm.is_call_to(node.args[1], "make_tuple")
):
# `tuple_get(i, make_tuple(e_0, e_1, ..., e_i, ..., e_N))` -> `e_i`
assert type_info.is_integer(node.args[0].type)
Expand All @@ -291,7 +265,7 @@ def transform_collapse_tuple_get_make_tuple(
return None

def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if node.fun == ir.SymRef(id="tuple_get") and isinstance(node.args[0], ir.Literal):
if cpm.is_call_to(node, "tuple_get") and isinstance(node.args[0], ir.Literal):
# TODO(tehrengruber): extend to general symbols as long as the tail call in the let
# does not capture
# `tuple_get(i, let(...)(make_tuple()))` -> `let(...)(tuple_get(i, make_tuple()))`
Expand All @@ -314,8 +288,8 @@ def transform_propagate_tuple_get(self, node: ir.FunCall, **kwargs) -> Optional[
)
return None

def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
if node.fun == ir.SymRef(id="make_tuple"):
def transform_letify_make_tuple_elements(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
if cpm.is_call_to(node, "make_tuple"):
# `make_tuple(expr1, expr1)`
# -> `let((_tuple_el_1, expr1), (_tuple_el_2, expr2))(make_tuple(_tuple_el_1, _tuple_el_2))`
bound_vars: dict[ir.Sym, ir.Expr] = {}
Expand All @@ -334,7 +308,7 @@ def transform_letify_make_tuple_elements(self, node: ir.FunCall, **kwargs) -> Op
)
return None

def transform_inline_trivial_make_tuple(self, node: ir.FunCall, **kwargs) -> Optional[ir.Node]:
def transform_inline_trivial_make_tuple(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
if cpm.is_let(node):
# `let(tup, make_tuple(trivial_expr1, trivial_expr2))(foo(tup))`
# -> `foo(make_tuple(trivial_expr1, trivial_expr2))`
Expand All @@ -349,7 +323,7 @@ def transform_propagate_to_if_on_tuples(self, node: ir.FunCall, **kwargs) -> Opt
# in local-view for now. Revisit.
return None

if not cpm.is_call_to(node, "if_"):
if isinstance(node, ir.FunCall) and not cpm.is_call_to(node, "if_"):
# TODO(tehrengruber): Only inline if type of branch value is a tuple.
# Examples:
# `(if cond then {1, 2} else {3, 4})[0]` -> `if cond then {1, 2}[0] else {3, 4}[0]`
Expand Down Expand Up @@ -391,7 +365,7 @@ def transform_propagate_to_if_on_tuples_cps(
# `if True then {2, 1} else {4, 3}`. The examples in the comments below all refer to this
# tuple reordering example here.

if cpm.is_call_to(node, "if_"):
if not isinstance(node, ir.FunCall) or cpm.is_call_to(node, "if_"):
return None

# The first argument that is eligible also transforms all remaining args (They will be
Expand Down
67 changes: 67 additions & 0 deletions src/gt4py/next/iterator/transforms/fixed_point_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
import enum
from typing import ClassVar, Optional, Type

from gt4py import eve
from gt4py.next.iterator import ir
from gt4py.next.iterator.type_system import inference as itir_type_inference


@dataclasses.dataclass(frozen=True, kw_only=True)
class FixedPointTransformation(eve.NodeTranslator):
    """
    Transformation pass that transforms until no transformation is applicable anymore.

    Subclasses declare an ``enum.Flag`` named ``Transformation`` whose members
    each correspond to a ``transform_<name>`` method, and select the active
    subset via ``enabled_transformations``.
    """

    #: Enum of all transformation (names). The transformations need to be defined as methods
    #: named `transform_<NAME>`.
    Transformation: ClassVar[Type[enum.Flag]]

    #: All transformations enabled in this instance, e.g. `Transformation.T1 & Transformation.T2`.
    #: Usually the default value is chosen to be all transformations.
    enabled_transformations: enum.Flag

    def visit(self, node, **kwargs):
        """Visit children first, then drive this node to a fixed point.

        Non-`ir.Node` results (e.g. primitive values produced by the generic
        visit) are returned unchanged, since `fp_transform` only operates on IR
        nodes.
        """
        node = super().visit(node, **kwargs)
        return self.fp_transform(node, **kwargs) if isinstance(node, ir.Node) else node

    def fp_transform(self, node: ir.Node, **kwargs) -> ir.Node:
        """
        Transform node until a fixed point is reached, i.e. no transformation is applicable anymore.
        """
        while True:
            new_node = self.transform(node, **kwargs)
            if new_node is None:
                # no enabled transformation matched: fixed point reached
                break
            assert new_node != node
            node = new_node
        return node

    def transform(self, node: ir.Node, **kwargs) -> Optional[ir.Node]:
        """
        Transform node once.

        Execute transformations until one is applicable. As soon as a transformation occurred
        the function will return the transformed node. Note that the transformation itself
        may call other transformations on child nodes again. Returns ``None`` if no enabled
        transformation applies.
        """
        for transformation in self.Transformation:
            if self.enabled_transformations & transformation:
                assert isinstance(transformation.name, str)
                # each flag member maps to a method by naming convention
                method = getattr(self, f"transform_{transformation.name.lower()}")
                result = method(node, **kwargs)
                if result is not None:
                    assert (
                        result is not node
                    )  # transformation should have returned None, since nothing changed
                    # keep type annex information consistent on the new node
                    itir_type_inference.reinfer(result)
                    return result
        return None
8 changes: 4 additions & 4 deletions src/gt4py/next/iterator/transforms/pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def apply_common_transforms(
# required in order to get rid of expressions without a domain (e.g. when a tuple element is never accessed)
ir = CollapseTuple.apply(
ir,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES,
uids=collapse_tuple_uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
Expand All @@ -98,7 +98,7 @@ def apply_common_transforms(
# is constant-folded the surrounding tuple_get calls can be removed.
inlined = CollapseTuple.apply(
inlined,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES,
uids=collapse_tuple_uids,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program
Expand Down Expand Up @@ -136,7 +136,7 @@ def apply_common_transforms(
ir,
ignore_tuple_size=True,
uids=collapse_tuple_uids,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES,
offset_provider_type=offset_provider_type,
) # type: ignore[assignment] # always an itir.Program

Expand Down Expand Up @@ -176,7 +176,7 @@ def apply_fieldview_transforms(
ir = InlineLambdas.apply(ir, opcount_preserving=True, force_inline_lambda_args=True)
ir = CollapseTuple.apply(
ir,
flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES,
enabled_transformations=~CollapseTuple.Transformation.PROPAGATE_TO_IF_ON_TUPLES,
offset_provider_type=common.offset_provider_to_type(offset_provider),
) # type: ignore[assignment] # type is still `itir.Program`
ir = inline_dynamic_shifts.InlineDynamicShifts.apply(
Expand Down
Loading

0 comments on commit 3ad6f97

Please sign in to comment.