Skip to content

Commit

Permalink
add cache interface for cf.tuple_push
Browse files Browse the repository at this point in the history
  • Loading branch information
gongshaotian committed Jan 8, 2025
1 parent 4128286 commit e2df3a0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
6 changes: 5 additions & 1 deletion paddle/pir/include/dialect/control_flow/ir/cf_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ class IR_API YieldOp : public Op<YieldOp, SideEffectTrait> {
///
/// \brief Push a value tuple to a container.
///
class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
class IR_API TuplePushOp : public Op<TuplePushOp,
SideEffectTrait,
CacheGradOpSymbolicShapeInterface> {
public:
using Op::Op;
static const char *name() { return "cf.tuple_push"; }
Expand Down Expand Up @@ -70,6 +72,8 @@ class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
return inlet().defining_op<ContainerOpInterface>();
}
TuplePopOp tuple_pop_op();

CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
};

class IR_API TuplePopOp : public Op<TuplePopOp, SideEffectTrait> {
Expand Down
24 changes: 19 additions & 5 deletions paddle/pir/src/dialect/control_flow/ir/cf_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,21 @@ TuplePopOp TuplePushOp::tuple_pop_op() {
return container_interface().tuple_pop_op();
}

void TuplePushOp::CacheGradOpSymbolicShape(
pir::InferSymbolicShapeContext *infer_context) {
const auto &x_shape = GetInputShape(infer_context, this->operation(), 0);
pir::InferSymbolicShapeCacheKey op_shape_info("cf.tuple_pop", {x_shape}, );

std::vector<symbol::ShapeOrDataDimExprs> pop_value_shape_list;
for (size_t index = 1; index < num_operands(); ++index) {
const auto &pop_value_shape =
GetGradVarShapeFromInput(infer_context, this->operation(), index);
pop_value_shape_list.emplace_back(pop_value_shape);
}
infer_context->SetOpInferSymbolicShapeCache(op_shape_info,
pop_value_shape_list);
}

void TuplePopOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value outlet) {
Expand Down Expand Up @@ -202,11 +217,10 @@ void StackCreateOp::VerifySig() {

bool StackCreateOp::InferSymbolicShape(
pir::InferSymbolicShapeContext *infer_context) {
const auto &null_shape_or_data =
symbol::ShapeOrDataDimExprs(symbol::NullShapeOrDataDimExpr());
infer_context->SetShapeOrDataForValue(result(0), null_shape_or_data);
infer_context->SetShapeOrDataForValue(result(1), null_shape_or_data);
infer_context->SetShapeOrDataForValue(result(2), null_shape_or_data);
symbol::DimExpr mark_symbol = infer_context->GetNextSymName();
infer_context->SetShapeOrDataForValue(result(0), mark_symbol);
infer_context->SetShapeOrDataForValue(result(1), mark_symbol);
infer_context->SetShapeOrDataForValue(result(2), mark_symbol);
return true;
}

Expand Down

0 comments on commit e2df3a0

Please sign in to comment.