1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #pragma once
10
11 #if defined(__cplusplus) && (__cplusplus >= 201103L)
12 #include <cstdint>
13 #include <cstddef>
14 #include <cassert>
15 #include <cmath>
16 #else
17 #include <stdint.h>
18 #include <stddef.h>
19 #include <assert.h>
20 #include <math.h>
21 #endif
22
23 #include <fp16.h>
24
25 #include <xnnpack/common.h>
26 #include <xnnpack/params.h>
27
28
xnn_init_scalar_q8_gemm_params(uint8_t input_zero_point,uint8_t kernel_zero_point,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)29 static inline union xnn_q8_gemm_params xnn_init_scalar_q8_gemm_params(
30 uint8_t input_zero_point,
31 uint8_t kernel_zero_point,
32 float scale,
33 uint8_t output_zero_point,
34 uint8_t output_min,
35 uint8_t output_max)
36 {
37 // Compute requantization parameters
38 const uint32_t scale_bits = fp32_to_bits(scale);
39
40 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
41 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
42 assert(multiplier >= INT32_C(0x40000000));
43 assert(multiplier <= INT32_C(0x7FFFFF80));
44
45 // Shift is in [0, 31] range.
46 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
47 assert(shift >= 0);
48 assert(shift < 32);
49
50 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
51 const uint32_t remainder_threshold = remainder_mask >> 1;
52
53 union xnn_q8_gemm_params params;
54 params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
55 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
56 params.scalar.multiplier = multiplier;
57 params.scalar.remainder_mask = (int32_t) remainder_mask;
58 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
59 params.scalar.shift = (uint32_t) shift;
60 params.scalar.output_min_less_zero_point =
61 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
62 params.scalar.output_max_less_zero_point =
63 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
64 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
65 return params;
66 }
67
xnn_init_q8_gemm_params(uint8_t input_zero_point,uint8_t kernel_zero_point,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)68 static inline union xnn_q8_gemm_params xnn_init_q8_gemm_params(
69 uint8_t input_zero_point,
70 uint8_t kernel_zero_point,
71 float scale,
72 uint8_t output_zero_point,
73 uint8_t output_min,
74 uint8_t output_max)
75 {
76 // Compute requantization parameters.
77 const uint32_t scale_bits = fp32_to_bits(scale);
78
79 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
80 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
81 assert(multiplier >= INT32_C(0x40000000));
82 assert(multiplier <= INT32_C(0x7FFFFF80));
83
84 // Shift is in [0, 31] range.
85 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
86 assert(shift >= 0);
87 assert(shift < 32);
88
89 union xnn_q8_gemm_params params;
90 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
91 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
92 const uint32_t remainder_threshold = remainder_mask >> 1;
93 for (uint32_t i = 0; i < 8; i++) {
94 params.sse2.input_zero_point[i] = (int16_t) (uint16_t) input_zero_point;
95 params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
96 }
97 params.sse2.multiplier[0] = multiplier;
98 params.sse2.multiplier[1] = multiplier;
99 params.sse2.multiplier[2] = multiplier;
100 params.sse2.multiplier[3] = multiplier;
101 params.sse2.rounding[0] = UINT64_C(0x40000000);
102 params.sse2.rounding[1] = UINT64_C(0x40000000);
103 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
104 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
105 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
106 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
107 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
108 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
109 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
110 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
111 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
112 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
113 for (uint32_t i = 0; i < 8; i++) {
114 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
115 }
116 for (uint32_t i = 0; i < 16; i++) {
117 params.sse2.output_max[i] = output_max;
118 params.sse2.output_min[i] = output_min;
119 }
120 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
121 params.neon.input_zero_point = (int16_t) (uint16_t) input_zero_point;
122 params.neon.kernel_zero_point = (int16_t) (uint16_t) kernel_zero_point;
123 params.neon.multiplier = multiplier;
124 params.neon.right_shift = -shift;
125 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
126 params.neon.output_max = output_max;
127 params.neon.output_min = output_min;
128 #else
129 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
130 const uint32_t remainder_threshold = remainder_mask >> 1;
131 params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
132 params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
133 params.scalar.multiplier = multiplier;
134 params.scalar.remainder_mask = (int32_t) remainder_mask;
135 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
136 params.scalar.shift = (uint32_t) shift;
137 params.scalar.output_min_less_zero_point =
138 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
139 params.scalar.output_max_less_zero_point =
140 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
141 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
142 #endif
143 return params;
144 }
145
xnn_init_q8_avgpool_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)146 static inline union xnn_q8_avgpool_params xnn_init_q8_avgpool_params(
147 int32_t bias,
148 float scale,
149 uint8_t output_zero_point,
150 uint8_t output_min,
151 uint8_t output_max)
152 {
153 // Compute requantization parameters.
154 assert(scale >= 0x1.0p-32f);
155 assert(scale < 256.0f);
156 const uint32_t scale_bits = fp32_to_bits(scale);
157
158 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
159 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
160 assert(multiplier >= INT32_C(0x00800000));
161 assert(multiplier <= INT32_C(0x00FFFFFF));
162
163 // Shift is in [16, 55] range.
164 const int32_t shift = 127 + 23 - (scale_bits >> 23);
165 assert(shift >= 16);
166 assert(shift < 64);
167
168 union xnn_q8_avgpool_params params;
169 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
170 const uint32_t right_shift = (uint32_t) shift;
171 const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
172 params.sse2.bias[0] = bias;
173 params.sse2.bias[1] = bias;
174 params.sse2.bias[2] = bias;
175 params.sse2.bias[3] = bias;
176 params.sse2.multiplier[0] = (uint32_t) multiplier;
177 params.sse2.multiplier[1] = (uint32_t) multiplier;
178 params.sse2.multiplier[2] = (uint32_t) multiplier;
179 params.sse2.multiplier[3] = (uint32_t) multiplier;
180 params.sse2.rounding[0] = rounding;
181 params.sse2.rounding[1] = rounding;
182 params.sse2.right_shift[0] = (uint64_t) right_shift;
183 params.sse2.right_shift[1] = (uint64_t) right_shift;
184 for (uint32_t i = 0; i < 8; i++) {
185 params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
186 }
187 for (uint32_t i = 0; i < 16; i++) {
188 params.sse2.output_max[i] = output_max;
189 params.sse2.output_min[i] = output_min;
190 }
191 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
192 params.neon.bias = bias;
193 params.neon.multiplier = multiplier;
194 params.neon.left_shift = (int64_t) -shift;
195 params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
196 params.neon.output_max = output_max;
197 params.neon.output_min = output_min;
198 #else
199 const uint32_t right_shift = (uint32_t) shift;
200 const int64_t rounding = INT64_C(1) << (right_shift - 1);
201 params.scalar.bias = bias;
202 params.scalar.multiplier = multiplier;
203 params.scalar.rounding = rounding;
204 params.scalar.right_shift = right_shift;
205 params.scalar.output_min_less_zero_point =
206 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
207 params.scalar.output_max_less_zero_point =
208 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
209 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
210 #endif
211 return params;
212 }
213
xnn_init_scalar_q8_avgpool_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)214 static inline union xnn_q8_avgpool_params xnn_init_scalar_q8_avgpool_params(
215 int32_t bias,
216 float scale,
217 uint8_t output_zero_point,
218 uint8_t output_min,
219 uint8_t output_max)
220 {
221 // Compute requantization parameters.
222 assert(scale >= 0x1.0p-32f);
223 assert(scale < 256.0f);
224 const uint32_t scale_bits = fp32_to_bits(scale);
225
226 // Multiplier is in [0x00800000, 0x00FFFFFF] range.
227 const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
228 assert(multiplier >= INT32_C(0x00800000));
229 assert(multiplier <= INT32_C(0x00FFFFFF));
230
231 // Shift is in [16, 55] range.
232 const int32_t shift = 127 + 23 - (scale_bits >> 23);
233 assert(shift >= 16);
234 assert(shift < 64);
235
236 union xnn_q8_avgpool_params params;
237 const uint32_t right_shift = (uint32_t) shift;
238 const int64_t rounding = INT64_C(1) << (right_shift - 1);
239 params.scalar.bias = bias;
240 params.scalar.rounding = rounding;
241 params.scalar.multiplier = multiplier;
242 params.scalar.right_shift = right_shift;
243 params.scalar.output_min_less_zero_point =
244 (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
245 params.scalar.output_max_less_zero_point =
246 (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
247 params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
248 return params;
249 }
250
xnn_update_f32_avgpool_params(union xnn_f32_avgpool_params * params,float multiplier)251 static inline void xnn_update_f32_avgpool_params(
252 union xnn_f32_avgpool_params* params,
253 float multiplier)
254 {
255 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
256 for (uint32_t i = 0; i < 4; i++) {
257 params->sse2.multiplier[i] = multiplier;
258 }
259 #else
260 params->scalar.multiplier = multiplier;
261 #endif
262 }
263
xnn_init_f32_avgpool_params(float multiplier,float output_min,float output_max)264 static inline union xnn_f32_avgpool_params xnn_init_f32_avgpool_params(
265 float multiplier,
266 float output_min,
267 float output_max)
268 {
269 union xnn_f32_avgpool_params params;
270 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
271 for (uint32_t i = 0; i < 4; i++) {
272 params.sse2.multiplier[i] = multiplier;
273 params.sse2.output_min[i] = output_min;
274 params.sse2.output_max[i] = output_max;
275 }
276 #else
277 params.scalar.multiplier = multiplier;
278 params.scalar.output_min = output_min;
279 params.scalar.output_max = output_max;
280 #endif
281 return params;
282 }
283
xnn_init_f32_gavgpool_params(float multiplier,float output_min,float output_max,uint32_t width)284 static inline union xnn_f32_gavgpool_params xnn_init_f32_gavgpool_params(
285 float multiplier,
286 float output_min,
287 float output_max,
288 uint32_t width)
289 {
290 union xnn_f32_gavgpool_params params;
291 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
292 for (uint32_t i = 0; i < 4; i++) {
293 params.sse.multiplier[i] = multiplier;
294 params.sse.output_min[i] = output_min;
295 params.sse.output_max[i] = output_max;
296 }
297
298 const uint32_t w = (width - 1) & 3;
299 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
300 params.sse.mask[1] = -(uint32_t) (w >= 1);
301 params.sse.mask[2] = -(uint32_t) (w >= 2);
302 params.sse.mask[3] = -(uint32_t) (w >= 3);
303 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
304 params.neon.multiplier = multiplier;
305 params.neon.output_min = output_min;
306 params.neon.output_max = output_max;
307
308 const uint32_t w = (width - 1) & 3;
309 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
310 params.neon.mask[1] = -(uint32_t) (w >= 1);
311 params.neon.mask[2] = -(uint32_t) (w >= 2);
312 params.neon.mask[3] = -(uint32_t) (w >= 3);
313 #else
314 params.scalar.multiplier = multiplier;
315 params.scalar.output_min = output_min;
316 params.scalar.output_max = output_max;
317 #endif
318 return params;
319 }
320
xnn_update_f32_gavgpool_params(union xnn_f32_gavgpool_params * params,float multiplier,uint32_t width)321 static inline void xnn_update_f32_gavgpool_params(
322 union xnn_f32_gavgpool_params* params,
323 float multiplier,
324 uint32_t width)
325 {
326 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
327 for (uint32_t i = 0; i < 4; i++) {
328 params->sse.multiplier[i] = multiplier;
329 }
330
331 const uint32_t w = (width - 1) & 3;
332 params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
333 params->sse.mask[1] = -(uint32_t) (w >= 1);
334 params->sse.mask[2] = -(uint32_t) (w >= 2);
335 params->sse.mask[3] = -(uint32_t) (w >= 3);
336 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
337 params->neon.multiplier = multiplier;
338
339 const uint32_t w = (width - 1) & 3;
340 params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
341 params->neon.mask[1] = -(uint32_t) (w >= 1);
342 params->neon.mask[2] = -(uint32_t) (w >= 2);
343 params->neon.mask[3] = -(uint32_t) (w >= 3);
344 #else
345 params->scalar.multiplier = multiplier;
346 #endif
347 }
348
xnn_init_scalar_f32_avgpool_params(float multiplier,float output_min,float output_max)349 static inline union xnn_f32_avgpool_params xnn_init_scalar_f32_avgpool_params(
350 float multiplier,
351 float output_min,
352 float output_max)
353 {
354 union xnn_f32_avgpool_params params;
355 params.scalar.multiplier = multiplier;
356 params.scalar.output_min = output_min;
357 params.scalar.output_max = output_max;
358 return params;
359 }
360
xnn_init_scalar_f32_gavgpool_params(float multiplier,float output_min,float output_max,uint32_t width)361 static inline union xnn_f32_gavgpool_params xnn_init_scalar_f32_gavgpool_params(
362 float multiplier,
363 float output_min,
364 float output_max,
365 uint32_t width)
366 {
367 union xnn_f32_gavgpool_params params;
368 params.scalar.multiplier = multiplier;
369 params.scalar.output_min = output_min;
370 params.scalar.output_max = output_max;
371 return params;
372 }
373
xnn_init_f32_output_params(float output_min,float output_max)374 static inline union xnn_f32_output_params xnn_init_f32_output_params(
375 float output_min,
376 float output_max)
377 {
378 union xnn_f32_output_params params;
379 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
380 for (uint32_t i = 0; i < 4; i++) {
381 params.sse.min[i] = output_min;
382 params.sse.max[i] = output_max;
383 }
384 #else
385 params.scalar.min = output_min;
386 params.scalar.max = output_max;
387 #endif
388 return params;
389 }
390
xnn_init_scalar_f32_output_params(float output_min,float output_max)391 static inline union xnn_f32_output_params xnn_init_scalar_f32_output_params(
392 float output_min,
393 float output_max)
394 {
395 union xnn_f32_output_params params;
396 params.scalar.min = output_min;
397 params.scalar.max = output_max;
398 return params;
399 }
400
xnn_init_f32_hswish_params(void)401 static inline union xnn_f32_hswish_params xnn_init_f32_hswish_params(void)
402 {
403 union xnn_f32_hswish_params params;
404 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
405 for (uint32_t i = 0; i < 4; i++) {
406 params.sse.sixth[i] = 0x1.555556p-3f;
407 params.sse.half[i] = 0.5f;
408 params.sse.one[i] = 1.0f;
409 }
410 #else
411 params.scalar.sixth = 0x1.555556p-3f;
412 params.scalar.half = 0.5f;
413 params.scalar.one = 1.0f;
414 #endif
415 return params;
416 }
417
xnn_init_scalar_f32_hswish_params(void)418 static inline union xnn_f32_hswish_params xnn_init_scalar_f32_hswish_params(void)
419 {
420 union xnn_f32_hswish_params params;
421 params.scalar.sixth = 0x1.555556p-3f;
422 params.scalar.half = 0.5f;
423 params.scalar.one = 1.0f;
424 return params;
425 }
426
xnn_init_f32_spchw_params(uint32_t width,float output_min,float output_max)427 static inline union xnn_f32_spchw_params xnn_init_f32_spchw_params(
428 uint32_t width,
429 float output_min,
430 float output_max)
431 {
432 union xnn_f32_spchw_params params;
433 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
434 for (uint32_t i = 0; i < 4; i++) {
435 params.sse.max[i] = output_max;
436 params.sse.min[i] = output_min;
437 }
438
439 const uint32_t w4 = (width - 1) & 3;
440 params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
441 params.sse.mask[1] = -(uint32_t) (w4 >= 1);
442 params.sse.mask[2] = -(uint32_t) (w4 >= 2);
443 params.sse.mask[3] = -(uint32_t) (w4 >= 3);
444
445 const uint32_t w8 = (width - 1) & 7;
446 params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
447 params.sse.mask_even[1] = -(uint32_t) (w8 >= 2);
448 params.sse.mask_even[2] = -(uint32_t) (w8 >= 4);
449 params.sse.mask_even[3] = -(uint32_t) (w8 >= 6);
450 params.sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
451 params.sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
452 params.sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
453 params.sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
454 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
455 params.neon.max = output_max;
456 params.neon.min = output_min;
457
458 const uint32_t w4 = (width - 1) & 3;
459 params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
460 params.neon.mask[1] = -(uint32_t) (w4 >= 1);
461 params.neon.mask[2] = -(uint32_t) (w4 >= 2);
462 params.neon.mask[3] = -(uint32_t) (w4 >= 3);
463
464 const uint32_t w8 = (width - 1) & 7;
465 params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
466 params.neon.mask_even[1] = -(uint32_t) (w8 >= 2);
467 params.neon.mask_even[2] = -(uint32_t) (w8 >= 4);
468 params.neon.mask_even[3] = -(uint32_t) (w8 >= 6);
469 params.neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
470 params.neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
471 params.neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
472 params.neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
473 #else
474 params.scalar.max = output_max;
475 params.scalar.min = output_min;
476 #endif
477 return params;
478 }
479
// Recomputes the width-derived masks of existing F32 SpCHW parameters on the
// SSE and NEON paths; the clamping bounds in *params are not modified. On
// other architectures this function does nothing.
//
// Mask contents:
//   mask[i]      is all-ones when i == 0 or (width - 1) & 3 >= i;
//   mask_even[i] is all-ones when i == 0 or (width - 1) & 7 >= 2 * i;
//   mask_odd[i]  is all-ones when (width - 1) & 7 >= 2 * i + 1.
static inline void xnn_update_f32_spchw_params(
  union xnn_f32_spchw_params* params,
  uint32_t width)
{
#if XNN_ARCH_X86 || XNN_ARCH_X86_64
  const uint32_t w4 = (width - 1) & 3;
  params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
  for (uint32_t k = 1; k < 4; k++) {
    params->sse.mask[k] = -(uint32_t) (w4 >= k);
  }

  const uint32_t w8 = (width - 1) & 7;
  params->sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
  for (uint32_t k = 1; k < 4; k++) {
    params->sse.mask_even[k] = -(uint32_t) (w8 >= 2 * k);
  }
  for (uint32_t k = 0; k < 4; k++) {
    params->sse.mask_odd[k] = -(uint32_t) (w8 >= 2 * k + 1);
  }
#elif XNN_ARCH_ARM || XNN_ARCH_ARM64
  const uint32_t w4 = (width - 1) & 3;
  params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
  for (uint32_t k = 1; k < 4; k++) {
    params->neon.mask[k] = -(uint32_t) (w4 >= k);
  }

  const uint32_t w8 = (width - 1) & 7;
  params->neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
  for (uint32_t k = 1; k < 4; k++) {
    params->neon.mask_even[k] = -(uint32_t) (w8 >= 2 * k);
  }
  for (uint32_t k = 0; k < 4; k++) {
    params->neon.mask_odd[k] = -(uint32_t) (w8 >= 2 * k + 1);
  }
#endif
}
518
xnn_init_scalar_f32_spchw_params(uint32_t width,float output_min,float output_max)519 static inline union xnn_f32_spchw_params xnn_init_scalar_f32_spchw_params(
520 uint32_t width,
521 float output_min,
522 float output_max)
523 {
524 union xnn_f32_spchw_params params;
525 params.scalar.max = output_max;
526 params.scalar.min = output_min;
527 return params;
528 }
529
xnn_init_u8_output_params(uint8_t output_min,uint8_t output_max)530 static inline union xnn_u8_output_params xnn_init_u8_output_params(
531 uint8_t output_min,
532 uint8_t output_max)
533 {
534 assert(output_min < output_max);
535
536 union xnn_u8_output_params params;
537 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
538 for (uint32_t i = 0; i < 16; i++) {
539 params.sse2.max[i] = output_max;
540 params.sse2.min[i] = output_min;
541 }
542 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
543 params.neon.max = output_max;
544 params.neon.min = output_min;
545 #else
546 params.scalar.min = (int32_t) (uint32_t) output_min;
547 params.scalar.max = (int32_t) (uint32_t) output_max;
548 #endif
549 return params;
550 }
551
xnn_init_scalar_u8_output_params(uint8_t output_min,uint8_t output_max)552 static inline union xnn_u8_output_params xnn_init_scalar_u8_output_params(
553 uint8_t output_min,
554 uint8_t output_max)
555 {
556 assert(output_min < output_max);
557
558 union xnn_u8_output_params params;
559 params.scalar.min = (int32_t) (uint32_t) output_min;
560 params.scalar.max = (int32_t) (uint32_t) output_max;
561 return params;
562 }
563
xnn_init_q8_add_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)564 static inline union xnn_q8_add_params xnn_init_q8_add_params(
565 uint8_t a_zero_point,
566 uint8_t b_zero_point,
567 uint8_t output_zero_point,
568 float a_output_scale,
569 float b_output_scale,
570 uint8_t output_min,
571 uint8_t output_max)
572 {
573 assert(a_output_scale >= 0x1.0p-14f);
574 assert(b_output_scale >= 0x1.0p-14f);
575 assert(a_output_scale < 0x1.0p+8f);
576 assert(b_output_scale < 0x1.0p+8f);
577
578 // Compute requantization parameters.
579 const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
580 assert(max_output_scale >= 0x1.0p-14f);
581 assert(max_output_scale < 0x1.0p+8f);
582 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
583 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
584 // Shift is in [13, 31] range.
585 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
586 assert(shift < 32);
587 assert(shift >= 13);
588
589 const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
590
591 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
592 const uint32_t a_multiplier = (uint32_t) (int32_t) __builtin_lrintf(a_output_scale * scale_multiplier);
593 const uint32_t b_multiplier = (uint32_t) (int32_t) __builtin_lrintf(b_output_scale * scale_multiplier);
594 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
595 assert(a_multiplier < UINT32_C(0x00400000));
596 assert(b_multiplier < UINT32_C(0x00400000));
597
598 union xnn_q8_add_params params;
599 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
600 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
601 const uint32_t remainder_threshold = remainder_mask >> 1;
602 const int32_t zero_point_product =
603 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
604 for (uint32_t i = 0; i < 4; i++) {
605 params.sse2.zero_point_product[i] = zero_point_product;
606 }
607 for (uint32_t i = 0; i < 8; i++) {
608 params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
609 }
610 for (uint32_t i = 0; i < 8; i++) {
611 params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
612 params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
613 params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
614 params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
615 }
616 params.sse2.a_multiplier = a_multiplier;
617 params.sse2.b_multiplier = b_multiplier;
618 for (uint32_t i = 0; i < 4; i++) {
619 params.sse2.remainder_mask[i] = remainder_mask;
620 params.sse2.remainder_threshold[i] = remainder_threshold;
621 }
622 params.sse2.shift = shift;
623 for (uint32_t i = 0; i < 16; i++) {
624 params.sse2.y_max[i] = output_max;
625 params.sse2.y_min[i] = output_min;
626 }
627 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
628 params.neon.a_zero_point = a_zero_point;
629 params.neon.b_zero_point = b_zero_point;
630 params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
631 params.neon.a_multiplier = (int32_t) a_multiplier;
632 params.neon.b_multiplier = (int32_t) b_multiplier;
633 params.neon.right_shift = (int32_t) -shift;
634 params.neon.y_max = output_max;
635 params.neon.y_min = output_min;
636 #else
637 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
638 const uint32_t remainder_threshold = remainder_mask >> 1;
639 params.scalar.zero_point_product =
640 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
641 params.scalar.a_multiplier = a_multiplier;
642 params.scalar.b_multiplier = b_multiplier;
643 params.scalar.remainder_mask = (int32_t) remainder_mask;
644 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
645 params.scalar.shift = shift;
646 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
647 params.scalar.y_max = (int32_t) (uint32_t) output_max;
648 params.scalar.y_min = (int32_t) (uint32_t) output_min;
649 #endif
650 return params;
651 }
652
xnn_init_scalar_q8_add_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)653 static inline union xnn_q8_add_params xnn_init_scalar_q8_add_params(
654 uint8_t a_zero_point,
655 uint8_t b_zero_point,
656 uint8_t output_zero_point,
657 float a_output_scale,
658 float b_output_scale,
659 uint8_t output_min,
660 uint8_t output_max)
661 {
662 assert(a_output_scale >= 0x1.0p-10f);
663 assert(b_output_scale >= 0x1.0p-10f);
664 assert(a_output_scale < 0x1.0p+8f);
665 assert(b_output_scale < 0x1.0p+8f);
666
667 // Compute requantization parameters.
668 const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
669 assert(max_output_scale >= 0x1.0p-10f);
670 assert(max_output_scale < 0x1.0p+8f);
671 const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
672 const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
673 // Shift is in [13, 31] range.
674 const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
675 assert(shift < 32);
676 assert(shift >= 13);
677
678 // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
679 const uint32_t a_multiplier = (uint32_t) (int32_t) __builtin_lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
680 const uint32_t b_multiplier = (uint32_t) (int32_t) __builtin_lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
681 assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
682 assert(a_multiplier < UINT32_C(0x00400000));
683 assert(b_multiplier < UINT32_C(0x00400000));
684
685 union xnn_q8_add_params params;
686 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
687 const uint32_t remainder_threshold = remainder_mask >> 1;
688 params.scalar.zero_point_product =
689 (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
690 params.scalar.a_multiplier = a_multiplier;
691 params.scalar.b_multiplier = b_multiplier;
692 params.scalar.remainder_mask = (int32_t) remainder_mask;
693 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
694 params.scalar.shift = shift;
695 params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
696 params.scalar.y_max = (int32_t) (uint32_t) output_max;
697 params.scalar.y_min = (int32_t) (uint32_t) output_min;
698 return params;
699 }
700
xnn_init_scalar_requantization_params(float scale,uint8_t zero_point,uint8_t min,uint8_t max)701 static inline union xnn_q31_requantization_params xnn_init_scalar_requantization_params(
702 float scale,
703 uint8_t zero_point,
704 uint8_t min,
705 uint8_t max)
706 {
707 // Compute requantization parameters.
708 assert(scale < 1.0f);
709 assert(scale >= 0x1.0p-32f);
710 const uint32_t scale_bits = fp32_to_bits(scale);
711
712 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
713 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
714 assert(multiplier >= INT32_C(0x40000000));
715 assert(multiplier <= INT32_C(0x7FFFFF80));
716
717 // Shift is in [0, 31] range.
718 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
719 assert(shift >= 0);
720 assert(shift < 32);
721
722 union xnn_q31_requantization_params params;
723 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
724 const uint32_t remainder_threshold = remainder_mask >> 1;
725 params.scalar.multiplier = multiplier;
726 params.scalar.remainder_mask = (int32_t) remainder_mask;
727 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
728 params.scalar.shift = (uint32_t) shift;
729 params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
730 params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
731 params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
732 return params;
733 }
734
xnn_init_requantization_params(float scale,uint8_t zero_point,uint8_t min,uint8_t max)735 static inline union xnn_q31_requantization_params xnn_init_requantization_params(
736 float scale,
737 uint8_t zero_point,
738 uint8_t min,
739 uint8_t max)
740 {
741 // Compute requantization parameters.
742 const uint32_t scale_bits = fp32_to_bits(scale);
743
744 // Multiplier is in [0x40000000, 0x7FFFFF80] range.
745 const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
746 assert(multiplier >= INT32_C(0x40000000));
747 assert(multiplier <= INT32_C(0x7FFFFF80));
748
749 // Shift is in [0, 31] range.
750 const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
751 assert(shift >= 0);
752 assert(shift < 32);
753
754 union xnn_q31_requantization_params params;
755 #if XNN_ARCH_X86 || XNN_ARCH_X86_64
756 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
757 const uint32_t remainder_threshold = remainder_mask >> 1;
758 params.sse2.multiplier[0] = multiplier;
759 params.sse2.multiplier[1] = multiplier;
760 params.sse2.multiplier[2] = multiplier;
761 params.sse2.multiplier[3] = multiplier;
762 params.sse2.rounding[0] = UINT64_C(0x40000000);
763 params.sse2.rounding[1] = UINT64_C(0x40000000);
764 params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
765 params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
766 params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
767 params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
768 params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
769 params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
770 params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
771 params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
772 params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
773 params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
774 for (uint32_t i = 0; i < 8; i++) {
775 params.sse2.zero_point[i] = (int16_t) (uint16_t) zero_point;
776 }
777 for (uint32_t i = 0; i < 16; i++) {
778 params.sse2.max[i] = max;
779 params.sse2.min[i] = min;
780 }
781 #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
782 params.neon.multiplier = multiplier;
783 params.neon.right_shift = -shift;
784 params.neon.zero_point = (int16_t) (uint16_t) zero_point;
785 params.neon.max = max;
786 params.neon.min = min;
787 #else
788 const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
789 const uint32_t remainder_threshold = remainder_mask >> 1;
790 params.scalar.multiplier = multiplier;
791 params.scalar.remainder_mask = (int32_t) remainder_mask;
792 params.scalar.remainder_threshold = (int32_t) remainder_threshold;
793 params.scalar.shift = (uint32_t) shift;
794 params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
795 params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
796 params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
797 #endif
798 return params;
799 }
800