// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert not CHANNELWISE or REQUANTIZATION == "FP32"
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gemm.h>
$if REQUANTIZATION == "FP32" and ARMV8:
  #include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>

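// GEMM microkernel template for an ${MR}x${NR} tile of signed 8-bit data with a
// shuffled 4-channel ("c4s2") inner loop. Template parameters: MR/NR set the
// tile size, MLA selects the 16-byte VMLAL main loop over the 8-byte VMULL-only
// loop, REQUANTIZATION picks FP32 or RNDNU rescaling, CHANNELWISE selects
// per-channel (qc8) instead of per-tensor (qs8) quantization, and ARMV8 enables
// the NEONv8 rounding conversion (VCVTN).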
$DATATYPE = "qc8" if CHANNELWISE else "qs8"
$PARAMS_UNION = "xnn_qs8_minmax_params" if CHANNELWISE else "xnn_qs8_conv_minmax_params"
$PARAMS_STRUCT = ("" if CHANNELWISE else REQUANTIZATION.lower() + "_") + ("neonv8" if ARMV8 and REQUANTIZATION == "FP32" else "neon")
$ISA = "neonv8" if ARMV8 else "neon"
void xnn_${DATATYPE}_gemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x${NR}c4s2__${ISA}_${"mlal" if MLA else "mull"}(
    size_t mr,
    size_t nc,
    size_t kc,
    const int8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

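  // Set up one input (a) and one output (c) pointer per row of the tile. When
  // mr is less than ${MR}, the unused row pointers alias the previous row so
  // that loads and stores stay in bounds; those rows are computed redundantly.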
  const int8_t* a0 = a;
  int8_t* c0 = c;
  $for M in range(1, MR):
    const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

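  // The inner loops consume K in 8-byte groups, so round kc up to a multiple
  // of 8. The packed weights are assumed to be padded accordingly, and the
  // XNN_OOB_READS annotation covers the matching over-read of the activations.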
  kc = round_up_po2(kc, 8 * sizeof(int8_t));
  do {
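    // Initialize the accumulators from the packed bias: each vld1_u32 loads two
    // 32-bit bias values, and vmovl_u32 spreads them into lanes 0 and 2 of a
    // 128-bit register. The pairwise additions after the K loop fold the
    // interleaved partial sums back onto these lanes.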
    $for N in range(0, NR, 2):
      int32x4_t vacc0x${ABC[N:N+2]} = vreinterpretq_s32_u64(vmovl_u32(vld1_u32(w))); w = (const int32_t*) w + 2;
    $for M in range(1, MR):
      $for N in range(0, NR, 2):
        int32x4_t vacc${M}x${ABC[N:N+2]} = vacc0x${ABC[N:N+2]};

    size_t k = kc;
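    // K loop. Each 8-byte step multiplies the activations against the packed
    // weights with VMULL.S8, accumulates the widened products into int32 with
    // VPADAL.S16, and rotates the activation vector by 4 bytes with VEXT
    // between the two shuffle steps. The MLAL variant fuses two 8-byte halves
    // with VMLAL.S8 to consume 16 bytes of K per iteration.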
    $if MLA:
      while (k >= 16 * sizeof(int8_t)) {
        $for M in range(MR):
          int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
          int8x8_t va${M}x1 = vld1_s8(a${M}); a${M} += 8;

        $for K in range(2):
          $for N in range(0, NR, 2):
            const int8x8_t vb${ABC[N:N+2]}c${K}x0 = vld1_s8(w); w = (const int8_t*) w + 8;

        $for K in range(2):
          $for N in range(0, NR, 2):
            $for M in range(MR):
              int16x8_t vprod${M}x${ABC[N:N+2]}c${K} = vmull_s8(vb${ABC[N:N+2]}c${K}x0, va${M}x0);
            const int8x8_t vb${ABC[N:N+2]}c${K}x1 = vld1_s8(w); w = (const int8_t*) w + 8;
            $for M in range(MR):
              vprod${M}x${ABC[N:N+2]}c${K} = vmlal_s8(vprod${M}x${ABC[N:N+2]}c${K}, vb${ABC[N:N+2]}c${K}x1, va${M}x1);
            $for M in range(MR):
              vacc${M}x${ABC[N:N+2]} = vpadalq_s16(vacc${M}x${ABC[N:N+2]}, vprod${M}x${ABC[N:N+2]}c${K});
          $if K + 1 != 2:
            $for M in range(MR):
              va${M}x0 = vext_s8(va${M}x0, va${M}x0, 4);
              va${M}x1 = vext_s8(va${M}x1, va${M}x1, 4);

        k -= 16 * sizeof(int8_t);
      }
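    // Process K in 8-byte steps. For the MULL-only variant this is the entire
    // K loop; the MLAL variant enters it at most once, for the final 8-byte
    // tail left over by the 16-byte main loop.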
102    ${"if (k != 0)" if MLA else "do"} {
103      $for M in range(MR):
104        int8x8_t va${M}x0 = vld1_s8(a${M}); a${M} += 8;
105
106      $for K in range(2):
107        $for N in range(0, NR, 2):
108          const int8x8_t vb${ABC[N:N+2]}c${K}x0 = vld1_s8(w); w = (const int8_t*) w + 8;
109
110      $for K in range(2):
111        $for N in range(0, NR, 2):
112          $for M in range(MR):
113            int16x8_t vprod${M}x${ABC[N:N+2]}c${K} = vmull_s8(vb${ABC[N:N+2]}c${K}x0, va${M}x0);
114          $for M in range(MR):
115            vacc${M}x${ABC[N:N+2]} = vpadalq_s16(vacc${M}x${ABC[N:N+2]}, vprod${M}x${ABC[N:N+2]}c${K});
116        $if K + 1 != 2:
117          $for M in range(MR):
118            va${M}x0 = vext_s8(va${M}x0, va${M}x0, 4);
119
120      $if not MLA:
121        k -= 8 * sizeof(int8_t);
122    }${"" if MLA else " while (k != 0);"}
123
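    // Reduce the shuffled partial sums: each accumulator holds two interleaved
    // column sums, and the pairwise additions (VPADDQ on AArch64, VPADD+VCOMBINE
    // elsewhere) produce one int32 sum per output column.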
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 4):
        int32x4_t vacc${M}x${ABC[N:N+4]} = vpaddq_s32(vacc${M}x${ABC[N:N+2]}, vacc${M}x${ABC[N+2:N+4]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 4):
        const int32x2_t vsum${M}x${ABC[N:N+2]} = vpadd_s32(vget_low_s32(vacc${M}x${ABC[N:N+2]}), vget_high_s32(vacc${M}x${ABC[N:N+2]}));
        const int32x2_t vsum${M}x${ABC[N+2:N+4]} = vpadd_s32(vget_low_s32(vacc${M}x${ABC[N+2:N+4]}), vget_high_s32(vacc${M}x${ABC[N+2:N+4]}));
        int32x4_t vacc${M}x${ABC[N:N+4]} = vcombine_s32(vsum${M}x${ABC[N:N+2]}, vsum${M}x${ABC[N+2:N+4]});
#endif

    $if REQUANTIZATION == "RNDNU":
      const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_pre_shift);
      const int32x4_t vmultiplier = vld1q_dup_s32(&params->${PARAMS_STRUCT}.multiplier);
      const int32x4_t vright_post_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_post_shift);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vqshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_pre_shift);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vqdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);

      $for M in range(MR):
        $for N in range(0, NR, 4):
          vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_post_shift);
    $elif REQUANTIZATION == "FP32":
      $for M in range(MR):
        $for N in range(0, NR, 4):
          float32x4_t vfpacc${M}x${ABC[N:N+4]} = vcvtq_f32_s32(vacc${M}x${ABC[N:N+4]});

      $if CHANNELWISE:
        $for N in range(0, NR, 4):
          const float32x4_t vscale${ABC[N:N+4]} = vld1q_f32(w); w = (const float*) w + 4;
          $for M in range(MR):
            vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale${ABC[N:N+4]});
      $else:
        const float32x4_t vscale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.scale);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vfpacc${M}x${ABC[N:N+4]} = vmulq_f32(vfpacc${M}x${ABC[N:N+4]}, vscale);

      $if ARMV8:
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vcvtnq_s32_f32(vfpacc${M}x${ABC[N:N+4]});
      $else:
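        // Without the ARMv8 rounding conversion (VCVTN), add a "magic bias" so
        // the float addition itself rounds to nearest, then subtract its integer
        // image (with the output zero point pre-folded in) to recover the
        // requantized value.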
        const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${M}x${ABC[N:N+4]}, vmagic_bias));

        const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(&params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
        $for M in range(MR):
          $for N in range(0, NR, 4):
            vacc${M}x${ABC[N:N+4]} = vqsubq_s32(vacc${M}x${ABC[N:N+4]}, vmagic_bias_less_output_zero_point);

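    // Add the output zero point (already folded in on the magic-bias path) and
    // narrow int32 -> int16 -> int8 with saturation.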
    $if REQUANTIZATION != "FP32" or ARMV8:
      const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->${PARAMS_STRUCT}.output_zero_point);
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 8):
        int16x8_t vacc${M}x${ABC[N:N+8]} = vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]});

    $if REQUANTIZATION != "FP32" or ARMV8:
      $for M in range(MR):
        $for N in range(0, NR, 8):
          vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 8):
        int16x8_t vacc${M}x${ABC[N:N+8]} = vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]}));

    $if REQUANTIZATION != "FP32" or ARMV8:
      $for M in range(MR):
        $for N in range(0, NR, 8):
          vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vacc${M}x${ABC[N:N+8]}, voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#endif

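    // Clamp the narrowed results to [output_min, output_max].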
    $if NR == 8 and MR == 1:
      const int8x8_t voutput_min = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_min);
    $else:
      const int8x16_t voutput_min = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_min);
    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
          $else:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));

    $if NR == 8 and MR == 1:
      const int8x8_t voutput_max = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_max);
    $else:
      const int8x16_t voutput_max = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_max);
    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
          $else:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));

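    // Full tile: store ${NR} columns per row, step the output pointers to the
    // next column block, and rewind the activation pointers by the rounded kc.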
    if (nc >= ${NR}) {
      $for M in range(MR):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
          $elif M % 2 == 1:
            vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
            vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
          $elif M + 1 == MR:
            vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});

      $for M in range(MR):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);

      nc -= ${NR};
    } else {
      // Final case where not all of the ${NR} columns fit in the destination.
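      // Store the remaining columns in power-of-2 chunks (8, then 4, 2, 1),
      // using VEXT to shift the next chunk into the low lanes after each store.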
      $if NR == 16:
        $for M in range(MR):
          $if M % 2 == 1:
            int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
          $elif M + 1 == MR:
            int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
        if (nc & 8) {
          $for M in range(MR):
            $if M % 2 == 1:
              vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8;
              vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8;
            $elif M + 1 == MR:
              vst1_s8(c${M}, vout${M}x01234567); c${M} += 8;
          $for M in range(MR):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
            $elif M + 1 == MR:
              vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
        }
      if (nc & 4) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u32((void*) c${M-1}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
            vst1q_lane_u32((void*) c${M}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
          $elif M + 1 == MR:
            vst1_lane_u32((void*) c${M}, vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
      }
      if (nc & 2) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_u16((void*) c${M-1}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
            vst1q_lane_u16((void*) c${M}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
          $elif M + 1 == MR:
            vst1_lane_u16((void*) c${M}, vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
        $for M in range(MR):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
      }
      if (nc & 1) {
        $for M in range(MR):
          $if M % 2 == 1:
            vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
            vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
          $elif M + 1 == MR:
            vst1_lane_s8(c${M}, vout${M}x01234567, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}