diff --git a/paddle/pir/include/dialect/control_flow/ir/cf_op.h b/paddle/pir/include/dialect/control_flow/ir/cf_op.h index b3cca969d44ca..24af0942ca982 100644 --- a/paddle/pir/include/dialect/control_flow/ir/cf_op.h +++ b/paddle/pir/include/dialect/control_flow/ir/cf_op.h @@ -39,7 +39,9 @@ class IR_API YieldOp : public Op { /// /// \brief Push a value tuple to a container. /// -class IR_API TuplePushOp : public Op { +class IR_API TuplePushOp : public Op { public: using Op::Op; static const char *name() { return "cf.tuple_push"; } @@ -70,6 +72,8 @@ class IR_API TuplePushOp : public Op { return inlet().defining_op(); } TuplePopOp tuple_pop_op(); + + CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context); }; class IR_API TuplePopOp : public Op { diff --git a/paddle/pir/src/dialect/control_flow/ir/cf_op.cc b/paddle/pir/src/dialect/control_flow/ir/cf_op.cc index 6b66ba21478ec..bfd5e14a404db 100644 --- a/paddle/pir/src/dialect/control_flow/ir/cf_op.cc +++ b/paddle/pir/src/dialect/control_flow/ir/cf_op.cc @@ -90,6 +90,21 @@ TuplePopOp TuplePushOp::tuple_pop_op() { return container_interface().tuple_pop_op(); } +void TuplePushOp::CacheGradOpSymbolicShape( + pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape = GetInputShape(infer_context, this->operation(), 0); + pir::InferSymbolicShapeCacheKey op_shape_info("cf.tuple_pop", {x_shape}, ); + + std::vector pop_value_shape_list; + for (size_t index = 1; index < num_operands(); ++index) { + const auto &pop_value_shape = + GetGradVarShapeFromInput(infer_context, this->operation(), index); + pop_value_shape_list.emplace_back(pop_value_shape); + } + infer_context->SetOpInferSymbolicShapeCache(op_shape_info, + pop_value_shape_list); +} + void TuplePopOp::Build(Builder &builder, // NOLINT OperationArgument &argument, // NOLINT Value outlet) { @@ -202,11 +217,10 @@ void StackCreateOp::VerifySig() { bool StackCreateOp::InferSymbolicShape( pir::InferSymbolicShapeContext *infer_context) { - const auto &null_shape_or_data = - symbol::ShapeOrDataDimExprs(symbol::NullShapeOrDataDimExpr()); - infer_context->SetShapeOrDataForValue(result(0), null_shape_or_data); - infer_context->SetShapeOrDataForValue(result(1), null_shape_or_data); - infer_context->SetShapeOrDataForValue(result(2), null_shape_or_data); + symbol::DimExpr mark_symbol = infer_context->GetNextSymName(); + infer_context->SetShapeOrDataForValue(result(0), mark_symbol); + infer_context->SetShapeOrDataForValue(result(1), mark_symbol); + infer_context->SetShapeOrDataForValue(result(2), mark_symbol); return true; }