• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6$assert SSE in [2, 3, 4]
7$assert not XOP or AVX
8$assert not AVX or SSE == 4
9$assert REQUANTIZATION == "FP32"
10$assert DATATYPE in ["QC8", "QS8", "QU8"]
11$assert VARIANT in ["LD64", "LD128"]
12$assert MR <= 4
#include <assert.h>
#include <string.h>

$if XOP:
  #if defined(__GNUC__) || defined(__clang__)
    #include <x86intrin.h>
  #else
    #include <immintrin.h>
    #include <ammintrin.h>
  #endif
$else:
  $SSE_HEADER = {2: "emmintrin.h", 3: "tmmintrin.h", 4: "smmintrin.h"}[SSE]
  #include <${SSE_HEADER}>

#include <xnnpack/igemm.h>
#include <xnnpack/math.h>
28
29
30$PARAMS_UNION = "xnn_qs8_minmax_params" if DATATYPE == "QC8" else "xnn_%s_conv_minmax_params" % DATATYPE.lower()
31$PARAMS_STRUCT = ("" if DATATYPE == "QC8" else "fp32_") + ("sse4" if SSE >= 4 and DATATYPE != "QU8" else "sse2")
32$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
33$ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 3: "ssse3", 4: "sse41"}[SSE]
34void xnn_${DATATYPE.lower()}_igemm_minmax_fp32_ukernel_${MR}x4c8__${ISA}_${VARIANT.lower()}(
35    size_t mr,
36    size_t nc,
37    size_t kc,
38    size_t ks,
39    const ${XINT8_T}** restrict a,
40    const void* restrict w,
41    ${XINT8_T}* restrict c,
42    size_t cm_stride,
43    size_t cn_stride,
44    size_t a_offset,
45    const ${XINT8_T}* zero,
46    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
47{
48  assert(mr != 0);
49  assert(mr <= ${MR});
50  assert(nc != 0);
51  assert(kc != 0);
52  assert(ks != 0);
53  assert(ks % (${MR} * sizeof(void*)) == 0);
54  assert(a_offset % sizeof(${XINT8_T}) == 0);
55  assert(a != NULL);
56  assert(w != NULL);
57  assert(c != NULL);
58
59  kc = round_up_po2(kc, 8);
60  ${XINT8_T}* c0 = c;
61  $for M in range(1, MR):
62    ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
63    $if M % 2 == 0:
64      if XNN_UNPREDICTABLE(mr <= ${M}) {
65        c${M} = c${M-1};
66      }
67    $elif M + 1 == MR:
68      if XNN_UNPREDICTABLE(mr != ${M+1}) {
69        c${M} = c${M-1};
70      }
71    $else:
72      if XNN_UNPREDICTABLE(mr < ${M+1}) {
73        c${M} = c${M-1};
74      }
75
76  do {
77    $for N in range(4):
78      __m128i vacc0x${N} = _mm_cvtsi32_si128((int) ((const int32_t*) w)[${N}]);
79    $for M in range(1, MR):
80      $for N in range(4):
81        __m128i vacc${M}x${N} = vacc0x${N};
82    w = (const void*) ((const int32_t*) w + 4);
83
84    size_t p = ks;
85    do {
86      $for M in range(MR):
87        const ${XINT8_T}* restrict a${M} = a[${M}];
88        if XNN_UNPREDICTABLE(a${M} != zero) {
89          a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} + a_offset);
90        }
91      a += ${MR};
92
93      size_t k = 0;
94      $if DATATYPE == "QU8":
95        const __m128i vb_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.kernel_zero_point);
96        $if SSE < 4 or VARIANT == "LD128":
97          const __m128i vzero = _mm_setzero_si128();
98      while (k < kc) {
99        $for M in range(MR):
100          const __m128i va${M} = _mm_loadl_epi64((const __m128i*) a${M});
101          $if DATATYPE == "QU8":
102            $if SSE == 4:
103              const __m128i vxa${M} = _mm_cvtepu8_epi16(va${M});
104            $else:
105              const __m128i vxa${M} = _mm_unpacklo_epi8(va${M}, vzero);
106          $else:
107            $if SSE == 4:
108              const __m128i vxa${M} = _mm_cvtepi8_epi16(va${M});
109            $else:
110              const __m128i vxa${M} = _mm_srai_epi16(_mm_unpacklo_epi8(va${M}, va${M}), 8);
111          a${M} += 8;
112
113        $if VARIANT == "LD128":
114          $for N in range(0, 4, 2):
115            $if N == 0:
116              const __m128i vb${N}${N+1} = _mm_load_si128((const __m128i*) w);
117            $else:
118              const __m128i vb${N}${N+1} = _mm_load_si128((const __m128i*) ((const ${XINT8_T}*) w + ${N * 8}));
119            $if DATATYPE == "QU8":
120              const __m128i vxb${N} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${N}${N+1}, vzero), vb_zero_point);
121              const __m128i vxb${N+1} = _mm_sub_epi16(_mm_unpackhi_epi8(vb${N}${N+1}, vzero), vb_zero_point);
122            $elif SSE == 4:
123              const __m128i vxb${N} = _mm_cvtepi8_epi16(vb${N}${N+1});
124              const __m128i vxb${N+1} = _mm_srai_epi16(_mm_unpackhi_epi8(vb${N}${N+1}, vb${N}${N+1}), 8);
125            $else:
126              const __m128i vsb${N}${N+1} = _mm_cmpgt_epi8(_mm_setzero_si128(), vb${N}${N+1});
127              const __m128i vxb${N} = _mm_unpacklo_epi8(vb${N}${N+1}, vsb${N}${N+1});
128              const __m128i vxb${N+1} = _mm_unpackhi_epi8(vb${N}${N+1}, vsb${N}${N+1});
129
130            $for M in range(MR):
131              $if XOP:
132                vacc${M}x${N} = _mm_maddd_epi16(vxa${M}, vxb${N}, vacc${M}x${N});
133                vacc${M}x${N+1} = _mm_maddd_epi16(vxa${M}, vxb${N+1}, vacc${M}x${N+1});
134              $else:
135                vacc${M}x${N} = _mm_add_epi32(vacc${M}x${N}, _mm_madd_epi16(vxa${M}, vxb${N}));
136                vacc${M}x${N+1} = _mm_add_epi32(vacc${M}x${N+1}, _mm_madd_epi16(vxa${M}, vxb${N+1}));
137        $else:
138          $for N in range(4):
139            $if N == 0:
140              const __m128i vb${N} = _mm_loadl_epi64((const __m128i*) w);
141            $else:
142              const __m128i vb${N} = _mm_loadl_epi64((const __m128i*) ((const ${XINT8_T}*) w + ${N * 8}));
143            $if DATATYPE == "QU8":
144              $if SSE == 4:
145                const __m128i vxb${N} = _mm_sub_epi16(_mm_cvtepu8_epi16(vb${N}), vb_zero_point);
146              $else:
147                const __m128i vxb${N} = _mm_sub_epi16(_mm_unpacklo_epi8(vb${N}, vzero), vb_zero_point);
148            $else:
149              $if SSE == 4:
150                const __m128i vxb${N} = _mm_cvtepi8_epi16(vb${N});
151              $else:
152                const __m128i vxb${N} = _mm_srai_epi16(_mm_unpacklo_epi8(vb${N}, vb${N}), 8);
153
154            $for M in range(MR):
155              $if XOP:
156                vacc${M}x${N} = _mm_maddd_epi16(vxa${M}, vxb${N}, vacc${M}x${N});
157              $else:
158                vacc${M}x${N} = _mm_add_epi32(vacc${M}x${N}, _mm_madd_epi16(vxa${M}, vxb${N}));
159
160        w = (const void*) ((const ${XINT8_T}*) w + 32);
161        k += 8 * sizeof(${XINT8_T});
162      }
163      p -= ${MR} * sizeof(void*);
164    } while (p != 0);
165
166    $if SSE >= 3:
167      $for M in range(MR):
168        const __m128i vacc${M}x01 = _mm_hadd_epi32(vacc${M}x0, vacc${M}x1);
169        const __m128i vacc${M}x23 = _mm_hadd_epi32(vacc${M}x2, vacc${M}x3);
170
171      $for M in range(MR):
172        __m128i vacc${M}x0123 = _mm_hadd_epi32(vacc${M}x01, vacc${M}x23);
173    $else:
174      $for M in range(MR):
175        const __m128i vacc${M}x02 = _mm_add_epi32(_mm_unpacklo_epi32(vacc${M}x0, vacc${M}x2), _mm_unpackhi_epi32(vacc${M}x0, vacc${M}x2));
176        const __m128i vacc${M}x13 = _mm_add_epi32(_mm_unpacklo_epi32(vacc${M}x1, vacc${M}x3), _mm_unpackhi_epi32(vacc${M}x1, vacc${M}x3));
177
178      $for M in range(MR):
179        __m128i vacc${M}x0123 = _mm_add_epi32(_mm_unpacklo_epi32(vacc${M}x02, vacc${M}x13), _mm_unpackhi_epi32(vacc${M}x02, vacc${M}x13));
180
181    $for M in range(MR):
182      __m128 vscaled${M}x0123 = _mm_cvtepi32_ps(vacc${M}x0123);
183
184    $if DATATYPE == "QC8":
185      const __m128 vscale0123 = _mm_load_ps((const float*) w);
186      w = (const void*) ((const float*) w + 4);
187      $for M in range(MR):
188        vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale0123);
189    $else:
190      const __m128 vscale = _mm_load_ps(params->${PARAMS_STRUCT}.scale);
191      $for M in range(MR):
192        vscaled${M}x0123 = _mm_mul_ps(vscaled${M}x0123, vscale);
193
194    const __m128 voutput_max_less_zero_point = _mm_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
195    $for M in range(MR):
196      vscaled${M}x0123 = _mm_min_ps(vscaled${M}x0123, voutput_max_less_zero_point);
197
198    $for M in range(MR):
199      vacc${M}x0123 = _mm_cvtps_epi32(vscaled${M}x0123);
200
201    const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
202    $for M in range(0, MR, 2):
203      __m128i vacc${M}${min(M+1, MR-1)}x0123 = _mm_adds_epi16(_mm_packs_epi32(vacc${M}x0123, vacc${min(M+1, MR-1)}x0123), voutput_zero_point);
204
205    $if DATATYPE == "QU8":
206      $if MR > 2:
207        __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
208      $else:
209        __m128i vout = _mm_packus_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);
210
211      vout = _mm_max_epu8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));
212    $else:
213      $if SSE < 4:
214        const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
215        $for M in range(0, MR, 2):
216          vacc${M}${min(M+1, MR-1)}x0123 = _mm_max_epi16(vacc${M}${min(M+1, MR-1)}x0123, voutput_min);
217
218      $if MR > 2:
219        __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
220      $else:
221        __m128i vout = _mm_packs_epi16(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);
222
223      $if SSE == 4:
224        vout = _mm_max_epi8(vout, _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min));
225
226    if (nc >= 4) {
227      $for M in reversed(range(1, MR)):
228        $if SSE == 4:
229          *((uint32_t*) c${M}) = (uint32_t) _mm_extract_epi32(vout, ${M});
230        $else:
231          *((uint32_t*) c${M}) = (uint32_t) _mm_cvtsi128_si32(_mm_shuffle_epi32(vout, _MM_SHUFFLE(${M}, ${M}, ${M}, ${M})));
232        c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
233      *((uint32_t*) c0) = (uint32_t) _mm_cvtsi128_si32(vout);
234      c0 = (${XINT8_T}*) ((uintptr_t) c0 + cn_stride);
235
236      a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks);
237
238      nc -= 4;
239    } else {
240      if (nc & 2) {
241        $for M in reversed(range(MR)):
242          *((uint16_t*) c${M}) = (uint16_t) _mm_extract_epi16(vout, ${M * 2});
243          c${M} += 2;
244        vout = _mm_srli_epi32(vout, 16);
245      }
246      if (nc & 1) {
247        $if SSE == 4:
248          $for M in reversed(range(MR)):
249            *c${M} = (${XINT8_T}) _mm_extract_epi8(vout, ${M * 4});
250        $else:
251          $for M in reversed(range(1, MR)):
252            *c${M} = (${XINT8_T}) _mm_extract_epi16(vout, ${M * 2});
253          *c0 = (${XINT8_T}) _mm_cvtsi128_si32(vout);
254      }
255
256      nc = 0;
257    }
258  } while (nc != 0);
259}
260