Fixed issue with cat in fx backend (openvinotoolkit#20744)
* Added fix for cat in torchfx

* Added batch_norm_legit_no_training op

* Fixed coding style

* Fixed clang format

* Addressed PR comments
suryasidd authored Nov 1, 2023
1 parent 26c9c41 commit bb0e4f8
Showing 4 changed files with 42 additions and 14 deletions.
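
As context for the diffs below: the two patterns this commit touches, aten.cat and eval-mode batch norm, can both be exercised through the torch.fx backend with a small torch.compile script. The following is a sketch, assuming the "openvino" backend name and the openvino.torch registration import from the 2023.2-era API; it is not the reproducer from issue #20744.

import torch
import openvino.torch  # noqa: F401  (assumed import that registers the "openvino" backend)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(4)

    def forward(self, x, y):
        # cat followed by eval-mode batch norm: the two patterns this commit touches
        return self.bn(torch.cat([x, y], dim=1))

model = Model().eval()
compiled = torch.compile(model, backend="openvino")
out = compiled(torch.randn(1, 2, 8, 8), torch.randn(1, 2, 8, 8))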
@@ -82,6 +82,8 @@ def __init__(self):
             "torch.ops.aten.mul.Scalar": None,
             "torch.ops.aten.mul.Tensor": None,
             "torch.ops.aten.native_batch_norm.default": None,
+            "torch.ops.aten._native_batch_norm_legit.default": None,
+            "torch.ops.aten._native_batch_norm_legit_no_training.default": None,
             "torch.ops.aten.native_group_norm.default": None,
             "torch.ops.aten.native_layer_norm.default": None,
             "torch.ops.aten.neg.default": None,
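The two new entries mirror what recent PyTorch emits for batch norm: eval-mode graphs carry aten._native_batch_norm_legit_no_training, while training-mode graphs carry aten._native_batch_norm_legit with an explicit training flag. A quick way to check which variant a given build produces is to export the module and print the graph; a sketch assuming PyTorch >= 2.1 (exact op names vary across releases):

import torch

bn = torch.nn.BatchNorm2d(3).eval()
ep = torch.export.export(bn, (torch.randn(1, 3, 8, 8),))
print(ep.graph)
# In eval mode, expect a call to torch.ops.aten._native_batch_norm_legit_no_training.default;
# in training mode, _native_batch_norm_legit with an explicit training flag appears instead.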
34 changes: 28 additions & 6 deletions src/frontends/pytorch/src/op/batch_norm.cpp
@@ -39,10 +39,13 @@ Output<Node> broadcast_const_to_channel_dim(const NodeContext& context,
 }
 } // namespace
 
-OutputVector translate_batch_norm(const NodeContext& context) {
+OutputVector translate_batch_norm_common(const NodeContext& context, bool training) {
     // Schema: aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var,
     // bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor
-    num_inputs_check(context, 8, 9);
+
+    // batch_norm_legit_no_training schema: aten::_native_batch_norm_legit_no_training(Tensor input, Tensor? weight,
+    // Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> Tensor
+
     auto input = context.get_input(0);
     Output<Node> weight;
     Output<Node> bias;
@@ -63,7 +66,6 @@ OutputVector translate_batch_norm(const NodeContext& context) {
         bias = broadcast_const_to_channel_dim(context, input, zero_f);
     }
     // index 3 running_mean and index 4 running_var can be none for the training case only; check that we are not training first
-    auto training = context.const_input<bool>(5);
     // if training is activated for batch norm but the model is in eval mode, current statistics are used instead of running ones
     if (training) {
         auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
@@ -92,14 +94,34 @@ OutputVector translate_batch_norm(const NodeContext& context) {
         running_var = current_var;
     }
     // Input with index 6 is momentum; it is used only for updating the running_mean accumulation during training
-    auto epsilon = context.const_input<float>(7);
+    // In batch_norm_legit_no_training, momentum is at index 5 and epsilon at index 6
+    float epsilon;
+    if (context.get_input_size() == 7) {
+        epsilon = context.const_input<float>(6);
+    } else {
+        epsilon = context.const_input<float>(7);
+    }
     // Input with index 8 is the "cudnn_enabled" flag; we can ignore it
     return {context.mark_node(
         std::make_shared<v5::BatchNormInference>(input, weight, bias, running_mean, running_var, epsilon))};
 };
 
-OutputVector translate_batch_norm_fx(const NodeContext& context) {
-    auto output = translate_batch_norm(context);
+OutputVector translate_batch_norm(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto training = context.const_input<bool>(5);
+    return translate_batch_norm_common(context, training);
+}
+
+OutputVector translate_batch_norm_legit_fx(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto training = context.const_input<bool>(5);
+    auto output = translate_batch_norm_common(context, training);
     return {context.mark_node(make_list_construct(output))};
 }
+
+OutputVector translate_batch_norm_legit_no_training_fx(const NodeContext& context) {
+    num_inputs_check(context, 7, 9);
+    auto output = translate_batch_norm_common(context, false);
+    return {context.mark_node(make_list_construct(output))};
+}
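
The index arithmetic above follows from the op schemas: dropping the bool training argument shifts momentum to input 5 and eps to input 6, which is why a 7-input call reads epsilon from index 6. The schemas can be inspected directly from Python; a minimal check using the _schema attribute of an OpOverload (the printed form varies by PyTorch version):

import torch

# The no_training variant drops the bool training argument, so momentum and eps
# sit one position earlier than in _native_batch_norm_legit.
print(torch.ops.aten._native_batch_norm_legit.default._schema)
print(torch.ops.aten._native_batch_norm_legit_no_training.default._schema)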

13 changes: 7 additions & 6 deletions src/frontends/pytorch/src/op/cat.cpp
@@ -22,7 +22,8 @@ using namespace ov::op;
 
 OutputVector translate_cat_common(const NodeContext& context,
                                   const std::deque<ov::Output<ov::Node>>& list_elems,
-                                  int64_t axis) {
+                                  int64_t axis,
+                                  bool is_fx) {
     if (list_elems.empty()) {
         // couldn't get list elements
         auto fw_node = std::make_shared<PtFrameworkNode>(context.get_decoder(), OutputVector{context.get_input(0)}, 1);
@@ -39,8 +40,8 @@ OutputVector translate_cat_common(const NodeContext& context,
                           "<aten/quantized>::cat is located inside body while inputs are located outside of the body. "
                           "This case is not supported.");
     if (list_elems.size() == 1 &&
-        !std::dynamic_pointer_cast<op::util::FrameworkNode>(context.get_input(0).get_node_shared_ptr())) {
-        // Case when list was merged into tensor
+        !std::dynamic_pointer_cast<op::util::FrameworkNode>(context.get_input(0).get_node_shared_ptr()) && !is_fx) {
+        // Case when the list was merged into a tensor; this case does not apply to torchfx
         auto tensor = list_elems[0];
         auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(tensor, element::i32));
         auto zero = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
@@ -63,7 +64,7 @@ OutputVector translate_cat(const NodeContext& context) {
     num_inputs_check(context, 2, 3);
     const auto&& list_elems = get_list_as_outputs(context.get_input(0));
     auto axis = context.const_input<int64_t>(1);
-    auto out = translate_cat_common(context, list_elems, axis);
+    auto out = translate_cat_common(context, list_elems, axis, false);
     if (!context.input_is_none(2)) {
         context.mutate_input(2, out[0]);
     }
@@ -78,7 +79,7 @@ OutputVector translate_cat_fx(const NodeContext& context) {
         list_elems.push_back(context.get_input(static_cast<int>(i)));
     }
     auto axis = context.const_input<int64_t>(context.get_input_size() - 1);
-    return translate_cat_common(context, list_elems, axis);
+    return translate_cat_common(context, list_elems, axis, true);
 };
 
 OutputVector translate_quantized_cat(const NodeContext& context) {
@@ -87,7 +88,7 @@ OutputVector translate_quantized_cat(const NodeContext& context) {
     auto axis = context.const_input<int64_t>(1);
     FRONT_END_OP_CONVERSION_CHECK(!list_elems.empty(), "Couldn't find quantized input for quantized::cat operation.");
     return {quantize(context,
-                     translate_cat_common(context, list_elems, axis)[0],
+                     translate_cat_common(context, list_elems, axis, false)[0],
                      context.get_input(2),
                      context.get_input(3),
                      list_elems.front())};
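The new is_fx flag exists because a torch.fx graph passes the tensors to concatenate as explicit node inputs, so a one-element list really is a single tensor to pass through, not a TorchScript list that was merged into one stacked tensor. A guess at the kind of pattern that failed before this fix; the actual reproducer from #20744 is not shown on this page:

import torch

def f(x):
    # A one-element cat is valid; in the fx graph, x arrives as a plain tensor input.
    # Without the is_fx flag, the "list merged into tensor" special case above
    # would mis-reshape it; with is_fx set to true, that branch is skipped.
    return torch.cat([x], dim=0)

compiled = torch.compile(f, backend="openvino")  # backend registration assumed as above
print(compiled(torch.randn(2, 3)).shape)  # expected: torch.Size([2, 3])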
7 changes: 5 additions & 2 deletions src/frontends/pytorch/src/op_table.cpp
@@ -213,7 +213,8 @@ OP_CONVERTER(translate_quantized_linear);
 OP_CONVERTER(translate_xor);
 // Torch FX Translations
 OP_CONVERTER(translate_arange_fx);
-OP_CONVERTER(translate_batch_norm_fx);
+OP_CONVERTER(translate_batch_norm_legit_fx);
+OP_CONVERTER(translate_batch_norm_legit_no_training_fx);
 OP_CONVERTER(translate_cat_fx);
 OP_CONVERTER(translate_chunk_fx);
 OP_CONVERTER(translate_expand_fx);
@@ -612,7 +613,9 @@ const std::map<std::string, CreatorFunction> get_supported_ops_fx() {
     {"aten.mm.default", op::translate_1to1_match_2_inputs<opset10::MatMul>},
     {"aten.mul.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
     {"aten.mul.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
-    {"aten.native_batch_norm.default", op::translate_batch_norm_fx},
+    {"aten.native_batch_norm.default", op::translate_batch_norm_legit_fx},
+    {"aten._native_batch_norm_legit.default", op::translate_batch_norm_legit_fx},
+    {"aten._native_batch_norm_legit_no_training.default", op::translate_batch_norm_legit_no_training_fx},
     {"aten.native_group_norm.default", op::translate_group_norm_fx},
     {"aten.native_layer_norm.default", op::translate_layer_norm_fx},
     {"aten.neg.default", op::translate_neg},
