Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Support inferSymbolicShape for cf.tuple_pop and cf.tuple_push #70723

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions paddle/pir/include/dialect/control_flow/ir/cf_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
#include "paddle/pir/include/core/op_base.h"
#include "paddle/pir/include/core/op_trait.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_interface.h"
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/cache_grad_op_symbolic_shape.h"
#include "paddle/pir/include/dialect/shape/interface/infer_symbolic_shape/infer_symbolic_shape.h"

#include "paddle/pir/include/dialect/shape/utils/original_attributes_filter.h"
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"
namespace pir {
class IR_API YieldOp : public Op<YieldOp, SideEffectTrait> {
public:
Expand All @@ -39,7 +41,9 @@ class IR_API YieldOp : public Op<YieldOp, SideEffectTrait> {
///
/// \brief Push a value tuple to a container.
///
class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
class IR_API TuplePushOp : public Op<TuplePushOp,
SideEffectTrait,
CacheGradOpSymbolicShapeInterface> {
public:
using Op::Op;
static const char *name() { return "cf.tuple_push"; }
Expand Down Expand Up @@ -70,6 +74,8 @@ class IR_API TuplePushOp : public Op<TuplePushOp, SideEffectTrait> {
return inlet().defining_op<ContainerOpInterface>();
}
TuplePopOp tuple_pop_op();

void CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
};

class IR_API TuplePopOp : public Op<TuplePopOp, SideEffectTrait> {
Expand Down
34 changes: 28 additions & 6 deletions paddle/pir/src/dialect/control_flow/ir/cf_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "paddle/pir/include/core/ir_printer.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"

namespace pir {

void YieldOp::Build(Builder &builder,
Expand Down Expand Up @@ -90,6 +89,26 @@ TuplePopOp TuplePushOp::tuple_pop_op() {
return container_interface().tuple_pop_op();
}

// Pre-populates the symbolic-shape cache entry that the paired
// cf.tuple_pop op will look up: the shapes of the values pushed here
// (operands 1..N-1) are exactly what the pop is expected to yield.
// The cache key is built from the pop op's name, the inlet's shape
// (operand 0), and the op's ordered original attributes.
// NOTE(review): the scraped source had review-thread UI text embedded in
// this argument list; it has been removed to restore compilable code.
void TuplePushOp::CacheGradOpSymbolicShape(
    pir::InferSymbolicShapeContext *infer_context) {
  // Operand 0 is the container inlet; its shape participates in the key.
  const auto &inlet_shape =
      infer_context->GetShapeOrDataForValue(operand_source(0));
  pir::InferSymbolicShapeCacheKey cache_key(
      "cf.tuple_pop",
      {inlet_shape},
      pir::GetOrderedOriginalAttributes("cf.tuple_pop",
                                        operation()->attributes()));

  // Operands 1..N-1 are the pushed values; record their shapes in push
  // order so tuple_pop can replay them.
  std::vector<symbol::ShapeOrDataDimExprs> pushed_shapes;
  for (size_t index = 1; index < num_operands(); ++index) {
    pushed_shapes.push_back(
        infer_context->GetShapeOrDataForValue(operand_source(index)));
  }
  infer_context->SetOpInferSymbolicShapeCache(cache_key, pushed_shapes);
}

void TuplePopOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
Value outlet) {
Expand Down Expand Up @@ -202,11 +221,14 @@ void StackCreateOp::VerifySig() {

// Assigns symbolic shapes to the three results of cf.stack_create
// (stack handle, inlet, outlet). The scraped diff had merged the deleted
// old body (null shapes) with the added new one; only the new
// implementation is kept here: all three results share one fresh 1-D
// symbolic dimension so shape analysis can track the container values
// instead of treating them as null.
bool StackCreateOp::InferSymbolicShape(
    pir::InferSymbolicShapeContext *infer_context) {
  // A single brand-new symbol marks this container uniquely.
  std::vector<symbol::DimExpr> shape;
  shape.emplace_back(symbol::DimExpr(infer_context->GetNextSymName()));
  const symbol::ShapeOrDataDimExprs &mark_shape_or_data =
      symbol::ShapeOrDataDimExprs(symbol::TensorShapeOrDataDimExprs(shape));

  // All three results (stack, inlet, outlet) carry the same marker shape.
  infer_context->SetShapeOrDataForValue(result(0), mark_shape_or_data);
  infer_context->SetShapeOrDataForValue(result(1), mark_shape_or_data);
  infer_context->SetShapeOrDataForValue(result(2), mark_shape_or_data);
  return true;
}

Expand Down
16 changes: 14 additions & 2 deletions paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,28 @@ void InferSymExprForOp(Operation* op,
op, infer_context->GetShapeOrDataForValue(op->result(0)));
}
} else {
bool is_grad_op = [&]() {
const bool is_grad_op = [&]() {
std::string suffix = "_grad";
const auto& op_name = op->name();
if (op_name.size() < suffix.size()) return false;
return op_name.compare(
op_name.size() - suffix.size(), suffix.size(), suffix) == 0;
}();

const bool is_special_cached_op = [&]() {
const auto& op_name = op->name();
std::vector<std::string> special_cached_ops = {
"cf.tuple_pop",
};
return (std::find(special_cached_ops.begin(),
special_cached_ops.end(),
op_name) != special_cached_ops.end());
}();

if (!is_grad_op)
LOG(WARNING) << op->name()
<< " DOES NOT have InferSymbolicShapeInterface!";

const bool all_outs_static_dims = [&] {
bool all_static_dims = true;
for (uint32_t i = 0; i < op->num_results(); ++i) {
Expand All @@ -288,7 +300,7 @@ void InferSymExprForOp(Operation* op,
return all_static_dims;
}();

if (all_outs_static_dims) {
if (all_outs_static_dims && !is_special_cached_op) {
for (uint32_t i = 0; i < op->num_results(); ++i) {
infer_context->SetSymbolForValueByStaticShape(op->result(i));
}
Expand Down
13 changes: 6 additions & 7 deletions python/paddle/jit/dy2static/pir_partial_program.py
gongshaotian marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,12 @@ class FullGraphPreProcessPass(ValuePreservePass):
def apply(self, program):
    """Pre-process the full graph before partial-program construction.

    Always fuses bn+add+act; when CINN is enabled, additionally runs
    delete_assert_op_pass, infer_symbolic_shape_pass, and
    reduce_as_sum_pass through a single PassManager.

    The scraped diff had merged the deleted direct reduce_as_sum_pass
    call with its PassManager-based replacement; only the replacement
    is kept, with indentation restored.
    """
    program = paddle.base.libpaddle.pir.apply_bn_add_act_pass(program)
    if self.use_cinn_pass:
        # NOTE(gongshaotian): execute infer_symbolic_shape_pass before
        # reduce_as_sum_pass so symbolic shape info is available to it.
        pm = paddle.base.libpaddle.pir.PassManager()
        pm.add_pass("delete_assert_op_pass", {})
        paddle.base.libpaddle.pir.infer_symbolic_shape_pass(pm, program)
        pm.add_pass("reduce_as_sum_pass", {})
        pm.run(program)
    return program


Expand Down Expand Up @@ -711,12 +716,6 @@ def _create_program(self, is_infer_mode=False):
if is_infer_mode:

def pass_fn(forward_program, backward_program, program_name_attr):
# common pass
pm = paddle.base.libpaddle.pir.PassManager()
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(
pm, forward_program
)
pm.run(forward_program)

apply_general_passes(
forward_program,
Expand Down
Loading