From d9218b63c37b678cc13f35adb135a7c679978778 Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 17 Dec 2024 13:29:17 +0100 Subject: [PATCH 01/33] These are the changes Edoardo implemented to fix some issues in the optimization pipeline when confronted with scans. --- .../transformations/__init__.py | 8 +- .../transformations/gpu_utils.py | 2 +- .../transformations/simplify.py | 5 +- .../dace_fieldview/transformations/strides.py | 132 ++++++++++++++++++ .../test_map_buffer_elimination.py | 93 +++++++++++- 5 files changed, 233 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 4f3efb19b0..439084674e 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -35,7 +35,11 @@ gt_simplify, gt_substitute_compiletime_symbols, ) -from .strides import gt_change_transient_strides +from .strides import ( + gt_change_transient_strides, + gt_map_strides_to_dst_nested_sdfg, + gt_map_strides_to_src_nested_sdfg, +) from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -59,6 +63,8 @@ "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_make_transients_persistent", + "gt_map_strides_to_dst_nested_sdfg", + "gt_map_strides_to_src_nested_sdfg", "gt_reduce_distributed_buffering", "gt_set_gpu_blocksize", "gt_set_iteration_order", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 2cd3020180..7b14144ead 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -95,7 +95,7 @@ def
gt_gpu_transformation( if try_removing_trivial_maps: # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on - # GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So + # GPU. `sdfg.apply_gpu_transformations()` will wrap such Tasklets in a Map. So # we might end up with lots of these trivial Maps, each requiring a separate # kernel launch. To prevent this we will combine these trivial maps, if # possible, with their downstream maps. diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 6b7bd1b6d5..1a132cacb2 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -950,7 +950,7 @@ def _perform_pointwise_test( def apply( self, - graph: dace.SDFGState | dace.SDFG, + graph: dace.SDFGState, sdfg: dace.SDFG, ) -> None: # Removal @@ -971,6 +971,9 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None + # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension + gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac) + # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. new_map_to_glob_edge = graph.add_edge( diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 4e254f2880..72a1916875 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,6 +6,9 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause +import functools +from typing import Iterable + import dace from dace import data as dace_data @@ -64,6 +67,13 @@ def _gt_change_transient_strides_non_recursive_impl( # we simply have to reverse the order. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) + for state in sdfg.states(): + for data_node in state.data_nodes(): + if data_node.data == top_level_transient: + for in_edge in state.in_edges(data_node): + gt_map_strides_to_src_nested_sdfg(sdfg, state, in_edge, data_node) + for out_edge in state.out_edges(data_node): + gt_map_strides_to_dst_nested_sdfg(sdfg, state, out_edge, data_node) def _find_toplevel_transients( @@ -97,3 +107,125 @@ def _find_toplevel_transients( continue top_level_transients.add(data) return top_level_transients + + +def gt_map_strides_to_dst_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, +) -> None: + """Propagates the strides of the given data node to the nested SDFGs on the edge destination. + + This function will recursively visit the nested SDFGs connected to the given + data node and apply mapping from inner to outer strides. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. + outer_node: The data node whose strides should be propagated. 
+ """ + if isinstance(edge.dst, dace.nodes.MapEntry): + # Find the destinaion of the edge entering the map entry node + map_entry_out_conn = edge.dst_conn.replace("IN_", "OUT_") + for edge_from_map_entry in state.out_edges_by_connector(edge.dst, map_entry_out_conn): + gt_map_strides_to_dst_nested_sdfg(sdfg, state, edge_from_map_entry, outer_node) + return + + if not isinstance(edge.dst, dace.nodes.NestedSDFG): + return + + outer_strides = outer_node.desc(sdfg).strides + _gt_map_strides_to_nested_sdfg(edge.dst, edge.dst_conn, edge.data, outer_strides) + + for inner_state in edge.dst.sdfg.states(): + for inner_node in inner_state.data_nodes(): + if inner_node.data == edge.dst: + for inner_edge in inner_state.out_edges(inner_node): + gt_map_strides_to_dst_nested_sdfg(sdfg, state, inner_edge, inner_node) + + +def gt_map_strides_to_src_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, +) -> None: + """Propagates the strides of the given data node to the nested SDFGs on the edge source. + + This function will recursively visit the nested SDFGs connected to the given + data node and apply mapping from inner to outer strides. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node, the nested SDFG is expected as the source. + outer_node: The data node whose strides should be propagated. 
+ """ + if isinstance(edge.src, dace.nodes.MapExit): + # Find the source of the edge entering the map exit node + map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") + for edge_to_map_exit in state.in_edges_by_connector(edge.src, map_exit_in_conn): + gt_map_strides_to_src_nested_sdfg(sdfg, state, edge_to_map_exit, outer_node) + return + + if not isinstance(edge.src, dace.nodes.NestedSDFG): + return + + if isinstance(edge.src.sdfg.data(edge.src_conn), dace.data.Scalar): + return # no strides to propagate + + outer_strides = outer_node.desc(sdfg).strides + _gt_map_strides_to_nested_sdfg(edge.src, edge.src_conn, edge.data, outer_strides) + + for inner_state in edge.src.sdfg.states(): + for inner_node in inner_state.data_nodes(): + if inner_node.data == edge.src_conn: + for inner_edge in inner_state.in_edges(inner_node): + gt_map_strides_to_src_nested_sdfg(sdfg, state, inner_edge, inner_node) + + +def _gt_map_strides_to_nested_sdfg( + nsdfg_node: dace.nodes.NestedSDFG, + inner_data: str, + edge_data: dace.Memlet, + outer_strides: Iterable[int | dace.symbolic.SymExpr], +) -> None: + # We need to propagate the strides inside the nested SDFG on the global arrays + new_strides = tuple( + stride + for stride, to_map_size in zip( + outer_strides, + edge_data.subset.size(), + strict=True, + ) + if to_map_size != 1 + ) + inner_desc = nsdfg_node.sdfg.arrays[inner_data] + assert not inner_desc.transient + + if isinstance(inner_desc, dace.data.Scalar): + assert len(new_strides) == 0 + return + + assert isinstance(inner_desc, dace.data.Array) + if all(isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides): + for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): + nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride + else: + inner_desc.set_shape(inner_desc.shape, new_strides) + + new_strides_symbols: list[dace.symbol] = functools.reduce( + lambda acc, itm: (acc + list(itm.free_symbols)) # type: ignore[union-attr] + 
if dace.symbolic.issymbolic(itm) + else acc, + new_strides, + [], + ) + new_strides_free_symbols = { + sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols + } + for sym in new_strides_free_symbols: + nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) + nsdfg_node.symbol_mapping[sym.name] = sym diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py index 1a4ce6d047..a98eac3c2c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -22,10 +22,6 @@ import dace -def _make_test_data(names: list[str]) -> dict[str, np.ndarray]: - return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names} - - def _make_test_sdfg( output_name: str = "G", input_name: str = "G", @@ -262,3 +258,92 @@ def test_map_buffer_elimination_not_apply(): validate_all=True, ) assert count == 0 + + +def test_map_buffer_elimination_with_nested_sdfgs(): + """ + After removing a transient connected to a nested SDFG node, ensure that the strides + are propagated to the arrays in nested SDFG. 
+ """ + + stride1, stride2, stride3 = [dace.symbol(f"stride{i}", dace.int32) for i in range(3)] + + # top-level sdfg + sdfg = dace.SDFG(util.unique_name("map_buffer")) + inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64) + out, out_desc = sdfg.add_array( + "__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3) + ) + tmp, _ = sdfg.add_temp_transient_like(out_desc) + state = sdfg.add_state() + tmp_node = state.add_access(tmp) + + nsdfg1 = dace.SDFG(util.unique_name("map_buffer")) + inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64) + out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64) + tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc) + state1 = nsdfg1.add_state() + tmp1_node = state1.add_access(tmp1) + + nsdfg2 = dace.SDFG(util.unique_name("map_buffer")) + inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64) + out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64) + tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc) + state2 = nsdfg2.add_state() + tmp2_node = state2.add_access(tmp2) + + state2.add_mapped_tasklet( + "broadcast2", + map_ranges={"__i": "0:10"}, + code="__oval = __ival + 1.0", + inputs={ + "__ival": dace.Memlet(f"{inp2}[__i]"), + }, + outputs={ + "__oval": dace.Memlet(f"{tmp2}[__i]"), + }, + output_nodes={tmp2_node}, + external_edges=True, + ) + state2.add_nedge(tmp2_node, state2.add_access(out2), dace.Memlet.from_array(out2, out2_desc)) + + nsdfg2_node = state1.add_nested_sdfg(nsdfg2, nsdfg1, inputs={"__inp"}, outputs={"__out"}) + me1, mx1 = state1.add_map("broadcast1", ndrange={"__i": "0:10"}) + state1.add_memlet_path( + state1.add_access(inp1), + me1, + nsdfg2_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp1, inp1_desc), + ) + state1.add_memlet_path( + nsdfg2_node, mx1, tmp1_node, src_conn="__out", memlet=dace.Memlet(f"{tmp1}[__i, 0:10]") + ) + state1.add_nedge(tmp1_node, state1.add_access(out1), dace.Memlet.from_array(out1, out1_desc)) + + nsdfg1_node = 
state.add_nested_sdfg(nsdfg1, sdfg, inputs={"__inp"}, outputs={"__out"}) + me, mx = state.add_map("broadcast", ndrange={"__i": "0:10"}) + state.add_memlet_path( + state.add_access(inp), + me, + nsdfg1_node, + dst_conn="__inp", + memlet=dace.Memlet.from_array(inp, inp_desc), + ) + state.add_memlet_path( + nsdfg1_node, mx, tmp_node, src_conn="__out", memlet=dace.Memlet(f"{tmp}[__i, 0:10, 0:10]") + ) + state.add_nedge(tmp_node, state.add_access(out), dace.Memlet.from_array(out, out_desc)) + + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMapBufferElimination( + assume_pointwise=False, + ), + validate=True, + validate_all=True, + ) + assert count == 3 + assert out1_desc.strides == out_desc.strides[1:] + assert out2_desc.strides == out_desc.strides[2:] From 9d7e7225333a1ada28f0273a2495c88ed0fea6df Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 11:08:22 +0100 Subject: [PATCH 02/33] First rework. However, the actual modifier function is not modified yet.
--- .../dace_fieldview/transformations/strides.py | 431 ++++++++++++++---- 1 file changed, 354 insertions(+), 77 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 72a1916875..196f7b3e74 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -7,10 +7,11 @@ # SPDX-License-Identifier: BSD-3-Clause import functools -from typing import Iterable +from typing import Iterable, Optional import dace from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( transformations as gtx_transformations, @@ -57,93 +58,160 @@ def gt_change_transient_strides( def _gt_change_transient_strides_non_recursive_impl( sdfg: dace.SDFG, ) -> None: - """Essentially this function just changes the stride to FORTRAN order.""" - for top_level_transient in _find_toplevel_transients(sdfg, only_arrays=True): + """Essentially this function just changes the stride to FORTRAN order. + + Todo: + Make this function more intelligent to analyse the access pattern and then + figuring out the best order. + """ + + # NOTE: processing the transient here is enough. If we are inside a + # NestedSDFG then they were handled before on the level above us. 
+ top_level_transients_and_their_accesses = _gt_find_toplevel_data_accesses( + sdfg=sdfg, + only_transients=True, + only_arrays=True, + ) + for top_level_transient, accesses in top_level_transients_and_their_accesses.items(): desc: dace_data.Array = sdfg.arrays[top_level_transient] + + # Setting the strides only make sense if we have more than two dimensions ndim = len(desc.shape) if ndim <= 1: continue + # We assume that everything is in C order initially, to get FORTRAN order # we simply have to reverse the order. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) - for state in sdfg.states(): - for data_node in state.data_nodes(): - if data_node.data == top_level_transient: - for in_edge in state.in_edges(data_node): - gt_map_strides_to_src_nested_sdfg(sdfg, state, in_edge, data_node) - for out_edge in state.out_edges(data_node): - gt_map_strides_to_dst_nested_sdfg(sdfg, state, out_edge, data_node) - -def _find_toplevel_transients( + # Now we have to propagate the changed strides. Because we already have + # collected all the AccessNodes we are using the + # `gt_propagate_strides_from_access_node()` function, but we have to + # create `processed_nsdfg` set already outside here. + # Furthermore, the same comment as above apply, we do not have to + # propagate the non-transients, because they either come from outside, + # or they were already handled in the levels above, where they were + # defined and then propagated down. + processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() + for state, access_node in accesses: + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=access_node, + processed_nsdfgs=processed_nsdfgs, + ) + + +def gt_propagate_strides_of( sdfg: dace.SDFG, - only_arrays: bool = False, -) -> set[str]: - """Find all top level transients in the SDFG. + data_name: str, +) -> None: + """Propagates the strides of `data_name` within the whole SDFG. 
- The function will scan the SDFG, ignoring nested one, and return the - name of all transients that have an access node at the top level. - However, it will ignore access nodes that refers to registers. + This function will call `gt_propagate_strides_from_access_node()` for every + AccessNode that refers to `data_name`. It will also make sure that + a NestedSDFG is visited only once. + + Args: + sdfg: The SDFG on which we operate. + data_name: Name of the data descriptor that should be handled. """ - top_level_transients: set[str] = set() + + # Defining it here ensures that we will not enter an NestedSDFG multiple times. + processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() + for state in sdfg.states(): - scope_dict = state.scope_dict() for dnode in state.data_nodes(): - data: str = dnode.data - if scope_dict[dnode] is not None: - if data in top_level_transients: - top_level_transients.remove(data) - continue - elif data in top_level_transients: - continue - elif gtx_transformations.util.is_view(dnode, sdfg): + if dnode.data != data_name: continue - desc: dace_data.Data = dnode.desc(sdfg) - - if not desc.transient: - continue - elif only_arrays and not isinstance(desc, dace_data.Array): - continue - top_level_transients.add(data) - return top_level_transients + gt_propagate_strides_from_access_node( + sdfg=sdfg, + state=state, + outer_node=dnode, + processed_nsdfgs=processed_nsdfgs, + ) -def gt_map_strides_to_dst_nested_sdfg( +def gt_propagate_strides_from_access_node( sdfg: dace.SDFG, state: dace.SDFGState, - edge: dace.sdfg.graph.Edge, - outer_node: dace.nodes.AccessNode, + outer_node: dace_nodes.AccessNode, + processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, ) -> None: - """Propagates the strides of the given data node to the nested SDFGs on the edge destination. + """Propagates the stride of `outer_node` along all adjacent edges of `outer_node`. 
+ + The function will propagate the strides of the data descriptor `outer_node` + refers to along all adjacent edges of `outer_node`. If one of these edges + leads to a NestedSDFG then the function will modify the strides of data + descriptor within to match the strides on the outside. The function will then + recursively process NestedSDFG. - This function will recursively visit the nested SDFGs connected to the given - data node and apply mapping from inner to outer strides. + It is important that this function will only handle the NestedSDFGs that are + reachable from `outer_node`. To fully propagate the strides the + `gt_propagate_strides_of()` should be used. Args: sdfg: The SDFG to process. state: The state where the data node is used. edge: The edge that reads from the data node, the nested SDFG is expected as the destination. outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed and will be ignored. + Only specify when you know what your are doing. + propagate_along_dataflow: Determine the direction of propagation. If `True` the + function follows the dataflow. """ - if isinstance(edge.dst, dace.nodes.MapEntry): - # Find the destinaion of the edge entering the map entry node - map_entry_out_conn = edge.dst_conn.replace("IN_", "OUT_") - for edge_from_map_entry in state.out_edges_by_connector(edge.dst, map_entry_out_conn): - gt_map_strides_to_dst_nested_sdfg(sdfg, state, edge_from_map_entry, outer_node) - return + if processed_nsdfgs is None: + # For preventing the case that nested SDFGs are handled multiple time. + # TODO: It certainly happens if a node is input and output, but are there other cases? 
+ processed_nsdfgs = set() + + for in_edge in state.in_edges(outer_node): + gt_map_strides_to_src_nested_sdfg( + sdfg=sdfg, + state=state, + edge=in_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + ) + for out_edge in state.out_edges(outer_node): + gt_map_strides_to_dst_nested_sdfg( + sdfg=sdfg, + state=state, + edge=out_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + ) - if not isinstance(edge.dst, dace.nodes.NestedSDFG): - return - outer_strides = outer_node.desc(sdfg).strides - _gt_map_strides_to_nested_sdfg(edge.dst, edge.dst_conn, edge.data, outer_strides) +def gt_map_strides_to_dst_nested_sdfg( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.Edge, + outer_node: dace.nodes.AccessNode, + processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, +) -> None: + """Propagates the strides of `outer_node` along `edge` along the dataflow. - for inner_state in edge.dst.sdfg.states(): - for inner_node in inner_state.data_nodes(): - if inner_node.data == edge.dst: - for inner_edge in inner_state.out_edges(inner_node): - gt_map_strides_to_dst_nested_sdfg(sdfg, state, inner_edge, inner_node) + For more information see the description of `_gt_map_strides_to_nested_sdfg_src_dst(). + However it is recommended to use `gt_propagate_strides_of()` directly. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that writes to the data node, the nested SDFG is expected as the source. + outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when + you know what your are doing. 
+ """ + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=True, + ) def gt_map_strides_to_src_nested_sdfg( @@ -151,39 +219,165 @@ def gt_map_strides_to_src_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, + processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, ) -> None: - """Propagates the strides of the given data node to the nested SDFGs on the edge source. + """Propagates the strides of `outer_node` along `edge` against the dataflow. - This function will recursively visit the nested SDFGs connected to the given - data node and apply mapping from inner to outer strides. + For more information see the description of `_gt_map_strides_to_nested_sdfg_src_dst(). + However it is recommended to use `gt_propagate_strides_of()` directly. Args: sdfg: The SDFG to process. state: The state where the data node is used. edge: The edge that writes to the data node, the nested SDFG is expected as the source. outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when + you know what your are doing. 
""" - if isinstance(edge.src, dace.nodes.MapExit): - # Find the source of the edge entering the map exit node - map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") - for edge_to_map_exit in state.in_edges_by_connector(edge.src, map_exit_in_conn): - gt_map_strides_to_src_nested_sdfg(sdfg, state, edge_to_map_exit, outer_node) - return + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=False, + ) - if not isinstance(edge.src, dace.nodes.NestedSDFG): - return - if isinstance(edge.src.sdfg.data(edge.src_conn), dace.data.Scalar): - return # no strides to propagate +def _gt_map_strides_to_nested_sdfg_src_dst( + sdfg: dace.SDFG, + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + outer_node: dace.nodes.AccessNode, + processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]], + propagate_along_dataflow: bool, +) -> None: + """Propagates the stride of `outer_node` along `edge`. + + The function will follow `edge`, the direction depends on the value of + `propagate_along_dataflow` and propagate the strides of `outer_node` + into every NestedSDFG that is reachable by following `edge`. + + When the function encounters a NestedSDFG it will determine the the data + descriptor `outer_node` refers on the inside of the NestedSDFG. + It will then replace the stride of the inner descriptor with the ones + of the outside. Afterwards it will recursively propagates the + stride inside the NestedSDFG. + During this propagation the function will follow any edges. + + If the function reaches a NestedSDFG that is listed inside `processed_nsdfgs` + then it will be skipped. NestedSDFGs that have been processed will be added + to the `processed_nsdfgs`. + + Args: + sdfg: The SDFG to process. + state: The state where the data node is used. + edge: The edge that reads from the data node, the nested SDFG is expected as the destination. 
+ outer_node: The data node whose strides should be propagated. + processed_nsdfgs: Set of Nested SDFG that were already processed and will be ignored. + Only specify when you know what your are doing. + propagate_along_dataflow: Determine the direction of propagation. If `True` the + function follows the dataflow. + + Note: + A user should not use this function directly, instead `gt_propagate_strides_of()`, + `gt_map_strides_to_src_nested_sdfg()` (`propagate_along_dataflow == `False`) + or `gt_map_strides_to_dst_nested_sdfg()` (`propagate_along_dataflow == `True`) + should be used. + + Todo: + Try using `MemletTree` for the propagation. + """ + # If `processed_nsdfg` is `None` then this is the first call. We will now + # allocate the `set` and pass it as argument to all recursive calls, this + # ensures that the `set` is the same everywhere. + if processed_nsdfgs is None: + processed_nsdfgs = set() - outer_strides = outer_node.desc(sdfg).strides - _gt_map_strides_to_nested_sdfg(edge.src, edge.src_conn, edge.data, outer_strides) + if propagate_along_dataflow: + # Propagate along the dataflow or forward, so we are interested at the `dst` of the edge. 
+ ScopeNode = dace_nodes.MapEntry - for inner_state in edge.src.sdfg.states(): - for inner_node in inner_state.data_nodes(): - if inner_node.data == edge.src_conn: - for inner_edge in inner_state.in_edges(inner_node): - gt_map_strides_to_src_nested_sdfg(sdfg, state, inner_edge, inner_node) + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.dst + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.dst_conn + + def next_edges_by_connector( + state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + if edge.dst_conn is None or not edge.dst_conn.startswith("IN_"): + return [] + return list(state.out_edges_by_connector(edge.dst, "OUT_" + edge.dst_conn[3:])) + + else: + # Propagate against the dataflow or backward, so we are interested at the `src` of the edge. + ScopeNode = dace_nodes.MapExit + + def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_nodes.Node: + return edge.src + + def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: + return edge.src_conn + + def next_edges_by_connector( + state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] + ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: + return list(state.in_edges_by_connector(edge.src, "IN_" + edge.src_conn[4:])) + + if isinstance(get_node(edge), ScopeNode): + for next_edge in next_edges_by_connector(state, edge): + _gt_map_strides_to_nested_sdfg_src_dst( + sdfg=sdfg, + state=state, + edge=next_edge, + outer_node=outer_node, + processed_nsdfgs=processed_nsdfgs, + propagate_along_dataflow=propagate_along_dataflow, + ) + + elif isinstance(get_node(edge), dace.nodes.NestedSDFG): + nsdfg_node = get_node(edge) + inner_data = get_inner_data(edge) + + if nsdfg_node in processed_nsdfgs: + # We have processed this nested SDFG already, so we have nothing to 
do. + return + + # Mark this nested SDFG as processed. + processed_nsdfgs.add(nsdfg_node) + + # Now set the stride of the data descriptor inside the nested SDFG to + # the ones it has outside. + _gt_map_strides_to_nested_sdfg( + nsdfg_node=nsdfg_node, + inner_data=inner_data, + edge_data=edge.data, + outer_strides=outer_node.desc(sdfg).strides, + ) + + # Because the function call above if not recursive we have now to scan the + # propagate the change into the nested SDFG. Using + # `_gt_find_toplevel_data_accesses()` is a bit overkill, but allows for a + # more uniform processing. + # TODO(phimuell): Instead of scanning every level for every data we modify + # we should scan the whole SDFG once and then reuse this information. + accesses_in_nested_sdfg = _gt_find_toplevel_data_accesses( + sdfg=nsdfg_node.sdfg, + only_transients=False, # Because on the nested levels they are globals. + only_arrays=True, + ) + for nested_state, nested_access in accesses_in_nested_sdfg.get(inner_data, list()): + # We have to use `gt_propagate_strides_of()` here because we have to + # handle its entirety. We could wait until the other branch processes + # the nested SDFG, but this might not work, so let's do it fully now. + gt_propagate_strides_from_access_node( + sdfg=nsdfg_node.sdfg, + state=nested_state, + outer_node=nested_access, + processed_nsdfgs=processed_nsdfgs, + ) def _gt_map_strides_to_nested_sdfg( @@ -192,6 +386,7 @@ def _gt_map_strides_to_nested_sdfg( edge_data: dace.Memlet, outer_strides: Iterable[int | dace.symbolic.SymExpr], ) -> None: + # TODO(phimuell/edopao): Refactor this function. 
# We need to propagate the strides inside the nested SDFG on the global arrays new_strides = tuple( stride @@ -214,6 +409,7 @@ def _gt_map_strides_to_nested_sdfg( for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride else: + assert len(inner_desc.shape) == len(new_strides) inner_desc.set_shape(inner_desc.shape, new_strides) new_strides_symbols: list[dace.symbol] = functools.reduce( @@ -229,3 +425,84 @@ def _gt_map_strides_to_nested_sdfg( for sym in new_strides_free_symbols: nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) nsdfg_node.symbol_mapping[sym.name] = sym + + +def _gt_find_toplevel_data_accesses( + sdfg: dace.SDFG, + only_transients: bool, + only_arrays: bool = False, +) -> dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]]: + """Find all data that is accessed on the top level. + + The function will scan the SDFG, ignoring nested one, and return the + name of all data (global and transient) that only have AccessNodes on + the top level. In data is found that has an AccessNode on both the top + level and in a nested scope and error is generated. + The function will ignore an access in the following cases: + - The AccessNode refers to data that is a register. + - The AccessNode refers to a View. + + Args: + sdfg: The SDFG to process. + only_transients: If `True` all non transients will be filtered out. + only_arrays: If `True`, defaults to `False`, only arrays are returned. + + Returns: + A `dict` that maps the name of a data container, that should be processed + to a list of tuples containing the state where the AccessNode was found + and the node. + """ + # List of data that is accessed on the top level and all its access node. + top_level_data: dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]] = dict() + + # List of all data that were found not on top level. 
+ not_top_level_data: set[str] = set() + + for state in sdfg.states(): + scope_dict = state.scope_dict() + for dnode in state.data_nodes(): + data: str = dnode.data + if scope_dict[dnode] is not None: + # The node was not found on the top level. So we can ignore it. + # We also check if it was ever found on the top level, this should + # not happen, as everything should go through Maps. But some strange + # DaCe transformation might do it. + assert data in top_level_data, f"Found {data} on the top level and inside a scope." + not_top_level_data.add(data) + continue + + elif data in top_level_data: + # The data is already known to be in top level data, so we must add the + # AccessNode to the list of known nodes. But nothing else. + top_level_data[data].append((state, dnode)) + continue + + elif gtx_transformations.util.is_view(dnode, sdfg): + # The AccessNode refers to a View so we ignore it anyway + # TODO(phimuell/edopao): Should the function return them? + continue + + # We have found a new data node that is on the top node and is unknown. + assert ( + data not in not_top_level_data + ), f"Found {data} on the top level and inside a scope." + desc: dace_data.Data = dnode.desc(sdfg) + + # Check if we only accept arrays + if only_arrays and not isinstance(desc, dace_data.Array): + continue + + # For now we ignore registers. + # We do this because register are allocated on the stack, so the compiler + # has all information and should organize the best thing possible. + # TODO(phimuell): verify this. + elif desc.storage is not dace.StorageType.Register: + continue + + # We are only interested in transients + if only_transients and desc.transient: + continue + + # Now create the new entry in the list and record the AccessNode. + top_level_data[data] = [(state, dnode)] + return top_level_data From 1ddd6fee4a21b565e300a5a6ea3ad2161998a53f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 11:37:55 +0100 Subject: [PATCH 03/33] Updated some comments.
--- .../dace_fieldview/transformations/strides.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 196f7b3e74..363ffd6a93 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -39,8 +39,6 @@ def gt_change_transient_strides( Todo: - Implement the estimation correctly. - - Handle the case of nested SDFGs correctly; on the outside a transient, - but on the inside a non transient. """ # TODO(phimeull): Implement this function correctly. @@ -50,22 +48,32 @@ def gt_change_transient_strides( return sdfg for nsdfg in sdfg.all_sdfgs_recursive(): - # TODO(phimuell): Handle the case when transient goes into nested SDFG - # on the inside it is a non transient, so it is ignored. _gt_change_transient_strides_non_recursive_impl(nsdfg) def _gt_change_transient_strides_non_recursive_impl( sdfg: dace.SDFG, ) -> None: - """Essentially this function just changes the stride to FORTRAN order. + """Set optimal strides of all transients in the SDFG. + + The function will look for all top level transients, see `_gt_find_toplevel_data_accesses()` + and set their strides such that the access is optimal, see Note. The function + will also run `gt_propagate_strides_of()` to propagate the strides into nested SDFGs. + + This function should never be called directly but always through + `gt_change_transient_strides()`! + + Note: + Currently the function just reverses the strides of the data descriptor + it processes. Since DaCe generates `C` order by default this lead to + FORTRAN order, which is (for now) sufficient to optimize the memory + layout to GPU. 
Todo: Make this function more intelligent to analyse the access pattern and then figuring out the best order. """ - - # NOTE: processing the transient here is enough. If we are inside a + # NOTE: Processing the transient here is enough. If we are inside a # NestedSDFG then they were handled before on the level above us. top_level_transients_and_their_accesses = _gt_find_toplevel_data_accesses( sdfg=sdfg, From 95e0007022dabe56762936451e583ef092625bb3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 11:44:49 +0100 Subject: [PATCH 04/33] I want to ignore register, not only consider them. --- .../runners/dace_fieldview/transformations/strides.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 363ffd6a93..48d5f7620d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -504,7 +504,7 @@ def _gt_find_toplevel_data_accesses( # We do this because register are allocated on the stack, so the compiler # has all information and should organize the best thing possible. # TODO(phimuell): verify this. - elif desc.storage is not dace.StorageType.Register: + elif desc.storage is dace.StorageType.Register: continue # We are only interested in transients From f1b7a3ff851884cfdb154950c7ebc4fd8fc46d47 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 12:56:06 +0100 Subject: [PATCH 05/33] There was a missing `not` in the check. Which is funny then if you look at the last commit, the number of `not`s in this function was correct. 
--- .../runners/dace_fieldview/transformations/strides.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 48d5f7620d..61471de74b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -475,7 +475,9 @@ def _gt_find_toplevel_data_accesses( # We also check if it was ever found on the top level, this should # not happen, as everything should go through Maps. But some strange # DaCe transformation might do it. - assert data in top_level_data, f"Found {data} on the top level and inside a scope." + assert ( + data not in top_level_data + ), f"Found {data} on the top level and inside a scope." not_top_level_data.add(data) continue From 50ad620b97284fa3c78c53f92c5bd0e474308f98 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 12:56:49 +0100 Subject: [PATCH 06/33] Had to update the propagation, to also handle aliasing. It seems that we also have to handle aliasing. It makes things a bit harder: instead of only looking at the NestedSDFG, we now look at the `(NameOfDataDescriptorInside, NestedSDFG)` pair. However, it still has some errors.
--- .../dace_fieldview/transformations/strides.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 61471de74b..78a25c4407 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import functools -from typing import Iterable, Optional +from typing import Iterable, Optional, TypeAlias import dace from dace import data as dace_data @@ -18,6 +18,19 @@ ) +PropagatedStrideRecord: TypeAlias = tuple[str, dace_nodes.NestedSDFG] +"""Record of a stride that has been propagated into a NestedSDFG. + +The type combines the NestedSDFG into which the strides were already propagated +and the data within that NestedSDFG to which we have propagated the data, +which is the connector name on the NestedSDFG. +We need the NestedSDFG because we have to know what was already processed, +however, we also need the name within because of aliasing, i.e. a data +descriptor on the outside could be mapped to multiple data descriptors +inside the NestedSDFG. +""" + + def gt_change_transient_strides( sdfg: dace.SDFG, gpu: bool, @@ -118,8 +131,8 @@ def gt_propagate_strides_of( """Propagates the strides of `data_name` within the whole SDFG. This function will call `gt_propagate_strides_from_access_node()` for every - AccessNode that refers to `data_name`. It will also make sure that - a NestedSDFG is visited only once. + AccessNode that refers to `data_name`. It will also make sure that a descriptor + inside a NestedSDFG is only processed once. Args: sdfg: The SDFG on which we operate. 
@@ -127,7 +140,7 @@ def gt_propagate_strides_of( """ # Defining it here ensures that we will not enter an NestedSDFG multiple times. - processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() + processed_nsdfgs: set[PropagatedStrideRecord] = set() for state in sdfg.states(): for dnode in state.data_nodes(): @@ -145,7 +158,7 @@ def gt_propagate_strides_from_access_node( sdfg: dace.SDFG, state: dace.SDFGState, outer_node: dace_nodes.AccessNode, - processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the stride of `outer_node` along all adjacent edges of `outer_node`. @@ -164,7 +177,7 @@ def gt_propagate_strides_from_access_node( state: The state where the data node is used. edge: The edge that reads from the data node, the nested SDFG is expected as the destination. outer_node: The data node whose strides should be propagated. - processed_nsdfgs: Set of Nested SDFG that were already processed and will be ignored. + processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. Only specify when you know what your are doing. propagate_along_dataflow: Determine the direction of propagation. If `True` the function follows the dataflow. @@ -197,7 +210,7 @@ def gt_map_strides_to_dst_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, - processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` along the dataflow. @@ -227,7 +240,7 @@ def gt_map_strides_to_src_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, - processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]] = None, + processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` against the dataflow. 
@@ -257,7 +270,7 @@ def _gt_map_strides_to_nested_sdfg_src_dst( state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], outer_node: dace.nodes.AccessNode, - processed_nsdfgs: Optional[set[dace_nodes.NestedSDFG]], + processed_nsdfgs: Optional[set[PropagatedStrideRecord]], propagate_along_dataflow: bool, ) -> None: """Propagates the stride of `outer_node` along `edge`. @@ -348,13 +361,14 @@ def next_edges_by_connector( elif isinstance(get_node(edge), dace.nodes.NestedSDFG): nsdfg_node = get_node(edge) inner_data = get_inner_data(edge) + process_record = (inner_data, nsdfg_node) - if nsdfg_node in processed_nsdfgs: - # We have processed this nested SDFG already, so we have nothing to do. + if process_record in processed_nsdfgs: + # We already handled this NestedSDFG and the inner data. return # Mark this nested SDFG as processed. - processed_nsdfgs.add(nsdfg_node) + processed_nsdfgs.add(process_record) # Now set the stride of the data descriptor inside the nested SDFG to # the ones it has outside. From 983022c3f80a6ccb2bd003dbc9c22bafaafc08b0 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 13:21:00 +0100 Subject: [PATCH 07/33] In the function for looking for top level accesses the `only_transients` flag was not implemented properly. --- .../dace_fieldview/transformations/strides.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 78a25c4407..e808422765 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -457,16 +457,19 @@ def _gt_find_toplevel_data_accesses( """Find all data that is accessed on the top level. 
The function will scan the SDFG, ignoring nested one, and return the - name of all data (global and transient) that only have AccessNodes on - the top level. In data is found that has an AccessNode on both the top - level and in a nested scope and error is generated. - The function will ignore an access in the following cases: + name of all data that only have AccessNodes on the top level. In data + is found that has an AccessNode on both the top level and in a nested + scope and error is generated. + By default the function will return transient and non transient data, + however, if `only_transients` is `True` then only transient data will + be returned. + Furthermore, the function will ignore an access in the following cases: - The AccessNode refers to data that is a register. - The AccessNode refers to a View. Args: sdfg: The SDFG to process. - only_transients: If `True` all non transients will be filtered out. + only_transients: If `True` only include transients. only_arrays: If `True`, defaults to `False`, only arrays are returned. Returns: @@ -524,7 +527,7 @@ def _gt_find_toplevel_data_accesses( continue # We are only interested in transients - if only_transients and desc.transient: + if only_transients and (not desc.transient): continue # Now create the new entry in the list and record the AccessNode. From e7b1afbf127a7f4a38df9492b147a015501a8a47 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 13:27:29 +0100 Subject: [PATCH 08/33] Small reminder of the future. 
--- .../runners/dace_fieldview/transformations/strides.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index e808422765..08cd08120a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -114,6 +114,7 @@ def _gt_change_transient_strides_non_recursive_impl( # propagate the non-transients, because they either come from outside, # or they were already handled in the levels above, where they were # defined and then propagated down. + # TODO(phimuell): Updated the functions such that only once scan is needed. processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() for state, access_node in accesses: gt_propagate_strides_from_access_node( From df7bd0ca993b59a0fcbde67a4ed194c24ac9b3e4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 14:47:31 +0100 Subject: [PATCH 09/33] Forgot to export the new SDFG stuff. 
--- .../runners/dace_fieldview/transformations/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 439084674e..0902bd665a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -39,6 +39,8 @@ gt_change_transient_strides, gt_map_strides_to_dst_nested_sdfg, gt_map_strides_to_src_nested_sdfg, + gt_propagate_strides_from_access_node, + gt_propagate_strides_of, ) from .util import gt_find_constant_arguments, gt_make_transients_persistent @@ -65,6 +67,8 @@ "gt_make_transients_persistent", "gt_map_strides_to_dst_nested_sdfg", "gt_map_strides_to_src_nested_sdfg", + "gt_propagate_strides_from_access_node", + "gt_propagate_strides_of", "gt_reduce_distributed_buffering", "gt_set_gpu_blocksize", "gt_set_iteration_order", From 363ab5942e4da737789554b57f5880cf80ef49ee Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 15:02:43 +0100 Subject: [PATCH 10/33] Had to update function for actual renaming of the strides. Before the function had a special mode in which it performed the renaming through the `symbol_mapping`. However, this made testing a bit harder and so I decided that there should be a flag to disable this.
--- .../dace_fieldview/transformations/strides.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 08cd08120a..e8eb25bd59 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -128,6 +128,7 @@ def _gt_change_transient_strides_non_recursive_impl( def gt_propagate_strides_of( sdfg: dace.SDFG, data_name: str, + ignore_symbol_mapping: bool = False, ) -> None: """Propagates the strides of `data_name` within the whole SDFG. @@ -138,6 +139,8 @@ def gt_propagate_strides_of( Args: sdfg: The SDFG on which we operate. data_name: Name of the data descriptor that should be handled. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. """ # Defining it here ensures that we will not enter an NestedSDFG multiple times. @@ -152,6 +155,7 @@ def gt_propagate_strides_of( state=state, outer_node=dnode, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -159,6 +163,7 @@ def gt_propagate_strides_from_access_node( sdfg: dace.SDFG, state: dace.SDFGState, outer_node: dace_nodes.AccessNode, + ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the stride of `outer_node` along all adjacent edges of `outer_node`. @@ -180,6 +185,8 @@ def gt_propagate_strides_from_access_node( outer_node: The data node whose strides should be propagated. processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. Only specify when you know what your are doing. 
+ ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. propagate_along_dataflow: Determine the direction of propagation. If `True` the function follows the dataflow. """ @@ -195,6 +202,7 @@ def gt_propagate_strides_from_access_node( edge=in_edge, outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, ) for out_edge in state.out_edges(outer_node): gt_map_strides_to_dst_nested_sdfg( @@ -203,6 +211,7 @@ def gt_propagate_strides_from_access_node( edge=out_edge, outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -211,6 +220,7 @@ def gt_map_strides_to_dst_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, + ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` along the dataflow. @@ -223,6 +233,8 @@ def gt_map_strides_to_dst_nested_sdfg( state: The state where the data node is used. edge: The edge that writes to the data node, the nested SDFG is expected as the source. outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when you know what your are doing. 
""" @@ -233,6 +245,7 @@ def gt_map_strides_to_dst_nested_sdfg( outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, propagate_along_dataflow=True, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -241,6 +254,7 @@ def gt_map_strides_to_src_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, + ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` against the dataflow. @@ -253,6 +267,8 @@ def gt_map_strides_to_src_nested_sdfg( state: The state where the data node is used. edge: The edge that writes to the data node, the nested SDFG is expected as the source. outer_node: The data node whose strides should be propagated. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when you know what your are doing. """ @@ -263,6 +279,7 @@ def gt_map_strides_to_src_nested_sdfg( outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, propagate_along_dataflow=False, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -273,6 +290,7 @@ def _gt_map_strides_to_nested_sdfg_src_dst( outer_node: dace.nodes.AccessNode, processed_nsdfgs: Optional[set[PropagatedStrideRecord]], propagate_along_dataflow: bool, + ignore_symbol_mapping: bool = False, ) -> None: """Propagates the stride of `outer_node` along `edge`. @@ -300,6 +318,8 @@ def _gt_map_strides_to_nested_sdfg_src_dst( Only specify when you know what your are doing. propagate_along_dataflow: Determine the direction of propagation. If `True` the function follows the dataflow. + ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + of NestedSDFGs instead of manipulating the data descriptor. 
Note: A user should not use this function directly, instead `gt_propagate_strides_of()`, @@ -357,6 +377,7 @@ def next_edges_by_connector( outer_node=outer_node, processed_nsdfgs=processed_nsdfgs, propagate_along_dataflow=propagate_along_dataflow, + ignore_symbol_mapping=ignore_symbol_mapping, ) elif isinstance(get_node(edge), dace.nodes.NestedSDFG): @@ -378,6 +399,7 @@ def next_edges_by_connector( inner_data=inner_data, edge_data=edge.data, outer_strides=outer_node.desc(sdfg).strides, + ignore_symbol_mapping=ignore_symbol_mapping, ) # Because the function call above if not recursive we have now to scan the @@ -400,6 +422,7 @@ def next_edges_by_connector( state=nested_state, outer_node=nested_access, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -408,6 +431,7 @@ def _gt_map_strides_to_nested_sdfg( inner_data: str, edge_data: dace.Memlet, outer_strides: Iterable[int | dace.symbolic.SymExpr], + ignore_symbol_mapping: bool = False, ) -> None: # TODO(phimuell/edopao): Refactor this function. # We need to propagate the strides inside the nested SDFG on the global arrays @@ -428,7 +452,9 @@ def _gt_map_strides_to_nested_sdfg( return assert isinstance(inner_desc, dace.data.Array) - if all(isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides): + if (not ignore_symbol_mapping) and all( + isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides + ): for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride else: From 9c19d32438477a87e447e0ed510a58c6a85b5fb2 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 15:06:00 +0100 Subject: [PATCH 11/33] Added a todo to the replacement function. 
--- .../runners/dace_fieldview/transformations/strides.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index e8eb25bd59..ea14cf97fb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -433,7 +433,12 @@ def _gt_map_strides_to_nested_sdfg( outer_strides: Iterable[int | dace.symbolic.SymExpr], ignore_symbol_mapping: bool = False, ) -> None: - # TODO(phimuell/edopao): Refactor this function. + """ + Todo: + - Refactor this function. + - Handle the case the stride is used somewhere else. + - Handle the case where we have an explicit size 1 dimension in slicing. + """ # We need to propagate the strides inside the nested SDFG on the global arrays new_strides = tuple( stride From 9cad1f7179b7bc8124d9248569c1ee2ccaf904e8 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 18 Dec 2024 15:10:07 +0100 Subject: [PATCH 12/33] Added a first test to the propagation function. There is some functionality missing, but it is looking good.
--- .../transformation_tests/test_strides.py | 221 ++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py new file mode 100644 index 0000000000..bb0af074c7 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -0,0 +1,221 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace import symbolic as dace_symbolic +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . 
import util + +import dace + + +def _make_strides_propagation_level3_sdfg() -> dace.SDFG: + """Generates the level 3 SDFG (nested-nested) SDFG for `test_strides_propagation()`.""" + sdfg = dace.SDFG(util.unique_name("level3")) + state = sdfg.add_state(is_start_block=True) + names = ["a3", "c3"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL3", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a3[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("c3[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_level2_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + """Generates the level 2 SDFG (nested) SDFG for `test_strides_propagation()`. + + The function returns the level 2 SDFG and the NestedSDFG node that contains + the level 3 SDFG. + """ + sdfg = dace.SDFG(util.unique_name("level2")) + state = sdfg.add_state(is_start_block=True) + names = ["a2", "a2_alias", "b2", "c2"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + state.add_mapped_tasklet( + "compL2_1", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + + state.add_mapped_tasklet( + "compL2_2", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("c2[__i0]")}, + code="__out = __in1", + outputs={"__out": dace.Memlet("a2_alias[__i0]")}, + external_edges=True, + ) + + # This is the nested SDFG we have here. 
+ sdfg_level3 = _make_strides_propagation_level3_sdfg() + + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level3, + parent=sdfg, + inputs={"a3"}, + outputs={"c3"}, + symbol_mapping={s3: s3 for s3 in sdfg_level3.free_symbols}, + ) + + state.add_edge(state.add_access("a2"), None, nsdfg, "a3", dace.Memlet("a2[0:10]")) + state.add_edge(nsdfg, "c3", state.add_access("c2"), None, dace.Memlet("c2[0:10]")) + sdfg.validate() + + return sdfg, nsdfg + + +def _make_strides_propagation_level1_sdfg() -> ( + tuple[dace.SDFG, dace_nodes.NestedSDFG, dace_nodes.NestedSDFG] +): + """Generates the level 1 SDFG (top) SDFG for `test_strides_propagation()`. + + Note that the SDFG is valid, but will be indeterminate. The only point of + this SDFG is to have a lot of different situations that have to be handled + for renaming. + + Returns: + A tuple of length three, with the following members: + - The top level SDFG. + - The NestedSDFG node that contains the level 2 SDFG (member of the top level SDFG). + - The NestedSDFG node that contains the lebel 3 SDFG (member of the level 2 SDFG). 
+ """ + + sdfg = dace.SDFG(util.unique_name("level1")) + state = sdfg.add_state(is_start_block=True) + names = ["a1", "b1", "c1"] + + for name in names: + stride_name = name + "_stride" + stride_sym = dace_symbolic.pystr_to_symbolic(stride_name) + sdfg.add_symbol(stride_name, dace.int64) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + strides=(stride_sym,), + ) + + sdfg_level2, nsdfg_level3 = _make_strides_propagation_level2_sdfg() + + nsdfg_level2: dace_nodes.NestedSDFG = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg, + inputs={"a2", "c2"}, + outputs={"a2_alias", "b2", "c2"}, + symbol_mapping={s: s for s in sdfg_level2.free_symbols}, + ) + + for inner_name in nsdfg_level2.in_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + state.add_access(outer_name), + None, + nsdfg_level2, + inner_name, + dace.Memlet(f"{outer_name}[0:10]"), + ) + for inner_name in nsdfg_level2.out_connectors: + outer_name = inner_name[0] + "1" + state.add_edge( + nsdfg_level2, + inner_name, + state.add_access(outer_name), + None, + dace.Memlet(f"{outer_name}[0:10]"), + ) + + sdfg.validate() + + return sdfg, nsdfg_level2, nsdfg_level3 + + +def test_strides_propagation(): + """ + Todo: + - Add a case where `ignore_symbol_mapping=False` can be tested. + - What happens if the stride symbol is used somewhere else? + """ + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + assert len(adesc.strides) == 1 + assert exp_stride == str( + adesc.strides[0] + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + # Now we propagate `a` and `b`, but not `c`. 
+ # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + + # After the propagation `a` and `b` should use the same stride (the one that + # it has on level 1, but `c` should still be level depending. + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + if aname.startswith("c"): + exp_stride = f"{aname}_stride" + else: + exp_stride = f"{aname[0]}1_stride" + assert len(adesc.strides) == 1 + assert exp_stride == str( + adesc.strides[0] + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + # Now we also propagate `c` thus now all data descriptors have the same stride + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True) + for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname[0]}1_stride" + assert len(adesc.strides) == 1 + assert exp_stride == str( + adesc.strides[0] + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." From 2700f534142464daa38d1ce95edfea26ff3dafc1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 10:02:46 +0100 Subject: [PATCH 13/33] Modified the function that performs the actual modification of the strides. However, it is not yet fully tested; the tests are on their way.
--- .../dace_fieldview/transformations/strides.py | 167 +++++++++++++----- 1 file changed, 127 insertions(+), 40 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index ea14cf97fb..17bdbceeec 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,8 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import functools -from typing import Iterable, Optional, TypeAlias +from typing import Optional, TypeAlias import dace from dace import data as dace_data @@ -346,6 +345,11 @@ def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_node def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: return edge.dst_conn + def get_subset( + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.src_subset + def next_edges_by_connector( state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: @@ -363,6 +367,11 @@ def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_node def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str: return edge.src_conn + def get_subset( + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], + ) -> dace.subsets.Subset: + return edge.data.dst_subset + def next_edges_by_connector( state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: @@ -394,11 +403,11 @@ def next_edges_by_connector( # Now set the stride of the data descriptor inside the nested SDFG to # the ones it has outside. 
- _gt_map_strides_to_nested_sdfg( + _gt_map_strides_into_nested_sdfg( nsdfg_node=nsdfg_node, inner_data=inner_data, - edge_data=edge.data, - outer_strides=outer_node.desc(sdfg).strides, + outer_subset=get_subset(edge), + outer_desc=outer_node.desc(sdfg), ignore_symbol_mapping=ignore_symbol_mapping, ) @@ -426,59 +435,137 @@ def next_edges_by_connector( ) -def _gt_map_strides_to_nested_sdfg( +def _gt_map_strides_into_nested_sdfg( nsdfg_node: dace.nodes.NestedSDFG, inner_data: str, - edge_data: dace.Memlet, - outer_strides: Iterable[int | dace.symbolic.SymExpr], + outer_subset: dace.subsets.Subset, + outer_desc: dace_data.Data, ignore_symbol_mapping: bool = False, ) -> None: - """ + """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. + + `inner_data` is the name of of a data descriptor inside the NestedSDFG. + The function will then modify the modify the strides of `inner_data` to + match the ones of `outer_desc`. + + Args: + nsdfg_node: The node in the parent SDFG that contains the NestedSDFG. + inner_data: The name of the data descriptor that should be processed + inside the NestedSDFG (by construction also a connector name). + outer_subset: The subset that describes what part of the outer data is + mapped into the NestedSDFG. + outer_desc: The data descriptor of the data on the outside. + ignore_symbol_mapping: If possible the function will perform the renaming + through the `symbol_mapping` of the nested SDFG. If `True` then + the function will always perform the renaming. + Todo: - - Refactor this function. - - Handle the case the stride is used somewhere else. - - Handle the case where we have an explicit size 1 dimension in slicing. + - Handle explicit dimensions of size 1. 
""" - # We need to propagate the strides inside the nested SDFG on the global arrays - new_strides = tuple( - stride - for stride, to_map_size in zip( - outer_strides, - edge_data.subset.size(), - strict=True, - ) - if to_map_size != 1 - ) - inner_desc = nsdfg_node.sdfg.arrays[inner_data] + # We need to compute the new strides. In the following we assume that the + # relative order of the dimension does not change, but some dimensions + # that are present on the outside are not present on the inside. For + # example this happens for the Memlet `a[__i0, 0:__a_size1]`. + # We detect this case by checking if that dimension has size 1. + # TODO(phimuell): Handle the case were some additional size 1 dimensions are added. + inner_desc: dace_data.Data = nsdfg_node.sdfg.arrays[inner_data] + inner_shape = inner_desc.shape + inner_strides_init = inner_desc.strides + + # TODO(phimuell): For now this is fine, but it should be possisble to allow it. assert not inner_desc.transient - if isinstance(inner_desc, dace.data.Scalar): - assert len(new_strides) == 0 + outer_strides = outer_desc.strides + outer_inflow = outer_subset.size() + + new_strides: list = [] + for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): + current_inner_dim = len(new_strides) + + if inner_shape[current_inner_dim] == 1 and dim_oinflow == 1: + # There is an explicit size 1 dimension. Because the only valid + # index for this dimension is `0` we can use any value here. + # To give the compiler more information we explicitly use `0`, + # instead of the outer value. + new_strides.append(0) + + elif dim_oinflow == 1: + # Only something flows in, thus there is no stride in this dimension. + pass + + else: + # There is inflow into the SDFG, so we need the stride. 
+ assert dim_oinflow != 0 + new_strides.append(dim_ostride) + assert len(new_strides) <= len(inner_shape) + + if len(new_strides) != len(inner_shape): + raise ValueError("Failed to compute the inner strides.") + + # If we have a scalar on the inside, then there is nothing to adjust. + # We could have performed the test above, but doing it here, gives us + # the chance of validating it. + if isinstance(inner_desc, dace_data.Scalar): + if len(new_strides) != 0: + raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.") return - assert isinstance(inner_desc, dace.data.Array) + if not isinstance(inner_desc, dace_data.Array): + raise TypeError( + f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'." + ) + + # Now we actually replace the strides, there are two ways of doing it. + # The first is to create an alias in the `symbol_mapping`, however, + # this is only possible if the current strides are singular symbols, + # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start` + # or literal values. + # The second way would be to replace `strides` attributer of the + # inner data descriptor. In case the new stride consists of expressions + # such as `value1 - value2` we have to make them available inside the + # NestedSDFG. However, it could be that the strides is used somewhere else. + # We will do the following, if `ignore_symbol_mapping` is `False` and + # the strides of the inner descriptors are symbols, we will use the + # symbol mapping. Otherwise, we will replace the `strides` attribute + # of the inner descriptor, in addition we will install a remapping, + # for those values that were a symbol. 
if (not ignore_symbol_mapping) and all( - isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides + isinstance(inner_stride, dace.symbol) for inner_stride in inner_strides_init ): + # Use the symbol for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True): nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride else: - assert len(inner_desc.shape) == len(new_strides) + # We have to replace the `strides` attribute of the inner descriptor. inner_desc.set_shape(inner_desc.shape, new_strides) - new_strides_symbols: list[dace.symbol] = functools.reduce( - lambda acc, itm: (acc + list(itm.free_symbols)) # type: ignore[union-attr] - if dace.symbolic.issymbolic(itm) - else acc, - new_strides, - [], - ) - new_strides_free_symbols = { - sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols + # Now find the free symbols that the new strides need. + new_strides_symbols: list[str] = [] + for new_stride_dim in new_strides: + if dace.symbolic.issymbolic(new_stride_dim): + new_strides_symbols.append(str(new_stride_dim)) + else: + new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols) + + # Now we determine the set of symbols that should be mapped inside the NestedSDFG. + # We will exclude all that are already inside the `symbol_mapping` (we do not + # check if they map to the same value, we just hope it). Furthermore, + # we will exclude all symbols that are listed in the `symbols` property + # of the SDFG that is nested, and hope that it has the same meaning. + missing_symbol_mappings: set[str] = { + sym + for sym in new_strides_symbols + if not (sym in nsdfg_node.sdfg.symbols or sym in nsdfg_node.symbol_mapping) } - for sym in new_strides_free_symbols: - nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) - nsdfg_node.symbol_mapping[sym.name] = sym + for sym in missing_symbol_mappings: + # We can not create symbols in the nested SDFG, because we do not have + # the type of the symbols. 
+ nsdfg_node.symbol_mapping[sym] = dace.symbolic.pystr_to_symbolic(sym) + + # Now create aliases for the old symbols that were used as strides. + for old_sym, new_sym in zip(inner_strides_init, new_strides): + if dace.symbolic.issymbolic(old_sym): + nsdfg_node.symbol_mapping[str(old_sym)] = dace.symbolic.pystr_to_symbolic(new_sym) def _gt_find_toplevel_data_accesses( From a20d3c00a202aea530dd66ee842b00bb550e7045 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 10:07:24 +0100 Subject: [PATCH 14/33] Updated some tests, but more are missing. --- .../transformation_tests/test_strides.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index bb0af074c7..655e50fb23 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -187,11 +187,17 @@ def test_strides_propagation(): for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: for aname, adesc in sdfg.arrays.items(): exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] assert len(adesc.strides) == 1 - assert exp_stride == str( - adesc.strides[0] + assert ( + str(actual_stride) == exp_stride ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + # Now we propagate `a` and `b`, but not `c`. # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`.
gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) @@ -201,6 +207,7 @@ def test_strides_propagation(): # it has on level 1, but `c` should still be level depending. for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: for aname, adesc in sdfg.arrays.items(): + original_stride = f"{aname}_stride" if aname.startswith("c"): exp_stride = f"{aname}_stride" else: @@ -210,12 +217,23 @@ def test_strides_propagation(): adesc.strides[0] ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride + # Now we also propagate `c` thus now all data descriptors have the same stride gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True) for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: for aname, adesc in sdfg.arrays.items(): exp_stride = f"{aname[0]}1_stride" + original_stride = f"{aname}_stride" assert len(adesc.strides) == 1 assert exp_stride == str( adesc.strides[0] ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride From b5ff46270733b4edd6fc7d43fcfe13c55558dd84 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:01:07 +0100 Subject: [PATCH 15/33] Subset caching strikes again. 
--- .../dace_fieldview/transformations/strides.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 17bdbceeec..8808248e40 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -346,12 +346,14 @@ def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str return edge.dst_conn def get_subset( + state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], ) -> dace.subsets.Subset: - return edge.data.src_subset + return edge.data.get_src_subset(edge, state) def next_edges_by_connector( - state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: if edge.dst_conn is None or not edge.dst_conn.startswith("IN_"): return [] @@ -368,12 +370,14 @@ def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str return edge.src_conn def get_subset( + state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], ) -> dace.subsets.Subset: - return edge.data.dst_subset + return edge.data.get_dst_subset(edge, state) def next_edges_by_connector( - state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet] + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet], ) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]: return list(state.in_edges_by_connector(edge.src, "IN_" + edge.src_conn[4:])) @@ -406,7 +410,7 @@ def next_edges_by_connector( _gt_map_strides_into_nested_sdfg( nsdfg_node=nsdfg_node, inner_data=inner_data, - outer_subset=get_subset(edge), + 
+ outer_subset=get_subset(state, edge), outer_desc=outer_node.desc(sdfg), ignore_symbol_mapping=ignore_symbol_mapping, ) From d326d3b7316f6561897f9b1e7117424b5911baf5 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:01:36 +0100 Subject: [PATCH 16/33] It seems that the explicit handling of size-1 dimensions is not working. It also seems that it interferes with something. --- .../runners/dace_fieldview/transformations/strides.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 8808248e40..e0f21e4163 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -484,16 +484,7 @@ def _gt_map_strides_into_nested_sdfg( new_strides: list = [] for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): - current_inner_dim = len(new_strides) - - if inner_shape[current_inner_dim] == 1 and dim_oinflow == 1: - # There is an explicit size 1 dimension. Because the only valid - # index for this dimension is `0` we can use any value here. - # To give the compiler more information we explicitly use `0`, - # instead of the outer value. - new_strides.append(0) - - elif dim_oinflow == 1: + if dim_oinflow == 1: # Only something flows in, thus there is no stride in this dimension. pass From 252f348e104cff7f75fccc5629044bcdb5347b33 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:02:47 +0100 Subject: [PATCH 17/33] The test must be moved below. Because a scalar has a shape of `(1,)` but a stride of `()`. Thus we first have to handle this case. However, now we are back at the index stuff, let's fix it.
--- .../runners/dace_fieldview/transformations/strides.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index e0f21e4163..1ee0260310 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -487,16 +487,12 @@ def _gt_map_strides_into_nested_sdfg( if dim_oinflow == 1: # Only something flows in, thus there is no stride in this dimension. pass - else: # There is inflow into the SDFG, so we need the stride. assert dim_oinflow != 0 new_strides.append(dim_ostride) assert len(new_strides) <= len(inner_shape) - if len(new_strides) != len(inner_shape): - raise ValueError("Failed to compute the inner strides.") - # If we have a scalar on the inside, then there is nothing to adjust. # We could have performed the test above, but doing it here, gives us # the chance of validating it. @@ -510,6 +506,9 @@ def _gt_map_strides_into_nested_sdfg( f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'." ) + if len(new_strides) != len(inner_shape): + raise ValueError("Failed to compute the inner strides.") + # Now we actually replace the strides, there are two ways of doing it. # The first is to create an alias in the `symbol_mapping`, however, # this is only possible if the current strides are singular symbols, From 49f81721b27440a22b2cf3f8fcc14401bb1fbaf1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:38:44 +0100 Subject: [PATCH 18/33] The symbol is also needed to be present in the nested SDFG. However, it still seems to fail in some cases. 
--- .../dace_fieldview/transformations/strides.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 1ee0260310..c03079037d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,6 +6,7 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +import warnings from typing import Optional, TypeAlias import dace @@ -408,6 +409,7 @@ def next_edges_by_connector( # Now set the stride of the data descriptor inside the nested SDFG to # the ones it has outside. _gt_map_strides_into_nested_sdfg( + sdfg=sdfg, nsdfg_node=nsdfg_node, inner_data=inner_data, outer_subset=get_subset(state, edge), @@ -440,6 +442,7 @@ def next_edges_by_connector( def _gt_map_strides_into_nested_sdfg( + sdfg: dace.SDFG, nsdfg_node: dace.nodes.NestedSDFG, inner_data: str, outer_subset: dace.subsets.Subset, @@ -453,6 +456,7 @@ def _gt_map_strides_into_nested_sdfg( match the ones of `outer_desc`. Args: + sdfg: The SDFG containing the NestedSDFG. nsdfg_node: The node in the parent SDFG that contains the NestedSDFG. inner_data: The name of the data descriptor that should be processed inside the NestedSDFG (by construction also a connector name). @@ -539,7 +543,9 @@ def _gt_map_strides_into_nested_sdfg( if dace.symbolic.issymbolic(new_stride_dim): new_strides_symbols.append(str(new_stride_dim)) else: - new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols) + # NOTE: In DaCe `free_symbols` is `set[str]` but in `sympy` it + # returns `set[symbol]`. We need `str` so we have to cast them. 
+ new_strides_symbols.extend(str(sym) for sym in new_stride_dim.free_symbols) # Now we determine the set of symbols that should be mapped inside the NestedSDFG. # We will exclude all that are already inside the `symbol_mapping` (we do not @@ -551,9 +557,19 @@ def _gt_map_strides_into_nested_sdfg( for sym in new_strides_symbols if not (sym in nsdfg_node.sdfg.symbols or sym in nsdfg_node.symbol_mapping) } + # Now create the symbol we in the NestedSDFG. for sym in missing_symbol_mappings: - # We can not create symbols in the nested SDFG, because we do not have - # the type of the symbols. + if sym in sdfg.symbols: + # TODO(phimuell): Handle the case the symbol is already defined. + nsdfg_node.sdfg.add_symbol(sym, sdfg.symbols[sym]) + else: + # The symbol is not known in the parent SDFG, but we need a symbol + # for it. So we use the default. + nsdfg_node.sdfg.add_symbol(sym, dace.symbol("__INVALID_SYMBOL__").dtype) + warnings.warn( + f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides.", + stacklevel=1, + ) nsdfg_node.symbol_mapping[sym] = dace.symbolic.pystr_to_symbolic(sym) # Now create aliases for the old symbols that were used as strides. From 2d6dfc0e9e7497f5e31161c9baeec5f92c71921c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:59:11 +0100 Subject: [PATCH 19/33] Fixed a bug in determining the free symbols that we need. 
--- .../runners/dace_fieldview/transformations/strides.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index c03079037d..1b8ebdfd41 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -541,11 +541,11 @@ def _gt_map_strides_into_nested_sdfg( new_strides_symbols: list[str] = [] for new_stride_dim in new_strides: if dace.symbolic.issymbolic(new_stride_dim): - new_strides_symbols.append(str(new_stride_dim)) - else: # NOTE: In DaCe `free_symbols` is `set[str]` but in `sympy` it # returns `set[symbol]`. We need `str` so we have to cast them. new_strides_symbols.extend(str(sym) for sym in new_stride_dim.free_symbols) + else: + new_strides_symbols.append(str(new_stride_dim)) # Now we determine the set of symbols that should be mapped inside the NestedSDFG. # We will exclude all that are already inside the `symbol_mapping` (we do not From 6124c6d7d461799963a8913a1629d5b74a2bee34 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 11:59:47 +0100 Subject: [PATCH 20/33] Updated the propagation code for the symbols. The type is now a bit better estimated. 
--- .../dace_fieldview/transformations/strides.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 1b8ebdfd41..30864c4449 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -557,17 +557,20 @@ def _gt_map_strides_into_nested_sdfg( for sym in new_strides_symbols if not (sym in nsdfg_node.sdfg.symbols or sym in nsdfg_node.symbol_mapping) } - # Now create the symbol we in the NestedSDFG. + + # Now propagate the symbols from the parent SDFG to the NestedSDFG. for sym in missing_symbol_mappings: if sym in sdfg.symbols: # TODO(phimuell): Handle the case the symbol is already defined. nsdfg_node.sdfg.add_symbol(sym, sdfg.symbols[sym]) else: - # The symbol is not known in the parent SDFG, but we need a symbol - # for it. So we use the default. - nsdfg_node.sdfg.add_symbol(sym, dace.symbol("__INVALID_SYMBOL__").dtype) + # The symbol is not known in the parent SDFG, but we need to define a + # symbol and for that we need a `dtype`. Our solution (which is as + # wrong as any other) is to create a symbol with that name and then + # use the type that was deduced. + nsdfg_node.sdfg.add_symbol(sym, dace.symbol(sym).dtype) warnings.warn( - f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides.", + f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym]}' as dtype.", stacklevel=1, ) nsdfg_node.symbol_mapping[sym] = dace.symbolic.pystr_to_symbolic(sym) From 45bcf9795496eb857f80bc594cd1fd37406a7ff4 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 13:55:09 +0100 Subject: [PATCH 21/33] Addressed Edoardo's changes. 
--- .../transformations/simplify.py | 2 +- .../dace_fieldview/transformations/strides.py | 94 ++++++++++--------- 2 files changed, 51 insertions(+), 45 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index 1a132cacb2..4339a761fa 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -971,7 +971,7 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None - # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension + # Recursively visit the nested SDFGs for mapping of strides from inner to outer array gtx_transformations.gt_map_strides_to_src_nested_sdfg(sdfg, graph, map_to_tmp_edge, glob_ac) # We now remove the `tmp` node, and create a new connection between diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 30864c4449..e69d392770 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -22,10 +22,10 @@ """Record of a stride that has been propagated into a NestedSDFG. The type combines the NestedSDFG into which the strides were already propagated -and the data within that NestedSDFG to which we have propagated the data, +and the data within that NestedSDFG to which we have propagated the strides, which is the connector name on the NestedSDFG. We need the NestedSDFG because we have to know what was already processed, -however, we also need the name within because of aliasing, i.e. 
a data +however, we also need the inner array name because of aliasing, i.e. a data descriptor on the outside could be mapped to multiple data descriptors inside the NestedSDFG. """ @@ -96,13 +96,14 @@ def _gt_change_transient_strides_non_recursive_impl( for top_level_transient, accesses in top_level_transients_and_their_accesses.items(): desc: dace_data.Array = sdfg.arrays[top_level_transient] - # Setting the strides only make sense if we have more than two dimensions + # Setting the strides only make sense if we have more than one dimensions ndim = len(desc.shape) if ndim <= 1: continue # We assume that everything is in C order initially, to get FORTRAN order # we simply have to reverse the order. + # TODO(phimuell): Improve this. new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) @@ -110,11 +111,11 @@ def _gt_change_transient_strides_non_recursive_impl( # collected all the AccessNodes we are using the # `gt_propagate_strides_from_access_node()` function, but we have to # create `processed_nsdfg` set already outside here. - # Furthermore, the same comment as above apply, we do not have to + # Furthermore, the same comment as above applies here, we do not have to # propagate the non-transients, because they either come from outside, # or they were already handled in the levels above, where they were # defined and then propagated down. - # TODO(phimuell): Updated the functions such that only once scan is needed. + # TODO(phimuell): Updated the functions such that only one scan is needed. processed_nsdfgs: set[dace_nodes.NestedSDFG] = set() for state, access_node in accesses: gt_propagate_strides_from_access_node( @@ -166,7 +167,7 @@ def gt_propagate_strides_from_access_node( ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: - """Propagates the stride of `outer_node` along all adjacent edges of `outer_node`. 
+ """Propagates the stride of `outer_node` to any adjacent reachable through its edges. The function will propagate the strides of the data descriptor `outer_node` refers to along all adjacent edges of `outer_node`. If one of these edges @@ -183,16 +184,13 @@ def gt_propagate_strides_from_access_node( state: The state where the data node is used. edge: The edge that reads from the data node, the nested SDFG is expected as the destination. outer_node: The data node whose strides should be propagated. - processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. - Only specify when you know what your are doing. ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. - propagate_along_dataflow: Determine the direction of propagation. If `True` the - function follows the dataflow. + processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. + Only specify when you know what your are doing. """ if processed_nsdfgs is None: # For preventing the case that nested SDFGs are handled multiple time. - # TODO: It certainly happens if a node is input and output, but are there other cases? processed_nsdfgs = set() for in_edge in state.in_edges(outer_node): @@ -225,8 +223,13 @@ def gt_map_strides_to_dst_nested_sdfg( ) -> None: """Propagates the strides of `outer_node` along `edge` along the dataflow. - For more information see the description of `_gt_map_strides_to_nested_sdfg_src_dst(). - However it is recommended to use `gt_propagate_strides_of()` directly. + In this context "along the dataflow" means that `edge` is an outgoing + edge of `outer_node` and the strides are into all NestedSDFGs that + are downstream of `outer_node`. + + Except in certain cases this function should not be used directly. It is + instead recommended to use `gt_propagate_strides_of()`, which propagates + all edges in the SDFG. Args: sdfg: The SDFG to process. 
@@ -235,9 +238,10 @@ def gt_map_strides_to_dst_nested_sdfg( outer_node: The data node whose strides should be propagated. ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. - processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when + processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when you know what your are doing. """ + assert edge.src is outer_node _gt_map_strides_to_nested_sdfg_src_dst( sdfg=sdfg, state=state, @@ -259,8 +263,13 @@ def gt_map_strides_to_src_nested_sdfg( ) -> None: """Propagates the strides of `outer_node` along `edge` against the dataflow. - For more information see the description of `_gt_map_strides_to_nested_sdfg_src_dst(). - However it is recommended to use `gt_propagate_strides_of()` directly. + In this context "along the dataflow" means that `edge` is an incoming + edge of `outer_node` and the strides are into all NestedSDFGs that + are upstream of `outer_node`. + + Except in certain cases this function should not be used directly. It is + instead recommended to use `gt_propagate_strides_of()`, which propagates + all edges in the SDFG. Args: sdfg: The SDFG to process. @@ -269,7 +278,7 @@ def gt_map_strides_to_src_nested_sdfg( outer_node: The data node whose strides should be propagated. ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. - processed_nsdfgs: Set of Nested SDFG that were already processed. Only specify when + processed_nsdfgs: Set of NestedSDFGs that were already processed. Only specify when you know what your are doing. """ _gt_map_strides_to_nested_sdfg_src_dst( @@ -298,11 +307,11 @@ def _gt_map_strides_to_nested_sdfg_src_dst( `propagate_along_dataflow` and propagate the strides of `outer_node` into every NestedSDFG that is reachable by following `edge`. 
- When the function encounters a NestedSDFG it will determine the the data - descriptor `outer_node` refers on the inside of the NestedSDFG. + When the function encounters a NestedSDFG it will determine what data + the `outer_node` is mapped to on the inside of the NestedSDFG. It will then replace the stride of the inner descriptor with the ones - of the outside. Afterwards it will recursively propagates the - stride inside the NestedSDFG. + of the outside. Afterwards it will recursively propagate the strides + inside the NestedSDFG. During this propagation the function will follow any edges. If the function reaches a NestedSDFG that is listed inside `processed_nsdfgs` @@ -417,10 +426,9 @@ def next_edges_by_connector( ignore_symbol_mapping=ignore_symbol_mapping, ) - # Because the function call above if not recursive we have now to scan the - # propagate the change into the nested SDFG. Using - # `_gt_find_toplevel_data_accesses()` is a bit overkill, but allows for a - # more uniform processing. + # Since the function call above is not recursive we have now to propagate + # the change into the NestedSDFGs. Using `_gt_find_toplevel_data_accesses()` + # is a bit overkill, but allows for a more uniform processing. # TODO(phimuell): Instead of scanning every level for every data we modify # we should scan the whole SDFG once and then reuse this information. accesses_in_nested_sdfg = _gt_find_toplevel_data_accesses( @@ -429,8 +437,8 @@ def next_edges_by_connector( only_arrays=True, ) for nested_state, nested_access in accesses_in_nested_sdfg.get(inner_data, list()): - # We have to use `gt_propagate_strides_of()` here because we have to - # handle its entirety. We could wait until the other branch processes + # We have to use `gt_propagate_strides_from_access_node()` here because we + # have to handle its entirety. We could wait until the other branch processes # the nested SDFG, but this might not work, so let's do it fully now. 
gt_propagate_strides_from_access_node( sdfg=nsdfg_node.sdfg, @@ -451,9 +459,9 @@ def _gt_map_strides_into_nested_sdfg( ) -> None: """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. - `inner_data` is the name of of a data descriptor inside the NestedSDFG. - The function will then modify the modify the strides of `inner_data` to - match the ones of `outer_desc`. + `inner_data` is the name of a data descriptor inside the NestedSDFG. + The function will then modify the strides of `inner_data`, assuming this + is an array, to match the ones of `outer_desc`. Args: sdfg: The SDFG containing the NestedSDFG. @@ -471,25 +479,24 @@ def _gt_map_strides_into_nested_sdfg( - Handle explicit dimensions of size 1. """ # We need to compute the new strides. In the following we assume that the - # relative order of the dimension does not change, but some dimensions - # that are present on the outside are not present on the inside. For - # example this happens for the Memlet `a[__i0, 0:__a_size1]`. - # We detect this case by checking if that dimension has size 1. + # relative order of the dimensions does not change, but we support the case + # where some dimensions of the outer data descriptor are not present on the + # inside. For example this happens for the Memlet `a[__i0, 0:__a_size1]`. We + # detect this case by checking if the Memlet subset in that dimension has size 1. # TODO(phimuell): Handle the case were some additional size 1 dimensions are added. inner_desc: dace_data.Data = nsdfg_node.sdfg.arrays[inner_data] inner_shape = inner_desc.shape inner_strides_init = inner_desc.strides - # TODO(phimuell): For now this is fine, but it should be possisble to allow it. 
- assert not inner_desc.transient - outer_strides = outer_desc.strides outer_inflow = outer_subset.size() new_strides: list = [] for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): if dim_oinflow == 1: - # Only something flows in, thus there is no stride in this dimension. + # This is the case of implicit slicing along one dimension. The inner + # array descriptor has shape != 1 in `current_inner_dim`, which has + # to map to a subsequent dimension of `outer_inflow` pass else: # There is inflow into the SDFG, so we need the stride. @@ -518,7 +525,7 @@ def _gt_map_strides_into_nested_sdfg( # this is only possible if the current strides are singular symbols, # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start` # or literal values. - # The second way would be to replace `strides` attributer of the + # The second way would be to replace `strides` attribute of the # inner data descriptor. In case the new stride consists of expressions # such as `value1 - value2` we have to make them available inside the # NestedSDFG. However, it could be that the strides is used somewhere else. @@ -552,6 +559,7 @@ def _gt_map_strides_into_nested_sdfg( # check if they map to the same value, we just hope it). Furthermore, # we will exclude all symbols that are listed in the `symbols` property # of the SDFG that is nested, and hope that it has the same meaning. + # TODO(phimuell): Add better checks here. missing_symbol_mappings: set[str] = { sym for sym in new_strides_symbols @@ -605,9 +613,8 @@ def _gt_find_toplevel_data_accesses( only_arrays: If `True`, defaults to `False`, only arrays are returned. Returns: - A `dict` that maps the name of a data container, that should be processed - to a list of tuples containing the state where the AccessNode was found - and the node. + A `dict` that maps the name of a data container, to a list of tuples + containing the state where the AccessNode was found and the AccessNode. 
""" # List of data that is accessed on the top level and all its access node. top_level_data: dict[str, list[tuple[dace.SDFGState, dace_nodes.AccessNode]]] = dict() @@ -637,8 +644,7 @@ def _gt_find_toplevel_data_accesses( continue elif gtx_transformations.util.is_view(dnode, sdfg): - # The AccessNode refers to a View so we ignore it anyway - # TODO(phimuell/edopao): Should the function return them? + # The AccessNode refers to a View so we ignore it anyway. continue # We have found a new data node that is on the top node and is unknown. From 23b0baa530077be44981e800a5b622bc2ae872bc Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 14:18:16 +0100 Subject: [PATCH 22/33] Updated how we get the type of symbols. The type are now extracted from the stuff we get from `free_symbols`. --- .../dace_fieldview/transformations/strides.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index e69d392770..5c501bca24 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -545,43 +545,47 @@ def _gt_map_strides_into_nested_sdfg( inner_desc.set_shape(inner_desc.shape, new_strides) # Now find the free symbols that the new strides need. - new_strides_symbols: list[str] = [] + # Note that usually `free_symbols` returns `set[str]`, but here, because + # we fall back on SymPy, we get back symbols. We will keep them, because + # then we can use them to extract the type form them, which we need later. + new_strides_symbols: list[dace.symbol] = [] for new_stride_dim in new_strides: if dace.symbolic.issymbolic(new_stride_dim): - # NOTE: In DaCe `free_symbols` is `set[str]` but in `sympy` it - # returns `set[symbol]`. 
We need `str` so we have to cast them. - new_strides_symbols.extend(str(sym) for sym in new_stride_dim.free_symbols) + new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols) else: - new_strides_symbols.append(str(new_stride_dim)) + # It is not already a symbol, so we turn it into a symbol. + # However, we only add it, if it is also a symbol, for example `1`. + # should not be added. + new_stride_symbol = dace.symbolic.pystr_to_symbolic(new_stride_dim) + if new_stride_symbol.is_symbol: + new_strides_symbols.append(new_stride_symbol) # Now we determine the set of symbols that should be mapped inside the NestedSDFG. # We will exclude all that are already inside the `symbol_mapping` (we do not # check if they map to the same value, we just hope it). Furthermore, # we will exclude all symbols that are listed in the `symbols` property # of the SDFG that is nested, and hope that it has the same meaning. - # TODO(phimuell): Add better checks here. - missing_symbol_mappings: set[str] = { + # TODO(phimuell): Add better checks to avoid overwriting. + missing_symbol_mappings: set[dace.symbol] = { sym for sym in new_strides_symbols - if not (sym in nsdfg_node.sdfg.symbols or sym in nsdfg_node.symbol_mapping) + if not (sym.name in nsdfg_node.sdfg.symbols or sym.name in nsdfg_node.symbol_mapping) } # Now propagate the symbols from the parent SDFG to the NestedSDFG. for sym in missing_symbol_mappings: if sym in sdfg.symbols: - # TODO(phimuell): Handle the case the symbol is already defined. - nsdfg_node.sdfg.add_symbol(sym, sdfg.symbols[sym]) + # TODO(phimuell): Handle the case the symbol is already defined in + # the nested SDFG. + nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) else: - # The symbol is not known in the parent SDFG, but we need to define a - # symbol and for that we need a `dtype`. Our solution (which is as - # wrong as any other) is to create a symbol with that name and then - # use the type that was deduced. 
- nsdfg_node.sdfg.add_symbol(sym, dace.symbol(sym).dtype) + # The symbol is not known in the parent SDFG, so we add it + nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) warnings.warn( - f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym]}' as dtype.", + f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym.name]}' as dtype.", stacklevel=1, ) - nsdfg_node.symbol_mapping[sym] = dace.symbolic.pystr_to_symbolic(sym) + nsdfg_node.symbol_mapping[sym.name] = sym # Now create aliases for the old symbols that were used as strides. for old_sym, new_sym in zip(inner_strides_init, new_strides): From ff058802b5128ddde404e01c909e0ff36f85fefb Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 14:25:10 +0100 Subject: [PATCH 23/33] New restriction on the update of the symbol mapping. --- .../runners/dace_fieldview/transformations/strides.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 5c501bca24..f683737f23 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -575,8 +575,6 @@ def _gt_map_strides_into_nested_sdfg( # Now propagate the symbols from the parent SDFG to the NestedSDFG. for sym in missing_symbol_mappings: if sym in sdfg.symbols: - # TODO(phimuell): Handle the case the symbol is already defined in - # the nested SDFG. nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) else: # The symbol is not known in the parent SDFG, so we add it @@ -589,7 +587,7 @@ def _gt_map_strides_into_nested_sdfg( # Now create aliases for the old symbols that were used as strides. 
for old_sym, new_sym in zip(inner_strides_init, new_strides): - if dace.symbolic.issymbolic(old_sym): + if dace.symbolic.issymbolic(old_sym) and old_sym.is_symbol: nsdfg_node.symbol_mapping[str(old_sym)] = dace.symbolic.pystr_to_symbolic(new_sym) From 43ec33ccff098c7beacf4a9588120a047abd0e44 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 14:47:35 +0100 Subject: [PATCH 24/33] Updated the tests, now also made one that has tests for the symbol mapping branch. --- .../transformation_tests/test_strides.py | 81 ++++++++++++++++--- 1 file changed, 71 insertions(+), 10 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 655e50fb23..45c3ebc739 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -174,12 +174,70 @@ def _make_strides_propagation_level1_sdfg() -> ( return sdfg, nsdfg_level2, nsdfg_level3 -def test_strides_propagation(): - """ - Todo: - - Add a case where `ignore_symbol_mapping=False` can be tested. - - What happens if the stride symbol is used somewhere else? - """ +def test_strides_propagation_use_symbol_mapping(): + # Note that the SDFG we are building here is not really meaningful. + sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() + + # Tests if all strides are distinct in the beginning and match what we expect. 
+ for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: + for aname, adesc in sdfg.arrays.items(): + exp_stride = f"{aname}_stride" + actual_stride = adesc.strides[0] + assert len(adesc.strides) == 1 + assert ( + str(actual_stride) == exp_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + nsdfg = sdfg.parent_nsdfg_node + if nsdfg is not None: + assert exp_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride + + # Now we propagate `a` and `b`, but not `c`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=False) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=False) + sdfg_level1.validate() + + # Because `ignore_symbol_mapping=False` the strides of the data descriptor should + # not have changed. But the `symbol_mapping` has been updated for `a` and `b`. + # However, the symbols will only point one level above. + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + + if aname.startswith("c"): + target_symbol = f"{aname}_stride" + else: + target_symbol = f"{aname[0]}{level - 1}_stride" + + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." 
+ + # Now we also propagate `c` thus now all data descriptors have the same stride + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=False) + sdfg_level1.validate() + for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1): + for aname, adesc in sdfg.arrays.items(): + nsdfg = sdfg.parent_nsdfg_node + original_stride = f"{aname}_stride" + target_symbol = f"{aname[0]}{level-1}_stride" + if nsdfg is not None: + assert original_stride in nsdfg.symbol_mapping + assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol + assert len(adesc.strides) == 1 + assert ( + str(adesc.strides[0]) == original_stride + ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + + +def test_strides_propagation_ignore_symbol_mapping(): # Note that the SDFG we are building here is not really meaningful. sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg() @@ -201,7 +259,9 @@ def test_strides_propagation(): # Now we propagate `a` and `b`, but not `c`. # TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`. gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + sdfg_level1.validate() gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + sdfg_level1.validate() # After the propagation `a` and `b` should use the same stride (the one that # it has on level 1, but `c` should still be level depending. @@ -213,8 +273,8 @@ def test_strides_propagation(): else: exp_stride = f"{aname[0]}1_stride" assert len(adesc.strides) == 1 - assert exp_stride == str( - adesc.strides[0] + assert ( + str(adesc.strides[0]) == exp_stride ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." 
nsdfg = sdfg.parent_nsdfg_node @@ -224,13 +284,14 @@ def test_strides_propagation(): # Now we also propagate `c` thus now all data descriptors have the same stride gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True) + sdfg_level1.validate() for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]: for aname, adesc in sdfg.arrays.items(): exp_stride = f"{aname[0]}1_stride" original_stride = f"{aname}_stride" assert len(adesc.strides) == 1 - assert exp_stride == str( - adesc.strides[0] + assert ( + str(adesc.strides[0]) == exp_stride ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." nsdfg = sdfg.parent_nsdfg_node From d43153a4165878a1cf91e033bf3cbdb5360babea Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 15:15:14 +0100 Subject: [PATCH 25/33] Fixed two bug in the stride propagation function. --- .../runners/dace_fieldview/transformations/strides.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index f683737f23..2cc75e195d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -574,7 +574,7 @@ def _gt_map_strides_into_nested_sdfg( # Now propagate the symbols from the parent SDFG to the NestedSDFG. 
for sym in missing_symbol_mappings: - if sym in sdfg.symbols: + if str(sym) in sdfg.symbols: nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) else: # The symbol is not known in the parent SDFG, so we add it @@ -583,7 +583,7 @@ def _gt_map_strides_into_nested_sdfg( f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym.name]}' as dtype.", stacklevel=1, ) - nsdfg_node.symbol_mapping[sym.name] = sym + nsdfg_node.symbol_mapping[sym.name] = sym # Now create aliases for the old symbols that were used as strides. for old_sym, new_sym in zip(inner_strides_init, new_strides): From 2e82bd5a90e0bb2d1abda35e428b337bf91a7efa Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 15:19:42 +0100 Subject: [PATCH 26/33] Added a test that ensures that the dependent adding works. --- .../transformation_tests/test_strides.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 45c3ebc739..17874a3450 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -298,3 +298,108 @@ def test_strides_propagation_ignore_symbol_mapping(): if nsdfg is not None: assert original_stride in nsdfg.symbol_mapping assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride + + +def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("nested_sdfg")) + state = sdfg.add_state(is_start_block=True) + + array_names = ["a2", "b2"] + for name in array_names: + stride_sym = dace.symbol(f"{name}_stride", dtype=dace.uint64) + 
sdfg.add_symbol(stride_sym.name, stride_sym.dtype) + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("a2[__i0]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("nested_level")) + state = sdfg_level1.add_state(is_start_block=True) + + array_names = ["a1", "b1"] + for name in array_names: + stride_sym1 = dace.symbol(f"{name}_1stride", dtype=dace.uint64) + stride_sym2 = dace.symbol(f"{name}_2stride", dtype=dace.int64) + sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype) + sdfg_level1.add_symbol(stride_sym2.name, stride_sym2.dtype) + stride_sym = stride_sym1 * stride_sym2 + sdfg_level1.add_array( + name, + shape=(10,), + dtype=dace.float64, + strides=(stride_sym,), + transient=False, + ) + + sdfg_level2 = _make_strides_propagation_dependent_symbol_nsdfg() + + for sym, sym_dtype in sdfg_level2.symbols.items(): + sdfg_level1.add_symbol(sym, sym_dtype) + + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg_level1, + inputs={"a2"}, + outputs={"b2"}, + symbol_mapping={s: s for s in sdfg_level2.symbols}, + ) + + state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10]")) + state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10]")) + sdfg_level1.validate() + + return sdfg_level1, nsdfg + + +def test_strides_propagation_dependent_symbol(): + sdfg_level1, nsdfg_level2 = _make_strides_propagation_dependent_symbol_sdfg() + sym1_dtype = dace.uint64 + sym2_dtype = dace.int64 + + # Ensure that the special symbols are not already present inside the nested SDFG. 
+ for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in {fs.name for fs in adesc.strides[0].free_symbols} + assert sym not in nsdfg_level2.symbol_mapping + assert sym not in nsdfg_level2.sdfg.symbols + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + + # Now propagate `a1` and `b1`. + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True) + sdfg_level1.validate() + gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True) + sdfg_level1.validate() + + # Now we check if the update has worked. + for aname, adesc in sdfg_level1.arrays.items(): + sym1 = f"{aname}_1stride" + sym2 = f"{aname}_2stride" + adesc2 = nsdfg_level2.sdfg.arrays[aname.replace("1", "2")] + assert adesc2.strides == adesc.strides + + for sym, dtype in [(sym1, sym1_dtype), (sym2, sym2_dtype)]: + assert sym in nsdfg_level2.symbol_mapping + assert nsdfg_level2.symbol_mapping[sym].name == sym + assert sym in sdfg_level1.symbols + assert sdfg_level1.symbols[sym] == dtype + assert sym in nsdfg_level2.sdfg.symbols + assert nsdfg_level2.sdfg.symbols[sym] == dtype From 07e6a5cd61c17b4039c0a6d3ce0d3003fbed8f9c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 15:23:13 +0100 Subject: [PATCH 27/33] Changed the default of `ignore_symbol_mapping` to `True`. 
--- .../dace_fieldview/transformations/strides.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 2cc75e195d..bf298c0164 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -41,6 +41,11 @@ def gt_change_transient_strides( transients in the optimal way. The function should run after all maps have been created. + After the strides have been adjusted the function will also propagate + the strides into nested SDFG. This propagation will happen with + `ignore_symbol_mapping` set to `True`, see `gt_propagate_strides_of()` + for more. + Args: sdfg: The SDFG to process. gpu: If the SDFG is supposed to run on the GPU. @@ -123,13 +128,14 @@ def _gt_change_transient_strides_non_recursive_impl( state=state, outer_node=access_node, processed_nsdfgs=processed_nsdfgs, + ignore_symbol_mapping=True, ) def gt_propagate_strides_of( sdfg: dace.SDFG, data_name: str, - ignore_symbol_mapping: bool = False, + ignore_symbol_mapping: bool = True, ) -> None: """Propagates the strides of `data_name` within the whole SDFG. @@ -140,7 +146,7 @@ def gt_propagate_strides_of( Args: sdfg: The SDFG on which we operate. data_name: Name of the data descriptor that should be handled. - ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + ignore_symbol_mapping: If `False` (default is `True`) try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. 
""" @@ -164,7 +170,7 @@ def gt_propagate_strides_from_access_node( sdfg: dace.SDFG, state: dace.SDFGState, outer_node: dace_nodes.AccessNode, - ignore_symbol_mapping: bool = False, + ignore_symbol_mapping: bool = True, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the stride of `outer_node` to any adjacent reachable through its edges. @@ -184,7 +190,7 @@ def gt_propagate_strides_from_access_node( state: The state where the data node is used. edge: The edge that reads from the data node, the nested SDFG is expected as the destination. outer_node: The data node whose strides should be propagated. - ignore_symbol_mapping: If `False`, the default, try to modify the `symbol_mapping` + ignore_symbol_mapping: If `False` (default is `True`), try to modify the `symbol_mapping` of NestedSDFGs instead of manipulating the data descriptor. processed_nsdfgs: Set of NestedSDFG that were already processed and will be ignored. Only specify when you know what your are doing. From 4bf145b7d63ca6c98e94cb0d02f6bebed4246690 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 19 Dec 2024 15:40:19 +0100 Subject: [PATCH 28/33] Added Edoardo's comments. --- .../dace_fieldview/transformations/strides.py | 33 +++++++------------ 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index bf298c0164..7854cbea12 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -import warnings from typing import Optional, TypeAlias import dace @@ -173,7 +172,7 @@ def gt_propagate_strides_from_access_node( ignore_symbol_mapping: bool = True, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: - """Propagates the stride of `outer_node` to any adjacent reachable through its edges. + """Propagates the stride of `outer_node` to any adjacent NestedSDFG. The function will propagate the strides of the data descriptor `outer_node` refers to along all adjacent edges of `outer_node`. If one of these edges @@ -227,10 +226,10 @@ def gt_map_strides_to_dst_nested_sdfg( ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: - """Propagates the strides of `outer_node` along `edge` along the dataflow. + """Propagates the strides of `outer_node` along `edge` in the dataflow direction. - In this context "along the dataflow" means that `edge` is an outgoing - edge of `outer_node` and the strides are into all NestedSDFGs that + In this context "along the dataflow direction" means that `edge` is an outgoing + edge of `outer_node` and the strides are propagated into all NestedSDFGs that are downstream of `outer_node`. Except in certain cases this function should not be used directly. It is @@ -267,11 +266,11 @@ def gt_map_strides_to_src_nested_sdfg( ignore_symbol_mapping: bool = False, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: - """Propagates the strides of `outer_node` along `edge` against the dataflow. + """Propagates the strides of `outer_node` along `edge` in the opposite direction of the dataflow - In this context "along the dataflow" means that `edge` is an incoming - edge of `outer_node` and the strides are into all NestedSDFGs that - are upstream of `outer_node`. 
+ In this context "in the opposite direction of the dataflow" means that `edge` + is an incoming edge of `outer_node` and the strides are propagated into all + NestedSDFGs that are upstream of `outer_node`. Except in certain cases this function should not be used directly. It is instead recommended to use `gt_propagate_strides_of()`, which propagates @@ -500,13 +499,10 @@ def _gt_map_strides_into_nested_sdfg( new_strides: list = [] for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True): if dim_oinflow == 1: - # This is the case of implicit slicing along one dimension. The inner - # array descriptor has shape != 1 in `current_inner_dim`, which has - # to map to a subsequent dimension of `outer_inflow` + # This is the case of implicit slicing along one dimension. pass else: # There is inflow into the SDFG, so we need the stride. - assert dim_oinflow != 0 new_strides.append(dim_ostride) assert len(new_strides) <= len(inner_shape) @@ -580,15 +576,8 @@ def _gt_map_strides_into_nested_sdfg( # Now propagate the symbols from the parent SDFG to the NestedSDFG. for sym in missing_symbol_mappings: - if str(sym) in sdfg.symbols: - nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) - else: - # The symbol is not known in the parent SDFG, so we add it - nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype) - warnings.warn( - f"Could not find the symbol '{sym}' in the parent SDFG while modifying the strides, use '{nsdfg_node.sdfg.symbols[sym.name]}' as dtype.", - stacklevel=1, - ) + assert sym.name in sdfg.symbols, f"Expected that '{sym}' is defined in the parent SDFG." + nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) nsdfg_node.symbol_mapping[sym.name] = sym # Now create aliases for the old symbols that were used as strides. From 2b03bb4799ff092a68aecea820074813779cfc17 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 07:55:32 +0100 Subject: [PATCH 29/33] Removed the creation of aliasing if symbol tables are ignored. 
I realized that allowing this is not very safe. I also added a test to show that. --- .../dace_fieldview/transformations/strides.py | 16 ++++++++-------- .../transformation_tests/test_strides.py | 5 +++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 7854cbea12..06dfe6626c 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -223,7 +223,7 @@ def gt_map_strides_to_dst_nested_sdfg( state: dace.SDFGState, edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode, - ignore_symbol_mapping: bool = False, + ignore_symbol_mapping: bool = True, processed_nsdfgs: Optional[set[PropagatedStrideRecord]] = None, ) -> None: """Propagates the strides of `outer_node` along `edge` in the dataflow direction. @@ -460,7 +460,7 @@ def _gt_map_strides_into_nested_sdfg( inner_data: str, outer_subset: dace.subsets.Subset, outer_desc: dace_data.Data, - ignore_symbol_mapping: bool = False, + ignore_symbol_mapping: bool = True, ) -> None: """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. @@ -479,9 +479,12 @@ def _gt_map_strides_into_nested_sdfg( ignore_symbol_mapping: If possible the function will perform the renaming through the `symbol_mapping` of the nested SDFG. If `True` then the function will always perform the renaming. + Note that setting this value to `False` might have negative side effects. Todo: - Handle explicit dimensions of size 1. + - What should we do if the stride symbol is used somewhere else, creating an + alias is probably not the right thing? """ # We need to compute the new strides. 
In the following we assume that the # relative order of the dimensions does not change, but we support the case @@ -526,7 +529,9 @@ def _gt_map_strides_into_nested_sdfg( # The first is to create an alias in the `symbol_mapping`, however, # this is only possible if the current strides are singular symbols, # like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start` - # or literal values. + # or literal values. Furthermore, this would change the meaning of the + # old stride symbol in any context and not only in the one of the stride + # of a single and isolated data descriptor. # The second way would be to replace `strides` attribute of the # inner data descriptor. In case the new stride consists of expressions # such as `value1 - value2` we have to make them available inside the @@ -580,11 +585,6 @@ def _gt_map_strides_into_nested_sdfg( nsdfg_node.sdfg.add_symbol(sym.name, sdfg.symbols[sym.name]) nsdfg_node.symbol_mapping[sym.name] = sym - # Now create aliases for the old symbols that were used as strides. 
- for old_sym, new_sym in zip(inner_strides_init, new_strides): - if dace.symbolic.issymbolic(old_sym) and old_sym.is_symbol: - nsdfg_node.symbol_mapping[str(old_sym)] = dace.symbolic.pystr_to_symbolic(new_sym) - def _gt_find_toplevel_data_accesses( sdfg: dace.SDFG, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 17874a3450..22d1b16b39 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -280,7 +280,7 @@ def test_strides_propagation_ignore_symbol_mapping(): nsdfg = sdfg.parent_nsdfg_node if nsdfg is not None: assert original_stride in nsdfg.symbol_mapping - assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride + assert str(nsdfg.symbol_mapping[original_stride]) == original_stride # Now we also propagate `c` thus now all data descriptors have the same stride gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True) @@ -296,8 +296,9 @@ def test_strides_propagation_ignore_symbol_mapping(): nsdfg = sdfg.parent_nsdfg_node if nsdfg is not None: + # The symbol mapping must not be updated. assert original_stride in nsdfg.symbol_mapping - assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride + assert str(nsdfg.symbol_mapping[original_stride]) == original_stride def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: From 40c225d6e601c7fee7c612da620bc0b37d15895b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 08:21:24 +0100 Subject: [PATCH 30/33] Added a test that shows that `ignore_symbol_mapping=False` produces errors in certain cases.
--- .../transformation_tests/test_strides.py | 129 +++++++++++++++++- 1 file changed, 127 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 22d1b16b39..6d6a36028a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -7,6 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause import pytest +import numpy as np +import copy dace = pytest.importorskip("dace") from dace import symbolic as dace_symbolic @@ -302,7 +304,7 @@ def test_strides_propagation_ignore_symbol_mapping(): def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("nested_sdfg")) + sdfg = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_nsdfg")) state = sdfg.add_state(is_start_block=True) array_names = ["a2", "b2"] @@ -330,7 +332,7 @@ def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: - sdfg_level1 = dace.SDFG(util.unique_name("nested_level")) + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_sdfg")) state = sdfg_level1.add_state(is_start_block=True) array_names = ["a1", "b1"] @@ -404,3 +406,126 @@ def test_strides_propagation_dependent_symbol(): assert sdfg_level1.symbols[sym] == dtype assert sym in nsdfg_level2.sdfg.symbols assert nsdfg_level2.sdfg.symbols[sym] == dtype + + +def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG: + sdfg = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_nsdfg")) + state = sdfg.add_state(is_start_block=True) + + # NOTE: Both 
arrays have the same symbols used for strides. + array_names = ["a2", "b2"] + stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64) + stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64) + sdfg.add_symbol(stride_sym0.name, stride_sym0.dtype) + sdfg.add_symbol(stride_sym1.name, stride_sym1.dtype) + for name in array_names: + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + strides=(stride_sym0, stride_sym1), + transient=False, + ) + + state.add_mapped_tasklet( + "nested_comp", + map_ranges={ + "__i0": "0:10", + "__i1": "0:10", + }, + inputs={"__in1": dace.Memlet("a2[__i0, __i1]")}, + code="__out = __in1 + 10.", + outputs={"__out": dace.Memlet("b2[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: + sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_sdfg")) + state = sdfg_level1.add_state(is_start_block=True) + + # NOTE: Both arrays use the same symbols as strides. + # Furthermore, they are the same as in the nested SDFG, i.e. they are shared. 
+ array_names = ["a1", "b1"] + stride_sym0 = dace.symbol(f"__stride_0", dtype=dace.uint64) + stride_sym1 = dace.symbol(f"__stride_1", dtype=dace.uint64) + sdfg_level1.add_symbol(stride_sym0.name, stride_sym0.dtype) + sdfg_level1.add_symbol(stride_sym1.name, stride_sym1.dtype) + for name in array_names: + sdfg_level1.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + strides=( + stride_sym0, + stride_sym1, + ), + transient=False, + ) + + sdfg_level2 = _make_strides_propagation_shared_symbols_nsdfg() + nsdfg = state.add_nested_sdfg( + sdfg=sdfg_level2, + parent=sdfg_level1, + inputs={"a2"}, + outputs={"b2"}, + symbol_mapping={s: s for s in sdfg_level2.symbols}, + ) + + state.add_edge(state.add_access("a1"), None, nsdfg, "a2", dace.Memlet("a1[0:10, 0:10]")) + state.add_edge(nsdfg, "b2", state.add_access("b1"), None, dace.Memlet("b1[0:10, 0:10]")) + sdfg_level1.validate() + + return sdfg_level1, nsdfg + + +def test_strides_propagation_shared_symbols_sdfg(): + """ + Note: + If `ignore_symbol_mapping` is `False` then this test will fail. + This is because the `symbol_mapping` of the NestedSDFG will act on the + whole SDFG. Thus it will not only change the strides of `b` but as an + unintended side effect also the strides of `a`. + """ + + def ref(a1, b1): + for i in range(10): + for j in range(10): + b1[i, j] = a1[i, j] + 10.0 + + sdfg_level1, nsdfg_level2 = _make_strides_propagation_shared_symbols_sdfg() + + res_args = { + "a1": np.array(np.random.rand(10, 10), order="C", dtype=np.float64, copy=True), + "b1": np.array(np.random.rand(10, 10), order="F", dtype=np.float64, copy=True), + } + ref_args = copy.deepcopy(res_args) + + # Now we change the strides of `b1`, and then we propagate the new strides + # into the nested SDFG. We want to keep (for whatever reasons) strides of `a1`. 
+ stride_b1_sym0 = dace.symbol(f"__b1_stride_0", dtype=dace.uint64) + stride_b1_sym1 = dace.symbol(f"__b1_stride_1", dtype=dace.uint64) + sdfg_level1.add_symbol(stride_b1_sym0.name, stride_b1_sym0.dtype) + sdfg_level1.add_symbol(stride_b1_sym1.name, stride_b1_sym1.dtype) + + desc_b1 = sdfg_level1.arrays["b1"] + desc_b1.set_shape((10, 10), (stride_b1_sym0, stride_b1_sym1)) + + # Now we propagate the data into it. + gtx_transformations.gt_propagate_strides_of(sdfg=sdfg_level1, data_name="b1") + + # Now we have to prepare the call arguments, i.e. adding the strides + itemsize = res_args["b1"].itemsize + res_args.update( + { + "__b1_stride_0": res_args["b1"].strides[0] // itemsize, + "__b1_stride_1": res_args["b1"].strides[1] // itemsize, + "__stride_0": res_args["a1"].strides[0] // itemsize, + "__stride_1": res_args["a1"].strides[1] // itemsize, + } + ) + ref(**ref_args) + sdfg_level1(**res_args) + assert np.allclose(ref_args["b1"], res_args["b1"]) From 419a386722a685316ee5917e9c8d8e44905e153b Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 08:44:46 +0100 Subject: [PATCH 31/33] Updated the description. 
--- .../transformation_tests/test_strides.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index 6d6a36028a..5b16e41bc3 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -482,7 +482,14 @@ def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nod def test_strides_propagation_shared_symbols_sdfg(): - """ + """Tests what happens if symbols are (unintentionally) shared between descriptors. + + This test looks rather artificial, but it is actually quite likely. Because + transients will most likely have the same shape and if the strides are not + set explicitly, which is the case, the strides will also be related to their + shape. This test explores the situation, where we can, for whatever reason, + only propagate the strides of one such data descriptor. + Note: If `ignore_symbol_mapping` is `False` then this test will fail. This is because the `symbol_mapping` of the NestedSDFG will act on the @@ -514,7 +521,10 @@ def ref(a1, b1): desc_b1.set_shape((10, 10), (stride_b1_sym0, stride_b1_sym1)) # Now we propagate the data into it. - gtx_transformations.gt_propagate_strides_of(sdfg=sdfg_level1, data_name="b1") + gtx_transformations.gt_propagate_strides_of( + sdfg=sdfg_level1, + data_name="b1", + ) # Now we have to prepare the call arguments, i.e. adding the strides itemsize = res_args["b1"].itemsize From cc9801b7364e91d782772b7b4bbb949857c03ec3 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 08:46:38 +0100 Subject: [PATCH 32/33] Applied Edoardo's comment.
--- .../runners/dace_fieldview/transformations/strides.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index 06dfe6626c..aa9d55b5f6 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -460,7 +460,7 @@ def _gt_map_strides_into_nested_sdfg( inner_data: str, outer_subset: dace.subsets.Subset, outer_desc: dace_data.Data, - ignore_symbol_mapping: bool = True, + ignore_symbol_mapping: bool, ) -> None: """Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`. From 360baae7f3b21521b8d55dec3f4f4e122501b5e1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 20 Dec 2024 09:16:33 +0100 Subject: [PATCH 33/33] Added a todo from Edoardo's suggestions. --- .../runners/dace_fieldview/transformations/strides.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py index aa9d55b5f6..980b2a8fdf 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -485,6 +485,8 @@ def _gt_map_strides_into_nested_sdfg( - Handle explicit dimensions of size 1. - What should we do if the stride symbol is used somewhere else, creating an alias is probably not the right thing? + - Handle the case if the outer stride symbol is already used in another + context inside the NestedSDFG. """ # We need to compute the new strides. In the following we assume that the # relative order of the dimensions does not change, but we support the case