diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index e771a6dfd9e5..25ba46dfef22 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -608,6 +608,7 @@ def TPU_AssumeMultipleOp : TPU_Op<"assume_multiple", [Pure, SameOperandsAndResul I32Attr:$multiple ); let results = (outs AnyTypeOf<[Index, AnyInteger]>:$result); + let hasVerifier = 1; } def TPU_MemRefSliceOp : TPU_Op<"memref_slice", [Pure, AttrSizedOperandSegments]> { diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index e2fe04579a31..69c29e51f3bc 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -1225,6 +1225,27 @@ LogicalResult DynamicGatherOp::verify() { return success(); } +LogicalResult AssumeMultipleOp::verify() { + auto operand_value = getValue(); + auto divisor = getMultiple(); + if (auto cst_op = operand_value.getDefiningOp()) { + auto int_attr = dyn_cast(cst_op.getValue()); + // Illegal usage of AssumeMultipleOp. + if (!int_attr) { + return emitOpError( + "Illegal user annotation, expected an integer, but got ") + << cst_op.getValue(); + } + if (int_attr.getInt() % divisor != 0) { + return emitOpError( + "Illegal user annotation, expected an integer that is " + "divisible by the multiple, but got ") + << int_attr.getInt() << " % " << divisor; + } + } + return success(); +} + } // namespace tpu } // namespace mlir