// Copyright 2022 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. $assert BATCH_TILE % 8 == 0 $assert BATCH_TILE >= 8 $SIMD_TILE = BATCH_TILE // 8 #include #include #include #include #include #include $SHIFT_VARIANT = "_shift%s" % SHIFT if SHIFT else "" void xnn_s16_window${SHIFT_VARIANT}_ukernel__neon_x${BATCH_TILE}( size_t rows, size_t batch_size, const int16_t* input, const int16_t* weights, int16_t* output, uint32_t shift) { assert(rows != 0); assert(batch_size != 0); assert(input != NULL); assert(weights != NULL); assert(output != NULL); $if SHIFT != 0: assert(shift == ${SHIFT}); $else: assert(shift < 32); $if SHIFT == 0: const int32x4_t vshift = vdupq_n_s32(-(int32_t)shift); // negative to shift right. do { const int16_t* w = weights; size_t n = batch_size * sizeof(int16_t); $if BATCH_TILE > 8: for (; n >= ${BATCH_TILE} * sizeof(int16_t); n -= ${BATCH_TILE} * sizeof(int16_t)) { $for N in range(SIMD_TILE): const int16x8_t vi${N} = vld1q_s16(input); input += 8; $for N in range(SIMD_TILE): const int16x8_t vw${N} = vld1q_s16(w); w += 8; $if SHIFT == 15: $for N in range(SIMD_TILE): const int16x8_t vout${N} = vqdmulhq_s16(vi${N}, vw${N}); $else: $for N in range(SIMD_TILE): int32x4_t vacc${N}_lo = vmull_s16(vget_low_s16(vi${N}), vget_low_s16(vw${N})); int32x4_t vacc${N}_hi = vmull_s16(vget_high_s16(vi${N}), vget_high_s16(vw${N})); $if SHIFT != 0: $for N in range(SIMD_TILE): const int16x4_t vshift${N}_lo = vqshrn_n_s32(vacc${N}_lo, ${SHIFT}); const int16x4_t vshift${N}_hi = vqshrn_n_s32(vacc${N}_hi, ${SHIFT}); $for N in range(SIMD_TILE): const int16x8_t vout${N} = vcombine_s16(vshift${N}_lo, vshift${N}_hi); $else: $for N in range(SIMD_TILE): vacc${N}_lo = vshlq_s32(vacc${N}_lo, vshift); vacc${N}_hi = vshlq_s32(vacc${N}_hi, vshift); $for N in range(SIMD_TILE): const int16x8_t vout${N} = vcombine_s16(vqmovn_s32(vacc${N}_lo), vqmovn_s32(vacc${N}_hi)); $for N in range(SIMD_TILE): vst1q_s16(output, vout${N}); output += 8; } // Remainder of full vectors for (; n >= 8 * sizeof(int16_t); n -= 8 * sizeof(int16_t)) { const int16x8_t vi = vld1q_s16(input); input += 8; const int16x8_t vw = vld1q_s16(w); w += 8; $if SHIFT == 15: const int16x8_t vout = vqdmulhq_s16(vi, vw); $else: int32x4_t vacc_lo = vmull_s16(vget_low_s16(vi), vget_low_s16(vw)); int32x4_t vacc_hi = vmull_s16(vget_high_s16(vi), vget_high_s16(vw)); $if SHIFT != 0: const int16x4_t vshift_lo = vqshrn_n_s32(vacc_lo, ${SHIFT}); const int16x4_t vshift_hi = vqshrn_n_s32(vacc_hi, ${SHIFT}); const int16x8_t vout = vcombine_s16(vshift_lo, vshift_hi); $else: vacc_lo = vshlq_s32(vacc_lo, vshift); vacc_hi = vshlq_s32(vacc_hi, vshift); const int16x8_t vout = vcombine_s16(vqmovn_s32(vacc_lo), vqmovn_s32(vacc_hi)); vst1q_s16(output, vout); output += 8; } assert(n % 2 == 0); // Remainder of 1 to 7 batch_size if XNN_UNLIKELY(n != 0) { const int16x8_t vi = vld1q_s16(input); input = (const int16_t*) ((uintptr_t) input + n); const int16x8_t vw = vld1q_s16(w); $if SHIFT == 15: int16x4_t vout = vqdmulh_s16(vget_low_s16(vi), vget_low_s16(vw)); $else: int32x4_t vacc = vmull_s16(vget_low_s16(vi), vget_low_s16(vw)); $if SHIFT != 0: int16x4_t vout = vqshrn_n_s32(vacc, ${SHIFT}); $else: vacc = vshlq_s32(vacc, vshift); int16x4_t vout = vqmovn_s32(vacc); if (n & (4 * sizeof(int16_t))) { vst1_s16(output, vout); output += 4; $if SHIFT == 15: vout = vqdmulh_s16(vget_high_s16(vi), vget_high_s16(vw)); $else: vacc = vmull_s16(vget_high_s16(vi), vget_high_s16(vw)); $if SHIFT != 0: vout = vqshrn_n_s32(vacc, ${SHIFT}); $else: vacc = vshlq_s32(vacc, vshift); vout = vqmovn_s32(vacc); } if (n & (2 * sizeof(int16_t))) { vst1_lane_u32((void*) output, vreinterpret_u32_s16(vout), 0); output += 2; vout = vext_s16(vout, vout, 2); } if (n & (1 * sizeof(int16_t))) { vst1_lane_s16(output, vout, 0); output += 1; } } } while (--rows != 0); }