1// Copyright 2022 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert BATCH_TILE in [1, 2] 7#include <assert.h> 8#include <stddef.h> 9#include <stdint.h> 10 11#include <arm_neon.h> 12 13#include <xnnpack/math.h> 14#include <xnnpack/filterbank.h> 15 16 17void xnn_u32_filterbank_accumulate_ukernel__neon_x${BATCH_TILE}( 18 size_t rows, 19 const uint32_t* input, 20 const uint8_t* weight_widths, 21 const uint16_t* weights, 22 uint64_t* output) { 23 24 assert(rows != 0); 25 assert(input != NULL); 26 assert(weight_widths != NULL); 27 assert(weights != NULL); 28 assert(output != NULL); 29 30 // Compute unweight as initial weight 31 size_t n = (size_t) *weight_widths++; 32 assert(n != 0); 33 uint64x2_t weight_accumulator = vdupq_n_u64(0); 34 35 do { 36 const uint32x2_t vi = vld1_dup_u32(input); input += 1; 37 const uint16x4_t vw = vreinterpret_u16_u32(vld1_dup_u32((const void*) weights)); weights += 2; 38 const uint32x2_t vw32 = vget_low_u32(vmovl_u16(vw)); 39 40 weight_accumulator = vmlal_u32(weight_accumulator, vw32, vi); 41 } while (--n != 0); 42 43 do { 44 size_t n = (size_t) *weight_widths++; 45 assert(n != 0); 46 weight_accumulator = vcombine_u64(vget_high_u64(weight_accumulator), vdup_n_u64(0)); 47 48 $if BATCH_TILE == 2: 49 for (; n >= 2; n -= 2) { 50 const uint32x2_t vi = vld1_u32(input); input += 2; 51 const uint16x4_t vw = vld1_u16(weights); weights += 4; 52 const uint32x4_t vw32 = vmovl_u16(vw); 53 54 weight_accumulator = vmlal_lane_u32(weight_accumulator, vget_low_u32(vw32), vi, 0); 55 weight_accumulator = vmlal_lane_u32(weight_accumulator, vget_high_u32(vw32), vi, 1); 56 } 57 58 if XNN_UNPREDICTABLE(n != 0) { 59 const uint32x2_t vi = vld1_dup_u32(input); input += 1; 60 const uint16x4_t vw = vreinterpret_u16_u32(vld1_dup_u32((const void*) weights)); weights += 2; 61 const uint32x2_t vw32 = vget_low_u32(vmovl_u16(vw)); 62 63 weight_accumulator = vmlal_u32(weight_accumulator, vw32, vi); 64 } 65 $elif BATCH_TILE == 1: 66 do { 67 const uint32x2_t vi = vld1_dup_u32(input); input += 1; 68 const uint16x4_t vw = vreinterpret_u16_u32(vld1_dup_u32((const void*) weights)); weights += 2; 69 const uint32x2_t vw32 = vget_low_u32(vmovl_u16(vw)); 70 71 weight_accumulator = vmlal_u32(weight_accumulator, vw32, vi); 72 } while (--n != 0); 73 74 vst1_u64(output, vget_low_u64(weight_accumulator)); output += 1; 75 76 } while (--rows != 0); 77} 78