diff --git a/kernels/portable/cpu/op_nonzero.cpp b/kernels/portable/cpu/op_nonzero.cpp index 6c149ec4de..77f80126d9 100644 --- a/kernels/portable/cpu/op_nonzero.cpp +++ b/kernels/portable/cpu/op_nonzero.cpp @@ -88,10 +88,9 @@ Tensor& nonzero_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { ET_KERNEL_CHECK(ctx, check_nonzero_args(in, out), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND( - Bool, in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] { - nonzero(ctx, in, out); - }); + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "nonzero.out", CTYPE, [&] { + nonzero(ctx, in, out); + }); return out; } diff --git a/kernels/test/op_nonzero_test.cpp b/kernels/test/op_nonzero_test.cpp index 2eb828413e..c1948a439b 100644 --- a/kernels/test/op_nonzero_test.cpp +++ b/kernels/test/op_nonzero_test.cpp @@ -28,10 +28,8 @@ class OpNonzeroTest : public OperatorTest { void test_dtype() { TensorFactory tf_input; TensorFactory tf_long; - // clang-format off - Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 0, - 2, 4}); - // clang-format on + Tensor a = tf_input.make( + /*sizes=*/{2, 2}, /*data=*/{CTYPE(2), CTYPE(0), CTYPE(2), CTYPE(4)}); Tensor out = tf_long.zeros({3, 2}); op_nonzero_out(a, out); @@ -45,7 +43,7 @@ class OpNonzeroTest : public OperatorTest { TEST_F(OpNonzeroTest, AllDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY }