• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #if defined(__cplusplus) && (__cplusplus >= 201103L)
12   #include <cstdint>
13   #include <cstddef>
14   #include <cassert>
15   #include <cmath>
16 #else
17   #include <stdint.h>
18   #include <stddef.h>
19   #include <assert.h>
20   #include <math.h>
21 #endif
22 
23 #include <fp16.h>
24 
25 #include <xnnpack/common.h>
26 #include <xnnpack/params.h>
27 
28 
xnn_init_scalar_q8_gemm_params(uint8_t input_zero_point,uint8_t kernel_zero_point,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)29 static inline union xnn_q8_gemm_params xnn_init_scalar_q8_gemm_params(
30   uint8_t input_zero_point,
31   uint8_t kernel_zero_point,
32   float scale,
33   uint8_t output_zero_point,
34   uint8_t output_min,
35   uint8_t output_max)
36 {
37   // Compute requantization parameters
38   const uint32_t scale_bits = fp32_to_bits(scale);
39 
40   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
41   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
42   assert(multiplier >= INT32_C(0x40000000));
43   assert(multiplier <= INT32_C(0x7FFFFF80));
44 
45   // Shift is in [0, 31] range.
46   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
47   assert(shift >= 0);
48   assert(shift < 32);
49 
50   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
51   const uint32_t remainder_threshold = remainder_mask >> 1;
52 
53   union xnn_q8_gemm_params params;
54   params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
55   params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
56   params.scalar.multiplier = multiplier;
57   params.scalar.remainder_mask = (int32_t) remainder_mask;
58   params.scalar.remainder_threshold = (int32_t) remainder_threshold;
59   params.scalar.shift = (uint32_t) shift;
60   params.scalar.output_min_less_zero_point =
61     (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
62   params.scalar.output_max_less_zero_point =
63     (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
64   params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
65   return params;
66 }
67 
xnn_init_q8_gemm_params(uint8_t input_zero_point,uint8_t kernel_zero_point,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)68 static inline union xnn_q8_gemm_params xnn_init_q8_gemm_params(
69   uint8_t input_zero_point,
70   uint8_t kernel_zero_point,
71   float scale,
72   uint8_t output_zero_point,
73   uint8_t output_min,
74   uint8_t output_max)
75 {
76   // Compute requantization parameters.
77   const uint32_t scale_bits = fp32_to_bits(scale);
78 
79   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
80   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
81   assert(multiplier >= INT32_C(0x40000000));
82   assert(multiplier <= INT32_C(0x7FFFFF80));
83 
84   // Shift is in [0, 31] range.
85   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
86   assert(shift >= 0);
87   assert(shift < 32);
88 
89   union xnn_q8_gemm_params params;
90   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
91     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
92     const uint32_t remainder_threshold = remainder_mask >> 1;
93     for (uint32_t i = 0; i < 8; i++) {
94       params.sse2.input_zero_point[i] = (int16_t) (uint16_t) input_zero_point;
95       params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
96     }
97     params.sse2.multiplier[0] = multiplier;
98     params.sse2.multiplier[1] = multiplier;
99     params.sse2.multiplier[2] = multiplier;
100     params.sse2.multiplier[3] = multiplier;
101     params.sse2.rounding[0] = UINT64_C(0x40000000);
102     params.sse2.rounding[1] = UINT64_C(0x40000000);
103     params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
104     params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
105     params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
106     params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
107     params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
108     params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
109     params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
110     params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
111     params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
112     params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
113     for (uint32_t i = 0; i < 8; i++) {
114       params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
115     }
116     for (uint32_t i = 0; i < 16; i++) {
117       params.sse2.output_max[i] = output_max;
118       params.sse2.output_min[i] = output_min;
119     }
120   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
121     params.neon.input_zero_point = (int16_t) (uint16_t) input_zero_point;
122     params.neon.kernel_zero_point = (int16_t) (uint16_t) kernel_zero_point;
123     params.neon.multiplier = multiplier;
124     params.neon.right_shift = -shift;
125     params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
126     params.neon.output_max = output_max;
127     params.neon.output_min = output_min;
128   #else
129     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
130     const uint32_t remainder_threshold = remainder_mask >> 1;
131     params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point;
132     params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point;
133     params.scalar.multiplier = multiplier;
134     params.scalar.remainder_mask = (int32_t) remainder_mask;
135     params.scalar.remainder_threshold = (int32_t) remainder_threshold;
136     params.scalar.shift = (uint32_t) shift;
137     params.scalar.output_min_less_zero_point =
138       (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
139     params.scalar.output_max_less_zero_point =
140       (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
141     params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
142   #endif
143   return params;
144 }
145 
xnn_init_q8_avgpool_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)146 static inline union xnn_q8_avgpool_params xnn_init_q8_avgpool_params(
147   int32_t bias,
148   float scale,
149   uint8_t output_zero_point,
150   uint8_t output_min,
151   uint8_t output_max)
152 {
153   // Compute requantization parameters.
154   assert(scale >= 0x1.0p-32f);
155   assert(scale < 256.0f);
156   const uint32_t scale_bits = fp32_to_bits(scale);
157 
158   // Multiplier is in [0x00800000, 0x00FFFFFF] range.
159   const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
160   assert(multiplier >= INT32_C(0x00800000));
161   assert(multiplier <= INT32_C(0x00FFFFFF));
162 
163   // Shift is in [16, 55] range.
164   const int32_t shift = 127 + 23 - (scale_bits >> 23);
165   assert(shift >= 16);
166   assert(shift < 64);
167 
168   union xnn_q8_avgpool_params params;
169   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
170     const uint32_t right_shift = (uint32_t) shift;
171     const uint64_t rounding = UINT64_C(1) << (right_shift - 1);
172     params.sse2.bias[0] = bias;
173     params.sse2.bias[1] = bias;
174     params.sse2.bias[2] = bias;
175     params.sse2.bias[3] = bias;
176     params.sse2.multiplier[0] = (uint32_t) multiplier;
177     params.sse2.multiplier[1] = (uint32_t) multiplier;
178     params.sse2.multiplier[2] = (uint32_t) multiplier;
179     params.sse2.multiplier[3] = (uint32_t) multiplier;
180     params.sse2.rounding[0] = rounding;
181     params.sse2.rounding[1] = rounding;
182     params.sse2.right_shift[0] = (uint64_t) right_shift;
183     params.sse2.right_shift[1] = (uint64_t) right_shift;
184     for (uint32_t i = 0; i < 8; i++) {
185       params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
186     }
187     for (uint32_t i = 0; i < 16; i++) {
188       params.sse2.output_max[i] = output_max;
189       params.sse2.output_min[i] = output_min;
190     }
191   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
192     params.neon.bias = bias;
193     params.neon.multiplier = multiplier;
194     params.neon.left_shift = (int64_t) -shift;
195     params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point;
196     params.neon.output_max = output_max;
197     params.neon.output_min = output_min;
198   #else
199     const uint32_t right_shift = (uint32_t) shift;
200     const int64_t rounding = INT64_C(1) << (right_shift - 1);
201     params.scalar.bias = bias;
202     params.scalar.multiplier = multiplier;
203     params.scalar.rounding = rounding;
204     params.scalar.right_shift = right_shift;
205     params.scalar.output_min_less_zero_point =
206       (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
207     params.scalar.output_max_less_zero_point =
208       (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
209     params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
210   #endif
211   return params;
212 }
213 
xnn_init_scalar_q8_avgpool_params(int32_t bias,float scale,uint8_t output_zero_point,uint8_t output_min,uint8_t output_max)214 static inline union xnn_q8_avgpool_params xnn_init_scalar_q8_avgpool_params(
215   int32_t bias,
216   float scale,
217   uint8_t output_zero_point,
218   uint8_t output_min,
219   uint8_t output_max)
220 {
221   // Compute requantization parameters.
222   assert(scale >= 0x1.0p-32f);
223   assert(scale < 256.0f);
224   const uint32_t scale_bits = fp32_to_bits(scale);
225 
226   // Multiplier is in [0x00800000, 0x00FFFFFF] range.
227   const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
228   assert(multiplier >= INT32_C(0x00800000));
229   assert(multiplier <= INT32_C(0x00FFFFFF));
230 
231   // Shift is in [16, 55] range.
232   const int32_t shift = 127 + 23 - (scale_bits >> 23);
233   assert(shift >= 16);
234   assert(shift < 64);
235 
236   union xnn_q8_avgpool_params params;
237   const uint32_t right_shift = (uint32_t) shift;
238   const int64_t rounding = INT64_C(1) << (right_shift - 1);
239   params.scalar.bias = bias;
240   params.scalar.rounding = rounding;
241   params.scalar.multiplier = multiplier;
242   params.scalar.right_shift = right_shift;
243   params.scalar.output_min_less_zero_point =
244     (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
245   params.scalar.output_max_less_zero_point =
246     (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
247   params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
248   return params;
249 }
250 
xnn_update_f32_avgpool_params(union xnn_f32_avgpool_params * params,float multiplier)251 static inline void xnn_update_f32_avgpool_params(
252   union xnn_f32_avgpool_params* params,
253   float multiplier)
254 {
255   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
256     for (uint32_t i = 0; i < 4; i++) {
257       params->sse2.multiplier[i] = multiplier;
258     }
259   #else
260     params->scalar.multiplier = multiplier;
261   #endif
262 }
263 
xnn_init_f32_avgpool_params(float multiplier,float output_min,float output_max)264 static inline union xnn_f32_avgpool_params xnn_init_f32_avgpool_params(
265   float multiplier,
266   float output_min,
267   float output_max)
268 {
269   union xnn_f32_avgpool_params params;
270   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
271     for (uint32_t i = 0; i < 4; i++) {
272       params.sse2.multiplier[i] = multiplier;
273       params.sse2.output_min[i] = output_min;
274       params.sse2.output_max[i] = output_max;
275     }
276   #else
277     params.scalar.multiplier = multiplier;
278     params.scalar.output_min = output_min;
279     params.scalar.output_max = output_max;
280   #endif
281   return params;
282 }
283 
xnn_init_f32_gavgpool_params(float multiplier,float output_min,float output_max,uint32_t width)284 static inline union xnn_f32_gavgpool_params xnn_init_f32_gavgpool_params(
285   float multiplier,
286   float output_min,
287   float output_max,
288   uint32_t width)
289 {
290   union xnn_f32_gavgpool_params params;
291   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
292     for (uint32_t i = 0; i < 4; i++) {
293       params.sse.multiplier[i] = multiplier;
294       params.sse.output_min[i] = output_min;
295       params.sse.output_max[i] = output_max;
296     }
297 
298     const uint32_t w = (width - 1) & 3;
299     params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
300     params.sse.mask[1] = -(uint32_t) (w >= 1);
301     params.sse.mask[2] = -(uint32_t) (w >= 2);
302     params.sse.mask[3] = -(uint32_t) (w >= 3);
303   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
304     params.neon.multiplier = multiplier;
305     params.neon.output_min = output_min;
306     params.neon.output_max = output_max;
307 
308     const uint32_t w = (width - 1) & 3;
309     params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
310     params.neon.mask[1] = -(uint32_t) (w >= 1);
311     params.neon.mask[2] = -(uint32_t) (w >= 2);
312     params.neon.mask[3] = -(uint32_t) (w >= 3);
313   #else
314     params.scalar.multiplier = multiplier;
315     params.scalar.output_min = output_min;
316     params.scalar.output_max = output_max;
317   #endif
318   return params;
319 }
320 
xnn_update_f32_gavgpool_params(union xnn_f32_gavgpool_params * params,float multiplier,uint32_t width)321 static inline void xnn_update_f32_gavgpool_params(
322   union xnn_f32_gavgpool_params* params,
323   float multiplier,
324   uint32_t width)
325 {
326   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
327     for (uint32_t i = 0; i < 4; i++) {
328       params->sse.multiplier[i] = multiplier;
329     }
330 
331     const uint32_t w = (width - 1) & 3;
332     params->sse.mask[0] = UINT32_C(0xFFFFFFFF);
333     params->sse.mask[1] = -(uint32_t) (w >= 1);
334     params->sse.mask[2] = -(uint32_t) (w >= 2);
335     params->sse.mask[3] = -(uint32_t) (w >= 3);
336   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
337     params->neon.multiplier = multiplier;
338 
339     const uint32_t w = (width - 1) & 3;
340     params->neon.mask[0] = UINT32_C(0xFFFFFFFF);
341     params->neon.mask[1] = -(uint32_t) (w >= 1);
342     params->neon.mask[2] = -(uint32_t) (w >= 2);
343     params->neon.mask[3] = -(uint32_t) (w >= 3);
344   #else
345     params->scalar.multiplier = multiplier;
346   #endif
347 }
348 
xnn_init_scalar_f32_avgpool_params(float multiplier,float output_min,float output_max)349 static inline union xnn_f32_avgpool_params xnn_init_scalar_f32_avgpool_params(
350   float multiplier,
351   float output_min,
352   float output_max)
353 {
354   union xnn_f32_avgpool_params params;
355   params.scalar.multiplier = multiplier;
356   params.scalar.output_min = output_min;
357   params.scalar.output_max = output_max;
358   return params;
359 }
360 
xnn_init_scalar_f32_gavgpool_params(float multiplier,float output_min,float output_max,uint32_t width)361 static inline union xnn_f32_gavgpool_params xnn_init_scalar_f32_gavgpool_params(
362   float multiplier,
363   float output_min,
364   float output_max,
365   uint32_t width)
366 {
367   union xnn_f32_gavgpool_params params;
368   params.scalar.multiplier = multiplier;
369   params.scalar.output_min = output_min;
370   params.scalar.output_max = output_max;
371   return params;
372 }
373 
xnn_init_f32_output_params(float output_min,float output_max)374 static inline union xnn_f32_output_params xnn_init_f32_output_params(
375   float output_min,
376   float output_max)
377 {
378   union xnn_f32_output_params params;
379   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
380     for (uint32_t i = 0; i < 4; i++) {
381       params.sse.min[i] = output_min;
382       params.sse.max[i] = output_max;
383     }
384   #else
385     params.scalar.min = output_min;
386     params.scalar.max = output_max;
387   #endif
388   return params;
389 }
390 
xnn_init_scalar_f32_output_params(float output_min,float output_max)391 static inline union xnn_f32_output_params xnn_init_scalar_f32_output_params(
392   float output_min,
393   float output_max)
394 {
395   union xnn_f32_output_params params;
396   params.scalar.min = output_min;
397   params.scalar.max = output_max;
398   return params;
399 }
400 
xnn_init_f32_hswish_params(void)401 static inline union xnn_f32_hswish_params xnn_init_f32_hswish_params(void)
402 {
403   union xnn_f32_hswish_params params;
404   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
405     for (uint32_t i = 0; i < 4; i++) {
406       params.sse.sixth[i] = 0x1.555556p-3f;
407       params.sse.half[i] = 0.5f;
408       params.sse.one[i] = 1.0f;
409     }
410   #else
411     params.scalar.sixth = 0x1.555556p-3f;
412     params.scalar.half = 0.5f;
413     params.scalar.one = 1.0f;
414   #endif
415   return params;
416 }
417 
xnn_init_scalar_f32_hswish_params(void)418 static inline union xnn_f32_hswish_params xnn_init_scalar_f32_hswish_params(void)
419 {
420   union xnn_f32_hswish_params params;
421   params.scalar.sixth = 0x1.555556p-3f;
422   params.scalar.half = 0.5f;
423   params.scalar.one = 1.0f;
424   return params;
425 }
426 
xnn_init_f32_spchw_params(uint32_t width,float output_min,float output_max)427 static inline union xnn_f32_spchw_params xnn_init_f32_spchw_params(
428   uint32_t width,
429   float output_min,
430   float output_max)
431 {
432   union xnn_f32_spchw_params params;
433   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
434     for (uint32_t i = 0; i < 4; i++) {
435       params.sse.max[i] = output_max;
436       params.sse.min[i] = output_min;
437     }
438 
439     const uint32_t w4 = (width - 1) & 3;
440     params.sse.mask[0] = UINT32_C(0xFFFFFFFF);
441     params.sse.mask[1] = -(uint32_t) (w4 >= 1);
442     params.sse.mask[2] = -(uint32_t) (w4 >= 2);
443     params.sse.mask[3] = -(uint32_t) (w4 >= 3);
444 
445     const uint32_t w8 = (width - 1) & 7;
446     params.sse.mask_even[0] = UINT32_C(0xFFFFFFFF);
447     params.sse.mask_even[1] = -(uint32_t) (w8 >= 2);
448     params.sse.mask_even[2] = -(uint32_t) (w8 >= 4);
449     params.sse.mask_even[3] = -(uint32_t) (w8 >= 6);
450     params.sse.mask_odd[0] = -(uint32_t) (w8 >= 1);
451     params.sse.mask_odd[1] = -(uint32_t) (w8 >= 3);
452     params.sse.mask_odd[2] = -(uint32_t) (w8 >= 5);
453     params.sse.mask_odd[3] = -(uint32_t) (w8 >= 7);
454   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
455     params.neon.max = output_max;
456     params.neon.min = output_min;
457 
458     const uint32_t w4 = (width - 1) & 3;
459     params.neon.mask[0] = UINT32_C(0xFFFFFFFF);
460     params.neon.mask[1] = -(uint32_t) (w4 >= 1);
461     params.neon.mask[2] = -(uint32_t) (w4 >= 2);
462     params.neon.mask[3] = -(uint32_t) (w4 >= 3);
463 
464     const uint32_t w8 = (width - 1) & 7;
465     params.neon.mask_even[0] = UINT32_C(0xFFFFFFFF);
466     params.neon.mask_even[1] = -(uint32_t) (w8 >= 2);
467     params.neon.mask_even[2] = -(uint32_t) (w8 >= 4);
468     params.neon.mask_even[3] = -(uint32_t) (w8 >= 6);
469     params.neon.mask_odd[0] = -(uint32_t) (w8 >= 1);
470     params.neon.mask_odd[1] = -(uint32_t) (w8 >= 3);
471     params.neon.mask_odd[2] = -(uint32_t) (w8 >= 5);
472     params.neon.mask_odd[3] = -(uint32_t) (w8 >= 7);
473   #else
474     params.scalar.max = output_max;
475     params.scalar.min = output_min;
476   #endif
477   return params;
478 }
479 
// Recomputes only the width-dependent lane masks in *params for a row of
// `width` elements; the clamping bounds are left untouched. On x86/x86-64 and
// ARM/ARM64, with last4 = (width - 1) & 3 and last8 = (width - 1) & 7, lane i
// of mask / mask_even / mask_odd is all ones iff i <= last4 / 2*i <= last8 /
// 2*i + 1 <= last8, respectively. On other architectures this function writes
// nothing.
static inline void xnn_update_f32_spchw_params(
  union xnn_f32_spchw_params* params,
  uint32_t width)
{
  #if XNN_ARCH_X86 || XNN_ARCH_X86_64
    const uint32_t last4 = (width - 1) & 3;
    const uint32_t last8 = (width - 1) & 7;
    for (uint32_t lane = 0; lane < 4; lane++) {
      params->sse.mask[lane] = -(uint32_t) (last4 >= lane);
      params->sse.mask_even[lane] = -(uint32_t) (last8 >= 2 * lane);
      params->sse.mask_odd[lane] = -(uint32_t) (last8 >= 2 * lane + 1);
    }
  #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
    const uint32_t last4 = (width - 1) & 3;
    const uint32_t last8 = (width - 1) & 7;
    for (uint32_t lane = 0; lane < 4; lane++) {
      params->neon.mask[lane] = -(uint32_t) (last4 >= lane);
      params->neon.mask_even[lane] = -(uint32_t) (last8 >= 2 * lane);
      params->neon.mask_odd[lane] = -(uint32_t) (last8 >= 2 * lane + 1);
    }
  #else
    (void) params;
    (void) width;
  #endif
}
518 
xnn_init_scalar_f32_spchw_params(uint32_t width,float output_min,float output_max)519 static inline union xnn_f32_spchw_params xnn_init_scalar_f32_spchw_params(
520   uint32_t width,
521   float output_min,
522   float output_max)
523 {
524   union xnn_f32_spchw_params params;
525   params.scalar.max = output_max;
526   params.scalar.min = output_min;
527   return params;
528 }
529 
xnn_init_u8_output_params(uint8_t output_min,uint8_t output_max)530 static inline union xnn_u8_output_params xnn_init_u8_output_params(
531   uint8_t output_min,
532   uint8_t output_max)
533 {
534   assert(output_min < output_max);
535 
536   union xnn_u8_output_params params;
537   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
538     for (uint32_t i = 0; i < 16; i++) {
539       params.sse2.max[i] = output_max;
540       params.sse2.min[i] = output_min;
541     }
542   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
543     params.neon.max = output_max;
544     params.neon.min = output_min;
545   #else
546     params.scalar.min = (int32_t) (uint32_t) output_min;
547     params.scalar.max = (int32_t) (uint32_t) output_max;
548   #endif
549   return params;
550 }
551 
xnn_init_scalar_u8_output_params(uint8_t output_min,uint8_t output_max)552 static inline union xnn_u8_output_params xnn_init_scalar_u8_output_params(
553   uint8_t output_min,
554   uint8_t output_max)
555 {
556   assert(output_min < output_max);
557 
558   union xnn_u8_output_params params;
559   params.scalar.min = (int32_t) (uint32_t) output_min;
560   params.scalar.max = (int32_t) (uint32_t) output_max;
561   return params;
562 }
563 
xnn_init_q8_add_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)564 static inline union xnn_q8_add_params xnn_init_q8_add_params(
565   uint8_t a_zero_point,
566   uint8_t b_zero_point,
567   uint8_t output_zero_point,
568   float a_output_scale,
569   float b_output_scale,
570   uint8_t output_min,
571   uint8_t output_max)
572 {
573   assert(a_output_scale >= 0x1.0p-14f);
574   assert(b_output_scale >= 0x1.0p-14f);
575   assert(a_output_scale < 0x1.0p+8f);
576   assert(b_output_scale < 0x1.0p+8f);
577 
578   // Compute requantization parameters.
579   const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
580   assert(max_output_scale >= 0x1.0p-14f);
581   assert(max_output_scale < 0x1.0p+8f);
582   const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
583   const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
584   // Shift is in [13, 31] range.
585   const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
586   assert(shift < 32);
587   assert(shift >= 13);
588 
589   const float scale_multiplier = fp32_from_bits((uint32_t) (21 - max_scale_exponent + 127) << 23);
590 
591   // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
592   const uint32_t a_multiplier = (uint32_t) (int32_t) __builtin_lrintf(a_output_scale * scale_multiplier);
593   const uint32_t b_multiplier = (uint32_t) (int32_t) __builtin_lrintf(b_output_scale * scale_multiplier);
594   assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
595   assert(a_multiplier < UINT32_C(0x00400000));
596   assert(b_multiplier < UINT32_C(0x00400000));
597 
598   union xnn_q8_add_params params;
599   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
600     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
601     const uint32_t remainder_threshold = remainder_mask >> 1;
602     const int32_t zero_point_product =
603       (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
604     for (uint32_t i = 0; i < 4; i++) {
605       params.sse2.zero_point_product[i] = zero_point_product;
606     }
607     for (uint32_t i = 0; i < 8; i++) {
608       params.sse2.y_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
609     }
610     for (uint32_t i = 0; i < 8; i++) {
611       params.sse2.a_multiplier_lo[i] = (uint16_t) (uint32_t) a_multiplier;
612       params.sse2.a_multiplier_hi[i] = (uint16_t) ((uint32_t) a_multiplier >> 16);
613       params.sse2.b_multiplier_lo[i] = (uint16_t) (uint32_t) b_multiplier;
614       params.sse2.b_multiplier_hi[i] = (uint16_t) ((uint32_t) b_multiplier >> 16);
615     }
616     params.sse2.a_multiplier = a_multiplier;
617     params.sse2.b_multiplier = b_multiplier;
618     for (uint32_t i = 0; i < 4; i++) {
619       params.sse2.remainder_mask[i] = remainder_mask;
620       params.sse2.remainder_threshold[i] = remainder_threshold;
621     }
622     params.sse2.shift = shift;
623     for (uint32_t i = 0; i < 16; i++) {
624       params.sse2.y_max[i] = output_max;
625       params.sse2.y_min[i] = output_min;
626     }
627   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
628     params.neon.a_zero_point = a_zero_point;
629     params.neon.b_zero_point = b_zero_point;
630     params.neon.y_zero_point = (int16_t) (uint16_t) output_zero_point;
631     params.neon.a_multiplier = (int32_t) a_multiplier;
632     params.neon.b_multiplier = (int32_t) b_multiplier;
633     params.neon.right_shift = (int32_t) -shift;
634     params.neon.y_max = output_max;
635     params.neon.y_min = output_min;
636   #else
637     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
638     const uint32_t remainder_threshold = remainder_mask >> 1;
639     params.scalar.zero_point_product =
640       (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
641     params.scalar.a_multiplier = a_multiplier;
642     params.scalar.b_multiplier = b_multiplier;
643     params.scalar.remainder_mask = (int32_t) remainder_mask;
644     params.scalar.remainder_threshold = (int32_t) remainder_threshold;
645     params.scalar.shift = shift;
646     params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
647     params.scalar.y_max = (int32_t) (uint32_t) output_max;
648     params.scalar.y_min = (int32_t) (uint32_t) output_min;
649   #endif
650   return params;
651 }
652 
xnn_init_scalar_q8_add_params(uint8_t a_zero_point,uint8_t b_zero_point,uint8_t output_zero_point,float a_output_scale,float b_output_scale,uint8_t output_min,uint8_t output_max)653 static inline union xnn_q8_add_params xnn_init_scalar_q8_add_params(
654   uint8_t a_zero_point,
655   uint8_t b_zero_point,
656   uint8_t output_zero_point,
657   float a_output_scale,
658   float b_output_scale,
659   uint8_t output_min,
660   uint8_t output_max)
661 {
662   assert(a_output_scale >= 0x1.0p-10f);
663   assert(b_output_scale >= 0x1.0p-10f);
664   assert(a_output_scale < 0x1.0p+8f);
665   assert(b_output_scale < 0x1.0p+8f);
666 
667   // Compute requantization parameters.
668   const float max_output_scale = a_output_scale > b_output_scale ? a_output_scale : b_output_scale;
669   assert(max_output_scale >= 0x1.0p-10f);
670   assert(max_output_scale < 0x1.0p+8f);
671   const uint32_t max_scale_bits = fp32_to_bits(max_output_scale);
672   const int32_t max_scale_exponent = (int32_t) (max_scale_bits >> 23) - 127;
673   // Shift is in [13, 31] range.
674   const uint32_t shift = (uint32_t) (21 - max_scale_exponent);
675   assert(shift < 32);
676   assert(shift >= 13);
677 
678   // Multipliers are in [0, 2**22) range, largest multiplier is in [2**21, 2**22) range.
679   const uint32_t a_multiplier = (uint32_t) (int32_t) __builtin_lrintf(fp32_from_bits(fp32_to_bits(a_output_scale) + (shift << 23)));
680   const uint32_t b_multiplier = (uint32_t) (int32_t) __builtin_lrintf(fp32_from_bits(fp32_to_bits(b_output_scale) + (shift << 23)));
681   assert((a_multiplier > b_multiplier ? a_multiplier : b_multiplier) >= UINT32_C(0x00200000));
682   assert(a_multiplier < UINT32_C(0x00400000));
683   assert(b_multiplier < UINT32_C(0x00400000));
684 
685   union xnn_q8_add_params params;
686   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
687   const uint32_t remainder_threshold = remainder_mask >> 1;
688   params.scalar.zero_point_product =
689     (int32_t) -(a_multiplier * (uint32_t) a_zero_point + b_multiplier * (uint32_t) b_zero_point);
690   params.scalar.a_multiplier = a_multiplier;
691   params.scalar.b_multiplier = b_multiplier;
692   params.scalar.remainder_mask = (int32_t) remainder_mask;
693   params.scalar.remainder_threshold = (int32_t) remainder_threshold;
694   params.scalar.shift = shift;
695   params.scalar.y_zero_point = (int32_t) (uint32_t) output_zero_point;
696   params.scalar.y_max = (int32_t) (uint32_t) output_max;
697   params.scalar.y_min = (int32_t) (uint32_t) output_min;
698   return params;
699 }
700 
xnn_init_scalar_requantization_params(float scale,uint8_t zero_point,uint8_t min,uint8_t max)701 static inline union xnn_q31_requantization_params xnn_init_scalar_requantization_params(
702   float scale,
703   uint8_t zero_point,
704   uint8_t min,
705   uint8_t max)
706 {
707   // Compute requantization parameters.
708   assert(scale < 1.0f);
709   assert(scale >= 0x1.0p-32f);
710   const uint32_t scale_bits = fp32_to_bits(scale);
711 
712   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
713   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
714   assert(multiplier >= INT32_C(0x40000000));
715   assert(multiplier <= INT32_C(0x7FFFFF80));
716 
717   // Shift is in [0, 31] range.
718   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
719   assert(shift >= 0);
720   assert(shift < 32);
721 
722   union xnn_q31_requantization_params params;
723   const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
724   const uint32_t remainder_threshold = remainder_mask >> 1;
725   params.scalar.multiplier = multiplier;
726   params.scalar.remainder_mask = (int32_t) remainder_mask;
727   params.scalar.remainder_threshold = (int32_t) remainder_threshold;
728   params.scalar.shift = (uint32_t) shift;
729   params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
730   params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
731   params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
732   return params;
733 }
734 
xnn_init_requantization_params(float scale,uint8_t zero_point,uint8_t min,uint8_t max)735 static inline union xnn_q31_requantization_params xnn_init_requantization_params(
736   float scale,
737   uint8_t zero_point,
738   uint8_t min,
739   uint8_t max)
740 {
741   // Compute requantization parameters.
742   const uint32_t scale_bits = fp32_to_bits(scale);
743 
744   // Multiplier is in [0x40000000, 0x7FFFFF80] range.
745   const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
746   assert(multiplier >= INT32_C(0x40000000));
747   assert(multiplier <= INT32_C(0x7FFFFF80));
748 
749   // Shift is in [0, 31] range.
750   const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23);
751   assert(shift >= 0);
752   assert(shift < 32);
753 
754   union xnn_q31_requantization_params params;
755   #if XNN_ARCH_X86 || XNN_ARCH_X86_64
756     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
757     const uint32_t remainder_threshold = remainder_mask >> 1;
758     params.sse2.multiplier[0] = multiplier;
759     params.sse2.multiplier[1] = multiplier;
760     params.sse2.multiplier[2] = multiplier;
761     params.sse2.multiplier[3] = multiplier;
762     params.sse2.rounding[0] = UINT64_C(0x40000000);
763     params.sse2.rounding[1] = UINT64_C(0x40000000);
764     params.sse2.remainder_mask[0] = (int32_t) remainder_mask;
765     params.sse2.remainder_mask[1] = (int32_t) remainder_mask;
766     params.sse2.remainder_mask[2] = (int32_t) remainder_mask;
767     params.sse2.remainder_mask[3] = (int32_t) remainder_mask;
768     params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold;
769     params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold;
770     params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold;
771     params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold;
772     params.sse2.shift[0] = (uint64_t) (uint32_t) shift;
773     params.sse2.shift[1] = (uint64_t) (uint32_t) shift;
774     for (uint32_t i = 0; i < 8; i++) {
775       params.sse2.zero_point[i] = (int16_t) (uint16_t) zero_point;
776     }
777     for (uint32_t i = 0; i < 16; i++) {
778       params.sse2.max[i] = max;
779       params.sse2.min[i] = min;
780     }
781   #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
782     params.neon.multiplier = multiplier;
783     params.neon.right_shift = -shift;
784     params.neon.zero_point = (int16_t) (uint16_t) zero_point;
785     params.neon.max = max;
786     params.neon.min = min;
787   #else
788     const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1);
789     const uint32_t remainder_threshold = remainder_mask >> 1;
790     params.scalar.multiplier = multiplier;
791     params.scalar.remainder_mask = (int32_t) remainder_mask;
792     params.scalar.remainder_threshold = (int32_t) remainder_threshold;
793     params.scalar.shift = (uint32_t) shift;
794     params.scalar.min_less_zero_point = (int32_t) (uint32_t) min - (int32_t) (uint32_t) zero_point;
795     params.scalar.max_less_zero_point = (int32_t) (uint32_t) max - (int32_t) (uint32_t) zero_point;
796     params.scalar.zero_point = (int32_t) (uint32_t) zero_point;
797   #endif
798   return params;
799 }
800