feat[dace][next]: Fixing strides in optimization #1782
…timization pipeline when confronted with scans.
First round.
for state in sdfg.states():
    for data_node in state.data_nodes():
        if data_node.data == top_level_transient:
            for in_edge in state.in_edges(data_node):
                gt_map_strides_to_src_nested_sdfg(sdfg, state, in_edge, data_node)
            for out_edge in state.out_edges(data_node):
                gt_map_strides_to_dst_nested_sdfg(sdfg, state, out_edge, data_node)
This is a bit inefficient as the SDFG is scanned multiple times.
`_find_toplevel_transients()` should be modified to also find the places where this is needed.
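As a rough illustration of the single-pass idea, the sketch below collects the locations of all interesting data names in one traversal instead of rescanning once per transient. It is plain Python over a toy model, not the DaCe API: `collect_access_nodes`, the `states` dictionary, and the `(state_label, node_index)` location tuples are assumptions made for this example only.

```python
from collections import defaultdict


def collect_access_nodes(states, wanted):
    """Return, for every wanted data name, all places where it is accessed.

    `states` is a toy stand-in for an SDFG: it maps a state label to the list
    of data names referenced by the access nodes of that state (hypothetical).
    """
    locations = defaultdict(list)
    for state_label, data_nodes in states.items():
        for node_index, data_name in enumerate(data_nodes):
            if data_name in wanted:
                # Remember where the data is used so no second scan is needed.
                locations[data_name].append((state_label, node_index))
    return dict(locations)


# Two states that both access the transient "t".
toy_states = {"s0": ["a", "t"], "s1": ["t", "b"]}
toy_locations = collect_access_nodes(toy_states, {"t"})
```

With such an index, the stride propagation can visit exactly the recorded locations of each transient rather than iterating over every state again.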
if isinstance(edge.dst, dace.nodes.MapEntry):
    # Find the destination of the edge entering the map entry node
    map_entry_out_conn = edge.dst_conn.replace("IN_", "OUT_")
    for edge_from_map_entry in state.out_edges_by_connector(edge.dst, map_entry_out_conn):
        gt_map_strides_to_dst_nested_sdfg(sdfg, state, edge_from_map_entry, outer_node)
    return
We should use the MemletTree to do this a bit more efficiently.
for inner_state in edge.dst.sdfg.states():
    for inner_node in inner_state.data_nodes():
        if inner_node.data == edge.dst:
            for inner_edge in inner_state.out_edges(inner_node):
                gt_map_strides_to_dst_nested_sdfg(sdfg, state, inner_edge, inner_node)
This might be a bit inefficient.
                gt_map_strides_to_dst_nested_sdfg(sdfg, state, inner_edge, inner_node)


def gt_map_strides_to_src_nested_sdfg(
Should be merged with the `dst` version.
        edge_data.subset.size(),
        strict=True,
    )
    if to_map_size != 1
@edopao why do we exclude these dimensions?
I think this code overlaps with the check `isinstance(inner_desc, dace.data.Scalar)` I introduced later at line 208. In our SDFGs, when the outer array is accessed point-wise by the memlet (`edge_data.subset.size() == 1` in all dimensions), it is implicitly converted into a scalar and the strides are ignored. Of course, this implicit slicing (if you call it that way) only works when the memlet subset has size=1 in all dimensions.
We could rewrite it as:
    if isinstance(inner_desc, dace.data.Scalar):
        assert set(edge_data.subset.size()) == {1}
        return
    assert isinstance(inner_desc, dace.data.Array)
and then just use `outer_strides` instead of `new_strides`.
No, actually there was a reason. Currently the scan writes into a 1D array along the vertical dimension, so the memlet from the nested SDFG implementing the scan implicitly un-slices the data into the outer field (which is usually multidimensional). That code is there to handle array slicing through the memlet subset. The question is whether we want to support it or not, since it will be deprecated in the next dace release.
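To make the effect of the `to_map_size != 1` filter concrete, here is a minimal sketch in plain Python. It does not use the DaCe API: the function name `strides_for_inner_array` and the use of strings as stand-ins for symbolic strides are assumptions for illustration. It only shows how dropping the size-1 dimensions of the memlet subset selects the outer strides that the lower-dimensional inner array has to use.

```python
def strides_for_inner_array(outer_strides, subset_sizes):
    """Keep the outer stride of every dimension that the memlet does not slice away.

    A dimension whose subset has size 1 is implicitly removed from the inner
    array, so its stride is dropped as well (hypothetical helper, not DaCe).
    """
    if len(outer_strides) != len(subset_sizes):
        raise ValueError("outer strides and memlet subset must have the same rank")
    return tuple(
        stride for stride, size in zip(outer_strides, subset_sizes) if size != 1
    )


# Outer field of shape (I, K) with strides (1, I). The scan writes one vertical
# column, so the memlet subset has the sizes (1, K).
column_strides = strides_for_inner_array((1, "I"), (1, "K"))

# Point-wise access: every dimension has size 1, no stride is left (scalar case).
scalar_strides = strides_for_inner_array(("K", 1), (1, 1))
```

In the first case the 1D inner array ends up with the stride `I` of the vertical dimension rather than a stride of 1, which is exactly the information that is lost if the outer strides are not propagated.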
new_strides_free_symbols = {
    sym for sym in new_strides_symbols if sym.name not in nsdfg_node.sdfg.symbols
}
Look at this more.
However, the actual modifier function is not modified yet.
Which is funny, because if you look at the last commit, the number of `not`s in this function was correct.
It seems that we also have to handle aliases. It makes things a bit harder: instead of only looking at the NestedSDFG, we now look at the `(NameOfDataDescriptorInside, NestedSDFG)` pair. However, it still has some errors.
…ts` flag was not implemented properly.
Before the function had a special mode in which it performed the renaming through the `symbol_mapping`. However, this made testing a bit harder and so I decided that there should be a flag to disable this.
There is some functionality missing, but it is looking good.
…trides. However, it is not yet fully tested; they are on their way.
new_strides_symbols.extend(sym for sym in new_stride_dim.free_symbols)

# Now we determine the set of symbols that should be mapped inside the NestedSDFG.
# We will exclude all that are already inside the `symbol_mapping` (we do not
I would prefer to have a separate check to ensure that the symbols already inside `symbol_mapping` map to equivalent expressions.
I agree.
However, you could apply this check only to the things that are defined in the symbol mapping, and even there it might be hard to do in general.
Furthermore, you could have interstate assignments.
I added a TODO; if you have concrete ideas, let me know.
It also seems that it interferes with something.
Because a scalar has a shape of `(1,)` but a stride of `()`. Thus we first have to handle this case. However, now we are back at the index stuff, let's fix it.
However, it still seems to fail in some cases.
The type is now a bit better estimated.
The types are now extracted from the stuff we get from `free_symbols`.
@@ -24,6 +40,11 @@ def gt_change_transient_strides(
    transients in the optimal way.
    The function should run after all maps have been created.

    After the strides have been adjusted the function will also propagate
    the strides into nested SDFG. This propagation will happen with
    `ignore_symbol_mapping` set to `True`, see `gt_propagate_strides_of()`
I am of the opinion that `ignore_symbol_mapping` should be `False`. Actually, I would even prefer that this parameter was not available as a configuration option. If the original SDFG was using symbol `A` for the stride of the inner array, the transformed SDFG should still use `A`; only the symbol mapping should change. Otherwise, if we replace symbols in nested SDFGs, we are doing two things in one transformation: stride propagation and symbol propagation.
I have now thought a bit more about it and came to the conclusion that in certain situations `ignore_symbol_mapping=False` is actually wrong and will produce invalid results.
The reason is that using the symbol mapping actually affects the whole (nested) SDFG and not only the data descriptor.
The problem only occurs if different data descriptors use the same symbols as strides, which happens, for example, for temporaries: for a long time they had the size `(horizontal_end, vertical_end)`, which would result in the strides `(horizontal_end, 1)`, which all transients share.
This is the situation that `tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py::test_strides_propagation_shared_symbols_sdfg` demonstrates.
For some reason the strides of only one descriptor (`a1`) are propagated, while the other remains the same. If the symbol mapping were used, then we would also change, as a side effect, the strides of `b2` (the inner view).
Earlier I removed the symbol aliasing code, 2b03bb4, in the `ignore_symbol_mapping=True` branch, for exactly that reason.
I also do not think that this is a big problem, as strides are not directly accessed; the shape, which does not change, is much more likely to be used.
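The side effect described above can be reproduced with a toy model in plain Python. This is only a sketch under stated assumptions, not DaCe code: the descriptor names `a_inner` and `b_inner`, the shared symbol `S`, the outer stride name `a_outer_stride_0`, and the helper `resolve` are all hypothetical. The nested SDFG is reduced to a table of per-descriptor strides plus one symbol mapping that applies to the whole nested scope.

```python
def resolve(strides, symbol_mapping):
    """Evaluate inner strides by substituting symbols through the symbol mapping."""
    return tuple(symbol_mapping.get(stride, stride) for stride in strides)


# Two inner descriptors that happen to share the stride symbol "S".
inner_strides = {"a_inner": ("S", 1), "b_inner": ("S", 1)}
symbol_mapping = {"S": "S"}

# Variant 1: propagate the new stride of the first array through the mapping.
# The mapping is global to the nested scope, so "b_inner" changes as well.
remapped = dict(symbol_mapping, S="a_outer_stride_0")
via_mapping = {name: resolve(s, remapped) for name, s in inner_strides.items()}

# Variant 2: overwrite only the strides of the affected descriptor and leave
# the symbol mapping alone. "b_inner" keeps its original stride.
patched = dict(inner_strides, a_inner=("a_outer_stride_0", 1))
direct = {name: resolve(s, symbol_mapping) for name, s in patched.items()}
```

Variant 1 corresponds to going through the symbol mapping, variant 2 to writing the outer strides directly into the inner descriptor.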
You are right. Based on your explanation, I now agree on using `ignore_symbol_mapping=True` in its current meaning.
However, as a suggestion, we could change the meaning of `ignore_symbol_mapping`:
- `ignore_symbol_mapping=True` would behave as in the current implementation: it takes the outer strides and just applies them to the inner arrays.
- `ignore_symbol_mapping=False` would create new symbols in the nested SDFG (using the utility function to generate stride symbols based on the data name, which should be enough to ensure that they are unique) and apply the symbol mapping to the outer strides. The benefit of this is to avoid the risk of using a symbol that is already defined in the nested SDFG with a different value than in the parent SDFG.
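A minimal sketch of the suggested `ignore_symbol_mapping=False` behaviour could look as follows. It is plain Python rather than the actual utility function: the helper name `make_stride_symbols` and the `__<name>_stride_<dim>` naming scheme are assumptions chosen for illustration. Every dimension of the inner array gets a fresh symbol derived from the data name, and the symbol mapping binds that symbol to the corresponding outer stride.

```python
def make_stride_symbols(data_name, outer_strides):
    """Create one fresh inner stride symbol per dimension and map it outwards.

    Returns the strides to store on the inner descriptor and the entries that
    would be added to the symbol mapping of the nested SDFG (hypothetical).
    """
    inner_strides = []
    new_mapping = {}
    for dim, outer_stride in enumerate(outer_strides):
        symbol = f"__{data_name}_stride_{dim}"
        inner_strides.append(symbol)
        new_mapping[symbol] = outer_stride
    return tuple(inner_strides), new_mapping


inner, mapping = make_stride_symbols("b2", ("horizontal_end", 1))
```

Because the generated names depend on the data name, two inner descriptors never share a stride symbol, so changing one mapping entry cannot silently alter another descriptor.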
This is an issue indeed, I added a todo.
I realized that allowing this is not very safe. I also added a test to show that.
…ces errors in certain cases.
The PR looks good, thank you for improving the fix. You can do as you prefer with my suggestion.
Added functionality to properly handle changes of strides.
During the implementation of the scan we found that the strides were not handled properly.
Most importantly a change on one level was not propagated into the next levels, i.e. they were still using the old strides.
This PR solves most of the problems, but there are still some issues that are unsolved:
The initial functionality of this PR was done by Edoardo Paone (@edopao).
Co-authored-by: edopao [email protected]