Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[dace][next]: Fixing strides in optimization #1782

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d9218b6
This are the changes Edoardo implemented to fix some issues in the op…
edopao Dec 17, 2024
9d7e722
First rework.
philip-paul-mueller Dec 18, 2024
1ddd6fe
Updated some commenst.
philip-paul-mueller Dec 18, 2024
95e0007
I want to ignore register, not only consider them.
philip-paul-mueller Dec 18, 2024
f1b7a3f
There was a missing `not` in the check.
philip-paul-mueller Dec 18, 2024
50ad620
Had to update the propagation, to also handle aliasing.
philip-paul-mueller Dec 18, 2024
983022c
In the function for looking for top level accesses the `only_transien…
philip-paul-mueller Dec 18, 2024
e7b1afb
Small reminder of the future.
philip-paul-mueller Dec 18, 2024
df7bd0c
Forgot to export the new SDFG stuff.
philip-paul-mueller Dec 18, 2024
363ab59
Had to update function for actuall renaming of the strides.
philip-paul-mueller Dec 18, 2024
9c19d32
Added a todo to the replacement function.
philip-paul-mueller Dec 18, 2024
9cad1f7
Added a first test to the propagation function.
philip-paul-mueller Dec 18, 2024
2700f53
Modified the function that performs the actuall modification of the s…
philip-paul-mueller Dec 19, 2024
a20d3c0
Updated some tes, but more are missing.
philip-paul-mueller Dec 19, 2024
b5ff462
Subset caching strikes again.
philip-paul-mueller Dec 19, 2024
d326d3b
It seems that the explicit handling of one dimensions is not working.
philip-paul-mueller Dec 19, 2024
252f348
The test must be moved bellow.
philip-paul-mueller Dec 19, 2024
49f8172
The symbol is also needed to be present in the nested SDFG.
philip-paul-mueller Dec 19, 2024
2d6dfc0
Fixed a bug in determining the free symbols that we need.
philip-paul-mueller Dec 19, 2024
6124c6d
Updated the propagation code for the symbols.
philip-paul-mueller Dec 19, 2024
45bcf97
Addressed Edoardo's changes.
philip-paul-mueller Dec 19, 2024
23b0baa
Updated how we get the type of symbols.
philip-paul-mueller Dec 19, 2024
ff05880
New restriction on the update of the symbol mapping.
philip-paul-mueller Dec 19, 2024
43ec33c
Updated the tests, now also made one that has tests for the symbol ma…
philip-paul-mueller Dec 19, 2024
d43153a
Fixed two bug in the stride propagation function.
philip-paul-mueller Dec 19, 2024
2e82bd5
Added a test that ensures that the dependent adding works.
philip-paul-mueller Dec 19, 2024
07e6a5c
Changed the default of `ignore_symbol_mapping` to `True`.
philip-paul-mueller Dec 19, 2024
4bf145b
Added Edoardo's comments.
philip-paul-mueller Dec 19, 2024
2b03bb4
Removed the creation of aliasing if symbol tables are ignored.
philip-paul-mueller Dec 20, 2024
40c225d
Added a test that shows that `ignore_symbol_mapping=False` does produ…
philip-paul-mueller Dec 20, 2024
419a386
Updated the description.
philip-paul-mueller Dec 20, 2024
cc9801b
Applied Edoardo's comment.
philip-paul-mueller Dec 20, 2024
360baae
Added a todo from Edoardo's suggestions.
philip-paul-mueller Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import functools
from typing import Iterable, Optional, TypeAlias
from typing import Optional, TypeAlias

import dace
from dace import data as dace_data
Expand Down Expand Up @@ -346,6 +345,11 @@ def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_node
def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str:
return edge.dst_conn

def get_subset(
edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet],
) -> dace.subsets.Subset:
# Source-side subset of the Memlet carried by `edge`; the caller forwards
# it as `outer_subset` when mapping the strides into the nested SDFG.
return edge.data.src_subset

def next_edges_by_connector(
state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]
) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]:
Expand All @@ -363,6 +367,11 @@ def get_node(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> dace_node
def get_inner_data(edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]) -> str:
return edge.src_conn

def get_subset(
edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet],
) -> dace.subsets.Subset:
# Destination-side subset of the Memlet carried by `edge`; the caller forwards
# it as `outer_subset` when mapping the strides into the nested SDFG.
return edge.data.dst_subset

def next_edges_by_connector(
state: dace.SDFGState, edge: dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]
) -> list[dace.sdfg.graph.MultiConnectorEdge[dace.Memlet]]:
Expand Down Expand Up @@ -394,11 +403,11 @@ def next_edges_by_connector(

# Now set the stride of the data descriptor inside the nested SDFG to
# the ones it has outside.
_gt_map_strides_to_nested_sdfg(
_gt_map_strides_into_nested_sdfg(
nsdfg_node=nsdfg_node,
inner_data=inner_data,
edge_data=edge.data,
outer_strides=outer_node.desc(sdfg).strides,
outer_subset=get_subset(edge),
outer_desc=outer_node.desc(sdfg),
ignore_symbol_mapping=ignore_symbol_mapping,
)

Expand Down Expand Up @@ -426,59 +435,137 @@ def next_edges_by_connector(
)


def _gt_map_strides_to_nested_sdfg(
def _gt_map_strides_into_nested_sdfg(
nsdfg_node: dace.nodes.NestedSDFG,
inner_data: str,
edge_data: dace.Memlet,
outer_strides: Iterable[int | dace.symbolic.SymExpr],
outer_subset: dace.subsets.Subset,
outer_desc: dace_data.Data,
ignore_symbol_mapping: bool = False,
edopao marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
"""Modify the strides of `inner_data` inside `nsdfg_node` to match `outer_desc`.

`inner_data` is the name of a data descriptor inside the NestedSDFG.
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
The function will then modify the strides of `inner_data` to
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
match the ones of `outer_desc`.

Args:
nsdfg_node: The node in the parent SDFG that contains the NestedSDFG.
inner_data: The name of the data descriptor that should be processed
inside the NestedSDFG (by construction also a connector name).
outer_subset: The subset that describes what part of the outer data is
mapped into the NestedSDFG.
outer_desc: The data descriptor of the data on the outside.
ignore_symbol_mapping: If possible the function will perform the renaming
through the `symbol_mapping` of the nested SDFG. If `True` then
the function will always replace the `strides` attribute of the inner descriptor.

Todo:
- Refactor this function.
- Handle the case the stride is used somewhere else.
- Handle the case where we have an explicit size 1 dimension in slicing.
- Handle explicit dimensions of size 1.
"""
# We need to propagate the strides inside the nested SDFG on the global arrays
new_strides = tuple(
stride
for stride, to_map_size in zip(
outer_strides,
edge_data.subset.size(),
strict=True,
)
if to_map_size != 1
)
inner_desc = nsdfg_node.sdfg.arrays[inner_data]
# We need to compute the new strides. In the following we assume that the
# relative order of the dimension does not change, but some dimensions
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
# that are present on the outside are not present on the inside. For
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
# example this happens for the Memlet `a[__i0, 0:__a_size1]`.
# We detect this case by checking if that dimension has size 1.
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
# TODO(phimuell): Handle the case where some additional size 1 dimensions are added.
inner_desc: dace_data.Data = nsdfg_node.sdfg.arrays[inner_data]
inner_shape = inner_desc.shape
inner_strides_init = inner_desc.strides

# TODO(phimuell): For now this is fine, but it should be possible to allow it.
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
assert not inner_desc.transient

if isinstance(inner_desc, dace.data.Scalar):
assert len(new_strides) == 0
outer_strides = outer_desc.strides
outer_inflow = outer_subset.size()

new_strides: list = []
for dim_ostride, dim_oinflow in zip(outer_strides, outer_inflow, strict=True):
current_inner_dim = len(new_strides)

if inner_shape[current_inner_dim] == 1 and dim_oinflow == 1:
# There is an explicit size 1 dimension. Because the only valid
# index for this dimension is `0` we can use any value here.
# To give the compiler more information we explicitly use `0`,
# instead of the outer value.
new_strides.append(0)

elif dim_oinflow == 1:
# Only a single element flows in, thus there is no stride in this dimension.
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
pass

else:
# There is inflow into the SDFG, so we need the stride.
assert dim_oinflow != 0
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
new_strides.append(dim_ostride)
assert len(new_strides) <= len(inner_shape)

if len(new_strides) != len(inner_shape):
raise ValueError("Failed to compute the inner strides.")

# If we have a scalar on the inside, then there is nothing to adjust.
# We could have performed the test above, but doing it here, gives us
# the chance of validating it.
if isinstance(inner_desc, dace_data.Scalar):
if len(new_strides) != 0:
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Dimensional error for '{inner_data}' in '{nsdfg_node.label}'.")
return

assert isinstance(inner_desc, dace.data.Array)
if not isinstance(inner_desc, dace_data.Array):
raise TypeError(
f"Expected that '{inner_data}' is an 'Array' but it is '{type(inner_desc).__name__}'."
)

# Now we actually replace the strides, there are two ways of doing it.
# The first is to create an alias in the `symbol_mapping`, however,
# this is only possible if the current strides are singular symbols,
# like `__a_strides_1`, but not expressions such as `horizontal_end - horizontal_start`
# or literal values.
# The second way would be to replace the `strides` attribute of the
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
# inner data descriptor. In case the new stride consists of expressions
# such as `value1 - value2` we have to make them available inside the
# NestedSDFG. However, it could be that the strides are used somewhere else.
# We will do the following, if `ignore_symbol_mapping` is `False` and
# the strides of the inner descriptors are symbols, we will use the
# symbol mapping. Otherwise, we will replace the `strides` attribute
# of the inner descriptor, in addition we will install a remapping,
# for those values that were a symbol.
if (not ignore_symbol_mapping) and all(
isinstance(inner_stride, dace.symbol) for inner_stride in inner_desc.strides
isinstance(inner_stride, dace.symbol) for inner_stride in inner_strides_init
):
# Use the symbol
for inner_stride, outer_stride in zip(inner_desc.strides, new_strides, strict=True):
nsdfg_node.symbol_mapping[inner_stride.name] = outer_stride
else:
assert len(inner_desc.shape) == len(new_strides)
# We have to replace the `strides` attribute of the inner descriptor.
inner_desc.set_shape(inner_desc.shape, new_strides)

new_strides_symbols: list[dace.symbol] = functools.reduce(
lambda acc, itm: (acc + list(itm.free_symbols)) # type: ignore[union-attr]
if dace.symbolic.issymbolic(itm)
else acc,
new_strides,
[],
)
new_strides_free_symbols = {
sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols
# Now find the free symbols that the new strides need.
new_strides_symbols: list[str] = []
for new_stride_dim in new_strides:
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
if dace.symbolic.issymbolic(new_stride_dim):
new_strides_symbols.append(str(new_stride_dim))
else:
edopao marked this conversation as resolved.
Show resolved Hide resolved
new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols)

# Now we determine the set of symbols that should be mapped inside the NestedSDFG.
# We will exclude all that are already inside the `symbol_mapping` (we do not
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to have a separate check to ensure that the symbols already inside symbol_mapping map to equivalent expressions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree.
However, you could apply this check only to the things that is defined in the symbol mapping and there it might be hard to do in general.
Furthermore, you could have interstate assignments.

I added a Todo, if you have concrete ideas let me know.

# check if they map to the same value, we just hope that they do). Furthermore,
# we will exclude all symbols that are listed in the `symbols` property
# of the SDFG that is nested, and hope that it has the same meaning.
missing_symbol_mappings: set[str] = {
sym
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
for sym in new_strides_symbols
if not (sym in nsdfg_node.sdfg.symbols or sym in nsdfg_node.symbol_mapping)
}
for sym in new_strides_free_symbols:
nsdfg_node.sdfg.add_symbol(sym.name, sym.dtype)
nsdfg_node.symbol_mapping[sym.name] = sym
for sym in missing_symbol_mappings:
# We can not create symbols in the nested SDFG, because we do not have
# the type of the symbols.
nsdfg_node.symbol_mapping[sym] = dace.symbolic.pystr_to_symbolic(sym)
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved

# Now create aliases for the old symbols that were used as strides.
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved
for old_sym, new_sym in zip(inner_strides_init, new_strides):
if dace.symbolic.issymbolic(old_sym):
nsdfg_node.symbol_mapping[str(old_sym)] = dace.symbolic.pystr_to_symbolic(new_sym)


def _gt_find_toplevel_data_accesses(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,11 +187,17 @@ def test_strides_propagation():
for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
for aname, adesc in sdfg.arrays.items():
exp_stride = f"{aname}_stride"
actual_stride = adesc.strides[0]
assert len(adesc.strides) == 1
assert exp_stride == str(
adesc.strides[0]
assert (
str(actual_stride) == exp_stride
), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."

nsdfg = sdfg.parent_nsdfg_node
if nsdfg is not None:
assert exp_stride in nsdfg.symbol_mapping
assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride

# Now we propagate `a` and `b`, but not `c`.
# TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`.
gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True)
Expand All @@ -201,6 +207,7 @@ def test_strides_propagation():
# it has on level 1, but `c` should still be level dependent.
for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
for aname, adesc in sdfg.arrays.items():
original_stride = f"{aname}_stride"
if aname.startswith("c"):
exp_stride = f"{aname}_stride"
else:
Expand All @@ -210,12 +217,23 @@ def test_strides_propagation():
adesc.strides[0]
), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."

nsdfg = sdfg.parent_nsdfg_node
if nsdfg is not None:
assert original_stride in nsdfg.symbol_mapping
assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride

# Now we also propagate `c` thus now all data descriptors have the same stride
gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True)
for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
for aname, adesc in sdfg.arrays.items():
exp_stride = f"{aname[0]}1_stride"
original_stride = f"{aname}_stride"
assert len(adesc.strides) == 1
assert exp_stride == str(
adesc.strides[0]
), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."

nsdfg = sdfg.parent_nsdfg_node
if nsdfg is not None:
assert original_stride in nsdfg.symbol_mapping
assert str(nsdfg.symbol_mapping[original_stride]) == exp_stride
Loading