From 17dd4a64504acec447f9ec64bfe48bcc965fb8ab Mon Sep 17 00:00:00 2001 From: Bilyana Indzheva Date: Sun, 29 Dec 2024 04:35:07 +0200 Subject: [PATCH 1/6] Add mul --- docs/OperatorKernels.md | 2 +- .../providers/cpu/cpu_execution_provider.cc | 12 +++++ .../providers/cpu/math/element_wise_ops.cc | 6 +++ .../cpu/math/element_wise_ops_test.cc | 48 +++++++++++++++++++ 4 files changed, 67 insertions(+), 1 deletion(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index eeb8ebb3ccefe..02d2e1a1f12bd 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -226,7 +226,7 @@ Do not modify directly.* |||[6, 7]|**T** = tensor(float)| |Mod|*in* A:**T**
*in* B:**T**
*out* C:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[10, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Mul|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|Mul|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Multinomial|*in* input:**T1**
*out* output:**T2**|7+|**T1** = tensor(float)
**T2** = tensor(int32), tensor(int64)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 0499a15e1df0a..acbcb066213f8 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -832,8 +832,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Sub); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, Mul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, Mul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int8_t, Mul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int16_t, Mul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int32_t, Mul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Mul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint8_t, Mul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint16_t, Mul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint32_t, Mul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint64_t, Mul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, Div); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, Div); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int32_t, Div); @@ -2375,8 +2381,14 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index a78ff69e5c894..bd6121e3f2762 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -189,8 +189,14 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 13, 13, int32_t, Mul); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 13, 13, int64_t, Mul); REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, float, Mul); REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, double, Mul); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, int8_t, Mul); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, int16_t, Mul); REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, int32_t, Mul); REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, int64_t, Mul); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, uint8_t, Mul); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, uint16_t, Mul); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, uint32_t, Mul); +REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, uint64_t, Mul); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 7, 12, float, Div); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 7, 12, double, Div); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index a74517840097c..d30a5a6db87dc 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -621,6 +621,22 @@ TEST(MathOpTest, Sub_Broadcast_Scalar) { run(true); } +TEST(MathOpTest, Mul_int8) { + OpTester test("Mul", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, -3, 6}); + test.AddOutput("C", {3}, {4, -6, 18}); + test.Run(); +} + +TEST(MathOpTest, Mul_int16) { + OpTester 
test("Mul", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, -3, 6}); + test.AddOutput("C", {3}, {4, -6, 18}); + test.Run(); +} + TEST(MathOpTest, Mul_int32) { OpTester test("Mul"); test.AddInput("A", {3}, {1, 2, 3}); @@ -637,6 +653,38 @@ TEST(MathOpTest, Mul_int64) { test.Run(); } +TEST(MathOpTest, Mul_uint8) { + OpTester test("Mul", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 3, 6}); + test.AddOutput("C", {3}, {4, 6, 18}); + test.Run(); +} + +TEST(MathOpTest, Mul_uint16) { + OpTester test("Mul", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 3, 6}); + test.AddOutput("C", {3}, {4, 6, 18}); + test.Run(); +} + +TEST(MathOpTest, Mul_uint32) { + OpTester test("Mul", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 3, 6}); + test.AddOutput("C", {3}, {4, 6, 18}); + test.Run(); +} + +TEST(MathOpTest, Mul_uint64) { + OpTester test("Mul", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 3, 6}); + test.AddOutput("C", {3}, {4, 6, 18}); + test.Run(); +} + TEST(MathOpTest, Mul) { OpTester test("Mul"); std::vector dims{3, 3}; From fa89f134766c31d1779ab31fa604a6b4d9c4dd6f Mon Sep 17 00:00:00 2001 From: Bilyana Indzheva Date: Sun, 29 Dec 2024 05:18:33 +0200 Subject: [PATCH 2/6] Add sub, add, div --- .../providers/cpu/cpu_execution_provider.cc | 134 +++++++++++++++- .../providers/cpu/math/element_wise_ops.cc | 48 ++++++ .../cpu/math/element_wise_ops_test.cc | 144 ++++++++++++++++++ 3 files changed, 318 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index acbcb066213f8..eceb96bd2ab95 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -284,8 +284,14 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, Less); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int8_t, Less); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int16_t, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, Less); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, uint8_t, Less); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, uint16_t, Less); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Less); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Less); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, EyeLike); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, IsNaN); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, IsNaN); @@ -397,8 +403,14 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 13, int32_t, CumSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 13, int64_t, CumSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, bool, 
Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int8_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int16_t, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint8_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint16_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, uint64_t, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, Round); @@ -654,15 +666,27 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Concat); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Less); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Less); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, Less); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int16_t, Less); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Less); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Less); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, Less); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint16_t, Less); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint32_t, Less); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint64_t, Less); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, bool, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int16_t, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int32_t, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int64_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, uint16_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, uint32_t, Equal); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, uint64_t, Equal); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, float, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, double, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, float, Add); @@ -824,12 +848,24 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Trilu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, Add); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int8_t, Add); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int16_t, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int32_t, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Add); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint8_t, Add); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint16_t, Add); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint32_t, Add); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint64_t, Add); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, Sub); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, Sub); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int8_t, Sub); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int16_t, Sub); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 
int32_t, Sub); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Sub); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint8_t, Sub); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint16_t, Sub); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint32_t, Sub); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint64_t, Sub); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, Mul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, Mul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int8_t, Mul); @@ -842,8 +878,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint64_t, Mul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, float, Div); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, double, Div); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int8_t, Div); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int16_t, Div); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int32_t, Div); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Div); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint8_t, Div); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint16_t, Div); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint32_t, Div); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, uint64_t, 
Div); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 18, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 15, Identity); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, @@ -1019,12 +1061,14 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2FNUZ, DequantizeLinear); #endif -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, bool, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, int8_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, int16_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, int32_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, int64_t, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, float, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, double, Equal); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, string, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint8_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint16_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint32_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint64_t, Equal); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Identity); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, If); class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Loop); @@ -1569,10 +1613,22 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { float, Less)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1732,10 +1788,22 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int64_t, CumSum)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2391,8 +2495,14 @@ Status 
RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, #endif - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index bd6121e3f2762..fb0c7b59ae2e5 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -163,8 +163,14 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 13, 13, int32_t, Add); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 13, 13, int64_t, Add); REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, float, Add); REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, double, Add); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, int8_t, Add); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, int16_t, Add); REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, int32_t, Add); REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, int64_t, Add); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, uint8_t, Add); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, uint16_t, Add); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, uint32_t, Add); +REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, uint64_t, Add); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 7, 12, float, Sub); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 7, 12, double, Sub); @@ -176,8 +182,14 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 13, 13, int32_t, Sub); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 13, 13, int64_t, Sub); 
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, float, Sub); REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, double, Sub); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, int8_t, Sub); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, int16_t, Sub); REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, int32_t, Sub); REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, int64_t, Sub); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, uint8_t, Sub); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, uint16_t, Sub); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, uint32_t, Sub); +REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, uint64_t, Sub); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 7, 12, float, Mul); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 7, 12, double, Mul); @@ -208,8 +220,14 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 13, 13, int32_t, Div); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 13, 13, int64_t, Div); REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, float, Div); REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, double, Div); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, int8_t, Div); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, int16_t, Div); REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, int32_t, Div); REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, int64_t, Div); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, uint8_t, Div); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, uint16_t, Div); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, uint32_t, Div); +REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, uint64_t, Div); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Abs, 6, 12, float, Abs); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Abs, 6, 12, double, Abs); @@ -314,12 +332,24 @@ REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, float, Less); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 7, 8, double, Less); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, float, Less); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, double, Less); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, int8_t, Less); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, int16_t, Less); 
REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, int32_t, Less); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, int64_t, Less); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, uint8_t, Less); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, uint16_t, Less); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, uint32_t, Less); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Less, 9, 12, uint64_t, Less); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, float, Less); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, double, Less); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, int8_t, Less); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, int16_t, Less); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, int32_t, Less); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, int64_t, Less); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, uint8_t, Less); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, uint16_t, Less); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, uint32_t, Less); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Less, 13, uint64_t, Less); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 8, float, Greater); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 8, double, Greater); @@ -338,18 +368,36 @@ REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, int64_t, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, float, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, double, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, bool, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, int8_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, int16_t, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, int32_t, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, int64_t, Equal); 
+REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, uint8_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, uint16_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, uint32_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, uint64_t, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, float, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 11, 12, double, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, bool, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, int8_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, int16_t, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, int32_t, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, int64_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, uint8_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, uint16_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, uint32_t, Equal); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, uint64_t, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, float, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 13, 18, double, Equal); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, bool, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, int8_t, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, int16_t, Equal); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, int32_t, Equal); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, int64_t, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, uint8_t, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, uint16_t, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, uint32_t, Equal); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, uint64_t, Equal); 
REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, float, Equal); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, double, Equal); using string = std::string; diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index d30a5a6db87dc..b692c47e77b45 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -183,6 +183,22 @@ TEST(MathOpTest, DimWithZeroHandling) { run(test5); } +TEST(MathOpTest, Add_int8) { + OpTester test("Add", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 5, 6}); + test.AddOutput("C", {3}, {5, 7, 9}); + test.Run(); +} + +TEST(MathOpTest, Add_int16) { + OpTester test("Add", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 5, 6}); + test.AddOutput("C", {3}, {5, 7, 9}); + test.Run(); +} + TEST(MathOpTest, Add_int32) { OpTester test("Add"); test.AddInput("A", {3}, {1, 2, 3}); @@ -199,6 +215,38 @@ TEST(MathOpTest, Add_int64) { test.Run(); } +TEST(MathOpTest, Add_uint8) { + OpTester test("Add", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 5, 6}); + test.AddOutput("C", {3}, {5, 7, 9}); + test.Run(); +} + +TEST(MathOpTest, Add_uint16) { + OpTester test("Add", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 5, 6}); + test.AddOutput("C", {3}, {5, 7, 9}); + test.Run(); +} + +TEST(MathOpTest, Add_uint32) { + OpTester test("Add", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 5, 6}); + test.AddOutput("C", {3}, {5, 7, 9}); + test.Run(); +} + +TEST(MathOpTest, Add_uint64) { + OpTester test("Add", 14); + test.AddInput("A", {3}, {1, 2, 3}); + test.AddInput("B", {3}, {4, 5, 6}); + test.AddOutput("C", {3}, {5, 7, 9}); + test.Run(); +} + TEST(MathOpTest, Add_float) { OpTester test("Add"); std::vector dims{3, 3}; @@ -567,6 +615,22 @@ TEST(MathOpTest, Add_Invalid_Broadcast) { 
// test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); // } +TEST(MathOpTest, Sub_int8) { + OpTester test("Sub", 14); + test.AddInput("A", {3}, {1, 5, 6}); + test.AddInput("B", {3}, {4, 5, 3}); + test.AddOutput("C", {3}, {-3, 0, 3}); + test.Run(); +} + +TEST(MathOpTest, Sub_int16) { + OpTester test("Sub", 14); + test.AddInput("A", {3}, {1, 5, 6}); + test.AddInput("B", {3}, {4, 5, 3}); + test.AddOutput("C", {3}, {-3, 0, 3}); + test.Run(); +} + TEST(MathOpTest, Sub_int32) { OpTester test("Sub"); test.AddInput("A", {3}, {1, 4, 3}); @@ -583,6 +647,38 @@ TEST(MathOpTest, Sub_int64) { test.Run(); } +TEST(MathOpTest, Sub_uint8) { + OpTester test("Sub", 14); + test.AddInput("A", {3}, {4, 5, 6}); + test.AddInput("B", {3}, {1, 5, 3}); + test.AddOutput("C", {3}, {3, 0, 3}); + test.Run(); +} + +TEST(MathOpTest, Sub_uint16) { + OpTester test("Sub", 14); + test.AddInput("A", {3}, {4, 5, 6}); + test.AddInput("B", {3}, {1, 5, 3}); + test.AddOutput("C", {3}, {3, 0, 3}); + test.Run(); +} + +TEST(MathOpTest, Sub_uint32) { + OpTester test("Sub", 14); + test.AddInput("A", {3}, {4, 5, 6}); + test.AddInput("B", {3}, {1, 5, 3}); + test.AddOutput("C", {3}, {3, 0, 3}); + test.Run(); +} + +TEST(MathOpTest, Sub_uint64) { + OpTester test("Sub", 14); + test.AddInput("A", {3}, {4, 5, 6}); + test.AddInput("B", {3}, {1, 5, 3}); + test.AddOutput("C", {3}, {3, 0, 3}); + test.Run(); +} + TEST(MathOpTest, Sub) { OpTester test("Sub"); std::vector dims{3, 3}; @@ -704,6 +800,22 @@ TEST(MathOpTest, Mul) { #endif } +TEST(MathOpTest, Div_int8) { + OpTester test("Div", 14); + test.AddInput("A", {3}, {4, 8, 8}); + test.AddInput("B", {3}, {1, 3, 2}); + test.AddOutput("C", {3}, {4, 2, 4}); + test.Run(); +} + +TEST(MathOpTest, Div_int16) { + OpTester test("Div", 14); + test.AddInput("A", {3}, {4, 8, 8}); + test.AddInput("B", {3}, {1, 3, 2}); + test.AddOutput("C", {3}, {4, 2, 4}); + test.Run(); +} + TEST(MathOpTest, Div_int32) { OpTester test("Div"); test.AddInput("A", {3}, 
{4, 8, 8}); @@ -722,6 +834,38 @@ TEST(MathOpTest, Div_int64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // TensorRT parser:elementwise inputs must not be Int32 } +TEST(MathOpTest, Div_uint8) { + OpTester test("Div", 14); + test.AddInput("A", {3}, {4, 8, 8}); + test.AddInput("B", {3}, {1, 3, 2}); + test.AddOutput("C", {3}, {4, 2, 4}); + test.Run(); +} + +TEST(MathOpTest, Div_uint16) { + OpTester test("Div", 14); + test.AddInput("A", {3}, {4, 8, 8}); + test.AddInput("B", {3}, {1, 3, 2}); + test.AddOutput("C", {3}, {4, 2, 4}); + test.Run(); +} + +TEST(MathOpTest, Div_uint32) { + OpTester test("Div", 14); + test.AddInput("A", {3}, {4, 8, 8}); + test.AddInput("B", {3}, {1, 3, 2}); + test.AddOutput("C", {3}, {4, 2, 4}); + test.Run(); +} + +TEST(MathOpTest, Div_uint64) { + OpTester test("Div", 14); + test.AddInput("A", {3}, {4, 8, 8}); + test.AddInput("B", {3}, {2, 3, 4}); + test.AddOutput("C", {3}, {2, 2, 2}); + test.Run(); +} + TEST(MathOpTest, Div) { OpTester test("Div"); std::vector dims{2, 3}; From c37ce598c41bf55d2030c5d04816885f1cc8e22a Mon Sep 17 00:00:00 2001 From: Bilyana Indzheva Date: Sun, 29 Dec 2024 05:26:31 +0200 Subject: [PATCH 3/6] Add equal, less --- .../cpu/math/element_wise_ops_test.cc | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index b692c47e77b45..e0a6b72a1a640 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -2651,6 +2651,21 @@ TEST(MathOpTest, Less_Scalar1) { test.Run(); } +TEST(MathOpTest, Less_int8_Scalar1) { + OpTester test("Less", 9); + test.AddInput("A", {4}, {1, 0, 2, -1}); + test.AddInput("B", {1}, {1}); + test.AddOutput("C", {4}, {false, true, false, true}); + test.Run(); +} +TEST(MathOpTest, Less_int16_Scalar1) { + OpTester 
test("Less", 9); + test.AddInput("A", {4}, {1, 0, 2, -1}); + test.AddInput("B", {1}, {1}); + test.AddOutput("C", {4}, {false, true, false, true}); + test.Run(); +} + TEST(MathOpTest, Less_int64_Scalar1) { OpTester test("Less", 9); test.AddInput("A", {4}, {1, 0, 2, -1}); @@ -2658,6 +2673,38 @@ TEST(MathOpTest, Less_int64_Scalar1) { test.AddOutput("C", {4}, {false, true, false, true}); test.Run(); } + +TEST(MathOpTest, Less_uint8_Scalar1) { + OpTester test("Less", 9); + test.AddInput("A", {4}, {1, 0, 2, 3}); + test.AddInput("B", {1}, {2}); + test.AddOutput("C", {4}, {true, true, false, false}); + test.Run(); +} +TEST(MathOpTest, Less_uint16_Scalar1) { + OpTester test("Less", 9); + test.AddInput("A", {4}, {1, 0, 2, 3}); + test.AddInput("B", {1}, {2}); + test.AddOutput("C", {4}, {true, true, false, false}); + test.Run(); +} + +TEST(MathOpTest, Less_uint32_Scalar1) { + OpTester test("Less", 9); + test.AddInput("A", {4}, {1, 0, 2, 3}); + test.AddInput("B", {1}, {2}); + test.AddOutput("C", {4}, {true, true, false, false}); + test.Run(); +} + +TEST(MathOpTest, Less_uint64_Scalar1) { + OpTester test("Less", 9); + test.AddInput("A", {4}, {1, 0, 2, 3}); + test.AddInput("B", {1}, {2}); + test.AddOutput("C", {4}, {true, true, false, false}); + test.Run(); +} + TEST(MathOpTest, Less_broadcastAB) { OpTester test("Less", 9); test.AddInput("A", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17}); @@ -3171,6 +3218,24 @@ TEST(MathOpTest, Equal_bool_scalar1) { test.Run(); } +TEST(MathOpTest, Equal_int8) { + OpTester test("Equal", 11); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, -1, -1}); + test.AddInput("B", dims, {1, 1, 2, -1}); + test.AddOutput("C", dims, {true, false, false, true}); + test.Run(); +} + +TEST(MathOpTest, Equal_int16) { + OpTester test("Equal", 11); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, -1, -1}); + test.AddInput("B", dims, {1, 1, 2, -1}); + test.AddOutput("C", dims, {true, false, false, true}); + test.Run(); +} + TEST(MathOpTest, Equal_int32) { 
OpTester test("Equal"); std::vector dims{4}; @@ -3189,6 +3254,42 @@ TEST(MathOpTest, Equal_int64) { test.Run(); } +TEST(MathOpTest, Equal_uint8) { + OpTester test("Equal", 11); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, 1, 1}); + test.AddInput("B", dims, {1, 1, 2, 1}); + test.AddOutput("C", dims, {true, false, false, true}); + test.Run(); +} + +TEST(MathOpTest, Equal_uint16) { + OpTester test("Equal", 11); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, 1, 1}); + test.AddInput("B", dims, {1, 1, 2, 1}); + test.AddOutput("C", dims, {true, false, false, true}); + test.Run(); +} + +TEST(MathOpTest, Equal_uint32) { + OpTester test("Equal", 11); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, 1, 1}); + test.AddInput("B", dims, {1, 1, 2, 1}); + test.AddOutput("C", dims, {true, false, false, true}); + test.Run(); +} + +TEST(MathOpTest, Equal_uint64) { + OpTester test("Equal", 11); + std::vector dims{4}; + test.AddInput("A", dims, {1, 0, 1, 1}); + test.AddInput("B", dims, {1, 1, 2, 1}); + test.AddOutput("C", dims, {true, false, false, true}); + test.Run(); +} + TEST(MathOpTest, Equal_float) { OpTester test("Equal", 11); std::vector dims{4}; From ba120dbecfc4f820430f4872f61da2c699b3821d Mon Sep 17 00:00:00 2001 From: Bilyana Indzheva Date: Sun, 29 Dec 2024 05:39:55 +0200 Subject: [PATCH 4/6] Add LessOrEqual, GreaterOrEqual --- .../providers/cpu/cpu_execution_provider.cc | 128 ++++++++++++++++-- .../providers/cpu/math/element_wise_ops.cc | 36 +++++ .../cpu/math/element_wise_ops_test.cc | 103 ++++++++++++++ 3 files changed, 259 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index eceb96bd2ab95..74c0525d69e95 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -280,8 +280,14 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kOnnxDoma class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, MeanVarianceNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, Greater); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int8_t, Greater); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int16_t, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, Greater); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, uint8_t, Greater); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, uint16_t, Greater); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, uint32_t, Greater); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, uint64_t, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, Less); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int8_t, Less); @@ -592,14 +598,32 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn GreaterOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int8_t, + GreaterOrEqual); 
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int16_t, + GreaterOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, GreaterOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int64_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, uint8_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, uint16_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, uint32_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, uint64_t, + GreaterOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int8_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int16_t, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int64_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, uint8_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, uint16_t, LessOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, uint32_t, LessOrEqual); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, uint64_t, LessOrEqual); // opset 13 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Erf); @@ -676,8 +700,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint64_t, Less); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Greater); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int8_t, Greater); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int16_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, Greater); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int64_t, Greater); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint8_t, Greater); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint16_t, Greater); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint32_t, Greater); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, uint64_t, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, bool, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, Equal); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 18, int16_t, Equal); @@ -931,12 +961,24 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, PR class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, 18, 
Scan); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, GreaterOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, GreaterOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int8_t, GreaterOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int16_t, GreaterOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int32_t, GreaterOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int64_t, GreaterOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, uint8_t, GreaterOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, uint16_t, GreaterOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, uint32_t, GreaterOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, uint64_t, GreaterOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int8_t, LessOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int16_t, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int32_t, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int64_t, LessOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, uint8_t, LessOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, uint16_t, LessOrEqual); +class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, uint32_t, LessOrEqual); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, uint64_t, LessOrEqual); // Opset 17 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, 19, DFT); @@ -1605,10 +1647,22 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { float, Greater)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 13 BuildKernelCreateInfo, @@ -2211,14 +2289,24 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { float, Equal)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index fb0c7b59ae2e5..ed576fc7699f3 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -355,12 +355,24 @@ REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 8, float, Greater); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 7, 8, double, Greater); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, float, Greater); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, double, Greater); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, int8_t, Greater); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, int16_t, Greater); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, int32_t, Greater); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, int64_t, Greater); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, uint8_t, Greater); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, uint16_t, Greater); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, uint32_t, Greater); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Greater, 9, 12, uint64_t, Greater); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, float, Greater); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, double, Greater); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, int8_t, Greater); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, int16_t, Greater); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, int32_t, Greater); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, int64_t, 
Greater); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, uint8_t, Greater); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, uint16_t, Greater); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, uint32_t, Greater); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Greater, 13, uint64_t, Greater); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, bool, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(Equal, 7, 10, int32_t, Equal); @@ -405,25 +417,49 @@ REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(Equal, 19, string, Equal); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, float, LessOrEqual); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, double, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, int8_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, int16_t, LessOrEqual); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, int32_t, LessOrEqual); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, int64_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, uint8_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, uint16_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, uint32_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(LessOrEqual, 12, 15, uint64_t, LessOrEqual); // Opset-16 adds BFloat16 to allowed types for the LessOrEqual operator REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 16, float, LessOrEqual); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 16, double, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 16, int8_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 16, int16_t, LessOrEqual); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 16, int32_t, LessOrEqual); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 
16, int64_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 16, uint8_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 16, uint16_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 16, uint32_t, LessOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(LessOrEqual, 16, uint64_t, LessOrEqual); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, float, GreaterOrEqual); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, double, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, int8_t, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, int16_t, GreaterOrEqual); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, int32_t, GreaterOrEqual); REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, int64_t, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, uint8_t, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, uint16_t, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, uint32_t, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_VERSIONED_TYPED_KERNEL(GreaterOrEqual, 12, 15, uint64_t, GreaterOrEqual); // Opset-16 adds BFloat16 to allowed types for the GreaterOrEqual operator REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, float, GreaterOrEqual); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, double, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, int8_t, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, int16_t, GreaterOrEqual); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, int32_t, GreaterOrEqual); REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, int64_t, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, uint8_t, 
GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, uint16_t, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, uint32_t, GreaterOrEqual); +REG_ELEMENTWISE_LOGICALOP_TYPED_KERNEL(GreaterOrEqual, 16, uint64_t, GreaterOrEqual); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 6, 7, float, Mean_6); REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mean, 8, 12, float, Mean_8); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index e0a6b72a1a640..e5aa6688fc5e0 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -2765,6 +2765,22 @@ TEST(MathOpTest, LessOrEqual_Scalar1) { {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider}); } +TEST(MathOpTest, LessOrEqual_int8_Scalar1) { + OpTester test("LessOrEqual", 12); + test.AddInput("A", {4}, {1, 0, 2, -1}); + test.AddInput("B", {1}, {1}); + test.AddOutput("C", {4}, {true, true, false, true}); + test.Run(); +} + +TEST(MathOpTest, LessOrEqual_int16_Scalar1) { + OpTester test("LessOrEqual", 12); + test.AddInput("A", {4}, {1, 0, 2, -1}); + test.AddInput("B", {1}, {1}); + test.AddOutput("C", {4}, {true, true, false, true}); + test.Run(); +} + TEST(MathOpTest, LessOrEqual_int64_Scalar1) { OpTester test("LessOrEqual", 12); test.AddInput("A", {4}, {1, 0, 2, -1}); @@ -2773,6 +2789,39 @@ TEST(MathOpTest, LessOrEqual_int64_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kNnapiExecutionProvider, kOpenVINOExecutionProvider}); } + +TEST(MathOpTest, LessOrEqual_uint8_Scalar1) { + OpTester test("LessOrEqual", 12); + test.AddInput("A", {4}, {1, 0, 2, 3}); + test.AddInput("B", {1}, {1}); + test.AddOutput("C", {4}, {true, true, false, false}); + test.Run(); +} + +TEST(MathOpTest, LessOrEqual_uint16_Scalar1) { + OpTester test("LessOrEqual", 12); + 
test.AddInput("A", {4}, {1, 0, 2, 3}); + test.AddInput("B", {1}, {1}); + test.AddOutput("C", {4}, {true, true, false, false}); + test.Run(); +} + +TEST(MathOpTest, LessOrEqual_uint32_Scalar1) { + OpTester test("LessOrEqual", 12); + test.AddInput("A", {4}, {1, 0, 2, 3}); + test.AddInput("B", {1}, {1}); + test.AddOutput("C", {4}, {true, true, false, false}); + test.Run(); +} + +TEST(MathOpTest, LessOrEqual_uint64_Scalar1) { + OpTester test("LessOrEqual", 12); + test.AddInput("A", {4}, {1, 0, 2, 3}); + test.AddInput("B", {1}, {1}); + test.AddOutput("C", {4}, {true, true, false, false}); + test.Run(); +} + TEST(MathOpTest, LessOrEqual_broadcastAB) { OpTester test("LessOrEqual", 12); test.AddInput("A", {4, 2}, {10, 11, 12, 13, 14, 15, 16, 17}); @@ -3006,6 +3055,24 @@ TEST(MathOpTest, Greater_9_double) { test.Run(); } +TEST(MathOpTest, Greater_9_int8) { + OpTester test("Greater", 9); + std::vector dims{4}; + test.AddInput("A", dims, {10, 11, 12, 13}); + test.AddInput("B", dims, {15, 7, 12, 9}); + test.AddOutput("C", dims, {false, true, false, true}); + test.Run(); +} + +TEST(MathOpTest, Greater_9_int16) { + OpTester test("Greater", 9); + std::vector dims{4}; + test.AddInput("A", dims, {10, 11, 12, 13}); + test.AddInput("B", dims, {15, 7, 12, 9}); + test.AddOutput("C", dims, {false, true, false, true}); + test.Run(); +} + TEST(MathOpTest, Greater_9_int32) { OpTester test("Greater", 9); std::vector dims{4}; @@ -3023,6 +3090,42 @@ TEST(MathOpTest, Greater_9_int64) { test.AddOutput("C", dims, {false, true, false, true}); test.Run(); } + +TEST(MathOpTest, Greater_9_uint8) { + OpTester test("Greater", 9); + std::vector dims{4}; + test.AddInput("A", dims, {10, 11, 12, 13}); + test.AddInput("B", dims, {15, 7, 12, 9}); + test.AddOutput("C", dims, {false, true, false, true}); + test.Run(); +} + +TEST(MathOpTest, Greater_9_uint16) { + OpTester test("Greater", 9); + std::vector dims{4}; + test.AddInput("A", dims, {10, 11, 12, 13}); + test.AddInput("B", dims, {15, 7, 12, 9}); + 
test.AddOutput("C", dims, {false, true, false, true}); + test.Run(); +} + +TEST(MathOpTest, Greater_9_uint32) { + OpTester test("Greater", 9); + std::vector dims{4}; + test.AddInput("A", dims, {10, 11, 12, 13}); + test.AddInput("B", dims, {15, 7, 12, 9}); + test.AddOutput("C", dims, {false, true, false, true}); + test.Run(); +} + +TEST(MathOpTest, Greater_9_uint64) { + OpTester test("Greater", 9); + std::vector dims{4}; + test.AddInput("A", dims, {10, 11, 12, 13}); + test.AddInput("B", dims, {15, 7, 12, 9}); + test.AddOutput("C", dims, {false, true, false, true}); + test.Run(); +} #if defined(USE_DNNL) TEST(MathOpTest, Greater_13_bfloat16) { #ifdef USE_DNNL From 9677e2db57112cad5e5d0eba3efbd7da421de450 Mon Sep 17 00:00:00 2001 From: Bilyana Indzheva Date: Sun, 29 Dec 2024 05:46:06 +0200 Subject: [PATCH 5/6] Add docs --- docs/OperatorKernels.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 02d2e1a1f12bd..2de2daa724ae3 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -21,7 +21,7 @@ Do not modify directly.* |||[6, 12]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Acos|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float)| |Acosh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float)| -|Add|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|Add|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| @@ -92,7 +92,7 @@ Do not modify directly.* |||[13, 18]|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||[10, 12]|**T** = tensor(int32), tensor(int8), tensor(uint8)| |Det|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(float)| -|Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Dropout|*in* data:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T**

or

*in* data:**T**
*out* output:**T**
*out* mask:**T1**|13+|**T** = tensor(double), tensor(float)
**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| @@ -103,9 +103,9 @@ Do not modify directly.* |DynamicSlice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)
**T1** = tensor(bool)| -|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 10]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)| |||[9, 12]|**T** = tensor(float)| @@ -139,11 +139,11 @@ Do not modify directly.* |GlobalAveragePool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GlobalLpPool|*in* X:**T**
*out* Y:**T**|2+|**T** = tensor(float)| |GlobalMaxPool|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| -|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|Greater|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)| -|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(double), tensor(float)| |||[16, 19]|**T1** = tensor(float)
**T2** = tensor(float)| |HammingWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -179,11 +179,11 @@ Do not modify directly.* |||[1, 16]|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float), tensor(float16)
**V** = tensor(double), tensor(float), tensor(float16)| |LeakyRelu|*in* X:**T**
*out* Y:**T**|16+|**T** = tensor(float)| |||[6, 15]|**T** = tensor(float)| -|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|Less|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[9, 12]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 8]|**T** = tensor(double), tensor(float)
**T1** = tensor(bool)| -|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| -|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| +|LessOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||[12, 15]|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |Log|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| |LogSoftmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| @@ -402,7 +402,7 @@ Do not modify directly.* |StringConcat|*in* X:**T**
*in* Y:**T**
*out* Z:**T**|20+|**T** = tensor(string)| |StringNormalizer|*in* X:**tensor(string)**
*out* Y:**tensor(string)**|10+|**X** = tensor(string)| |StringSplit|*in* X:**T1**
*out* Y:**T2**
*out* Z:**T3**|20+|**T1** = tensor(string)
**T2** = tensor(string)
**T3** = tensor(int64)| -|Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| +|Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Sum|*in* data_0:**T**
*out* sum:**T**|13+|**T** = tensor(double), tensor(float)| From b35146d0418341aecc3d00fa062aaa0f5760a979 Mon Sep 17 00:00:00 2001 From: Bilyana Indzheva Date: Sun, 29 Dec 2024 06:03:42 +0200 Subject: [PATCH 6/6] Small fix --- docs/OperatorKernels.md | 2 +- .../providers/cpu/cpu_execution_provider.cc | 46 +++++++++++-------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 2de2daa724ae3..a18a84083d4de 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -103,7 +103,7 @@ Do not modify directly.* |DynamicSlice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[13, 18]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||[7, 10]|**T** = tensor(bool), tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 74c0525d69e95..df1ffaed3ae12 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1103,6 +1103,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Float8E5M2FNUZ, DequantizeLinear); #endif +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, bool, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, int8_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, int16_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, int32_t, Equal); @@ -1111,6 +1112,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint16_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint32_t, Equal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, uint64_t, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, float, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, double, Equal); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, string, Equal); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Identity); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, 20, Loop); @@ -2289,24 +2293,26 @@ 
Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { float, Equal)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, #endif + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo,