• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
7$assert NR % 8 == 0
8$assert 8 <= NR <= 16
9#include <assert.h>
10
11#include <arm_neon.h>
12
13#include <xnnpack/gemm.h>
14#include <xnnpack/math.h>
15
16
// Template for a signed 8-bit GEMM micro-kernel with min/max output clamping,
// written with ARM NEON intrinsics. The template variables MR and NR set how
// many rows and columns of output one pass produces; MLA adds a main loop that
// consumes 16 bytes of each input row per iteration (vmull_s8 + vmlal_s8)
// ahead of the 8-byte vmull_s8 step.
//
// Parameters, as they are used in the body:
//   mr         - number of rows to compute; asserted to be in 1..MR.
//   nc         - number of output columns still to produce; consumed NR at a
//                time, with a partial store for the final tile.
//   kc         - number of bytes read per row of a; rounded with
//                round_up_po2(kc, 2) before use.
//   a          - first input row; further rows are a_stride bytes apart.
//   w          - packed stream read strictly sequentially: for each column
//                tile, NR int32 initial accumulator values followed by int8
//                data loaded 8 bytes at a time.
//   c          - first output row; further rows are cm_stride bytes apart and
//                each row pointer advances by cn_stride bytes per full tile.
//   params     - params->neon supplies multiplier, right_shift,
//                output_zero_point, output_min and output_max.
17void xnn_qs8_gemm_minmax_ukernel_${MR}x${NR}c2__neon_${"mlal" if MLA else "mull"}_padal_dup(
18    size_t mr,
19    size_t nc,
20    size_t kc,
21    const int8_t* restrict a,
22    size_t a_stride,
23    const void* restrict w,
24    int8_t* restrict c,
25    size_t cm_stride,
26    size_t cn_stride,
27    const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
28{
29  assert(mr != 0);
30  assert(mr <= ${MR});
31  assert(nc != 0);
32  assert(kc != 0);
33  assert(kc % sizeof(int8_t) == 0);
34  assert(a != NULL);
35  assert(w != NULL);
36  assert(c != NULL);
37
  // The K loops below work on pairs of bytes (one 16-bit lane per step), so
  // kc is rounded first. NOTE(review): round_up_po2 is defined outside this
  // file; presumably it rounds kc up to a multiple of 2 - confirm.
38  kc = round_up_po2(kc, 2);
39  const int8_t* a0 = a;
40  int8_t* c0 = c;
  // Set up one input pointer and one output pointer per row. When mr is
  // smaller than the row index requires, the row aliases the previous row's
  // pointers, so every row pointer stays usable and surplus rows just repeat
  // the previous row's loads and stores.
41  $for M in range(1, MR):
42    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
43    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
44    $if M % 2 == 0:
45      if XNN_UNPREDICTABLE(mr <= ${M}) {
46        a${M} = a${M-1};
47        c${M} = c${M-1};
48      }
49    $elif M + 1 == MR:
50      if XNN_UNPREDICTABLE(mr != ${M+1}) {
51        a${M} = a${M-1};
52        c${M} = c${M-1};
53      }
54    $else:
55      if XNN_UNPREDICTABLE(mr < ${M+1}) {
56        a${M} = a${M-1};
57        c${M} = c${M-1};
58      }
59
  // Each iteration of this loop produces one tile of up to NR output columns.
60  do {
    // Row 0 accumulators are initialised from NR int32 values at the head of
    // the w stream (4 per load); all other rows start from the same values.
61    $for N in range(0, NR, 4):
62      int32x4_t vacc0x${ABC[N:N+4]} = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 4 * sizeof(int32_t));
63    $for M in range(1, MR):
64      $for N in range(0, NR, 4):
65        int32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
66
67    size_t k = kc;
68
    // MLA variant only: consume 16 bytes of every input row per iteration.
    // For each 2-byte lane K of the input, the lane is broadcast with
    // vdup_lane_s16 and multiplied against 8 bytes of w. The product of the
    // first 8 input bytes (vmull_s8) and that of the second 8 (vmlal_s8) are
    // summed in int16 before vpadalq_s16 pairwise-adds them into the int32
    // accumulators. The x0 halves of w are all loaded up front; the x1 halves
    // are loaded inside the compute loop, in the order the stream stores them.
69    $if MLA:
70      while (k >= 16 * sizeof(int8_t)) {
71        $for M in range(MR):
72          const int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
73          const int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8;
74
75        $for K in range(4):
76          $for N in range(0, NR, 4):
77            const int8x8_t vb${ABC[N:N+4]}c${K}x0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
78
79        $for K in range(4):
80          $for N in range(0, NR, 4):
81            $for M in range(MR):
82              int16x8_t vprod${M}x${ABC[N:N+4]}c${K} = vmull_s8(vb${ABC[N:N+4]}c${K}x0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}x0), ${K})));
83            const int8x8_t vb${ABC[N:N+4]}c${K}x1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
84            $for M in range(MR):
85              vprod${M}x${ABC[N:N+4]}c${K} = vmlal_s8(vprod${M}x${ABC[N:N+4]}c${K}, vb${ABC[N:N+4]}c${K}x1, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}x1), ${K})));
86            $for M in range(MR):
87              vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c${K});
88
89        k -= 16 * sizeof(int8_t);
90      }
91
    // 8-byte step: the whole main loop in the MULL variant, or a single
    // optional step after the 16-byte loop in the MLA variant. Each of the
    // four 2-byte input lanes is broadcast, multiplied with vmull_s8 and
    // accumulated with vpadalq_s16.
92    ${"if" if MLA else "while"} (k >= 8 * sizeof(int8_t)) {
93      $for M in range(MR):
94        const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8;
95
96      $for K in range(4):
97        $for N in range(0, NR, 4):
98          const int8x8_t vb${ABC[N:N+4]}c${K} = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
99
100      $for M in range(MR):
101        $for N in range(0, NR, 4):
102          $for K in range(4):
103            const int16x8_t vprod${M}x${ABC[N:N+4]}c${K} = vmull_s8(vb${ABC[N:N+4]}c${K}, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), ${K})));
104          $for K in range(4):
105            vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c${K});
106
107      k -= 8 * sizeof(int8_t);
108    }
109
    // Remainder of fewer than 8 bytes. vld1_s8 still reads a full 8 bytes from
    // each row, but the row pointer advances by k only, so the load can touch
    // bytes beyond the k that remain. Lane 0 is always processed, lane 1 when
    // k > 2 and lane 2 when k > 4; w is advanced by 8 bytes per column group
    // for each lane that is processed.
110    if XNN_UNLIKELY(k != 0) {
111      $for M in range(MR):
112        const int8x8_t va${M} = vld1_s8(a${M}); a${M} = (const int8_t*) ((uintptr_t) a${M} + k);
113
114      $for N in range(0, NR, 4):
115        const int8x8_t vb${ABC[N:N+4]}c0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
116
117      $for M in range(MR):
118        $for N in range(0, NR, 4):
119          const int16x8_t vprod${M}x${ABC[N:N+4]}c0 = vmull_s8(vb${ABC[N:N+4]}c0, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 0)));
120          vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c0);
121
122      if (k > 2 * sizeof(int8_t)) {
123        $for N in range(0, NR, 4):
124          const int8x8_t vb${ABC[N:N+4]}c1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
125
126        $for M in range(MR):
127          $for N in range(0, NR, 4):
128            const int16x8_t vprod${M}x${ABC[N:N+4]}c1 = vmull_s8(vb${ABC[N:N+4]}c1, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 1)));
129            vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c1);
130
131        if (k > 4 * sizeof(int8_t)) {
132          $for N in range(0, NR, 4):
133            const int8x8_t vb${ABC[N:N+4]}c2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));
134
135          $for M in range(MR):
136            $for N in range(0, NR, 4):
137              const int16x8_t vprod${M}x${ABC[N:N+4]}c2 = vmull_s8(vb${ABC[N:N+4]}c2, vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(va${M}), 2)));
138              vacc${M}x${ABC[N:N+4]} = vpadalq_s16(vacc${M}x${ABC[N:N+4]}, vprod${M}x${ABC[N:N+4]}c2);
139        }
140      }
141    }
    // Requantization, step 1: saturating rounding doubling multiply-high of
    // every accumulator by the broadcast multiplier.
142    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
143    $for M in range(MR):
144      $for N in range(0, NR, 4):
145        vacc${M}x${ABC[N:N+4]} = vqrdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);
146
    // Step 2: add (vacc >> 31), i.e. -1 for negative lanes and 0 otherwise,
    // except where right_shift is 0 (the mask clears the addend there).
    // NOTE(review): presumably this biases negative values so the rounding
    // shift below rounds ties away from zero - confirm against the reference
    // requantization.
147    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
148    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
149    $for M in range(MR):
150      $for N in range(0, NR, 4):
151        vacc${M}x${ABC[N:N+4]} = vsraq_n_s32(vacc${M}x${ABC[N:N+4]}, vbicq_s32(vacc${M}x${ABC[N:N+4]}, vzero_shift_mask), 31);
152
    // Step 3: rounding shift by vright_shift. vrshlq_s32 shifts left for
    // positive counts, so a right shift requires the stored value to be
    // non-positive. NOTE(review): sign convention of params->neon.right_shift
    // is set by the caller - confirm.
153    $for M in range(MR):
154      $for N in range(0, NR, 4):
155        vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_shift);
156
    // Narrow int32 -> int16 with saturation, add the output zero point with
    // saturation, then narrow int16 -> int8 with saturation. When a row has
    // only 8 columns left to pack, two consecutive rows share one 16-byte
    // vector (low half: even row, high half: odd row); a trailing unpaired row
    // is kept in an 8-byte vector. The two preprocessor branches differ only
    // in which intrinsics build the combined vectors.
157    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
158#if XNN_ARCH_ARM64
159    $for M in range(MR):
160      $for N in range(0, NR, 8):
161        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]}), voutput_zero_point);
162
163    $for M in range(MR):
164      $for N in range(0, NR, 16):
165        $if N + 8 < NR:
166          int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
167        $elif M % 2 == 1:
168          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
169        $elif M + 1 == MR:
170          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
171#else
172    $for M in range(MR):
173      $for N in range(0, NR, 8):
174        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]})), voutput_zero_point);
175
176    $for M in range(MR):
177      $for N in range(0, NR, 16):
178        $if N + 8 < NR:
179          int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
180        $elif M % 2 == 1:
181          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
182        $elif M + 1 == MR:
183          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
184#endif
    // Clamp bounds: 8-byte vectors suffice when the only output vector is a
    // single 8-byte row (NR == 8, MR == 1); otherwise 16-byte vectors are
    // loaded and their low half is reused for any 8-byte output vector.
185    $if NR == 8 and MR == 1:
186      const int8x8_t voutput_min = vld1_dup_s8(&params->neon.output_min);
187      const int8x8_t voutput_max = vld1_dup_s8(&params->neon.output_max);
188    $else:
189      const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
190      const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);
191
    // Apply the lower bound to every output vector.
192    $for M in range(MR):
193      $for N in range(0, NR, 16):
194        $if N + 8 < NR:
195          vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
196        $elif M % 2 == 1:
197          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
198        $elif M + 1 == MR:
199          $if NR == 8 and MR == 1:
200            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
201          $else:
202            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));
203
    // Apply the upper bound to every output vector.
204    $for M in range(MR):
205      $for N in range(0, NR, 16):
206        $if N + 8 < NR:
207          vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
208        $elif M % 2 == 1:
209          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
210        $elif M + 1 == MR:
211          $if NR == 8 and MR == 1:
212            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
213          $else:
214            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));
215
    // Full tile: store NR bytes per row, advance each output row pointer by
    // cn_stride, and rewind each input row pointer by kc so the next tile
    // re-reads the same rows of a against the next part of w.
216    if (nc >= ${NR}) {
217      $for M in range(MR):
218        $for N in range(0, NR, 16):
219          $if N + 8 < NR:
220            vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
221          $elif M % 2 == 1:
222            vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
223            vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
224          $elif M + 1 == MR:
225            vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});
226
227      $for M in range(MR):
228        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
229
230      $for M in range(MR):
231        a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
232
233      nc -= ${NR};
234    } else {
      // Partial tile: fewer than NR columns remain. The bits of nc select
      // stores of 8, 4, 2 and 1 bytes per row; after each store the vectors
      // are shifted so the next unwritten byte sits in lane 0 (even row) and
      // lane 8 (odd row) of the paired vector. With NR == 16 the per-row
      // 16-byte vectors are first repacked into the paired-row layout.
      // NOTE(review): inside the nc & 8 branch the identifier suffix is built
      // from N, which has no enclosing for-N loop here; it presumably keeps
      // the value left by the last for-N loop above (0 when NR == 16) -
      // confirm with the template generator.
235      $if NR == 16:
236        $for M in range(MR):
237          $if M % 2 == 1:
238            int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
239          $elif M + 1 == MR:
240            int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
241        if (nc & 8) {
242          $for M in range(MR):
243            $if M % 2 == 1:
244              vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); c${M-1} += 8;
245              vst1_s8(c${M}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); c${M} += 8;
246            $elif M + 1 == MR:
247              vst1_s8(c${M}, vout${M}x${ABC[N:N+8]}); c${M} += 8;
248          $for M in range(MR):
249            $if M % 2 == 1:
250              vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
251            $elif M + 1 == MR:
252              vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
253        }
254      if (nc & 4) {
255        $for M in range(MR):
256          $if M % 2 == 1:
257            vst1q_lane_u32(__builtin_assume_aligned(c${M-1}, 1), vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
258            vst1q_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
259          $elif M + 1 == MR:
260            vst1_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
261        $for M in range(MR):
262          $if M % 2 == 1:
263            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
264          $elif M + 1 == MR:
265            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
266      }
267      if (nc & 2) {
268        $for M in range(MR):
269          $if M % 2 == 1:
270            vst1q_lane_u16(__builtin_assume_aligned(c${M-1}, 1), vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
271            vst1q_lane_u16(__builtin_assume_aligned(c${M}, 1), vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
272          $elif M + 1 == MR:
273            vst1_lane_u16(__builtin_assume_aligned(c${M}, 1), vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
274        $for M in range(MR):
275          $if M % 2 == 1:
276            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
277          $elif M + 1 == MR:
278            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
279      }
280      if (nc & 1) {
281        $for M in range(MR):
282          $if M % 2 == 1:
283            vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
284            vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
285          $elif M + 1 == MR:
286            vst1_lane_s8(c${M}, vout${M}x01234567, 0);
287      }
288
      // The partial tile is always the last one; terminate the outer loop.
289      nc = 0;
290    }
291  } while (nc != 0);
292}
293