// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <stdint.h>
12 #include <stddef.h>
13 #include <assert.h>
14 #include <math.h>
15 
16 #include <xnnpack/common.h>
17 #include <xnnpack/math.h>
18 #include <xnnpack/microparams.h>
19 
20 
// Pointer to a function that converts one 32-bit integer value to a signed
// 8-bit value, given a floating-point scale, an output zero point, and the
// inclusive output range [output_min, output_max].
typedef int8_t (*xnn_qs8_requantize_fn)(
  int32_t input, float scale,
  int8_t output_zero_point, int8_t output_min, int8_t output_max);
27 
// Pointer to a function that converts one 32-bit integer value to an unsigned
// 8-bit value, given a floating-point scale, an output zero point, and the
// inclusive output range [output_min, output_max].
typedef uint8_t (*xnn_qu8_requantize_fn)(
  int32_t input, float scale,
  uint8_t output_zero_point, uint8_t output_min, uint8_t output_max);
34 
// Requantizes a 32-bit integer to int8 through single-precision arithmetic.
//
// The input is multiplied by `scale`, clamped to
// [min - zero_point, max - zero_point], rounded to an integer with lrintf()
// (i.e. according to the current rounding mode), and finally offset by
// `zero_point`.
//
// `scale` must lie in [2**-32, 256); this is checked with assert() only.
static inline int8_t xnn_qs8_requantize_fp32(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  // Clamping bounds expressed relative to the zero point, so the clamp can be
  // applied before the zero point is added back.
  const float lower_bound = (float) ((int32_t) min - (int32_t) zero_point);
  const float upper_bound = (float) ((int32_t) max - (int32_t) zero_point);

  float scaled = (float) input * scale;
  scaled = math_min_f32(math_max_f32(scaled, lower_bound), upper_bound);

  return (int8_t) ((int32_t) lrintf(scaled) + (int32_t) zero_point);
}
55 
// Requantizes a 32-bit integer to uint8 through single-precision arithmetic.
//
// The input is multiplied by `scale`, clamped to
// [min - zero_point, max - zero_point], rounded to an integer with lrintf()
// (i.e. according to the current rounding mode), and finally offset by
// `zero_point`.
//
// `scale` must lie in [2**-32, 256); this is checked with assert() only.
static inline uint8_t xnn_qu8_requantize_fp32(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  // Clamping bounds expressed relative to the zero point, so the clamp can be
  // applied before the zero point is added back.
  const float lower_bound = (float) ((int32_t) min - (int32_t) zero_point);
  const float upper_bound = (float) ((int32_t) max - (int32_t) zero_point);

  float scaled = (float) input * scale;
  scaled = math_min_f32(math_max_f32(scaled, lower_bound), upper_bound);

  return (uint8_t) ((int32_t) lrintf(scaled) + (int32_t) zero_point);
}
76 
// Requantizes a 32-bit integer to int8 using integer-only arithmetic.
//
// `scale` is decomposed into a 24-bit integer multiplier and a right shift.
// The magnitude of the input is multiplied, half of the shift unit is added,
// and the result is shifted right; the sign is restored afterwards, so exact
// ties round away from zero. The value is then clamped to
// [min - zero_point, max - zero_point] and offset by `zero_point`.
//
// `scale` must lie in [2**-32, 256); this is checked with assert() only.
static inline int8_t xnn_qs8_requantize_rndna(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  // Low 23 bits of the float encoding with bit 23 set form the multiplier;
  // the biased exponent field determines the right-shift amount.
  const uint32_t bits = float_as_uint32(scale);
  const uint32_t mantissa = UINT32_C(0x00800000) | (bits & UINT32_C(0x007FFFFF));
  const uint32_t right_shift = (127 + 23) - (bits >> 23);
  assert(right_shift >= 16);
  assert(right_shift < 56);

  const int32_t lower_bound = (int32_t) min - (int32_t) zero_point;
  const int32_t upper_bound = (int32_t) max - (int32_t) zero_point;
  const uint64_t half = UINT64_C(1) << (right_shift - 1);

  // Unsigned negation yields the magnitude without signed overflow.
  const uint32_t magnitude = input < 0 ? -(uint32_t) input : (uint32_t) input;

  const uint64_t product = (uint64_t) magnitude * (uint64_t) mantissa;
  const uint32_t scaled_magnitude = (uint32_t) ((product + half) >> right_shift);

  // Re-apply the sign of the original input.
  int32_t result = input < 0 ? -(int32_t) scaled_magnitude : (int32_t) scaled_magnitude;

  result = math_min_s32(math_max_s32(result, lower_bound), upper_bound);
  return (int8_t) (result + (int32_t) zero_point);
}
114 
// Requantizes a 32-bit integer to uint8 using integer-only arithmetic.
//
// `scale` is decomposed into a 24-bit integer multiplier and a right shift.
// The magnitude of the input is multiplied, half of the shift unit is added,
// and the result is shifted right; the sign is restored afterwards, so exact
// ties round away from zero. The value is then clamped to
// [min - zero_point, max - zero_point] and offset by `zero_point`.
//
// `scale` must lie in [2**-32, 256); this is checked with assert() only.
static inline uint8_t xnn_qu8_requantize_rndna(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  // Low 23 bits of the float encoding with bit 23 set form the multiplier;
  // the biased exponent field determines the right-shift amount.
  const uint32_t bits = float_as_uint32(scale);
  const uint32_t mantissa = UINT32_C(0x00800000) | (bits & UINT32_C(0x007FFFFF));
  const uint32_t right_shift = (127 + 23) - (bits >> 23);
  assert(right_shift >= 16);
  assert(right_shift < 56);

  const int32_t lower_bound = (int32_t) min - (int32_t) zero_point;
  const int32_t upper_bound = (int32_t) max - (int32_t) zero_point;
  const uint64_t half = UINT64_C(1) << (right_shift - 1);

  // Unsigned negation yields the magnitude without signed overflow.
  const uint32_t magnitude = input < 0 ? -(uint32_t) input : (uint32_t) input;

  const uint64_t product = (uint64_t) magnitude * (uint64_t) mantissa;
  const uint32_t scaled_magnitude = (uint32_t) ((product + half) >> right_shift);

  // Re-apply the sign of the original input.
  int32_t result = input < 0 ? -(int32_t) scaled_magnitude : (int32_t) scaled_magnitude;

  result = math_min_s32(math_max_s32(result, lower_bound), upper_bound);
  return (uint8_t) (result + (int32_t) zero_point);
}
152 
// Requantizes a 32-bit integer to int8 using signed integer-only arithmetic.
//
// `scale` is decomposed into a 24-bit integer multiplier and a right shift.
// The signed product has half of the shift unit added and is then shifted
// right with math_asr_s64 (presumably a sign-preserving arithmetic shift, in
// which case exact ties round toward positive infinity -- confirm against its
// definition). The value is then clamped to
// [min - zero_point, max - zero_point] and offset by `zero_point`.
//
// `scale` must lie in [2**-32, 256); this is checked with assert() only.
static inline int8_t xnn_qs8_requantize_rndnu(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale < 256.0f);
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);

  // Low 23 bits of the float encoding with bit 23 set form the multiplier;
  // the biased exponent field determines the right-shift amount.
  const uint32_t bits = float_as_uint32(scale);
  const int32_t mantissa = INT32_C(0x00800000) | ((int32_t) bits & INT32_C(0x007FFFFF));
  const uint32_t right_shift = (127 + 23) - (bits >> 23);
  assert(right_shift >= 16);
  assert(right_shift < 56);

  const int32_t lower_bound = (int32_t) min - (int32_t) zero_point;
  const int32_t upper_bound = (int32_t) max - (int32_t) zero_point;
  const int64_t half = INT64_C(1) << (right_shift - 1);

  const int64_t product = (int64_t) input * (int64_t) mantissa;
  int32_t result = (int32_t) math_asr_s64(product + half, right_shift);

  result = math_min_s32(math_max_s32(result, lower_bound), upper_bound);
  return (int8_t) (result + (int32_t) zero_point);
}
179 
// Requantizes a 32-bit integer to uint8 using signed integer-only arithmetic.
//
// `scale` is decomposed into a 24-bit integer multiplier and a right shift.
// The signed product has half of the shift unit added and is then shifted
// right with math_asr_s64 (presumably a sign-preserving arithmetic shift, in
// which case exact ties round toward positive infinity -- confirm against its
// definition). The value is then clamped to
// [min - zero_point, max - zero_point] and offset by `zero_point`.
//
// `scale` must lie in [2**-32, 256); this is checked with assert() only.
static inline uint8_t xnn_qu8_requantize_rndnu(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale < 256.0f);
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);

  // Low 23 bits of the float encoding with bit 23 set form the multiplier;
  // the biased exponent field determines the right-shift amount.
  const uint32_t bits = float_as_uint32(scale);
  const int32_t mantissa = INT32_C(0x00800000) | ((int32_t) bits & INT32_C(0x007FFFFF));
  const uint32_t right_shift = (127 + 23) - (bits >> 23);
  assert(right_shift >= 16);
  assert(right_shift < 56);

  const int32_t lower_bound = (int32_t) min - (int32_t) zero_point;
  const int32_t upper_bound = (int32_t) max - (int32_t) zero_point;
  const int64_t half = INT64_C(1) << (right_shift - 1);

  const int64_t product = (int64_t) input * (int64_t) mantissa;
  int32_t result = (int32_t) math_asr_s64(product + half, right_shift);

  result = math_min_s32(math_max_s32(result, lower_bound), upper_bound);
  return (uint8_t) (result + (int32_t) zero_point);
}
206 
// Computes one element of a quantized uint8 addition from precomputed
// parameters.
//
// The two inputs are widened to 32 bits, multiplied by their respective
// multipliers and accumulated together with the bias. The sum is shifted right
// by `params.scalar.shift`, clamped to
// [output_min_less_zero_point, output_max_less_zero_point], and offset by the
// output zero point before being narrowed to uint8.
static inline uint8_t xnn_qu8_quantize_add(
  uint8_t a, uint8_t b,
  union xnn_qu8_add_minmax_params params)
{
  // Multiply by factors and accumulate products.
  int32_t acc = params.scalar.bias
    + (int32_t) (uint32_t) a * params.scalar.a_multiplier
    + (int32_t) (uint32_t) b * params.scalar.b_multiplier;

  // Shift right by the precomputed amount.
  // NOTE(review): any rounding offset is presumably folded into the bias by
  // the code that initializes params -- confirm.
  acc = math_asr_s32(acc, params.scalar.shift);

  // Clamp and add output zero point.
  acc = math_max_s32(acc, params.scalar.output_min_less_zero_point);
  acc = math_min_s32(acc, params.scalar.output_max_less_zero_point);

  // Narrow directly to the unsigned return type: casting through int8_t would
  // make results above 127 depend on implementation-defined signed conversion.
  return (uint8_t) (acc + params.scalar.output_zero_point);
}
222 
// Computes one element of a quantized int8 addition from precomputed
// parameters.
//
// The two inputs are widened to 32 bits, multiplied by their respective
// multipliers and accumulated together with the bias. The sum is shifted right
// by `params.scalar.shift`, clamped to
// [output_min_less_zero_point, output_max_less_zero_point], and offset by the
// output zero point before being narrowed to int8.
static inline int8_t xnn_qs8_quantize_add(
  int8_t a, int8_t b,
  union xnn_qs8_add_minmax_params params)
{
  // Weighted sum of both operands plus the bias term.
  int32_t sum = params.scalar.bias
    + (int32_t) a * params.scalar.a_multiplier
    + (int32_t) b * params.scalar.b_multiplier;

  // Scale down by the precomputed shift.
  // NOTE(review): any rounding offset is presumably folded into the bias by
  // the code that initializes params -- confirm.
  sum = math_asr_s32(sum, params.scalar.shift);

  // Restrict to the permitted range (relative to the zero point), then
  // re-center on the output zero point.
  sum = math_min_s32(
    math_max_s32(sum, params.scalar.output_min_less_zero_point),
    params.scalar.output_max_less_zero_point);
  return (int8_t) (sum + params.scalar.output_zero_point);
}
238 
xnn_qs8_quantize(float val,float scale,int32_t zero_point)239 inline static int8_t xnn_qs8_quantize(float val, float scale, int32_t zero_point)
240 {
241   return (int8_t) lrintf(fminf(fmaxf(val / scale + (float) zero_point, -128.0f), 127.0f));
242 }
243 
xnn_qu8_quantize(float val,float scale,int32_t zero_point)244 inline static uint8_t xnn_qu8_quantize(float val, float scale, int32_t zero_point)
245 {
246   return (uint8_t) lrintf(fminf(fmaxf(val / scale + (float) zero_point, 0.0f), 255.0f));
247 }
248