Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[dace][next]: Fixing strides in optimization #1782

Merged
Merged
Changes from 1 commit
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d9218b6
This are the changes Edoardo implemented to fix some issues in the op…
edopao Dec 17, 2024
9d7e722
First rework.
philip-paul-mueller Dec 18, 2024
1ddd6fe
Updated some commenst.
philip-paul-mueller Dec 18, 2024
95e0007
I want to ignore register, not only consider them.
philip-paul-mueller Dec 18, 2024
f1b7a3f
There was a missing `not` in the check.
philip-paul-mueller Dec 18, 2024
50ad620
Had to update the propagation, to also handle aliasing.
philip-paul-mueller Dec 18, 2024
983022c
In the function for looking for top level accesses the `only_transien…
philip-paul-mueller Dec 18, 2024
e7b1afb
Small reminder of the future.
philip-paul-mueller Dec 18, 2024
df7bd0c
Forgot to export the new SDFG stuff.
philip-paul-mueller Dec 18, 2024
363ab59
Had to update function for actuall renaming of the strides.
philip-paul-mueller Dec 18, 2024
9c19d32
Added a todo to the replacement function.
philip-paul-mueller Dec 18, 2024
9cad1f7
Added a first test to the propagation function.
philip-paul-mueller Dec 18, 2024
2700f53
Modified the function that performs the actuall modification of the s…
philip-paul-mueller Dec 19, 2024
a20d3c0
Updated some tes, but more are missing.
philip-paul-mueller Dec 19, 2024
b5ff462
Subset caching strikes again.
philip-paul-mueller Dec 19, 2024
d326d3b
It seems that the explicit handling of one dimensions is not working.
philip-paul-mueller Dec 19, 2024
252f348
The test must be moved bellow.
philip-paul-mueller Dec 19, 2024
49f8172
The symbol is also needed to be present in the nested SDFG.
philip-paul-mueller Dec 19, 2024
2d6dfc0
Fixed a bug in determining the free symbols that we need.
philip-paul-mueller Dec 19, 2024
6124c6d
Updated the propagation code for the symbols.
philip-paul-mueller Dec 19, 2024
45bcf97
Addressed Edoardo's changes.
philip-paul-mueller Dec 19, 2024
23b0baa
Updated how we get the type of symbols.
philip-paul-mueller Dec 19, 2024
ff05880
New restriction on the update of the symbol mapping.
philip-paul-mueller Dec 19, 2024
43ec33c
Updated the tests, now also made one that has tests for the symbol ma…
philip-paul-mueller Dec 19, 2024
d43153a
Fixed two bug in the stride propagation function.
philip-paul-mueller Dec 19, 2024
2e82bd5
Added a test that ensures that the dependent adding works.
philip-paul-mueller Dec 19, 2024
07e6a5c
Changed the default of `ignore_symbol_mapping` to `True`.
philip-paul-mueller Dec 19, 2024
4bf145b
Added Edoardo's comments.
philip-paul-mueller Dec 19, 2024
2b03bb4
Removed the creation of aliasing if symbol tables are ignored.
philip-paul-mueller Dec 20, 2024
40c225d
Added a test that shows that `ignore_symbol_mapping=False` does produ…
philip-paul-mueller Dec 20, 2024
419a386
Updated the description.
philip-paul-mueller Dec 20, 2024
cc9801b
Applied Edoardo's comment.
philip-paul-mueller Dec 20, 2024
360baae
Added a todo from Edoardo's suggestions.
philip-paul-mueller Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Updated the tests, now also made one that has tests for the symbol ma…
…pping branch.
philip-paul-mueller committed Dec 19, 2024
commit 43ec33ccff098c7beacf4a9588120a047abd0e44
Original file line number Diff line number Diff line change
@@ -174,12 +174,70 @@ def _make_strides_propagation_level1_sdfg() -> (
return sdfg, nsdfg_level2, nsdfg_level3


def test_strides_propagation():
"""
Todo:
- Add a case where `ignore_symbol_mapping=False` can be tested.
- What happens if the stride symbol is used somewhere else?
"""
def test_strides_propagation_use_symbol_mapping():
# Note that the SDFG we are building here is not really meaningful.
sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg()

# Tests if all strides are distinct in the beginning and match what we expect.
for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
for aname, adesc in sdfg.arrays.items():
exp_stride = f"{aname}_stride"
actual_stride = adesc.strides[0]
assert len(adesc.strides) == 1
assert (
str(actual_stride) == exp_stride
), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."

nsdfg = sdfg.parent_nsdfg_node
if nsdfg is not None:
assert exp_stride in nsdfg.symbol_mapping
assert str(nsdfg.symbol_mapping[exp_stride]) == exp_stride

# Now we propagate `a` and `b`, but not `c`.
gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=False)
sdfg_level1.validate()
gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=False)
sdfg_level1.validate()

# Because `ignore_symbol_mapping=False` the strides of the data descriptor should
# not have changed. But the `symbol_mapping` has been updated for `a` and `b`.
# However, the symbols will only point one level above.
for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1):
for aname, adesc in sdfg.arrays.items():
nsdfg = sdfg.parent_nsdfg_node
original_stride = f"{aname}_stride"

if aname.startswith("c"):
target_symbol = f"{aname}_stride"
else:
target_symbol = f"{aname[0]}{level - 1}_stride"

if nsdfg is not None:
assert original_stride in nsdfg.symbol_mapping
assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol
assert len(adesc.strides) == 1
assert (
str(adesc.strides[0]) == original_stride
), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."

# Now we also propagate `c` thus now all data descriptors have the same stride
gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=False)
sdfg_level1.validate()
for level, sdfg in enumerate([sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg], start=1):
for aname, adesc in sdfg.arrays.items():
nsdfg = sdfg.parent_nsdfg_node
original_stride = f"{aname}_stride"
target_symbol = f"{aname[0]}{level-1}_stride"
if nsdfg is not None:
assert original_stride in nsdfg.symbol_mapping
assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol
assert len(adesc.strides) == 1
assert (
str(adesc.strides[0]) == original_stride
), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."


def test_strides_propagation_ignore_symbol_mapping():
# Note that the SDFG we are building here is not really meaningful.
sdfg_level1, nsdfg_level2, nsdfg_level3 = _make_strides_propagation_level1_sdfg()

@@ -201,7 +259,9 @@ def test_strides_propagation():
# Now we propagate `a` and `b`, but not `c`.
# TODO(phimuell): Create a version where we can set `ignore_symbol_mapping=False`.
gtx_transformations.gt_propagate_strides_of(sdfg_level1, "a1", ignore_symbol_mapping=True)
sdfg_level1.validate()
gtx_transformations.gt_propagate_strides_of(sdfg_level1, "b1", ignore_symbol_mapping=True)
sdfg_level1.validate()

# After the propagation `a` and `b` should use the same stride (the one that
# it has on level 1, but `c` should still be level depending.
@@ -213,8 +273,8 @@ def test_strides_propagation():
else:
exp_stride = f"{aname[0]}1_stride"
assert len(adesc.strides) == 1
assert exp_stride == str(
adesc.strides[0]
assert (
str(adesc.strides[0]) == exp_stride
), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."

nsdfg = sdfg.parent_nsdfg_node
@@ -224,13 +284,14 @@ def test_strides_propagation():

# Now we also propagate `c` thus now all data descriptors have the same stride
gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=True)
sdfg_level1.validate()
for sdfg in [sdfg_level1, nsdfg_level2.sdfg, nsdfg_level3.sdfg]:
for aname, adesc in sdfg.arrays.items():
exp_stride = f"{aname[0]}1_stride"
original_stride = f"{aname}_stride"
assert len(adesc.strides) == 1
assert exp_stride == str(
adesc.strides[0]
assert (
str(adesc.strides[0]) == exp_stride
), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'."

nsdfg = sdfg.parent_nsdfg_node