• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6$assert BATCH_TILE in [1, 2]
7#include <assert.h>
8#include <stddef.h>
9#include <stdint.h>
10
11#include <arm_neon.h>
12
13#include <xnnpack/math.h>
14#include <xnnpack/filterbank.h>
15
16
17void xnn_u32_filterbank_accumulate_ukernel__neon_x${BATCH_TILE}(
18    size_t rows,
19    const uint32_t* input,
20    const uint8_t* weight_widths,
21    const uint16_t* weights,
22    uint64_t* output) {
23
24  assert(rows != 0);
25  assert(input != NULL);
26  assert(weight_widths != NULL);
27  assert(weights != NULL);
28  assert(output != NULL);
29
30  // Compute unweight as initial weight
31  size_t n = (size_t) *weight_widths++;
32  assert(n != 0);
33  uint64x2_t weight_accumulator = vdupq_n_u64(0);
34
35  do {
36    const uint32x2_t vi = vld1_dup_u32(input); input += 1;
37    const uint16x4_t vw = vreinterpret_u16_u32(vld1_dup_u32((const void*) weights)); weights += 2;
38    const uint32x2_t vw32 = vget_low_u32(vmovl_u16(vw));
39
40    weight_accumulator = vmlal_u32(weight_accumulator, vw32, vi);
41  } while (--n != 0);
42
43  do {
44    size_t n = (size_t) *weight_widths++;
45    assert(n != 0);
46    weight_accumulator = vcombine_u64(vget_high_u64(weight_accumulator), vdup_n_u64(0));
47
48    $if BATCH_TILE == 2:
49      for (; n >= 2; n -= 2) {
50        const uint32x2_t vi = vld1_u32(input); input += 2;
51        const uint16x4_t vw = vld1_u16(weights); weights += 4;
52        const uint32x4_t vw32 = vmovl_u16(vw);
53
54        weight_accumulator = vmlal_lane_u32(weight_accumulator, vget_low_u32(vw32), vi, 0);
55        weight_accumulator = vmlal_lane_u32(weight_accumulator, vget_high_u32(vw32), vi, 1);
56      }
57
58      if XNN_UNPREDICTABLE(n != 0) {
59        const uint32x2_t vi = vld1_dup_u32(input); input += 1;
60        const uint16x4_t vw = vreinterpret_u16_u32(vld1_dup_u32((const void*) weights)); weights += 2;
61        const uint32x2_t vw32 = vget_low_u32(vmovl_u16(vw));
62
63        weight_accumulator = vmlal_u32(weight_accumulator, vw32, vi);
64      }
65    $elif BATCH_TILE == 1:
66      do {
67        const uint32x2_t vi = vld1_dup_u32(input); input += 1;
68        const uint16x4_t vw = vreinterpret_u16_u32(vld1_dup_u32((const void*) weights)); weights += 2;
69        const uint32x2_t vw32 = vget_low_u32(vmovl_u16(vw));
70
71        weight_accumulator = vmlal_u32(weight_accumulator, vw32, vi);
72      } while (--n != 0);
73
74    vst1_u64(output, vget_low_u64(weight_accumulator));  output += 1;
75
76  } while (--rows != 0);
77}
78