Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ARM Neon and scalar implementations of SIMD functions #359

Merged
merged 4 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ fastmap.o: bwa.h bntseq.h bwt.h bwamem.h kvec.h malloc_wrap.h utils.h kseq.h
is.o: malloc_wrap.h
kopen.o: malloc_wrap.h
kstring.o: kstring.h malloc_wrap.h
ksw.o: ksw.h malloc_wrap.h
ksw.o: ksw.h neon_sse.h scalar_sse.h malloc_wrap.h
main.o: kstring.h malloc_wrap.h utils.h
malloc_wrap.o: malloc_wrap.h
maxk.o: bwa.h bntseq.h bwt.h bwamem.h kseq.h malloc_wrap.h
Expand Down
46 changes: 41 additions & 5 deletions ksw.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@
#include <stdlib.h>
#include <stdint.h>
#include <assert.h>
#if defined __SSE2__
#include <emmintrin.h>
#elif defined __ARM_NEON
#include "neon_sse.h"
#else
#include "scalar_sse.h"
#endif
#include "ksw.h"

#ifdef USE_MALLOC_WRAPPERS
Expand Down Expand Up @@ -108,13 +114,19 @@ kswq_t *ksw_qinit(int size, int qlen, const uint8_t *query, int m, const int8_t
return q;
}

#if defined __ARM_NEON
// This macro implicitly uses each function's `zero` local variable
#define _mm_slli_si128(a, n) (vextq_u8(zero, (a), 16 - (n)))
#endif

kswr_t ksw_u8(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_del, int _o_ins, int _e_ins, int xtra) // the first gap costs -(_o+_e)
{
int slen, i, m_b, n_b, te = -1, gmax = 0, minsc, endsc;
uint64_t *b;
__m128i zero, oe_del, e_del, oe_ins, e_ins, shift, *H0, *H1, *E, *Hmax;
kswr_t r;

#if defined __SSE2__
#define __max_16(ret, xx) do { \
(xx) = _mm_max_epu8((xx), _mm_srli_si128((xx), 8)); \
(xx) = _mm_max_epu8((xx), _mm_srli_si128((xx), 4)); \
Expand All @@ -123,6 +135,18 @@ kswr_t ksw_u8(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_del
(ret) = _mm_extract_epi16((xx), 0) & 0x00ff; \
} while (0)

// Given entries with arbitrary values, return whether they are all 0x00
#define allzero_16(xx) (_mm_movemask_epi8(_mm_cmpeq_epi8((xx), zero)) == 0xffff)

#elif defined __ARM_NEON
#define __max_16(ret, xx) (ret) = vmaxvq_u8((xx))
#define allzero_16(xx) (vmaxvq_u8((xx)) == 0)

#else
#define __max_16(ret, xx) (ret) = m128i_max_u8((xx))
#define allzero_16(xx) (m128i_allzero((xx)))
#endif

// initialization
r = g_defr;
minsc = (xtra&KSW_XSUBO)? xtra&0xffff : 0x10000;
Expand All @@ -143,7 +167,7 @@ kswr_t ksw_u8(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_del
}
// the core loop
for (i = 0; i < tlen; ++i) {
int j, k, cmp, imax;
int j, k, imax;
__m128i e, h, t, f = zero, max = zero, *S = q->qp + target[i] * slen; // s is the 1st score vector
h = _mm_load_si128(H0 + slen - 1); // h={2,5,8,11,14,17,-1,-1} in the above example
h = _mm_slli_si128(h, 1); // h=H(i-1,-1); << instead of >> because x64 is little-endian
Expand Down Expand Up @@ -182,8 +206,7 @@ kswr_t ksw_u8(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_del
_mm_store_si128(H1 + j, h);
h = _mm_subs_epu8(h, oe_ins);
f = _mm_subs_epu8(f, e_ins);
cmp = _mm_movemask_epi8(_mm_cmpeq_epi8(_mm_subs_epu8(f, h), zero));
if (UNLIKELY(cmp == 0xffff)) goto end_loop16;
if (UNLIKELY(allzero_16(_mm_subs_epu8(f, h)))) goto end_loop16;
}
}
end_loop16:
Expand Down Expand Up @@ -236,13 +259,26 @@ kswr_t ksw_i16(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_de
__m128i zero, oe_del, e_del, oe_ins, e_ins, *H0, *H1, *E, *Hmax;
kswr_t r;

#if defined __SSE2__
#define __max_8(ret, xx) do { \
(xx) = _mm_max_epi16((xx), _mm_srli_si128((xx), 8)); \
(xx) = _mm_max_epi16((xx), _mm_srli_si128((xx), 4)); \
(xx) = _mm_max_epi16((xx), _mm_srli_si128((xx), 2)); \
(ret) = _mm_extract_epi16((xx), 0); \
} while (0)

// Given entries all either 0x0000 or 0xffff, return whether they are all 0x0000
#define allzero_0f_8(xx) (!_mm_movemask_epi8((xx)))

#elif defined __ARM_NEON
#define __max_8(ret, xx) (ret) = vmaxvq_s16(vreinterpretq_s16_u8((xx)))
#define allzero_0f_8(xx) (vmaxvq_u16(vreinterpretq_u16_u8((xx))) == 0)

#else
#define __max_8(ret, xx) (ret) = m128i_max_s16((xx))
#define allzero_0f_8(xx) (m128i_allzero((xx)))
#endif

// initialization
r = g_defr;
minsc = (xtra&KSW_XSUBO)? xtra&0xffff : 0x10000;
Expand All @@ -267,7 +303,7 @@ kswr_t ksw_i16(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_de
h = _mm_load_si128(H0 + slen - 1); // h={2,5,8,11,14,17,-1,-1} in the above example
h = _mm_slli_si128(h, 2);
for (j = 0; LIKELY(j < slen); ++j) {
h = _mm_adds_epi16(h, *S++);
h = _mm_adds_epi16(h, _mm_load_si128(S++));
e = _mm_load_si128(E + j);
h = _mm_max_epi16(h, e);
h = _mm_max_epi16(h, f);
Expand All @@ -290,7 +326,7 @@ kswr_t ksw_i16(kswq_t *q, int tlen, const uint8_t *target, int _o_del, int _e_de
_mm_store_si128(H1 + j, h);
h = _mm_subs_epu16(h, oe_ins);
f = _mm_subs_epu16(f, e_ins);
if(UNLIKELY(!_mm_movemask_epi8(_mm_cmpgt_epi16(f, h)))) goto end_loop8;
if(UNLIKELY(allzero_0f_8(_mm_cmpgt_epi16(f, h)))) goto end_loop8;
}
}
end_loop8:
Expand Down
33 changes: 33 additions & 0 deletions neon_sse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#ifndef NEON_SSE_H
#define NEON_SSE_H

#include <arm_neon.h>

typedef uint8x16_t __m128i;

static inline __m128i _mm_load_si128(const __m128i *ptr) { return vld1q_u8((const uint8_t *) ptr); }
static inline __m128i _mm_set1_epi32(int n) { return vreinterpretq_u8_s32(vdupq_n_s32(n)); }
static inline void _mm_store_si128(__m128i *ptr, __m128i a) { vst1q_u8((uint8_t *) ptr, a); }

static inline __m128i _mm_adds_epu8(__m128i a, __m128i b) { return vqaddq_u8(a, b); }
static inline __m128i _mm_max_epu8(__m128i a, __m128i b) { return vmaxq_u8(a, b); }
static inline __m128i _mm_set1_epi8(int8_t n) { return vreinterpretq_u8_s8(vdupq_n_s8(n)); }
static inline __m128i _mm_subs_epu8(__m128i a, __m128i b) { return vqsubq_u8(a, b); }

#define M128I(a) vreinterpretq_u8_s16((a))
#define UM128I(a) vreinterpretq_u8_u16((a))
#define S16(a) vreinterpretq_s16_u8((a))
#define U16(a) vreinterpretq_u16_u8((a))

static inline __m128i _mm_adds_epi16(__m128i a, __m128i b) { return M128I(vqaddq_s16(S16(a), S16(b))); }
static inline __m128i _mm_cmpgt_epi16(__m128i a, __m128i b) { return UM128I(vcgtq_s16(S16(a), S16(b))); }
static inline __m128i _mm_max_epi16(__m128i a, __m128i b) { return M128I(vmaxq_s16(S16(a), S16(b))); }
static inline __m128i _mm_set1_epi16(int16_t n) { return vreinterpretq_u8_s16(vdupq_n_s16(n)); }
static inline __m128i _mm_subs_epu16(__m128i a, __m128i b) { return UM128I(vqsubq_u16(U16(a), U16(b))); }

#undef M128I
#undef UM128I
#undef S16
#undef U16

#endif
119 changes: 119 additions & 0 deletions scalar_sse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#ifndef SCALAR_SSE_H
#define SCALAR_SSE_H

#include <assert.h>
#include <stdint.h>
#include <string.h>

typedef union m128i {
uint8_t u8[16];
int16_t i16[8];
} __m128i;

static inline __m128i _mm_set1_epi32(int32_t n) {
assert(n >= 0 && n <= 255);
__m128i r; memset(&r, n, sizeof r); return r;
}

static inline __m128i _mm_load_si128(const __m128i *ptr) { __m128i r; memcpy(&r, ptr, sizeof r); return r; }
static inline void _mm_store_si128(__m128i *ptr, __m128i a) { memcpy(ptr, &a, sizeof a); }

static inline int m128i_allzero(__m128i a) {
static const char zero[] = "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0";
return memcmp(&a, zero, sizeof a) == 0;
}

static inline __m128i _mm_slli_si128(__m128i a, int n) {
int i;
memmove(&a.u8[n], &a.u8[0], 16 - n);
for (i = 0; i < n; i++) a.u8[i] = 0;
return a;
}

static inline __m128i _mm_adds_epu8(__m128i a, __m128i b) {
int i;
for (i = 0; i < 16; i++) {
uint16_t aa = a.u8[i];
aa += b.u8[i];
a.u8[i] = (aa < 256)? aa : 255;
}
return a;
}

static inline __m128i _mm_max_epu8(__m128i a, __m128i b) {
int i;
for (i = 0; i < 16; i++)
if (a.u8[i] < b.u8[i]) a.u8[i] = b.u8[i];
return a;
}

static inline uint8_t m128i_max_u8(__m128i a) {
uint8_t max = 0;
int i;
for (i = 0; i < 16; i++)
if (max < a.u8[i]) max = a.u8[i];
return max;
}

static inline __m128i _mm_set1_epi8(int8_t n) { __m128i r; memset(&r, n, sizeof r); return r; }

static inline __m128i _mm_subs_epu8(__m128i a, __m128i b) {
int i;
for (i = 0; i < 16; i++) {
int16_t aa = a.u8[i];
aa -= b.u8[i];
a.u8[i] = (aa >= 0)? aa : 0;
}
return a;
}

static inline __m128i _mm_adds_epi16(__m128i a, __m128i b) {
int i;
for (i = 0; i < 8; i++) {
int32_t aa = a.i16[i];
aa += b.i16[i];
a.i16[i] = (aa < 32768)? aa : 32767;
}
return a;
}

static inline __m128i _mm_cmpgt_epi16(__m128i a, __m128i b) {
int i;
for (i = 0; i < 8; i++)
a.i16[i] = (a.i16[i] > b.i16[i])? 0xffff : 0x0000;
return a;
}

static inline __m128i _mm_max_epi16(__m128i a, __m128i b) {
int i;
for (i = 0; i < 8; i++)
if (a.i16[i] < b.i16[i]) a.i16[i] = b.i16[i];
return a;
}

static inline __m128i _mm_set1_epi16(int16_t n) {
__m128i r;
r.i16[0] = r.i16[1] = r.i16[2] = r.i16[3] =
r.i16[4] = r.i16[5] = r.i16[6] = r.i16[7] = n;
return r;
}

static inline int16_t m128i_max_s16(__m128i a) {
int16_t max = -32768;
int i;
for (i = 0; i < 8; i++)
if (max < a.i16[i]) max = a.i16[i];
return max;
}

static inline __m128i _mm_subs_epu16(__m128i a, __m128i b) {
int i;
for (i = 0; i < 8; i++) {
int32_t aa = a.i16[i];
aa -= b.i16[i];
a.i16[i] = (aa >= 0)? aa : 0;
}
return a;
}

#endif