diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 7eef09e55101d0..f718cbf65480ab 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7095,8 +7095,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
   // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
   auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
   ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
-  if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat &&
-      N1.hasOneUse()) {
+  if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && Splat) {
     EVT LoadVT = MLoad->getMemoryVT();
     EVT ExtVT = VT;
     if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
diff --git a/llvm/test/CodeGen/AArch64/sve-hadd.ll b/llvm/test/CodeGen/AArch64/sve-hadd.ll
index ce440d3095d3f3..857a883d80ea3d 100644
--- a/llvm/test/CodeGen/AArch64/sve-hadd.ll
+++ b/llvm/test/CodeGen/AArch64/sve-hadd.ll
@@ -1341,3 +1341,66 @@ entry:
   %avg = ashr <vscale x 2 x i64> %add, splat (i64 1)
   ret <vscale x 2 x i64> %avg
 }
+
+define void @zext_mload_avgflooru(ptr %p1, ptr %p2, <vscale x 8 x i1> %mask) {
+; SVE-LABEL: zext_mload_avgflooru:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ld1b { z0.h }, p0/z, [x0]
+; SVE-NEXT:    ld1b { z1.h }, p0/z, [x1]
+; SVE-NEXT:    eor z2.d, z0.d, z1.d
+; SVE-NEXT:    and z0.d, z0.d, z1.d
+; SVE-NEXT:    lsr z1.h, z2.h, #1
+; SVE-NEXT:    add z0.h, z0.h, z1.h
+; SVE-NEXT:    st1h { z0.h }, p0, [x0]
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: zext_mload_avgflooru:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ld1b { z0.h }, p0/z, [x0]
+; SVE2-NEXT:    ld1b { z1.h }, p0/z, [x1]
+; SVE2-NEXT:    ptrue p1.h
+; SVE2-NEXT:    uhadd z0.h, p1/m, z0.h, z1.h
+; SVE2-NEXT:    st1h { z0.h }, p0, [x0]
+; SVE2-NEXT:    ret
+  %ld1 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p1, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+  %ld2 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p2, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+  %and = and <vscale x 8 x i8> %ld1, %ld2
+  %xor = xor <vscale x 8 x i8> %ld1, %ld2
+  %shift = lshr <vscale x 8 x i8> %xor, splat(i8 1)
+  %avg = add <vscale x 8 x i8> %and, %shift
+  %avgext = zext <vscale x 8 x i8> %avg to <vscale x 8 x i16>
+  call void @llvm.masked.store.nxv8i16(<vscale x 8 x i16> %avgext, ptr %p1, i32 16, <vscale x 8 x i1> %mask)
+  ret void
+}
+
+define void @zext_mload_avgceilu(ptr %p1, ptr %p2, <vscale x 8 x i1> %mask) {
+; SVE-LABEL: zext_mload_avgceilu:
+; SVE:       // %bb.0:
+; SVE-NEXT:    ld1b { z0.h }, p0/z, [x0]
+; SVE-NEXT:    ld1b { z1.h }, p0/z, [x1]
+; SVE-NEXT:    eor z2.d, z0.d, z1.d
+; SVE-NEXT:    orr z0.d, z0.d, z1.d
+; SVE-NEXT:    lsr z1.h, z2.h, #1
+; SVE-NEXT:    sub z0.h, z0.h, z1.h
+; SVE-NEXT:    st1b { z0.h }, p0, [x0]
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: zext_mload_avgceilu:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    ld1b { z0.h }, p0/z, [x0]
+; SVE2-NEXT:    ld1b { z1.h }, p0/z, [x1]
+; SVE2-NEXT:    ptrue p1.h
+; SVE2-NEXT:    urhadd z0.h, p1/m, z0.h, z1.h
+; SVE2-NEXT:    st1b { z0.h }, p0, [x0]
+; SVE2-NEXT:    ret
+  %ld1 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p1, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+  %ld2 = call <vscale x 8 x i8> @llvm.masked.load(ptr %p2, i32 16, <vscale x 8 x i1> %mask, <vscale x 8 x i8> zeroinitializer)
+  %zext1 = zext <vscale x 8 x i8> %ld1 to <vscale x 8 x i16>
+  %zext2 = zext <vscale x 8 x i8> %ld2 to <vscale x 8 x i16>
+  %add1 = add nuw nsw <vscale x 8 x i16> %zext1, splat(i16 1)
+  %add2 = add nuw nsw <vscale x 8 x i16> %add1, %zext2
+  %shift = lshr <vscale x 8 x i16> %add2, splat(i16 1)
+  %trunc = trunc <vscale x 8 x i16> %shift to <vscale x 8 x i8>
+  call void @llvm.masked.store.nxv8i8(<vscale x 8 x i8> %trunc, ptr %p1, i32 16, <vscale x 8 x i1> %mask)
+  ret void
+}
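
Why the one-use check mattered, as a sketch of the assumed motivating DAG shape (node names and the mask value are illustrative, not a verbatim dump): splat constants are CSE'd in the SelectionDAG, so when two any-extending masked loads are masked with the same constant, a single splat_vector node feeds both AND nodes, and the old N1.hasOneUse() check failed for each of them:

    c  = splat_vector (i16 255)                        ; one CSE'd node, two uses
    t3 = and (masked_load anyext nxv8i8 -> nxv8i16), c
    t4 = and (masked_load anyext nxv8i8 -> nxv8i16), c

With the check removed, each AND folds to a zero-extending masked load (the ld1b { z0.h } forms above), which in turn lets the AVGFLOORU/AVGCEILU patterns match so SVE2 can emit uhadd/urhadd, as the new tests check.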