Skip to content

Commit

Permalink
utility for symbol safe replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 29, 2025
1 parent 9bbedf7 commit 5d4a71c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,9 @@ def map_to_parent_sdfg(
lambda m: dace.sdfg.replace_properties_dict(outer_desc, m),
)
# Same applies to the symbols used as field origin (the domain range start)
outer_origin = [val.subs(symbol_mapping) for val in self.origin]
outer_origin = [
gtx_dace_utils.safe_replace_symbolic(val, symbol_mapping) for val in self.origin
]

outer_node = outer_sdfg_state.add_access(outer)
return FieldopData(outer_node, self.gt_type, tuple(outer_origin))
Expand Down
25 changes: 24 additions & 1 deletion src/gt4py/next/program_processors/runners/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import annotations

import re
from typing import Final, Literal
from typing import Final, Literal, Mapping, Union

import dace

Expand Down Expand Up @@ -100,3 +100,26 @@ def filter_connectivity_types(
for offset, conn in offset_provider_type.items()
if isinstance(conn, gtx_common.NeighborConnectivityType)
}


def safe_replace_symbolic(
val: dace.symbolic.SymbolicType,
symbol_mapping: Mapping[
Union[dace.symbolic.SymbolicType, str], Union[dace.symbolic.SymbolicType, str]
],
) -> dace.symbolic.SymbolicType:
"""
Replace free symbols in a dace symbolic expression, using `safe_replace()`
in order to avoid clashes in case the new symbol value is also a free symbol
in the original exoression.
Args:
val: The symbolic expression where to apply the replacement.
symbol_mapping: The mapping table for symbol replacement.
Returns:
A new symbolic expression as result of symbol replacement.
"""
x = [val]
dace.symbolic.safe_replace(symbol_mapping, lambda m, xx=x: xx.append(xx[-1].subs(m)))
return x[-1]
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

"""Test utility functions of the dace backend module."""

import pytest

from gt4py.next.program_processors.runners.dace import utils as gtx_dace_utils


dace = pytest.importorskip("dace")


def test_safe_replace_symbolic():
assert gtx_dace_utils.safe_replace_symbolic(
dace.symbolic.pystr_to_symbolic("x*x + y"), symbol_mapping={"x": "y", "y": "x"}
) == dace.symbolic.pystr_to_symbolic("y*y + x")

0 comments on commit 5d4a71c

Please sign in to comment.