// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION == "FP32"
$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert VARIANT in ["LD64", "LD128", "EXTENDED"]
$assert MR <= 4
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/gemm.h>
#include <xnnpack/math.h>


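// Template parameters: MR is the number of rows in the microtile; VARIANT
// selects how packed weights are loaded (LD64/LD128 widen 8-bit weights at
// runtime, EXTENDED reads pre-widened 16-bit weights); DATATYPE selects the
// element type and the matching signed/unsigned WAsm SIMD intrinsics.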
$LOAD_SUFFIX = {"LD128": "_ld128", "LD64": "_ld64", "EXTENDED": ""}[VARIANT]
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
$PARAMS_STRUCT = REQUANTIZATION.lower() + "_wasmsimd"
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
$WASM_X16X8_LOAD8X8 = "wasm_u16x8_load8x8" if DATATYPE == "QU8" else "wasm_i16x8_load8x8"
$WASM_X8X16_NARROW_I16X8 = "wasm_u8x16_narrow_i16x8" if DATATYPE == "QU8" else "wasm_i8x16_narrow_i16x8"
$WASM_X8X16_MIN = "wasm_u8x16_min" if DATATYPE == "QU8" else "wasm_i8x16_min"
void xnn_${DATATYPE.lower()}_gemm${GEMM_SUFFIX}_minmax_fp32_ukernel_${MR}x4c2s4__wasmsimd_dot16x2${LOAD_SUFFIX}(
    size_t mr,
    size_t nc,
    size_t kc,
    const ${XINT8_T}* restrict a,
    size_t a_stride,
    const void* restrict w,
    ${XINT8_T}* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(${XINT8_T}) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

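  // One input (a) and one output (c) pointer per row of the microtile.
  // Rows past mr alias the previous row, so all rows can be computed
  // unconditionally without reading or writing out of bounds.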
  const ${XINT8_T}* a0 = a;
  ${XINT8_T}* c0 = c;
  $for M in range(1, MR):
    const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride);
    ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        a${M} = a${M-1};
        c${M} = c${M-1};
      }

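  // Round K up to a full group of 8 elements: the packed weights are
  // expected to be padded to this granularity, and XNN_OOB_READS marks the
  // trailing reads of A as safe.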
  kc = round_up_po2(kc, 8 * sizeof(${XINT8_T}));
  do {
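    // Each row's accumulators start from the bias, packed ahead of the
    // weights as 4 int32 values (one per output channel).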
    v128_t vacc0x0123 = wasm_v128_load(w);
    $for M in range(1, MR):
      v128_t vacc${M}x0123 = vacc0x0123;
    w = (const void*) ((const int32_t*) w + 4);

    $if DATATYPE == "QU8":
      const v128_t vb_zero_point = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.kernel_zero_point);
    size_t k = kc;
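    // Inner loop over K: each iteration consumes 8 elements from every row
    // of A and an 8x4 block of packed weights.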
    do {
      $for M in range(MR):
        v128_t vxa${M} = ${WASM_X16X8_LOAD8X8}((const v128_t*) a${M});
        a${M} += 8;

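      // Load an 8x4 block of weights and widen it to 16 bits; for QU8 the
      // kernel zero point is subtracted as part of the widening.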
      $if VARIANT == "LD128":
        $for K in range(0, 4, 2):
          $if K == 0:
            const v128_t vb${K}${K+1} = wasm_v128_load(w);
          $else:
            const v128_t vb${K}${K+1} = wasm_v128_load((const ${XINT8_T}*) w + ${K * 8});
          $if DATATYPE == "QU8":
            const v128_t vxb${K} = wasm_i16x8_sub(wasm_u16x8_extend_low_u8x16(vb${K}${K+1}), vb_zero_point);
            const v128_t vxb${K+1} = wasm_i16x8_sub(wasm_u16x8_extend_high_u8x16(vb${K}${K+1}), vb_zero_point);
          $else:
            const v128_t vxb${K} = wasm_i16x8_extend_low_i8x16(vb${K}${K+1});
            const v128_t vxb${K+1} = wasm_i16x8_extend_high_i8x16(vb${K}${K+1});

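          // wasm_i32x4_dot_i16x8 multiplies 16-bit lanes pairwise and sums
          // adjacent products, so each dot product folds 2 K elements into
          // every int32 lane; the 32-bit rotate of vxa then lines A up with
          // the next weight group (the c2s4 layout).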
          $for M in range(MR):
            vacc${M}x0123 = wasm_i32x4_add(vacc${M}x0123, wasm_i32x4_dot_i16x8(vxa${M}, vxb${K}));
            vxa${M} = wasm_v32x4_shuffle(vxa${M}, vxa${M}, 1, 2, 3, 4);

          $for M in range(MR):
            vacc${M}x0123 = wasm_i32x4_add(vacc${M}x0123, wasm_i32x4_dot_i16x8(vxa${M}, vxb${K+1}));
            $if K + 2 != 4:
              vxa${M} = wasm_v32x4_shuffle(vxa${M}, vxa${M}, 1, 2, 3, 4);
      $else:
        $for K in range(4):
          $if VARIANT == "LD64":
            $if DATATYPE == "QU8":
              $if K == 0:
                const v128_t vxb${K} = wasm_i16x8_sub(wasm_u16x8_load8x8(w), vb_zero_point);
              $else:
                const v128_t vxb${K} = wasm_i16x8_sub(wasm_u16x8_load8x8((const ${XINT8_T}*) w + ${K * 8}), vb_zero_point);
            $else:
              $if K == 0:
                const v128_t vxb${K} = wasm_i16x8_load8x8(w);
              $else:
                const v128_t vxb${K} = wasm_i16x8_load8x8((const ${XINT8_T}*) w + ${K * 8});
          $elif VARIANT == "EXTENDED":
            $if K == 0:
              const v128_t vxb${K} = wasm_v128_load(w);
            $else:
              const v128_t vxb${K} = wasm_v128_load((const int16_t*) w + ${K * 8});

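          // Accumulate one dot product per row, then rotate A by one 32-bit
          // lane for the next group (no rotate is needed after the last).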
          $for M in range(MR):
            vacc${M}x0123 = wasm_i32x4_add(vacc${M}x0123, wasm_i32x4_dot_i16x8(vxa${M}, vxb${K}));
            $if K + 1 != 4:
              vxa${M} = wasm_v32x4_shuffle(vxa${M}, vxa${M}, 1, 2, 3, 4);

      $if VARIANT == "EXTENDED":
        w = (const int16_t*) w + 32;
      $else:
        w = (const ${XINT8_T}*) w + 32;
      k -= 8 * sizeof(${XINT8_T});
    } while (k != 0);

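    // FP32 requantization: convert the int32 accumulators to float and scale.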
    $for M in range(MR):
      vacc${M}x0123 = wasm_f32x4_convert_i32x4(vacc${M}x0123);

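    // QC8 uses per-channel scales stored after the packed weights; QS8/QU8
    // use a single per-tensor scale from params.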
    $if DATATYPE == "QC8":
      const v128_t vscale0123 = wasm_v128_load(w);
      w = (const float*) w + 4;
      $for M in range(MR):
        vacc${M}x0123 = wasm_f32x4_mul(vacc${M}x0123, vscale0123);
    $else:
      const v128_t vscale = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.scale);
      $for M in range(MR):
        vacc${M}x0123 = wasm_f32x4_mul(vacc${M}x0123, vscale);

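    // Magic-bias conversion: adding magic_bias places the rounded integer in
    // the low mantissa bits; the integer max then clamps to the lower output
    // bound, and subtracting magic_bias_less_output_zero_point removes the
    // bias while folding in the output zero point.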
    const v128_t vmagic_bias = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_bias);
    $for M in range(MR):
      vacc${M}x0123 = wasm_f32x4_add(vacc${M}x0123, vmagic_bias);

    const v128_t vmagic_min = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_min);
    $for M in range(MR):
      vacc${M}x0123 = wasm_i32x4_max(vacc${M}x0123, vmagic_min);

    const v128_t vmagic_bias_less_output_zero_point = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
    $for M in range(MR):
      vacc${M}x0123 = wasm_i32x4_sub(vacc${M}x0123, vmagic_bias_less_output_zero_point);

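    // Saturating narrow: int32 accumulators to int16 (two rows per vector),
    // then to 8 bits with all rows in a single vector, capped at output_max
    // (the lower bound was already applied via magic_min).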
    $for M in range(0, MR, 2):
      v128_t vacc${M}${min(M+1, MR-1)}x0123 = wasm_i16x8_narrow_i32x4(vacc${M}x0123, vacc${min(M+1, MR-1)}x0123);

    $if MR > 2:
      v128_t vout = ${WASM_X8X16_NARROW_I16X8}(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
    $else:
      v128_t vout = ${WASM_X8X16_NARROW_I16X8}(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);

    const v128_t voutput_max = wasm_v128_load64_splat(params->${PARAMS_STRUCT}.output_max);
    vout = ${WASM_X8X16_MIN}(vout, voutput_max);

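    // Full store: each row writes its 4 output bytes through a 32-bit float
    // lane extract, then C steps to the next column group and A rewinds by
    // kc for the next pass over the same rows.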
    if (nc >= 4) {
      $for M in range(MR):
        *((float*) c${M}) = (float) wasm_f32x4_extract_lane(vout, ${M});

      $for M in range(MR):
        c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);

      $for M in range(MR):
        a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);

      nc -= 4;
    } else {
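      // Partial store: extract each row's 4 output bytes into a uint32 and
      // write 2 and/or 1 bytes depending on the remaining columns.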
      $for M in range(MR):
        uint32_t vout${M} = wasm_i32x4_extract_lane(vout, ${M});
      if (nc & 2) {
        $for M in range(MR):
          *((uint16_t*) c${M}) = (uint16_t) vout${M};
          vout${M} >>= 16;
          c${M} += 2;
      }
      if (nc & 1) {
        $for M in range(MR):
          *c${M} = (${XINT8_T}) vout${M};
      }

      nc = 0;
    }
  } while (nc != 0);
}