1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #pragma once
10
11 #if defined(__cplusplus) && (__cplusplus >= 201103L)
12 #include <cstdint>
13 #include <cstddef>
14 #include <cassert>
15 #include <cmath>
16 #else
17 #include <stdint.h>
18 #include <stddef.h>
19 #include <assert.h>
20 #include <math.h>
21 #endif
22
23 #include <fp16.h>
24
25 #include <xnnpack/common.h>
26 #include <xnnpack/math.h>
27 #include <xnnpack/params.h>
28
29
30 static inline union xnn_qu8_gemm_params xnn_init_scalar_qu8_gemm_params(
31 uint8_t kernel_zero_point,
32 float scale,
33 uint8_t output_zero_point,
34 uint8_t output_min,
35 uint8_t output_max)
36 {
37   // Compute requantization parameters.
38 const uint32_t scale_bits = fp32_to_bits(scale);
39
40 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
41 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
42 assert(multiplier >= INT32_C(0x40000000));
43 assert(multiplier <= INT32_C(0x7FFFFF80));
44
45 // Shift is in [0, 31] range.
46 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
47 assert(shift >= 0);
48 assert(shift < 32);
49
50 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
51 const uint32_t remainder_threshold = remainder_mask >> 1;
52
53 union xnn_qu8_gemm_params params;
54 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
55 params.scalar.multiplier = multiplier;
56 params.scalar.remainder_mask = (int32_t) remainder_mask;
57 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
58 params.scalar.shift = (uint32_t) shift;
59 params.scalar.output_min_less_zero_point =
60 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
61 params.scalar.output_max_less_zero_point =
62 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
63 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
64 return params;
65 }
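
// Example (illustrative sketch, not copied from any XNNPACK microkernel): how the scalar
// fields initialized above are typically applied to requantize a 32-bit accumulator.
// The helper name is hypothetical, and arithmetic right shifts on signed values are assumed.
//
//   static inline uint8_t example_requantize_qu8(int32_t acc, const union xnn_qu8_gemm_params* params) {
//     // Q31 fixed-point multiplication with rounding: acc * multiplier / 2**31.
//     const int64_t product = (int64_t) acc * (int64_t) params->scalar.multiplier;
//     const int32_t q31_product = (int32_t) ((product + INT64_C(0x40000000)) >> 31);
//     // Remainder-based correction so the final shift rounds to nearest.
//     const int32_t remainder =
//       (q31_product & params->scalar.remainder_mask) - (int32_t) (q31_product < 0);
//     int32_t out = (q31_product >> params->scalar.shift) +
//       (int32_t) (remainder > params->scalar.remainder_threshold);
//     // Clamp in the zero-point-free domain, then add the output zero point back.
//     out = out < params->scalar.output_min_less_zero_point ? params->scalar.output_min_less_zero_point : out;
//     out = out > params->scalar.output_max_less_zero_point ? params->scalar.output_max_less_zero_point : out;
//     return (uint8_t) (out + params->scalar.output_zero_point);
//   }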
66
67 static inline union xnn_qu8_gemm_params xnn_init_qu8_gemm_params(
68 uint8_t kernel_zero_point,
69 float scale,
70 uint8_t output_zero_point,
71 uint8_t output_min,
72 uint8_t output_max)
73 {
74 // Compute requantization parameters.
75 const uint32_t scale_bits = fp32_to_bits(scale);
76
77 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
78 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
79 assert(multiplier >= INT32_C(0x40000000));
80 assert(multiplier <= INT32_C(0x7FFFFF80));
81
82 // Shift is in [0, 31] range.
83 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
84 assert(shift >= 0);
85 assert(shift < 32);
86
87 union xnn_qu8_gemm_params params;
88 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
89 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
90 const uint32_t remainder_threshold = remainder_mask >> 1;
91 for (uint32_t i = 0; i < 8; i++) {
92 params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
93 }
94 params.sse2.multiplier[0] = multiplier;
95 params.sse2.multiplier[1] = multiplier;
96 params.sse2.multiplier[2] = multiplier;
97 params.sse2.multiplier[3] = multiplier;
98 params.sse2.rounding[0] = UINT64_C(0x40000000);
99 params.sse2.rounding[1] = UINT64_C(0x40000000);
100 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
101 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
102 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
103 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
104 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
105 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
106 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
107 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
108 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
109 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
110 for (uint32_t i = 0; i < 8; i++) {
111 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
112 }
113 for (uint32_t i = 0; i < 16; i++) {
114 params.sse2.output_min[i] = output_min;
115 params.sse2.output_max[i] = output_max;
116 }
117 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
118 params.neon.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
119 params.neon.multiplier = multiplier;
120 params.neon.right_shift = -shift;
121 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
122 params.neon.output_min = output_min;
123 params.neon.output_max = output_max;
124 #else
125 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
126 const uint32_t remainder_threshold = remainder_mask >> 1;
127 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
128 params.scalar.multiplier = multiplier;
129 params.scalar.remainder_mask = (int32_t) remainder_mask;
130 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
131 params.scalar.shift = (uint32_t) shift;
132 params.scalar.output_min_less_zero_point =
133 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
134 params.scalar.output_max_less_zero_point =
135 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
136 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
137 #endif
138 return params;
139 }
140
141 static inline union xnn_qs8_gemm_params xnn_init_scalar_qs8_gemm_params(
142 float scale,
143 int8_t output_zero_point,
144 int8_t output_min,
145 int8_t output_max)
146 {
147   // Compute requantization parameters.
148 const uint32_t scale_bits = fp32_to_bits(scale);
149
150 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
151 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
152 assert(multiplier >= INT32_C(0x40000000));
153 assert(multiplier <= INT32_C(0x7FFFFF80));
154
155 // Shift is in [0, 31] range.
156 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
157 assert(shift >= 0);
158 assert(shift < 32);
159
160 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
161 const uint32_t remainder_threshold = remainder_mask >> 1;
162
163 union xnn_qs8_gemm_params params;
164 params.scalar.multiplier = multiplier;
165 params.scalar.remainder_mask = (int32_t) remainder_mask;
166 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
167 params.scalar.shift = (uint32_t) shift;
168 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
169 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
170 params.scalar.output_zero_point = (int32_t) output_zero_point;
171 return params;
172 }
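
// The decomposition above is exact: scale == multiplier * 2**-(31 + shift), because the
// multiplier carries the 24-bit significand of scale shifted left by 7 bits and the shift
// absorbs its exponent. A hypothetical self-check (valid for scale in [0x1.0p-32f, 1.0f),
// where the shift stays in the [0, 31] range):
//
//   static inline void example_check_qs8_gemm_scale(float scale) {
//     const union xnn_qs8_gemm_params params =
//       xnn_init_scalar_qs8_gemm_params(scale, 0, INT8_MIN, INT8_MAX);
//     const double reconstructed =
//       (double) params.scalar.multiplier * ldexp(1.0, -31 - (int) params.scalar.shift);
//     assert(reconstructed == (double) scale);
//   }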
173
174 static inline union xnn_qs8_gemm_params xnn_init_qs8_gemm_params(
175 float scale,
176 int8_t output_zero_point,
177 int8_t output_min,
178 int8_t output_max)
179 {
180 // Compute requantization parameters.
181 const uint32_t scale_bits = fp32_to_bits(scale);
182
183 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
184 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
185 assert(multiplier >= INT32_C(0x40000000));
186 assert(multiplier <= INT32_C(0x7FFFFF80));
187
188 // Shift is in [0, 31] range.
189 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
190 assert(shift >= 0);
191 assert(shift < 32);
192
193 union xnn_qs8_gemm_params params;
194 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
195 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
196 const uint32_t remainder_threshold = remainder_mask >> 1;
197 params.sse2.multiplier[0] = multiplier;
198 params.sse2.multiplier[1] = multiplier;
199 params.sse2.multiplier[2] = multiplier;
200 params.sse2.multiplier[3] = multiplier;
201 params.sse2.rounding[0] = UINT64_C(0x40000000);
202 params.sse2.rounding[1] = UINT64_C(0x40000000);
203 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
204 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
205 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
206 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
207 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
208 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
209 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
210 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
211 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
212 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
213 for (uint32_t i = 0; i < 8; i++) {
214 params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
215 params.sse2.output_min[i] = (int16_t) output_min;
216 params.sse2.output_max[i] = (int16_t) output_max;
217 }
218 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
219 params.neon.multiplier = multiplier;
220 params.neon.right_shift = -shift;
221 params.neon.output_zero_point = (int16_t) output_zero_point;
222 params.neon.output_min = output_min;
223 params.neon.output_max = output_max;
224 #elif XNN_ARCH_WASMSIMD
225 const int64_t twice_multiplier = INT64_C(2) * (int64_t) multiplier;
226 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
227 const uint32_t remainder_threshold = remainder_mask >> 1;
228 params.wasmsimd.multiplier[0] = twice_multiplier;
229 params.wasmsimd.multiplier[1] = twice_multiplier;
230 params.wasmsimd.rounding[0] = INT64_C(0x80000000);
231 params.wasmsimd.rounding[1] = INT64_C(0x80000000);
232 params.wasmsimd.remainder_mask[0] = (int32_t) remainder_mask;
233 params.wasmsimd.remainder_mask[1] = (int32_t) remainder_mask;
234 params.wasmsimd.remainder_mask[2] = (int32_t) remainder_mask;
235 params.wasmsimd.remainder_mask[3] = (int32_t) remainder_mask;
236 params.wasmsimd.remainder_threshold[0] = (int32_t) remainder_threshold;
237 params.wasmsimd.remainder_threshold[1] = (int32_t) remainder_threshold;
238 params.wasmsimd.remainder_threshold[2] = (int32_t) remainder_threshold;
239 params.wasmsimd.remainder_threshold[3] = (int32_t) remainder_threshold;
240 params.wasmsimd.shift = shift;
241 for (uint32_t i = 0; i < 8; i++) {
242 params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
243 }
244 for (uint32_t i = 0; i < 16; i++) {
245 params.wasmsimd.output_min[i] = output_min;
246 params.wasmsimd.output_max[i] = output_max;
247 }
248 #else
249 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
250 const uint32_t remainder_threshold = remainder_mask >> 1;
251 params.scalar.multiplier = multiplier;
252 params.scalar.remainder_mask = (int32_t) remainder_mask;
253 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
254 params.scalar.shift = (uint32_t) shift;
255 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
256 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
257 params.scalar.output_zero_point = (int32_t) output_zero_point;
258 #endif
259 return params;
260 }
261
262 static inline union xnn_qs8_gemm_xw_params xnn_init_scalar_qs8_gemm_xw_params(
263 float scale,
264 int8_t output_zero_point,
265 int8_t output_min,
266 int8_t output_max)
267 {
268 union {
269 union xnn_qs8_gemm_xw_params gemm_xw;
270 union xnn_qs8_gemm_params gemm;
271 } params;
272 params.gemm = xnn_init_scalar_qs8_gemm_params(scale, output_zero_point, output_min, output_max);
273 return params.gemm_xw;
274 }
275
276 static inline union xnn_qs8_gemm_xw_params xnn_init_qs8_gemm_xw_params(
277 float scale,
278 int8_t output_zero_point,
279 int8_t output_min,
280 int8_t output_max)
281 {
282 union {
283 union xnn_qs8_gemm_xw_params gemm_xw;
284 union xnn_qs8_gemm_params gemm;
285 } params;
286 params.gemm = xnn_init_qs8_gemm_params(scale, output_zero_point, output_min, output_max);
287 return params.gemm_xw;
288 }
289
290 static inline union xnn_qu8_avgpool_params xnn_init_qu8_avgpool_params(
291 int32_t bias,
292 float scale,
293 uint8_t output_zero_point,
294 uint8_t output_min,
295 uint8_t output_max)
296 {
297 // Compute requantization parameters.
298 assert(scale >= 0x1.0p-32f);
299 assert(scale < 256.0f);
300 const uint32_t scale_bits = fp32_to_bits(scale);
301
302 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
303 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
304 assert(multiplier >= INT32_C(0x00800000));
305 assert(multiplier <= INT32_C(0x00FFFFFF));
306
307 // Shift is in [16, 55] range.
308 const int32_t shift = 127 + 23 - (scale_bits >> 23);
309 assert(shift >= 16);
310 assert(shift < 64);
311
312 union xnn_qu8_avgpool_params params;
313 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
314 const uint32_t right_shift = (uint32_t) shift;
315 const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
316 params.sse2.bias[0] = bias;
317 params.sse2.bias[1] = bias;
318 params.sse2.bias[2] = bias;
319 params.sse2.bias[3] = bias;
320 params.sse2.multiplier[0] = (uint32_t) multiplier;
321 params.sse2.multiplier[1] = (uint32_t) multiplier;
322 params.sse2.multiplier[2] = (uint32_t) multiplier;
323 params.sse2.multiplier[3] = (uint32_t) multiplier;
324 params.sse2.rounding[0] = rounding;
325 params.sse2.rounding[1] = rounding;
326 params.sse2.right_shift[0] = (uint64_t) right_shift;
327 params.sse2.right_shift[1] = (uint64_t) right_shift;
328 for (uint32_t i = 0; i < 8; i++) {
329 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
330 }
331 for (uint32_t i = 0; i < 16; i++) {
332 params.sse2.output_min[i] = output_min;
333 params.sse2.output_max[i] = output_max;
334 }
335 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
336 params.neon.bias = bias;
337 params.neon.multiplier = multiplier;
338 params.neon.left_shift = (int64_t) -shift;
339 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
340 params.neon.output_min = output_min;
341 params.neon.output_max = output_max;
342 #else
343 const uint32_t right_shift = (uint32_t) shift;
344 const int64_t rounding = INT64_C(1) << (right_shift - 1);
345 params.scalar.bias = bias;
346 params.scalar.multiplier = multiplier;
347 params.scalar.rounding = rounding;
348 params.scalar.right_shift = right_shift;
349 params.scalar.output_min_less_zero_point =
350 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
351 params.scalar.output_max_less_zero_point =
352 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
353 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
354 #endif
355 return params;
356 }
357
358 static inline union xnn_qu8_avgpool_params xnn_init_scalar_qu8_avgpool_params(
359 int32_t bias,
360 float scale,
361 uint8_t output_zero_point,
362 uint8_t output_min,
363 uint8_t output_max)
364 {
365 // Compute requantization parameters.
366 assert(scale >= 0x1.0p-32f);
367 assert(scale < 256.0f);
368 const uint32_t scale_bits = fp32_to_bits(scale);
369
370 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
371 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
372 assert(multiplier >= INT32_C(0x00800000));
373 assert(multiplier <= INT32_C(0x00FFFFFF));
374
375 // Shift is in [16, 55] range.
376 const int32_t shift = 127 + 23 - (scale_bits >> 23);
377 assert(shift >= 16);
378 assert(shift < 64);
379
380 union xnn_qu8_avgpool_params params;
381 const uint32_t right_shift = (uint32_t) shift;
382 const int64_t rounding = INT64_C(1) << (right_shift - 1);
383 params.scalar.bias = bias;
384 params.scalar.rounding = rounding;
385 params.scalar.multiplier = multiplier;
386 params.scalar.right_shift = right_shift;
387 params.scalar.output_min_less_zero_point =
388 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
389 params.scalar.output_max_less_zero_point =
390 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
391 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
392 return params;
393 }
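
// Illustrative sketch (hypothetical helper, not the actual microkernel) of how the scalar
// average-pooling fields above are consumed. Here 'acc' is 'bias' plus the sum of the pooled
// inputs; rounding to nearest with ties toward positive infinity and an arithmetic right
// shift are assumed, and the production kernels may break ties differently.
//
//   static inline uint8_t example_requantize_qu8_avgpool(int32_t acc, const union xnn_qu8_avgpool_params* params) {
//     const int64_t product = (int64_t) acc * (int64_t) params->scalar.multiplier;
//     int32_t out = (int32_t) ((product + params->scalar.rounding) >> params->scalar.right_shift);
//     out = out < params->scalar.output_min_less_zero_point ? params->scalar.output_min_less_zero_point : out;
//     out = out > params->scalar.output_max_less_zero_point ? params->scalar.output_max_less_zero_point : out;
//     return (uint8_t) (out + params->scalar.output_zero_point);
//   }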
394
395 static inline void xnn_update_qu8_avgpool_params(
396 union xnn_qu8_avgpool_params* params,
397 int32_t bias,
398 float scale)
399 {
400 // Compute requantization parameters.
401 assert(scale >= 0x1.0p-32f);
402 assert(scale < 256.0f);
403 const uint32_t scale_bits = fp32_to_bits(scale);
404
405 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
406 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
407 assert(multiplier >= INT32_C(0x00800000));
408 assert(multiplier <= INT32_C(0x00FFFFFF));
409
410 // Shift is in [16, 55] range.
411 const int32_t shift = 127 + 23 - (scale_bits >> 23);
412 assert(shift >= 16);
413 assert(shift < 64);
414
415 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
416 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
417 params->sse2.bias[0] = bias;
418 params->sse2.bias[1] = bias;
419 params->sse2.bias[2] = bias;
420 params->sse2.bias[3] = bias;
421 params->sse2.multiplier[0] = (uint32_t) multiplier;
422 params->sse2.multiplier[1] = (uint32_t) multiplier;
423 params->sse2.multiplier[2] = (uint32_t) multiplier;
424 params->sse2.multiplier[3] = (uint32_t) multiplier;
425 params->sse2.rounding[0] = rounding;
426 params->sse2.rounding[1] = rounding;
427 params->sse2.right_shift[0] = (uint64_t) (uint32_t) shift;
428 params->sse2.right_shift[1] = (uint64_t) (uint32_t) shift;
429 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
430 params->neon.bias = bias;
431 params->neon.multiplier = multiplier;
432 params->neon.left_shift = (int64_t) -shift;
433 #else
434 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
435 params->scalar.bias = bias;
436 params->scalar.multiplier = multiplier;
437 params->scalar.rounding = rounding;
438 params->scalar.right_shift = (uint32_t) shift;
439 #endif
440 }
441
442 static inline union xnn_qs8_avgpool_params xnn_init_qs8_avgpool_params(
443 int32_t bias,
444 float scale,
445 int8_t output_zero_point,
446 int8_t output_min,
447 int8_t output_max)
448 {
449 // Compute requantization parameters.
450 assert(scale >= 0x1.0p-32f);
451 assert(scale < 256.0f);
452 const uint32_t scale_bits = fp32_to_bits(scale);
453
454 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
455 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
456 assert(multiplier >= INT32_C(0x00800000));
457 assert(multiplier <= INT32_C(0x00FFFFFF));
458
459 // Shift is in [16, 55] range.
460 const int32_t shift = 127 + 23 - (scale_bits >> 23);
461 assert(shift >= 16);
462 assert(shift < 64);
463
464 union xnn_qs8_avgpool_params params;
465 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
466 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
467 params.sse2.bias[0] = bias;
468 params.sse2.bias[1] = bias;
469 params.sse2.bias[2] = bias;
470 params.sse2.bias[3] = bias;
471 params.sse2.multiplier[0] = (uint32_t) multiplier;
472 params.sse2.multiplier[1] = (uint32_t) multiplier;
473 params.sse2.multiplier[2] = (uint32_t) multiplier;
474 params.sse2.multiplier[3] = (uint32_t) multiplier;
475 params.sse2.rounding[0] = rounding;
476 params.sse2.rounding[1] = rounding;
477 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
478 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
479 for (uint32_t i = 0; i < 8; i++) {
480 params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
481 params.sse2.output_min[i] = (int16_t) output_min;
482 params.sse2.output_max[i] = (int16_t) output_max;
483 }
484 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
485 params.neon.bias = bias;
486 params.neon.multiplier = multiplier;
487 params.neon.left_shift = (int64_t) -shift;
488 params.neon.output_zero_point = (int16_t) output_zero_point;
489 params.neon.output_min = output_min;
490 params.neon.output_max = output_max;
491 #elif XNN_ARCH_WASMSIMD
492 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
493 params.wasmsimd.bias[0] = bias;
494 params.wasmsimd.bias[1] = bias;
495 params.wasmsimd.bias[2] = bias;
496 params.wasmsimd.bias[3] = bias;
497 params.wasmsimd.multiplier[0] = (int64_t) multiplier;
498 params.wasmsimd.multiplier[1] = (int64_t) multiplier;
499 params.wasmsimd.rounding[0] = rounding;
500 params.wasmsimd.rounding[1] = rounding;
501 params.wasmsimd.shift = shift;
502 for (uint32_t i = 0; i < 8; i++) {
503 params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
504 }
505 for (uint32_t i = 0; i < 16; i++) {
506 params.wasmsimd.output_min[i] = output_min;
507 params.wasmsimd.output_max[i] = output_max;
508 }
509 #else
510 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
511 params.scalar.bias = bias;
512 params.scalar.multiplier = multiplier;
513 params.scalar.rounding = rounding;
514 params.scalar.shift = (uint32_t) shift;
515 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
516 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
517 params.scalar.output_zero_point = (int32_t) output_zero_point;
518 #endif
519 return params;
520 }
521
522 static inline union xnn_qs8_avgpool_params xnn_init_scalar_qs8_avgpool_params(
523 int32_t bias,
524 float scale,
525 int8_t output_zero_point,
526 int8_t output_min,
527 int8_t output_max)
528 {
529 // Compute requantization parameters.
530 assert(scale >= 0x1.0p-32f);
531 assert(scale < 256.0f);
532 const uint32_t scale_bits = fp32_to_bits(scale);
533
534 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
535 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
536 assert(multiplier >= INT32_C(0x00800000));
537 assert(multiplier <= INT32_C(0x00FFFFFF));
538
539 // Shift is in [16, 55] range.
540 const int32_t shift = 127 + 23 - (scale_bits >> 23);
541 assert(shift >= 16);
542 assert(shift < 64);
543
544 union xnn_qs8_avgpool_params params;
545 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
546 params.scalar.bias = bias;
547 params.scalar.rounding = rounding;
548 params.scalar.multiplier = multiplier;
549 params.scalar.shift = shift;
550 params.scalar.output_min_less_zero_point = (int32_t) output_min - (int32_t) output_zero_point;
551 params.scalar.output_max_less_zero_point = (int32_t) output_max - (int32_t) output_zero_point;
552 params.scalar.output_zero_point = (int32_t) output_zero_point;
553 return params;
554 }
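
// As with the GEMM parameters, the decomposition above is exact: scale == multiplier * 2**-shift,
// because the multiplier is the 24-bit significand of scale and shift = 127 + 23 - exponent.
// A hypothetical self-check for scale in [0x1.0p-32f, 256.0f):
//
//   static inline void example_check_qs8_avgpool_scale(float scale) {
//     const union xnn_qs8_avgpool_params params =
//       xnn_init_scalar_qs8_avgpool_params(0 /* bias */, scale, 0, INT8_MIN, INT8_MAX);
//     const double reconstructed =
//       (double) params.scalar.multiplier * ldexp(1.0, -(int) params.scalar.shift);
//     assert(reconstructed == (double) scale);
//   }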
555
556 static inline void xnn_update_qs8_avgpool_params(
557 union xnn_qs8_avgpool_params* params,
558 int32_t bias,
559 float scale)
560 {
561 // Compute requantization parameters.
562 assert(scale >= 0x1.0p-32f);
563 assert(scale < 256.0f);
564 const uint32_t scale_bits = fp32_to_bits(scale);
565
566 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
567 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
568 assert(multiplier >= INT32_C(0x00800000));
569 assert(multiplier <= INT32_C(0x00FFFFFF));
570
571 // Shift is in [16, 55] range.
572 const int32_t shift = 127 + 23 - (scale_bits >> 23);
573 assert(shift >= 16);
574 assert(shift < 64);
575
576 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
577 const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);
578 params->sse2.bias[0] = bias;
579 params->sse2.bias[1] = bias;
580 params->sse2.bias[2] = bias;
581 params->sse2.bias[3] = bias;
582 params->sse2.multiplier[0] = (uint32_t) multiplier;
583 params->sse2.multiplier[1] = (uint32_t) multiplier;
584 params->sse2.multiplier[2] = (uint32_t) multiplier;
585 params->sse2.multiplier[3] = (uint32_t) multiplier;
586 params->sse2.rounding[0] = rounding;
587 params->sse2.rounding[1] = rounding;
588 params->sse2.shift[0] = (uint64_t) (uint32_t) shift;
589 params->sse2.shift[1] = (uint64_t) (uint32_t) shift;
590 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
591 params->neon.bias = bias;
592 params->neon.multiplier = multiplier;
593 params->neon.left_shift = (int64_t) -shift;
594 #elif XNN_ARCH_WASMSIMD
595 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
596 params->wasmsimd.bias[0] = bias;
597 params->wasmsimd.bias[1] = bias;
598 params->wasmsimd.bias[2] = bias;
599 params->wasmsimd.bias[3] = bias;
600 params->wasmsimd.multiplier[0] = (int64_t) multiplier;
601 params->wasmsimd.multiplier[1] = (int64_t) multiplier;
602 params->wasmsimd.rounding[0] = rounding;
603 params->wasmsimd.rounding[1] = rounding;
604 params->wasmsimd.shift = shift;
605 #else
606 const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);
607 params->scalar.bias = bias;
608 params->scalar.multiplier = multiplier;
609 params->scalar.rounding = rounding;
610 params->scalar.shift = (uint32_t) shift;
611 #endif
612 }
613
614 static inline void xnn_update_f16_scaleminmax_params(
615 struct xnn_f16_scaleminmax_params* params,
616 uint16_t scale)
617 {
618 params->scale = scale;
619 }
620
621 static inline void xnn_update_f32_scaleminmax_params(
622 union xnn_f32_scaleminmax_params* params,
623 float scale)
624 {
625 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
626 for (uint32_t i = 0; i < 4; i++) {
627 params->sse2.scale[i] = scale;
628 }
629 #else
630 params->scalar.scale = scale;
631 #endif
632 }
633
634 static inline struct xnn_f16_scaleminmax_params xnn_init_f16_scaleminmax_params(
635 uint16_t scale,
636 uint16_t min,
637 uint16_t max)
638 {
639 struct xnn_f16_scaleminmax_params params;
640 params.scale = scale;
641 params.min = min;
642 params.max = max;
643 return params;
644 }
645
646 static inline union xnn_f32_scaleminmax_params xnn_init_f32_scaleminmax_params(
647 float scale,
648 float min,
649 float max)
650 {
651 union xnn_f32_scaleminmax_params params;
652 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
653 for (uint32_t i = 0; i < 4; i++) {
654 params.sse2.scale[i] = scale;
655 params.sse2.min[i] = min;
656 params.sse2.max[i] = max;
657 }
658 #else
659 params.scalar.scale = scale;
660 params.scalar.min = min;
661 params.scalar.max = max;
662 #endif
663 return params;
664 }
665
666 static inline union xnn_f32_gavgpool_params xnn_init_f32_gavgpool_params(
667 float multiplier,
668 float output_min,
669 float output_max,
670 uint32_t width)
671 {
672 union xnn_f32_gavgpool_params params;
673 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
674 for (uint32_t i = 0; i < 4; i++) {
675 params.sse.multiplier[i] = multiplier;
676 params.sse.output_min[i] = output_min;
677 params.sse.output_max[i] = output_max;
678 }
679
680 const uint32_t w = (width - 1) & 3;
681 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
682 params.sse.mask[1] = -(uint32_t) (w >= 1);
683 params.sse.mask[2] = -(uint32_t) (w >= 2);
684 params.sse.mask[3] = -(uint32_t) (w >= 3);
685 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
686 params.neon.multiplier = multiplier;
687 params.neon.output_min = output_min;
688 params.neon.output_max = output_max;
689
690 const uint32_t w = (width - 1) & 3;
691 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
692 params.neon.mask[1] = -(uint32_t) (w >= 1);
693 params.neon.mask[2] = -(uint32_t) (w >= 2);
694 params.neon.mask[3] = -(uint32_t) (w >= 3);
695 #else
696 params.scalar.multiplier = multiplier;
697 params.scalar.output_min = output_min;
698 params.scalar.output_max = output_max;
699
700 const uint32_t w = (width - 1) & 3;
701 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
702 params.scalar.mask[1] = -(int32_t) (w >= 1);
703 params.scalar.mask[2] = -(int32_t) (w >= 2);
704 params.scalar.mask[3] = -(int32_t) (w >= 3);
705 #endif
706 return params;
707 }
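
// Worked example of the mask computation above: for width = 6, w = (6 - 1) & 3 = 1, so
// mask = { 0xFFFFFFFF, 0xFFFFFFFF, 0, 0 }. The last (possibly partial) group of 4 elements
// in a width-6 row has 2 valid lanes (6 = 4 + 2); presumably the vector kernels AND the
// trailing load with this mask so stale lanes do not contribute to the average.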
708
709 static inline void xnn_update_f32_gavgpool_params(
710 union xnn_f32_gavgpool_params* params,
711 float multiplier,
712 uint32_t width)
713 {
714 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
715 for (uint32_t i = 0; i < 4; i++) {
716 params->sse.multiplier[i] = multiplier;
717 }
718
719 const uint32_t w = (width - 1) & 3;
720 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
721 params->sse.mask[1] = -(uint32_t) (w >= 1);
722 params->sse.mask[2] = -(uint32_t) (w >= 2);
723 params->sse.mask[3] = -(uint32_t) (w >= 3);
724 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
725 params->neon.multiplier = multiplier;
726
727 const uint32_t w = (width - 1) & 3;
728 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
729 params->neon.mask[1] = -(uint32_t) (w >= 1);
730 params->neon.mask[2] = -(uint32_t) (w >= 2);
731 params->neon.mask[3] = -(uint32_t) (w >= 3);
732 #else
733 params->scalar.multiplier = multiplier;
734
735 const uint32_t w = (width - 1) & 3;
736 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
737 params->scalar.mask[1] = -(int32_t) (w >= 1);
738 params->scalar.mask[2] = -(int32_t) (w >= 2);
739 params->scalar.mask[3] = -(int32_t) (w >= 3);
740 #endif
741 }
742
743 static inline union xnn_f32_scaleminmax_params xnn_init_scalar_f32_scaleminmax_params(
744 float scale,
745 float min,
746 float max)
747 {
748 union xnn_f32_scaleminmax_params params;
749 params.scalar.scale = scale;
750 params.scalar.min = min;
751 params.scalar.max = max;
752 return params;
753 }
754
755 static inline union xnn_f32_gavgpool_params xnn_init_scalar_f32_gavgpool_params(
756 float multiplier,
757 float output_min,
758 float output_max,
759 uint32_t width)
760 {
761 union xnn_f32_gavgpool_params params;
762 params.scalar.multiplier = multiplier;
763 params.scalar.output_min = output_min;
764 params.scalar.output_max = output_max;
765
766 const uint32_t w = (width - 1) & 3;
767 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
768 params.scalar.mask[1] = -(int32_t) (w >= 1);
769 params.scalar.mask[2] = -(int32_t) (w >= 2);
770 params.scalar.mask[3] = -(int32_t) (w >= 3);
771 return params;
772 }
773
774 static inline struct xnn_f16_minmax_params xnn_init_f16_minmax_params(
775 uint16_t min,
776 uint16_t max)
777 {
778 struct xnn_f16_minmax_params params;
779 params.min = min;
780 params.max = max;
781 return params;
782 }
783
784 static inline union xnn_f32_minmax_params xnn_init_f32_minmax_params(
785 float output_min,
786 float output_max)
787 {
788 union xnn_f32_minmax_params params;
789 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
790 for (uint32_t i = 0; i < 4; i++) {
791 params.sse.min[i] = output_min;
792 params.sse.max[i] = output_max;
793 }
794 #else
795 params.scalar.min = output_min;
796 params.scalar.max = output_max;
797 #endif
798 return params;
799 }
800
801 static inline union xnn_f32_minmax_params xnn_init_scalar_f32_minmax_params(
802 float output_min,
803 float output_max)
804 {
805 union xnn_f32_minmax_params params;
806 params.scalar.min = output_min;
807 params.scalar.max = output_max;
808 return params;
809 }
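
// These parameters describe a simple per-element clamp, e.g. (sketch, using the scalar variant):
//
//   const float y = fminf(fmaxf(x, params.scalar.min), params.scalar.max);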
810
811 static inline struct xnn_f16_hswish_params xnn_init_f16_hswish_params(void)
812 {
813 struct xnn_f16_hswish_params params;
814 params.sixth = UINT16_C(0x3155);
815 params.three = UINT16_C(0x4200);
816 params.six = UINT16_C(0x4600);
817 return params;
818 }
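
// The three bit patterns above are IEEE half-precision encodings of 1/6, 3.0 and 6.0
// (0x3155 is the fp16 value closest to 1/6, ~0.16663). A hypothetical spot-check using the
// conversion helper from the <fp16.h> header already included above:
//
//   assert(fp16_ieee_to_fp32_value(UINT16_C(0x4200)) == 3.0f);
//   assert(fp16_ieee_to_fp32_value(UINT16_C(0x4600)) == 6.0f);
//   assert(fabsf(fp16_ieee_to_fp32_value(UINT16_C(0x3155)) - 0x1.555556p-3f) < 0x1.0p-13f);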
819
820 static inline union xnn_f32_hswish_params xnn_init_f32_hswish_params(void)
821 {
822 union xnn_f32_hswish_params params;
823 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
824 for (uint32_t i = 0; i < 4; i++) {
825 params.sse.sixth[i] = 0x1.555556p-3f;
826 params.sse.half[i] = 0.5f;
827 params.sse.one[i] = 1.0f;
828 }
829 #else
830 params.scalar.sixth = 0x1.555556p-3f;
831 params.scalar.three = 3.0f;
832 params.scalar.six = 6.0f;
833 #endif
834 return params;
835 }
836
837 static inline union xnn_f32_hswish_params xnn_init_scalar_f32_hswish_params(void)
838 {
839 union xnn_f32_hswish_params params;
840 params.scalar.sixth = 0x1.555556p-3f;
841 params.scalar.three = 3.0f;
842 params.scalar.six = 6.0f;
843 return params;
844 }
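
// One way the scalar constants above combine into hard-swish, x * min(max(x + 3, 0), 6) / 6;
// a sketch with a hypothetical helper name, not the actual microkernel code:
//
//   static inline float example_f32_hswish(float x, const union xnn_f32_hswish_params* params) {
//     float acc = x + params->scalar.three;
//     acc = acc > 0.0f ? acc : 0.0f;
//     acc = acc < params->scalar.six ? acc : params->scalar.six;
//     return x * acc * params->scalar.sixth;
//   }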
845
846 static inline union xnn_f32_abs_params xnn_init_f32_abs_params(void)
847 {
848 union xnn_f32_abs_params params = { 0 };
849 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
850 for (uint32_t i = 0; i < 4; i++) {
851 params.sse.nonsign_mask[i] = math_nonsign_mask_f32();
852 }
853 #elif XNN_ARCH_WASMSIMD
854 params.wasmsimd.nonsign_mask = math_nonsign_mask_f32();
855 #endif
856 return params;
857 }
858
859 static inline union xnn_f32_abs_params xnn_init_scalar_f32_abs_params(void)
860 {
861 union xnn_f32_abs_params params = { 0 };
862 return params;
863 }
864
865 static inline union xnn_f32_neg_params xnn_init_f32_neg_params(void)
866 {
867 union xnn_f32_neg_params params = { 0 };
868 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
869 for (uint32_t i = 0; i < 4; i++) {
870 params.sse.sign_mask[i] = -0.0f;
871 }
872 #elif XNN_ARCH_WASMSIMD
873 params.wasmsimd.sign_mask = -0.0f;
874 #endif
875 return params;
876 }
877
878 static inline union xnn_f32_neg_params xnn_init_scalar_f32_neg_params(void)
879 {
880 union xnn_f32_neg_params params = { 0 };
881 return params;
882 }
883
884 static inline union xnn_f32_rnd_params xnn_init_f32_rnd_params(void)
885 {
886 union xnn_f32_rnd_params params = { 0 };
887 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
888 for (uint32_t i = 0; i < 4; i++) {
889 params.sse2.sign_mask[i] = -0.0f;
890 }
891 for (uint32_t i = 0; i < 4; i++) {
892 params.sse2.one[i] = 1.0f;
893 }
894 #endif
895 return params;
896 }
897
898 static inline union xnn_f32_rnd_params xnn_init_scalar_f32_rnd_params(void)
899 {
900 union xnn_f32_rnd_params params = { 0 };
901 return params;
902 }
903
904 static inline union xnn_f32_elu_params xnn_init_f32_elu_params(float prescale, float alpha, float beta)
905 {
906 union xnn_f32_elu_params params;
907 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
908 for (uint32_t i = 0; i < 4; i++) {
909 params.sse.prescale[i] = prescale;
910 params.sse.alpha[i] = alpha;
911 params.sse.beta[i] = beta;
912 }
913 #else
914 params.scalar.prescale = prescale;
915 params.scalar.alpha = alpha;
916 params.scalar.beta = beta;
917 #endif
918 return params;
919 }
920
921 static inline union xnn_f32_elu_params xnn_init_scalar_f32_elu_params(float prescale, float alpha, float beta)
922 {
923 union xnn_f32_elu_params params;
924 params.scalar.prescale = prescale;
925 params.scalar.alpha = alpha;
926 params.scalar.beta = beta;
927 return params;
928 }
929
930 static inline union xnn_f32_lrelu_params xnn_init_f32_lrelu_params(float slope)
931 {
932 union xnn_f32_lrelu_params params;
933 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
934 for (uint32_t i = 0; i < 4; i++) {
935 params.sse.slope[i] = slope;
936 }
937 #else
938 params.scalar.slope = slope;
939 #endif
940 return params;
941 }
942
943 static inline union xnn_f32_lrelu_params xnn_init_scalar_f32_lrelu_params(float slope)
944 {
945 union xnn_f32_lrelu_params params;
946 params.scalar.slope = slope;
947 return params;
948 }
949
950 static inline union xnn_f32_sqrt_params xnn_init_f32_sqrt_params(void)
951 {
952 union xnn_f32_sqrt_params params = { 0 };
953 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
954 params.fma.half = 0.5f;
955 #endif
956 return params;
957 }
958
959 static inline union xnn_f32_sqrt_params xnn_init_scalar_f32_sqrt_params(void)
960 {
961 union xnn_f32_sqrt_params params = { 0 };
962 return params;
963 }
964
965 static inline union xnn_f32_chw_params xnn_init_f32_chw_params(
966 uint32_t width,
967 float output_min,
968 float output_max)
969 {
970 union xnn_f32_chw_params params;
971 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
972 for (uint32_t i = 0; i < 4; i++) {
973 params.sse.min[i] = output_min;
974 params.sse.max[i] = output_max;
975 }
976
977 const uint32_t w4 = (width - 1) & 3;
978 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
979 params.sse.mask[1] = -(uint32_t) (w4 >= 1);
980 params.sse.mask[2] = -(uint32_t) (w4 >= 2);
981 params.sse.mask[3] = -(uint32_t) (w4 >= 3);
982
983 const uint32_t w8 = (width - 1) & 7;
984 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
985 params.sse.mask_even[1] = -(uint32_t) (w8 >= 2);
986 params.sse.mask_even[2] = -(uint32_t) (w8 >= 4);
987 params.sse.mask_even[3] = -(uint32_t) (w8 >= 6);
988 params.sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
989 params.sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
990 params.sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
991 params.sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
992 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
993 params.neon.min = output_min;
994 params.neon.max = output_max;
995
996 const uint32_t w4 = (width - 1) & 3;
997 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
998 params.neon.mask[1] = -(uint32_t) (w4 >= 1);
999 params.neon.mask[2] = -(uint32_t) (w4 >= 2);
1000 params.neon.mask[3] = -(uint32_t) (w4 >= 3);
1001
1002 const uint32_t w8 = (width - 1) & 7;
1003 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1004 params.neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1005 params.neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1006 params.neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1007 params.neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1008 params.neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1009 params.neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1010 params.neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
1011 #else
1012 params.scalar.min = output_min;
1013 params.scalar.max = output_max;
1014
1015 const uint32_t w4 = (width - 1) & 3;
1016 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1017 params.scalar.mask[1] = -(uint32_t) (w4 >= 1);
1018 params.scalar.mask[2] = -(uint32_t) (w4 >= 2);
1019 params.scalar.mask[3] = -(uint32_t) (w4 >= 3);
1020
1021 const uint32_t w8 = (width - 1) & 7;
1022 params.scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1023 params.scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1024 params.scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1025 params.scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1026 params.scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1027 params.scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1028 params.scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1029 params.scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1030 #endif
1031 return params;
1032 }
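
// Worked example of the masks above for width = 7: w4 = (7 - 1) & 3 = 2, so
// mask = { ~0, ~0, ~0, 0 } keeps the 3 valid lanes of the final group of 4 (7 = 4 + 3);
// w8 = (7 - 1) & 7 = 6, so mask_even = { ~0, ~0, ~0, ~0 } and mask_odd = { ~0, ~0, ~0, 0 },
// keeping the 4 even-indexed columns (0, 2, 4, 6) and the 3 odd-indexed columns (1, 3, 5) of
// the final group of 8; presumably intended for stride-2 CHW kernels that de-interleave a row
// into even and odd columns.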
1033
1034 static inline void xnn_update_f32_chw_params(
1035 union xnn_f32_chw_params* params,
1036 uint32_t width)
1037 {
1038 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1039 const uint32_t w4 = (width - 1) & 3;
1040 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
1041 params->sse.mask[1] = -(uint32_t) (w4 >= 1);
1042 params->sse.mask[2] = -(uint32_t) (w4 >= 2);
1043 params->sse.mask[3] = -(uint32_t) (w4 >= 3);
1044
1045 const uint32_t w8 = (width - 1) & 7;
1046 params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
1047 params->sse.mask_even[1] = -(uint32_t) (w8 >= 2);
1048 params->sse.mask_even[2] = -(uint32_t) (w8 >= 4);
1049 params->sse.mask_even[3] = -(uint32_t) (w8 >= 6);
1050 params->sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
1051 params->sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
1052 params->sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
1053 params->sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
1054 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1055 const uint32_t w4 = (width - 1) & 3;
1056 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
1057 params->neon.mask[1] = -(uint32_t) (w4 >= 1);
1058 params->neon.mask[2] = -(uint32_t) (w4 >= 2);
1059 params->neon.mask[3] = -(uint32_t) (w4 >= 3);
1060
1061 const uint32_t w8 = (width - 1) & 7;
1062 params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
1063 params->neon.mask_even[1] = -(uint32_t) (w8 >= 2);
1064 params->neon.mask_even[2] = -(uint32_t) (w8 >= 4);
1065 params->neon.mask_even[3] = -(uint32_t) (w8 >= 6);
1066 params->neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
1067 params->neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
1068 params->neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
1069 params->neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
1070 #else
1071 const uint32_t w4 = (width - 1) & 3;
1072 params->scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1073 params->scalar.mask[1] = -(uint32_t) (w4 >= 1);
1074 params->scalar.mask[2] = -(uint32_t) (w4 >= 2);
1075 params->scalar.mask[3] = -(uint32_t) (w4 >= 3);
1076
1077 const uint32_t w8 = (width - 1) & 7;
1078 params->scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1079 params->scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1080 params->scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1081 params->scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1082 params->scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1083 params->scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1084 params->scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1085 params->scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1086 #endif
1087 }
1088
1089 static inline union xnn_f32_chw_params xnn_init_scalar_f32_chw_params(
1090 uint32_t width,
1091 float output_min,
1092 float output_max)
1093 {
1094 union xnn_f32_chw_params params;
1095 params.scalar.min = output_min;
1096 params.scalar.max = output_max;
1097
1098 const uint32_t w4 = (width - 1) & 3;
1099 params.scalar.mask[0] = UINT32_C(0xFFFFFFFF);
1100 params.scalar.mask[1] = -(uint32_t) (w4 >= 1);
1101 params.scalar.mask[2] = -(uint32_t) (w4 >= 2);
1102 params.scalar.mask[3] = -(uint32_t) (w4 >= 3);
1103
1104 const uint32_t w8 = (width - 1) & 7;
1105 params.scalar.mask_even[0] = UINT32_C(0xFFFFFFFF);
1106 params.scalar.mask_even[1] = -(uint32_t) (w8 >= 2);
1107 params.scalar.mask_even[2] = -(uint32_t) (w8 >= 4);
1108 params.scalar.mask_even[3] = -(uint32_t) (w8 >= 6);
1109 params.scalar.mask_odd[0] = -(uint32_t) (w8 >= 1);
1110 params.scalar.mask_odd[1] = -(uint32_t) (w8 >= 3);
1111 params.scalar.mask_odd[2] = -(uint32_t) (w8 >= 5);
1112 params.scalar.mask_odd[3] = -(uint32_t) (w8 >= 7);
1113
1114 return params;
1115 }
1116
1117 static inline union xnn_u8_minmax_params xnn_init_u8_minmax_params(
1118 uint8_t output_min,
1119 uint8_t output_max)
1120 {
1121 assert(output_min < output_max);
1122
1123 union xnn_u8_minmax_params params;
1124 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1125 for (uint32_t i = 0; i < 16; i++) {
1126 params.sse2.min[i] = output_min;
1127 params.sse2.max[i] = output_max;
1128 }
1129 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1130 params.neon.min = output_min;
1131 params.neon.max = output_max;
1132 #else
1133 params.scalar.min = (int32_t) (uint32_t) output_min;
1134 params.scalar.max = (int32_t) (uint32_t) output_max;
1135 #endif
1136 return params;
1137 }
1138
1139 static inline union xnn_u8_minmax_params xnn_init_scalar_u8_minmax_params(
1140 uint8_t output_min,
1141 uint8_t output_max)
1142 {
1143 assert(output_min < output_max);
1144
1145 union xnn_u8_minmax_params params;
1146 params.scalar.min = (int32_t) (uint32_t) output_min;
1147 params.scalar.max = (int32_t) (uint32_t) output_max;
1148 return params;
1149 }
1150
1151 static inline union xnn_qu8_add_params xnn_init_qu8_add_params(
1152 uint8_t a_zero_point,
1153 uint8_t b_zero_point,
1154 uint8_t output_zero_point,
1155 float a_output_scale,
1156 float b_output_scale,
1157 uint8_t output_min,
1158 uint8_t output_max)
1159 {
1160 assert(a_output_scale >= 0x1.0p-14f);
1161 assert(b_output_scale >= 0x1.0p-14f);
1162 assert(a_output_scale < 0x1.0p+8f);
1163 assert(b_output_scale < 0x1.0p+8f);
1164
1165 // Compute requantization parameters.
1166 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
1167 assert(max_output_scale >= 0x1.0p-14f);
1168 assert(max_output_scale < 0x1.0p+8f);
1169 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1170 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1171 // Shift is in [13, 31] range.
1172 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1173 assert(shift < 32);
1174 assert(shift >= 13);
1175
1176 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1177
1178 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1179 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(a_output_scale * scale_multiplier);
1180 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(b_output_scale * scale_multiplier);
1181 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1182 assert(a_multiplier < UINT32_C(0x00400000));
1183 assert(b_multiplier < UINT32_C(0x00400000));
1184
1185 union xnn_qu8_add_params params;
1186 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1187 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1188 const uint32_t remainder_threshold = remainder_mask >> 1;
1189 const int32_t zero_point_product =
1190 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1191 for (uint32_t i = 0; i < 4; i++) {
1192 params.sse2.zero_point_product[i] = zero_point_product;
1193 }
1194 for (uint32_t i = 0; i < 8; i++) {
1195 params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
1196 }
1197 for (uint32_t i = 0; i < 8; i++) {
1198 params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
1199 params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
1200 params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
1201 params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
1202 }
1203 params.sse2.a_multiplier = a_multiplier;
1204 params.sse2.b_multiplier = b_multiplier;
1205 for (uint32_t i = 0; i < 4; i++) {
1206 params.sse2.remainder_mask[i] = remainder_mask;
1207 params.sse2.remainder_threshold[i] = remainder_threshold;
1208 }
1209 params.sse2.shift = shift;
1210 for (uint32_t i = 0; i < 16; i++) {
1211 params.sse2.y_min[i] = output_min;
1212 params.sse2.y_max[i] = output_max;
1213 }
1214 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1215 params.neon.a_zero_point = a_zero_point;
1216 params.neon.b_zero_point = b_zero_point;
1217 params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
1218 params.neon.a_multiplier = (int32_t) a_multiplier;
1219 params.neon.b_multiplier = (int32_t) b_multiplier;
1220 params.neon.right_shift = (int32_t) -shift;
1221 params.neon.y_min = output_min;
1222 params.neon.y_max = output_max;
1223 #else
1224 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1225 const uint32_t remainder_threshold = remainder_mask >> 1;
1226 params.scalar.zero_point_product =
1227 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1228 params.scalar.a_multiplier = a_multiplier;
1229 params.scalar.b_multiplier = b_multiplier;
1230 params.scalar.remainder_mask = (int32_t) remainder_mask;
1231 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1232 params.scalar.shift = shift;
1233 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1234 params.scalar.y_min = (int32_t) (uint32_t) output_min;
1235 params.scalar.y_max = (int32_t) (uint32_t) output_max;
1236 #endif
1237 return params;
1238 }
1239
1240 static inline union xnn_qu8_add_params xnn_init_scalar_qu8_add_params(
1241 uint8_t a_zero_point,
1242 uint8_t b_zero_point,
1243 uint8_t output_zero_point,
1244 float a_output_scale,
1245 float b_output_scale,
1246 uint8_t output_min,
1247 uint8_t output_max)
1248 {
1249 assert(a_output_scale >= 0x1.0p-10f);
1250 assert(b_output_scale >= 0x1.0p-10f);
1251 assert(a_output_scale < 0x1.0p+8f);
1252 assert(b_output_scale < 0x1.0p+8f);
1253
1254 // Compute requantization parameters.
1255 const float max_output_scale = math_max_f32(a_output_scale, b_output_scale);
1256 assert(max_output_scale >= 0x1.0p-10f);
1257 assert(max_output_scale < 0x1.0p+8f);
1258 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1259 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1260 // Shift is in [13, 31] range.
1261 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1262 assert(shift < 32);
1263 assert(shift >= 13);
1264
1265 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1266 const uint32_t a_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
1267 const uint32_t b_multiplier = (uint32_t) (int32_t) lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
1268 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
1269 assert(a_multiplier < UINT32_C(0x00400000));
1270 assert(b_multiplier < UINT32_C(0x00400000));
1271
1272 union xnn_qu8_add_params params;
1273 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1274 const uint32_t remainder_threshold = remainder_mask >> 1;
1275 params.scalar.zero_point_product =
1276 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
1277 params.scalar.a_multiplier = a_multiplier;
1278 params.scalar.b_multiplier = b_multiplier;
1279 params.scalar.remainder_mask = (int32_t) remainder_mask;
1280 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1281 params.scalar.shift = shift;
1282 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
1283 params.scalar.y_min = (int32_t) (uint32_t) output_min;
1284 params.scalar.y_max = (int32_t) (uint32_t) output_max;
1285 return params;
1286 }
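
// Illustrative sketch (hypothetical helper, arithmetic right shift on signed values assumed)
// of how the scalar addition fields above are applied to one output element:
//
//   static inline uint8_t example_qu8_add(uint8_t a, uint8_t b, const union xnn_qu8_add_params* params) {
//     // zero_point_product already folds in -(a_zero_point * a_multiplier + b_zero_point * b_multiplier).
//     int32_t acc = params->scalar.zero_point_product;
//     acc += (int32_t) ((uint32_t) a * params->scalar.a_multiplier);
//     acc += (int32_t) ((uint32_t) b * params->scalar.b_multiplier);
//     // Shift right with rounding to nearest via the remainder mask/threshold pair.
//     const int32_t remainder = (acc & params->scalar.remainder_mask) - (int32_t) (acc < 0);
//     int32_t out = (acc >> params->scalar.shift) + (int32_t) (remainder > params->scalar.remainder_threshold);
//     out += params->scalar.y_zero_point;
//     out = out < params->scalar.y_min ? params->scalar.y_min : out;
//     out = out > params->scalar.y_max ? params->scalar.y_max : out;
//     return (uint8_t) out;
//   }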
1287
1288 static inline union xnn_qs8_add_params xnn_init_qs8_add_params(
1289 int8_t x_zero_point,
1290 int8_t y_zero_point,
1291 int8_t output_zero_point,
1292 float x_output_scale,
1293 float y_output_scale,
1294 int8_t output_min,
1295 int8_t output_max)
1296 {
1297 assert(x_output_scale >= 0x1.0p-14f);
1298 assert(y_output_scale >= 0x1.0p-14f);
1299 assert(x_output_scale < 0x1.0p+8f);
1300 assert(y_output_scale < 0x1.0p+8f);
1301
1302 // Compute requantization parameters.
1303 const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1304 assert(max_output_scale >= 0x1.0p-14f);
1305 assert(max_output_scale < 0x1.0p+8f);
1306 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1307 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1308 // Shift is in [13, 31] range.
1309 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1310 assert(shift < 32);
1311 assert(shift >= 13);
1312
1313 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
1314
1315 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1316 const int32_t x_multiplier = (int32_t) lrintf(x_output_scale * scale_multiplier);
1317 const int32_t y_multiplier = (int32_t) lrintf(y_output_scale * scale_multiplier);
1318 assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1319 assert(x_multiplier < INT32_C(0x00400000));
1320 assert(y_multiplier < INT32_C(0x00400000));
1321
1322 union xnn_qs8_add_params params;
1323 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
1324 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1325 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1326 const int32_t zero_point_product =
1327 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1328 for (uint32_t i = 0; i < 4; i++) {
1329 params.sse2.zero_point_product[i] = zero_point_product;
1330 }
1331 const uint16_t x_multiplier_lo = (uint16_t) x_multiplier;
1332 const uint16_t x_multiplier_hi = (uint16_t) ((uint32_t) x_multiplier >> 16);
1333 const uint16_t y_multiplier_lo = (uint16_t) y_multiplier;
1334 const uint16_t y_multiplier_hi = (uint16_t) ((uint32_t) y_multiplier >> 16);
1335 for (uint32_t i = 0; i < 8; i++) {
1336 params.sse2.x_multiplier_lo[i] = x_multiplier_lo;
1337 params.sse2.x_multiplier_hi[i] = x_multiplier_hi;
1338 params.sse2.y_multiplier_lo[i] = y_multiplier_lo;
1339 params.sse2.y_multiplier_hi[i] = y_multiplier_hi;
1340 }
1341 params.sse2.shift = shift;
1342 for (uint32_t i = 0; i < 4; i++) {
1343 params.sse2.x_multiplier[i] = x_multiplier;
1344 params.sse2.y_multiplier[i] = y_multiplier;
1345 params.sse2.remainder_mask[i] = remainder_mask;
1346 params.sse2.remainder_threshold[i] = remainder_threshold;
1347 }
1348 for (uint32_t i = 0; i < 8; i++) {
1349 params.sse2.output_zero_point[i] = (int16_t) output_zero_point;
1350 params.sse2.output_min[i] = (int16_t) output_min;
1351 params.sse2.output_max[i] = (int16_t) output_max;
1352 }
1353 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
1354 params.neon.x_zero_point = x_zero_point;
1355 params.neon.y_zero_point = y_zero_point;
1356 params.neon.x_multiplier = (int32_t) x_multiplier;
1357 params.neon.y_multiplier = (int32_t) y_multiplier;
1358 params.neon.right_shift = (int32_t) -shift;
1359 params.neon.output_zero_point = (int16_t) output_zero_point;
1360 params.neon.output_min = output_min;
1361 params.neon.output_max = output_max;
1362 #elif XNN_ARCH_WASMSIMD
1363 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1364 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1365 const int32_t zero_point_product =
1366 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1367 for (uint32_t i = 0; i < 4; i++) {
1368 params.wasmsimd.zero_point_product[i] = zero_point_product;
1369 params.wasmsimd.x_multiplier[i] = x_multiplier;
1370 params.wasmsimd.y_multiplier[i] = y_multiplier;
1371 params.wasmsimd.remainder_mask[i] = remainder_mask;
1372 params.wasmsimd.remainder_threshold[i] = remainder_threshold;
1373 }
1374 params.wasmsimd.shift = shift;
1375 for (uint32_t i = 0; i < 8; i++) {
1376 params.wasmsimd.output_zero_point[i] = (int16_t) output_zero_point;
1377 }
1378 for (uint32_t i = 0; i < 16; i++) {
1379 params.wasmsimd.output_min[i] = output_min;
1380 params.wasmsimd.output_max[i] = output_max;
1381 }
1382 #else
1383 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1384 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1385 params.scalar.zero_point_product =
1386 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1387 params.scalar.x_multiplier = x_multiplier;
1388 params.scalar.y_multiplier = y_multiplier;
1389 params.scalar.remainder_mask = (int32_t) remainder_mask;
1390 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1391 params.scalar.shift = (int32_t) shift;
1392 params.scalar.output_zero_point = (int32_t) output_zero_point;
1393 params.scalar.output_min = (int32_t) output_min;
1394 params.scalar.output_max = (int32_t) output_max;
1395 #endif
1396 return params;
1397 }
1398
1399 static inline union xnn_qs8_add_params xnn_init_scalar_qs8_add_params(
1400 int8_t x_zero_point,
1401 int8_t y_zero_point,
1402 int8_t output_zero_point,
1403 float x_output_scale,
1404 float y_output_scale,
1405 int8_t output_min,
1406 int8_t output_max)
1407 {
1408 assert(x_output_scale >= 0x1.0p-10f);
1409 assert(y_output_scale >= 0x1.0p-10f);
1410 assert(x_output_scale < 0x1.0p+8f);
1411 assert(y_output_scale < 0x1.0p+8f);
1412
1413 // Compute requantization parameters.
1414 const float max_output_scale = math_max_f32(x_output_scale, y_output_scale);
1415 assert(max_output_scale >= 0x1.0p-10f);
1416 assert(max_output_scale < 0x1.0p+8f);
1417 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
1418 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
1419 // Shift is in [13, 31] range.
1420 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
1421 assert(shift < 32);
1422 assert(shift >= 13);
1423
1424 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
1425 const int32_t x_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(x_output_scale) + (shift << 23)));
1426 const int32_t y_multiplier = (int32_t) lrintf(fp32_from_bits(fp32_to_bits(y_output_scale) + (shift << 23)));
1427 assert((x_multiplier > y_multiplier ? x_multiplier : y_multiplier) >= INT32_C(0x00200000));
1428 assert(x_multiplier < INT32_C(0x00400000));
1429 assert(y_multiplier < INT32_C(0x00400000));
1430
1431 union xnn_qs8_add_params params;
1432 const int32_t remainder_mask = (INT32_C(1) << shift) - INT32_C(1);
1433 const int32_t remainder_threshold = (int32_t) ((uint32_t) remainder_mask >> 1);
1434 params.scalar.zero_point_product =
1435 (int32_t) -(x_multiplier * (int32_t) x_zero_point + y_multiplier * (int32_t) y_zero_point);
1436 params.scalar.x_multiplier = x_multiplier;
1437 params.scalar.y_multiplier = y_multiplier;
1438 params.scalar.remainder_mask = (int32_t) remainder_mask;
1439 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
1440 params.scalar.shift = shift;
1441 params.scalar.output_zero_point = (int32_t) output_zero_point;
1442 params.scalar.output_min = (int32_t) output_min;
1443 params.scalar.output_max = (int32_t) output_max;
1444 return params;
1445 }
1446
1447 static inline union xnn_qu8_requantization_params xnn_init_scalar_qu8_requantization_params(
1448 float scale,
1449 uint8_t zero_point,
1450 uint8_t min,
1451 uint8_t max)
1452 {
1453 // Compute requantization parameters.
1454 assert(scale < 1.0f);
1455 assert(scale >= 0x1.0p-32f);
1456 const uint32_t scale_bits = fp32_to_bits(scale);
1457
1458 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
1459 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1460 assert(multiplier >= INT32_C(0x40000000));
1461 assert(multiplier <= INT32_C(0x7FFFFF80));
1462
1463 // Shift is in [0, 31] range.
1464 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1465 assert(shift >= 0);
1466 assert(shift < 32);
1467
1468 union xnn_qu8_requantization_params params;
1469 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1470 const uint32_t remainder_threshold = remainder_mask >> 1;
1471 params.q31.multiplier = multiplier;
1472 params.q31.remainder_mask = (int32_t) remainder_mask;
1473 params.q31.remainder_threshold = (int32_t) remainder_threshold;
1474 params.q31.shift = (uint32_t) shift;
1475 params.q31.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
1476 params.q31.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
1477 params.q31.zero_point = (int32_t) (uint32_t) zero_point;
1478 return params;
1479 }
1480
1481 static inline union xnn_qs8_requantization_params xnn_init_scalar_qs8_requantization_params(
1482 float scale,
1483 int8_t zero_point,
1484 int8_t min,
1485 int8_t max)
1486 {
1487 // Compute requantization parameters.
1488 assert(scale < 1.0f);
1489 assert(scale >= 0x1.0p-32f);
1490 const uint32_t scale_bits = fp32_to_bits(scale);
1491
1492 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
1493 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
1494 assert(multiplier >= INT32_C(0x40000000));
1495 assert(multiplier <= INT32_C(0x7FFFFF80));
1496
1497 // Shift is in [0, 31] range.
1498 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
1499 assert(shift >= 0);
1500 assert(shift < 32);
1501
1502 union xnn_qs8_requantization_params params;
1503 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
1504 const uint32_t remainder_threshold = remainder_mask >> 1;
1505 params.q31.multiplier = multiplier;
1506 params.q31.remainder_mask = (int32_t) remainder_mask;
1507 params.q31.remainder_threshold = (int32_t) remainder_threshold;
1508 params.q31.shift = (uint32_t) shift;
1509 params.q31.min_less_zero_point = (int32_t) min - (int32_t) zero_point;
1510 params.q31.max_less_zero_point = (int32_t) max - (int32_t) zero_point;
1511 params.q31.zero_point = (int32_t) zero_point;
1512 return params;
1513 }
1514
1515