diff --git a/iree/turbine/kernel/wave/constraints.py b/iree/turbine/kernel/wave/constraints.py
index b83ee825..6310d2a5 100644
--- a/iree/turbine/kernel/wave/constraints.py
+++ b/iree/turbine/kernel/wave/constraints.py
@@ -14,7 +14,6 @@ from .._support.dtype import DataType
 from ..lang.global_symbols import *
 
 
-
 """
 Formatting for different target intrinsics:
     __xx_[_]
@@ -498,6 +497,27 @@ def apply(self) -> IndexSequence:
             raise ValueError("Index is being computed without setting wave id")
         return IndexSequence(self.tile_size * self.wave_id, 1)
 
+    def set_wave_id_from_hardware_and_workgroup_constraint(
+        self,
+        hardware_constraint: HardwareConstraint,
+        workgroup_constraint: WorkgroupConstraint,
+    ):
+        """
+        Derive wave_id from the thread id along the workgroup dimension.
+        By convention wave_id equals thread_id, except along workgroup dim 0,
+        where wave_id[0] = floor(thread_id[0] / threads_per_wave).
+        """
+        old_wave_id = self.wave_id  # keep any preset value for the check below
+        assert self.dim == workgroup_constraint.dim, "Dimension mismatch"
+        self.wave_id = hardware_constraint.get_thread_id_from_workgroup_dim(
+            workgroup_constraint.workgroup_dim
+        )
+        if workgroup_constraint.workgroup_dim == 0:
+            self.wave_id = floor(self.wave_id / hardware_constraint.threads_per_wave)
+        assert (
+            old_wave_id is None or self.wave_id == old_wave_id
+        ), f"Conflicting preset wave_id old: {old_wave_id} new: {self.wave_id}"
+
 
 def get_constrained_shape(
     shape: list[IndexExpr], constraints: list[WorkgroupConstraint | TilingConstraint]
diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py
index c9bee670..066740ac 100644
--- a/iree/turbine/kernel/wave/wave.py
+++ b/iree/turbine/kernel/wave/wave.py
@@ -209,20 +209,10 @@ def initialize_wave_constraints(self, trace: CapturedTrace) -> None:
         hardware_constraint = self.hardware_constraints[0]
         for wave_constraint in self.wave_constraints:
             for workgroup_constraint in self.workgroup_constraints:
-                # The wave_id is the same as the thread_id, with the exception
-                # of wave_id[0] = thread_id[0] / threads_per_wave. This is
-                # a convention that we adopt.
                 if wave_constraint.dim == workgroup_constraint.dim:
-                    wave_constraint.wave_id = (
-                        hardware_constraint.get_thread_id_from_workgroup_dim(
-                            workgroup_constraint.workgroup_dim
-                        )
+                    wave_constraint.set_wave_id_from_hardware_and_workgroup_constraint(
+                        hardware_constraint, workgroup_constraint
                     )
-                    if workgroup_constraint.workgroup_dim == 0:
-                        wave_constraint.wave_id = sympy.floor(
-                            wave_constraint.wave_id
-                            / hardware_constraint.threads_per_wave
-                        )
 
     def initialize_reductions(self, trace: CapturedTrace) -> None:
         """