Skip to content

Commit

Permalink
test2
Browse files Browse the repository at this point in the history
  • Loading branch information
sbalandi committed May 23, 2024
1 parent 86c3a14 commit 6096bb3
Show file tree
Hide file tree
Showing 19 changed files with 100 additions and 225 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,11 @@ void ReduceCPULayerTest::SetUp() {

function = makeNgraphFunction(netPrecision, params, reduce, "Reduce");

if (netPrecision == ov::element::f32 && configuration.count(ov::hint::inference_precision.name()) &&
(configuration.at(ov::hint::inference_precision.name()) == ov::element::f16 ||
configuration.at(ov::hint::inference_precision.name()) == ov::element::bf16)) {
abs_threshold = 5e-3;
// if (ov::with_cpu_x86_avx512_core_amx()) {
// abs_threshold = 5e-3;
// } else {
// abs_threshold = 5e-2;
// }
if (ov::with_cpu_x86_avx512_core_amx()) {
if (netPrecision == ov::element::f32 && configuration.count(ov::hint::inference_precision.name()) &&
configuration.at(ov::hint::inference_precision.name()) == ov::element::f16) {
abs_threshold = 5e-3;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ class GroupConvolutionLayerCPUTest : public testing::WithParamInterface<groupCon
selectedType = makeSelectedTypeStr(selectedType, netType);
}

if (fusedOps.size() == 3 && fusedOps[1] == std::string("Elu") && fusedOps[2] == std::string("FakeQuantize")) {
abs_threshold = 5e-3f;
}

ov::op::PadType padType;
std::vector<size_t> kernel, stride, dilation;
std::vector<ptrdiff_t> padBegin, padEnd;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ class LoopWhileLayerCPUTest : public LoopLayerCPUTest {
bool exec_cond;
std::vector<InputShape> shapes;
std::vector<LOOP_IN_TYPE> types;
// ov::element::Type in_type;
std::tie(trip_count_type, trip_count, exec_cond, shapes, types, inType) = this->GetParam();

targetDevice = ov::test::utils::DEVICE_CPU;
Expand Down Expand Up @@ -249,7 +248,6 @@ class LoopForDiffShapesLayerCPUTest : public LoopLayerCPUTest {
bool exec_cond;
std::vector<InputShape> shapes;
std::vector<LOOP_IN_TYPE> types;
ov::element::Type inType;
std::tie(trip_count_type, trip_count, exec_cond, shapes, types, inType) = this->GetParam();

targetDevice = ov::test::utils::DEVICE_CPU;
Expand Down Expand Up @@ -328,7 +326,6 @@ class LoopForConcatLayerCPUTest : public LoopLayerCPUTest {
bool exec_cond;
std::vector<InputShape> shapes;
std::vector<LOOP_IN_TYPE> types;
// ov::element::Type in_type;
std::tie(trip_count_type, trip_count, exec_cond, shapes, types, inType) = this->GetParam();

targetDevice = ov::test::utils::DEVICE_CPU;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,28 +313,6 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*smoke_LoopForCommon/LoopLayerCPUTest.CompareWithRefs/.*trip_count=5_exec_cond=1_netType=i8.*)",
R"(.*smoke_LoopForCommon/LoopLayerCPUTest.CompareWithRefs/Input0_IS=\[\?.1.\?\]_TS=\(10.1.10\)_\(1.1.1\)_\(1.1.1\)_\(5.1.3\)_Input1_IS=\[\?.\?.\?\]_TS=.*_Input2_IS=\[\?.1.\?\]_.*_types=0_0_1_trip_count_type=.*_trip_count=(1|5)_exec_cond=1_netType=i8.*)",
R"(.*smoke_LoopForCommon/LoopLayerCPUTest.CompareWithRefs/Input0_IS=\[1..10.1.1..10\]_.*_Input1_IS=\[1..8.1.1..8\]_.*_Input2_IS=\[1..10.\?.1..10\]_TS=.*_types=0_0_1_trip_count_type=.*_trip_count=(1|5)_exec_cond=1_netType=i8.*)",
// R"(.*smoke_Reduce_MultiAxis_4D_fusing_CPU/ReduceCPULayerTest.CompareWithRefs/.*=VECTOR_type=(Max|ReduceL2|Mean)_.*_INFERENCE_PRECISION_HINT=bf16_.*_Fused=Multiply\(PerChannel\).Add\(PerChannel\).*)",
// R"(.*smoke_Reduce_MultiAxis_4D_fusing_CPU/ReduceCPULayerTest.CompareWithRefs/IS=.*_axes=\((0.2|0.3)\)_opType=VECTOR_type=Mean_.*_INFERENCE_PRECISION_HINT=bf16_.*_Fused=Multiply\(PerChannel\).Add\(PerChannel\).*)",
// R"(.*smoke_Reduce_OneAxis_fusing_CPU/ReduceCPULayerTest.CompareWithRefs/.*_INFERENCE_PRECISION_HINT=bf16_Fused=Multiply\(PerChannel\).Add\(PerChannel\).*)",
// R"(.*smoke_Reduce_OneAxis_CPU/ReduceCPULayerTest.CompareWithRefs/IS=.*_axes=\((1|3)\)_opType=.*_type=(ReduceL1|Sum)_.*_INFERENCE_PRECISION_HINT=bf16.*)",
// R"(.*smoke_Reduce_MultiAxis_4D_CPU/ReduceCPULayerTest.CompareWithRefs/IS=.*_axes=\((0.1|0.3|1.2|1.3|2.3|0.1.2|0.1.3|0.2.3|1.2.3)\)_opType=.*_type=(ReduceL1|Sum)_.*_INFERENCE_PRECISION_HINT=bf16.*)",
// R"(.*smoke_Reduce_MultiAxis_5D_fusing_CPU/ReduceCPULayerTest.CompareWithRefs/IS=.*_axes=\((2.4|0.2.4|0.1.2.3.4)\)_opType=.*_type=(ReduceL1|Mean|ReduceL2|Max)_KeepDims=true_netPRC=f32_.*_INFERENCE_PRECISION_HINT=bf16_.*_Fused=Multiply\(PerChannel\).Add\(PerChannel\).*)",
// R"(.*smoke_Reduce_MultiAxis_5D_fusing_CPU/ReduceCPULayerTest.CompareWithRefs/IS=.*_axes=\(1.2.4\)_opType=.*_type=(ReduceL2|Max)_KeepDims=true_netPRC=f32_.*_INFERENCE_PRECISION_HINT=bf16_.*_Fused=Multiply\(PerChannel\).Add\(PerChannel\).*)",
// R"(.*smoke_Reduce_SingleBatch_CPU/ReduceCPULayerTest.CompareWithRefs/.*_axes=\((1|3)\)_opType=.*_type=(ReduceL1|Sum)_KeepDims=true_netPRC=f32_.*_INFERENCE_PRECISION_HINT=bf16_.*)",
// R"(.*smoke_Reduce_MultiAxis_5D_CPU/ReduceCPULayerTest.CompareWithRefs/IS=.*_axes=\((2.4|0.2.4)\)_opType=.*_type=(ReduceL1|Sum)_KeepDims=true_netPRC=f32_.*_INFERENCE_PRECISION_HINT=bf16.*)",
// R"(.*smoke_Reduce_MultiAxis_4D_dynamic_CPU/ReduceCPULayerTest.CompareWithRefs/IS=.*_axes=\(0.1\)_opType=.*_type=(ReduceL1|Sum)_KeepDims=true_netPRC=f32_.*_INFERENCE_PRECISION_HINT=bf16.*)",
// R"(.*smoke_Reduce_NHWC_SmallChannel_CPU/ReduceCPULayerTest.CompareWithRefs/IS=.*_axes=\(2.3\)_opType=.*_type=(ReduceL1|Sum)_KeepDims=true_netPRC=f32_.*_INFERENCE_PRECISION_HINT=bf16.*)",
// R"(.*smoke_GatherTree/GatherTreeLayerTest.Inference/IS=\(20.(1|20).10\)_secondary_input_type=CONSTANT_netPRC=f32.*)",
// R"(.*smoke_FakeQuantizeLayerCPUTest_5D_jit/FakeQuantizeLayerCPUTest.CompareWithRefs/IS=\[\?.\?.\?.\?.\?\]_TS=\(\(4.16.6.7.8\)\)_\(\(1.16.1.1.1\)\)_.*_inPrec=f32_LOW_BOUNDS=-10_HIGH_BOUNDS=10_IL=\(-10\)_IH=\(-5\)_OL=\(5\)_OH=\(25\)_LEVELS=256_.*)",
// R"(.*smoke_basic/PermConvPermConcat.CompareWithRefs/IS=\(1.1.8.16\)_KS=\(1.3\)_OC=(32|64)_ET=f32_targetDevice=CPU.*)",
// R"(.*smoke_GatherTreeCPUStatic/GatherTreeLayerCPUTest.CompareWithRefs/IS=.*_TS=\(20.(1|20).10\)_secondaryInputType=CONSTANT_netPRC=f32_inPRC=undefined_outPRC=undefined_trgDev=CPU.*)",
// R"(.*smoke_GatherTreeCPUDynamicConstant/GatherTreeLayerCPUTest.CompareWithRefs/IS=.*_TS=\((7.1.10|2.1.7|20.1.10|20.20.15)\)_secondaryInputType=CONSTANT_netPRC=f32_inPRC=undefined_outPRC=undefined_trgDev=CPU.*)",
// R"(.*smoke_GroupConv_3D_Gemm_FP32/GroupConvolutionLayerCPUTest.CompareWithRefs/IS=\[\].*S\(1.1.1\).*O=6_G=3_AP=explicit_netPRC=f32_.*inFmts=(ndhwc|ncdhw)_outFmts=(ndhwc|ncdhw)_.*_Fused=FakeQuantize\(PerTensor\).Relu.*)",
// R"(.*smoke_GroupConv_3D_Gemm_FP32/GroupConvolutionLayerCPUTest.CompareWithRefs/IS=\[1..200.12.\?.1..200.\?\]_.*_S\(1.1.1\)_.*_O=6_G=3_AP=explicit_netPRC=f32_inPRC=undefined_outPRC=undefined_trgDev=CPU_inFmts=(ndhwc|ncdhw)_outFmts=(ndhwc|ncdhw)_primitive=jit_gemm_Fused=FakeQuantize\(PerTensor\).Relu.*)",
// R"(.*smoke_JIT_SSE42_DW_GroupConv/GroupConvolutionLayerCPUTest.CompareWithRefs/IS=.*_TS=\(\(2.8.129.129\)_\)_K\(3.3\)_S\(2.2\)_PB\(1.1\)_PE\(1.1\)_D=\(1.1\)_O=8_G=8_AP=explicit_netPRC=f32_.*_inFmts=(nChw8c|nhwc)_outFmts=(nChw8c|nhwc)_primitive=jit_sse42_dw_Fused=Add\(Parameters\).Elu.FakeQuantize\(PerTensor\).*)",
// R"(.*smoke_FC_3D_BF16/MatMulLayerCPUTest.CompareWithRefs/FullyConnected_IS=\[\?.\?\]_\[1.129.1\]_TS=\(\(1.129\)_\(2.129\)_\(1.129\)_\(2.129\)\)_\(\(1.129.1\)_\(1.129.1\)_\(1.129.1\)_\(1.129.1\)\)_transpose_a=1_transpose_b=1_secondaryInputType=CONSTANT_netPRC=(bf16|f32).*config=\(INFERENCE_PRECISION_HINT=bf16_\)_Fused=Multiply\(PerChannel\).*)",
// R"(.*smoke_MatMulCompressedWeights_non_default_dyn_quant_group_sizes/MatmulWeightsDecompression.CompareWithRefs/data_shape=\[\]_\(\[1,1,1728\]\)_weights_shape=\[1728,128\]_group_size=64_weights_precision=(u8|u4)_decompression_precision=f32_transpose_weights=
// 1_decompression_subtract=full_reshape_on_decompression=0_config=\(DYNAMIC_QUANTIZATION_GROUP_SIZE, 128:\).*)",
};

#if defined(OPENVINO_ARCH_X86)
Expand Down Expand Up @@ -367,6 +345,9 @@ std::vector<std::string> disabledTestPatterns() {
retVector.emplace_back(R"(.*smoke_LogSoftmax4D/LogSoftmaxLayerTest.Inference/.*TS=\{\(2.3.4.5\)\}_modelType=f32_axis=(-4|0).*)");
retVector.emplace_back(R"(.*smoke_Interpolate_Basic/InterpolateLayerTest.Inference/.*InterpolateMode=cubic_ShapeCalcMode=sizes_CoordinateTransformMode=tf_half_pixel.*PB=\(0.0.0.0\)_PE=\(0.0.1.1\)_.*netType=f32.*)");
retVector.emplace_back(R"(.*smoke_CompareWithRefs_4D_Bitwise.*/EltwiseLayerCPUTest.*_eltwise_op_type=Bitwise.*_model_type=i32_.*)");
// range propagation
retVector.emplace_back(R"(.*smoke_CompareWithRefs_static/EltwiseLayerTest.Inference/IS=.*_TS=\(\(16.16.16.(16|1).(16|1)\)_\)_eltwise_op_type=Sum_secondary_input_type=PARAMETER_opType=SCALAR_model_type=i32.*)");
retVector.emplace_back(R"(.*smoke_CompareWithRefs_static_check_collapsing/EltwiseLayerTest.Inference/IS=.*_TS=\(\(16.16.16.16\)_\(16.16.(16.1|1.16)\)_\)_eltwise_op_type=Sum_secondary_input_type=PARAMETER_opType=VECTOR_model_type=i32_.*)");
}
// invalid test: checks u8 precision for runtime graph, while it should be f32
retVector.emplace_back(R"(smoke_NegativeQuantizedMatMulMultiplyFusion.*)");
Expand All @@ -388,26 +369,6 @@ std::vector<std::string> disabledTestPatterns() {
retVector.emplace_back(R"(smoke_VariableState/OVInferRequestVariableStateTest.*)");
// Issue: 141705
retVector.emplace_back(R"(.*smoke_arm_Deconv_2D_Planar_FP16/DeconvolutionLayerCPUTest.*INFERENCE_PRECISION_HINT=f16.*)");
// fill_data_random fix
// retVector.emplace_back(R"(.*smoke_arm_Deconv_2D_Planar_FP16/DeconvolutionLayerCPUTest.*INFERENCE_PRECISION_HINT=f16.*)");
// retVector.emplace_back(R"(.*smoke_CompareWithRefs_dynamic/EltwiseLayerTest.Inference/IS=.*1..10.200.*1..10.200.*_TS=.*2.200.*1.200.*2.200.*5.200.*_eltwise_op_type=Sum_secondary_input_type=PARAMETER_opType=VECTOR_model_type=i32_.*)");
// retVector.emplace_back(R"(.*smoke_CompareWithRefs_static_check_collapsing/EltwiseLayerTest.Inference/IS=.*_TS=\(\(16.16.16.16\)_\(16.16.(16.1|1.16)\)_\)_eltwise_op_type=Sum_secondary_input_type=PARAMETER_opType=VECTOR_model_type=i32_.*)");
// retVector.emplace_back(R"(.*smoke_CompareWithRefs_static_check_collapsing/EltwiseLayerTest.Inference/IS=.*_TS=\(\(16.16.16.(16|1)\)_\(16.16.16.1\)_\)_eltwise_op_type=SqDiff_secondary_input_type=.*_opType=VECTOR_model_type=i32.*)");
// retVector.emplace_back(R"(.*smoke_CompareWithRefs_static_check_collapsing/EltwiseLayerTest.Inference/IS=.*_TS=\(\(16.16.(16|1).16\)_\(16.16.1.16\)_\)_eltwise_op_type=SqDiff_secondary_input_type=.*_opType=VECTOR_model_type=i32.*)");
// retVector.emplace_back(R"(.*smoke_CompareWithRefs_static/EltwiseLayerTest.Inference/IS=.*_TS=\(\(16.16.16.(16|1).(16|1)\)_\)_eltwise_op_type=Sum_secondary_input_type=PARAMETER_opType=SCALAR_model_type=i32.*)");
// retVector.emplace_back(R"(.*smoke_LSTMCellCommon/LSTMCellTest.Inference/decomposition0_batch=5_hidden_size=10_input_size=1_IS=\(5.1\)\(5.10\)\(5.10\)\(40.1\)\(40.10\)\(40\)_activations=\(relu.(sigmoid.tanh|relu.relu)\)_clip=0_WType=.*_RType=CONSTANT_BType=PARAMETER_modelType=f16.*)");
// retVector.emplace_back(R"(.*smoke_LSTMCellCommon/LSTMCellTest.Inference/decomposition0_batch=5_hidden_size=10_input_size=30_IS=\(5.30\)\(5.10\)\(5.10\)\(40.30\)\(40.10\)\(40\)_activations=\(relu.(sigmoid.tanh|relu.relu)\)_clip=0_WType=.*_RType=CONSTANT_BType=PARAMETER_modelType=f16.*)");
// retVector.emplace_back(R"(.*smoke_GRUCellCommon/GRUCellTest.Inference/decomposition1_batch=5_hidden_size=1_input_size=1_IS=\(5.1\)\(5.1\)\(3.1\)\(3.1\)\(4\)_activations=\(tanh.relu\)_clip=0_linear_before_reset=1_WType=.*_RType=.*_BType=CONSTANT_netPRC=f32_.*)");
// retVector.emplace_back(R"(.*smoke_GRUCellCommon/GRUCellTest.Inference/decomposition1_batch=5_hidden_size=1_input_size=30_IS=\(5.30\)\(5.1\)\(3.30\)\(3.1\)\(4\)_activations=\(tanh.relu\)_clip=0_linear_before_reset=1_WType=.*_RType=.*_BType=CONSTANT_netPRC=f32_.*)");
// retVector.emplace_back(R"(.*smoke_GRUCellCommon/GRUCellTest.Inference/decomposition1_batch=5_hidden_size=10_input_size=1_IS=\(5.1\)\(5.10\)\(30.1\)\(30.10\)\((40|30)\)_activations=\(tanh.relu\)_clip=0_linear_before_reset=(0|1)_WType=.*_RType=.*_BType=CONSTANT_netPRC=f32_.*)");
// retVector.emplace_back(R"(.*smoke_GRUCellCommon/GRUCellTest.Inference/decomposition1_batch=5_hidden_size=10_input_size=30_IS=\(5.30\)\(5.10\)\(30.30\)\(30.10\)\(30\)_activations=\(tanh.relu\)_clip=0_linear_before_reset=0_WType=.*_RType=.*_BType=CONSTANT_netPRC=f32.*)");
// retVector.emplace_back(R"(.*moke_Activation5D_dynamicMath_CPU/ActivationLayerCPUTest.CompareWithRefs/Log_IS=\(\[?.?\]\)_TS=\(1.50\)_\(5.128\)_\(3.64\)_AS=\(\)_ConstantsValue=\(\)_netPRC=f32_inPRC=f32_outPRC=f32_.*)");
// retVector.emplace_back(R"(.*moke_Activation5D_dynamicMath_CPU/ActivationLayerCPUTest.CompareWithRefs/Log_IS=\(\[1..5.128\]\)_TS=\(1.128\)_\(3.128\)_\(5.128\)_AS=\(\)_ConstantsValue=\(\)_netPRC=f32_inPRC=f32_outPRC=f32_.*)");
// retVector.emplace_back(R"(.*smoke_EltwiseChain_MergeConvert_int8/EltwiseChainTest.CompareWithRefs/IS=.*_TS=\(\(1.1.2.3\)_\(1.1.2.3\)_\(1.1.2.3\)_InPRC0=f16_InPRC1=f32_InPRC2=f32_Op0=Div_secondaryInputType=CONSTANT_WithQuant=0_Conversion=(i8|u8).*)");
// retVector.emplace_back(R"(.*smoke_EltwiseChain_MergeConvert_int8/EltwiseChainTest.CompareWithRefs/IS=.*_TS=\(\(1.1.2.3\)_\(1.1.2.3\)_\(1.1.2.3\)_InPRC0=f32_InPRC1=f32_InPRC2=f32_Op0=Prod_secondaryInputType=CONSTANT_WithQuant=0_Conversion=(i8|u8).*)");
// // to long
// retVector.emplace_back(R"(.*smoke_TensorIteratorCommonClip/TensorIteratorTest.Inference/.*_TensorIteratorBody=LSTM_.*_modelType=(f16|f32).*)");
// retVector.emplace_back(R"(.*smoke_EltwiseChain_MergeConvert_int8/EltwiseChainTest.CompareWithRefs/IS=.*_TS=\(\(1.1.2.3\)_\(1.1.2.3\)_\(1.1.2.3\)_InPRC0=f16_InPRC1=f32_InPRC2=f32_Op0=(Prod|Sum)_secondaryInputType=CONSTANT_WithQuant=0_Conversion=(i8|u8)_.*)");
#endif

#if defined(OPENVINO_ARCH_ARM)
Expand Down Expand Up @@ -563,14 +524,6 @@ std::vector<std::string> disabledTestPatterns() {
// Issue: 141705
retVector.emplace_back(R"(.*smoke_Deconv_(2|3)D_NSPC_INT8_AMX/DeconvolutionLayerCPUTest.*)");
retVector.emplace_back(R"(.*smoke_Deconv_(2|3)D_NSPC_INT8_AMX/DeconvolutionLayerCPUTest.*)");
// range
// retVector.emplace_back(R"(.*smoke_FC_3D_BF16/MatMulLayerCPUTest.CompareWithRefs/FullyConnected_IS=\[\?.\?\]_\[1.129.1\]_.*_netPRC=(f32|bf16)_.*config=\(INFERENCE_PRECISION_HINT=bf16_\)_Fused=Multiply\(PerChannel\)_primitive=jit_gemm.*)");
// retVector.emplace_back(R"(.*smoke_MatMulCompressedWeights_non_default_dyn_quant_group_sizes/MatmulWeightsDecompression.CompareWithRefs/.*_\(\[1,1,1728\]\)_.*_precision=(u8|u4)_decompression_precision=f32_.*_subtract=full_reshape_on_decompression=0_config=\(DYNAMIC_QUANTIZATION_GROUP_SIZE.*128.*Fused=fusingBias.*)");
// retVector.emplace_back(R"(.*smoke_LoopForDiffShapesConcat/LoopForDiffShapesLayerCPUTest.CompareWithRefs/Input0_IS=.*_TS=\(10.1.10\)_\(1.10.1\)_\(1.10.1\)_\(2.2.2\)_types=trip_count_type=PARAMETER_trip_count=(1|5)_exec_cond=1_netType=bf16.*)");
// retVector.emplace_back(R"(.*smoke_LoopForDiffShapesConcat/LoopForDiffShapesLayerCPUTest.CompareWithRefs/Input0_IS=.*_TS=\(10.5.10\)_\(1.10.1\)_\(1.10.1\)_\(2.1.2\)_types=trip_count_type=PARAMETER_trip_count=(1|5)_exec_cond=1_netType=bf16.*)");
// retVector.emplace_back(R"(.*smoke_LoopForConcat/LoopForConcatLayerCPUTest.CompareWithRefs/Input0_IS=.*_TS=\(10.5.10\)_\(1.10.1\)_\(1.10.1\)_\(2.1.2\)_types=trip_count_type=PARAMETER_trip_count=(1|5)_exec_cond=1_netType=bf16.*)");
// retVector.emplace_back(R"(.*smoke_LoopForConcat/LoopForConcatLayerCPUTest.CompareWithRefs/Input0_IS=.*_TS=\(10.10.10\)_\(5.10.10\)_\(5.10.10\)_\(8.10.10\)_Input1_IS=\[\?.10.10\]_.*_types=trip_count_type=PARAMETER_trip_count=1_exec_cond=1_netType=bf16.*)");
// retVector.emplace_back(R"(.*smoke_Deconv_(2|3)D_NSPC_INT8_AMX/DeconvolutionLayerCPUTest.*)");
}

if (ov::with_cpu_x86_avx512_core_fp16()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ TEST(VariablesTest, smoke_set_get_state_with_convert) {
auto variables = request.query_state();
ASSERT_EQ(variables.size(), 1);
auto variable = variables.front();
ASSERT_EQ(variable.get_name(), "v0");
ASSERT_EQ(variable.get_name(), "v0");
auto state_tensor = variable.get_state();
ASSERT_EQ(state_tensor.get_shape(), virable_shape);
ASSERT_EQ(state_tensor.get_element_type(), et);
Expand Down
Loading

0 comments on commit 6096bb3

Please sign in to comment.