1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #pragma once
10
11 #include <stdint.h>
12 #include <stddef.h>
13 #include <assert.h>
14 #include <math.h>
15
16 #include <xnnpack/common.h>
17 #include <xnnpack/math.h>
18 #include <xnnpack/microparams.h>
19
20
// Function-pointer type for scalar signed 8-bit (qs8) requantization routines.
// Its signature matches xnn_qs8_requantize_fp32, xnn_qs8_requantize_rndna and
// xnn_qs8_requantize_rndnu in this header:
//   input             - 32-bit value to requantize.
//   scale             - requantization scale.
//   output_zero_point - zero point of the int8 output.
//   output_min/max    - clamping bounds for the int8 output.
// Returns the requantized int8 value.
typedef int8_t (*xnn_qs8_requantize_fn)(
  int32_t input,
  float scale,
  int8_t output_zero_point,
  int8_t output_min,
  int8_t output_max);
27
// Function-pointer type for scalar unsigned 8-bit (qu8) requantization routines.
// Its signature matches xnn_qu8_requantize_fp32, xnn_qu8_requantize_rndna and
// xnn_qu8_requantize_rndnu in this header:
//   input             - 32-bit value to requantize.
//   scale             - requantization scale.
//   output_zero_point - zero point of the uint8 output.
//   output_min/max    - clamping bounds for the uint8 output.
// Returns the requantized uint8 value.
typedef uint8_t (*xnn_qu8_requantize_fn)(
  int32_t input,
  float scale,
  uint8_t output_zero_point,
  uint8_t output_min,
  uint8_t output_max);
34
// Requantizes a 32-bit value to int8 using single-precision floating point.
//
// The value is multiplied by `scale` in float and clamped to the output range expressed
// relative to the zero point. It is then rounded with lrintf (current FP rounding mode;
// round-to-nearest-even under the default mode) and shifted back by the zero point.
// Clamping happens in float before rounding, so large products never reach the
// integer conversion.
//
// Parameters:
//   input      - 32-bit value to requantize.
//   scale      - requantization scale; must be in [2**-32, 256).
//   zero_point - output zero point.
//   min, max   - output clamping bounds in the quantized (int8) domain.
// Returns the requantized value, within [min, max].
static inline int8_t xnn_qs8_requantize_fp32(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  const int32_t zp = (int32_t) zero_point;

  // Output bounds moved into the zero-point-relative domain of the scaled value.
  const float lower_bound = (float) ((int32_t) min - zp);
  const float upper_bound = (float) ((int32_t) max - zp);

  const float product = (float) input * scale;
  const float clamped = math_min_f32(math_max_f32(product, lower_bound), upper_bound);

  return (int8_t) ((int32_t) lrintf(clamped) + zp);
}
55
// Requantizes a 32-bit value to uint8 using single-precision floating point.
//
// The value is multiplied by `scale` in float and clamped to the output range expressed
// relative to the zero point. It is then rounded with lrintf (current FP rounding mode;
// round-to-nearest-even under the default mode) and shifted back by the zero point.
// Clamping happens in float before rounding, so large products never reach the
// integer conversion.
//
// Parameters:
//   input      - 32-bit value to requantize.
//   scale      - requantization scale; must be in [2**-32, 256).
//   zero_point - output zero point.
//   min, max   - output clamping bounds in the quantized (uint8) domain.
// Returns the requantized value, within [min, max].
static inline uint8_t xnn_qu8_requantize_fp32(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  const int32_t zp = (int32_t) zero_point;

  // Output bounds moved into the zero-point-relative domain of the scaled value.
  const float lower_bound = (float) ((int32_t) min - zp);
  const float upper_bound = (float) ((int32_t) max - zp);

  const float product = (float) input * scale;
  const float clamped = math_min_f32(math_max_f32(product, lower_bound), upper_bound);

  return (uint8_t) ((int32_t) lrintf(clamped) + zp);
}
76
// Requantizes a 32-bit value to int8 using fixed-point arithmetic, rounding to nearest
// with ties away from zero ("rndna").
//
// The scale is decomposed exactly as scale = multiplier * 2**-shift:
//   - multiplier is the 24-bit significand of `scale` (implicit leading 1 included);
//   - shift is derived from the exponent of `scale`.
// The magnitude of `input` is multiplied by `multiplier` in 64 bits. Half of 2**shift is
// added and the sum is shifted right, then the sign is restored, so ties round away
// from zero.
//
// Parameters:
//   input      - 32-bit value to requantize.
//   scale      - requantization scale; must be in [2**-32, 256).
//   zero_point - output zero point.
//   min, max   - output clamping bounds in the quantized (int8) domain.
// Returns the requantized value, within [min, max].
static inline int8_t xnn_qs8_requantize_rndna(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  const uint32_t scale_bits = float_as_uint32(scale);
  const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const uint64_t rounding = UINT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // Negating in uint32_t is well-defined, including for INT32_MIN (magnitude 2**31).
  uint32_t abs_input = (uint32_t) input;
  if (input < 0) {
    abs_input = -abs_input;
  }

  // abs_input <= 2**31 and multiplier < 2**24, so the product is < 2**55.
  // rounding <= 2**54, so the sum below cannot overflow 64 bits.
  const uint64_t abs_prescaled_input = (uint64_t) abs_input * (uint64_t) multiplier;
  const uint64_t abs_scaled_input = (abs_prescaled_input + rounding) >> shift;

  // For scale >= 1 the magnitude can reach ~2**39. The previous code narrowed it to
  // 32 bits, which wraps and can yield the wrong sign. Saturate to INT32_MAX instead.
  // The clamping bounds lie in [-255, 255], so saturation never changes a clamped
  // result, and results in the normal range are unaffected.
  int32_t output = abs_scaled_input > (uint64_t) INT32_MAX
      ? INT32_MAX
      : (int32_t) abs_scaled_input;
  if (input < 0) {
    output = -output;
  }

  output = math_max_s32(output, min_less_zero_point);
  output = math_min_s32(output, max_less_zero_point);
  return (int8_t) (output + (int32_t) zero_point);
}
114
// Requantizes a 32-bit value to uint8 using fixed-point arithmetic, rounding to nearest
// with ties away from zero ("rndna").
//
// The scale is decomposed exactly as scale = multiplier * 2**-shift:
//   - multiplier is the 24-bit significand of `scale` (implicit leading 1 included);
//   - shift is derived from the exponent of `scale`.
// The magnitude of `input` is multiplied by `multiplier` in 64 bits. Half of 2**shift is
// added and the sum is shifted right, then the sign is restored, so ties round away
// from zero.
//
// Parameters:
//   input      - 32-bit value to requantize.
//   scale      - requantization scale; must be in [2**-32, 256).
//   zero_point - output zero point.
//   min, max   - output clamping bounds in the quantized (uint8) domain.
// Returns the requantized value, within [min, max].
static inline uint8_t xnn_qu8_requantize_rndna(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  const uint32_t scale_bits = float_as_uint32(scale);
  const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const uint64_t rounding = UINT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // Negating in uint32_t is well-defined, including for INT32_MIN (magnitude 2**31).
  uint32_t abs_input = (uint32_t) input;
  if (input < 0) {
    abs_input = -abs_input;
  }

  // abs_input <= 2**31 and multiplier < 2**24, so the product is < 2**55.
  // rounding <= 2**54, so the sum below cannot overflow 64 bits.
  const uint64_t abs_prescaled_input = (uint64_t) abs_input * (uint64_t) multiplier;
  const uint64_t abs_scaled_input = (abs_prescaled_input + rounding) >> shift;

  // For scale >= 1 the magnitude can reach ~2**39. The previous code narrowed it to
  // 32 bits, which wraps and can yield the wrong sign. Saturate to INT32_MAX instead.
  // The clamping bounds lie in [-255, 255], so saturation never changes a clamped
  // result, and results in the normal range are unaffected.
  int32_t output = abs_scaled_input > (uint64_t) INT32_MAX
      ? INT32_MAX
      : (int32_t) abs_scaled_input;
  if (input < 0) {
    output = -output;
  }

  output = math_max_s32(output, min_less_zero_point);
  output = math_min_s32(output, max_less_zero_point);
  return (uint8_t) (output + (int32_t) zero_point);
}
152
// Requantizes a 32-bit value to int8 using fixed-point arithmetic, rounding to nearest
// with ties rounded up, i.e. toward +infinity ("rndnu").
//
// The scale is decomposed exactly as scale = multiplier * 2**-shift:
//   - multiplier is the 24-bit significand of `scale` (implicit leading 1 included);
//   - shift is derived from the exponent of `scale`.
// The signed product input * multiplier is formed in 64 bits and half of 2**shift is
// added. The sum is then shifted right arithmetically, so ties round toward +infinity
// for both signs.
//
// Parameters:
//   input      - 32-bit value to requantize.
//   scale      - requantization scale; must be in [2**-32, 256).
//   zero_point - output zero point.
//   min, max   - output clamping bounds in the quantized (int8) domain.
// Returns the requantized value, within [min, max].
static inline int8_t xnn_qs8_requantize_rndnu(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  const uint32_t scale_bits = float_as_uint32(scale);
  const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int64_t rounding = INT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // |input| <= 2**31 and multiplier < 2**24, so |prescaled_input| < 2**55.
  // rounding <= 2**54, so the sum cannot overflow int64_t.
  // This is the signed product, not an absolute value.
  const int64_t prescaled_input = (int64_t) input * (int64_t) multiplier;
  // Assumes math_asr_s64 is an arithmetic (sign-propagating) right shift, per its name.
  const int64_t scaled_input = math_asr_s64(prescaled_input + rounding, shift);

  // For scale >= 1 the scaled value can reach ~2**39 in magnitude. The previous code
  // cast it straight to int32_t, which truncates and can produce the wrong result
  // (e.g. zero_point instead of max). Saturate to the int32_t range instead.
  // The clamping bounds lie in [-255, 255], so saturation never changes a clamped result.
  int32_t output;
  if (scaled_input > (int64_t) INT32_MAX) {
    output = INT32_MAX;
  } else if (scaled_input < (int64_t) INT32_MIN) {
    output = INT32_MIN;
  } else {
    output = (int32_t) scaled_input;
  }

  output = math_max_s32(output, min_less_zero_point);
  output = math_min_s32(output, max_less_zero_point);
  return (int8_t) (output + (int32_t) zero_point);
}
179
// Requantizes a 32-bit value to uint8 using fixed-point arithmetic, rounding to nearest
// with ties rounded up, i.e. toward +infinity ("rndnu").
//
// The scale is decomposed exactly as scale = multiplier * 2**-shift:
//   - multiplier is the 24-bit significand of `scale` (implicit leading 1 included);
//   - shift is derived from the exponent of `scale`.
// The signed product input * multiplier is formed in 64 bits and half of 2**shift is
// added. The sum is then shifted right arithmetically, so ties round toward +infinity
// for both signs.
//
// Parameters:
//   input      - 32-bit value to requantize.
//   scale      - requantization scale; must be in [2**-32, 256).
//   zero_point - output zero point.
//   min, max   - output clamping bounds in the quantized (uint8) domain.
// Returns the requantized value, within [min, max].
static inline uint8_t xnn_qu8_requantize_rndnu(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 1.0f / 4294967296.0f /* 0x1.0p-32f */);
  assert(scale < 256.0f);

  const uint32_t scale_bits = float_as_uint32(scale);
  const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int64_t rounding = INT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // |input| <= 2**31 and multiplier < 2**24, so |prescaled_input| < 2**55.
  // rounding <= 2**54, so the sum cannot overflow int64_t.
  // This is the signed product, not an absolute value.
  const int64_t prescaled_input = (int64_t) input * (int64_t) multiplier;
  // Assumes math_asr_s64 is an arithmetic (sign-propagating) right shift, per its name.
  const int64_t scaled_input = math_asr_s64(prescaled_input + rounding, shift);

  // For scale >= 1 the scaled value can reach ~2**39 in magnitude. The previous code
  // cast it straight to int32_t, which truncates and can produce the wrong result
  // (e.g. zero_point instead of max). Saturate to the int32_t range instead.
  // The clamping bounds lie in [-255, 255], so saturation never changes a clamped result.
  int32_t output;
  if (scaled_input > (int64_t) INT32_MAX) {
    output = INT32_MAX;
  } else if (scaled_input < (int64_t) INT32_MIN) {
    output = INT32_MIN;
  } else {
    output = (int32_t) scaled_input;
  }

  output = math_max_s32(output, min_less_zero_point);
  output = math_min_s32(output, max_less_zero_point);
  return (uint8_t) (output + (int32_t) zero_point);
}
206
// Scalar quantized addition of two uint8 values.
//
// Steps:
//   1. Forms bias + a * a_multiplier + b * b_multiplier in 32 bits.
//   2. Shifts the sum right by params.scalar.shift.
//   3. Clamps it to the output range expressed relative to the output zero point.
//   4. Adds the output zero point back.
//
// Parameters:
//   a, b   - quantized uint8 inputs.
//   params - precomputed scalar add parameters (bias, multipliers, shift,
//            clamping bounds and output zero point).
// Returns the quantized uint8 sum.
static inline uint8_t xnn_qu8_quantize_add(
  uint8_t a, uint8_t b,
  union xnn_qu8_add_minmax_params params)
{
  // Multiply by factors and accumulate products.
  int32_t acc = params.scalar.bias + (int32_t) (uint32_t) a * params.scalar.a_multiplier + (int32_t) (uint32_t) b * params.scalar.b_multiplier;

  // Right shift. No rounding happens here (assuming math_asr_s32 is a plain arithmetic
  // shift, as its name suggests), so any rounding must already be in the bias.
  // NOTE(review): presumably params.scalar.bias folds in a rounding term; confirm
  // against the params initializer.
  acc = math_asr_s32(acc, params.scalar.shift);

  // Clamp and add output zero point.
  acc = math_max_s32(acc, params.scalar.output_min_less_zero_point);
  acc = math_min_s32(acc, params.scalar.output_max_less_zero_point);
  // The result lies in [0, 255]. The previous (int8_t) cast was a copy-paste error: it
  // sent values >= 128 through an implementation-defined signed conversion.
  return (uint8_t) ((int32_t) acc + params.scalar.output_zero_point);
}
222
// Scalar quantized addition of two int8 values.
//
// Steps:
//   1. Forms bias + a * a_multiplier + b * b_multiplier in 32 bits.
//   2. Shifts the sum right by params.scalar.shift.
//   3. Clamps it to the output range expressed relative to the output zero point.
//   4. Adds the output zero point back.
// NOTE(review): no rounding is applied by the shift itself (assuming math_asr_s32 is a
// plain arithmetic shift); presumably a rounding term is folded into params.scalar.bias.
//
// Parameters:
//   a, b   - quantized int8 inputs.
//   params - precomputed scalar add parameters (bias, multipliers, shift,
//            clamping bounds and output zero point).
// Returns the quantized int8 sum.
static inline int8_t xnn_qs8_quantize_add(
  int8_t a, int8_t b,
  union xnn_qs8_add_minmax_params params)
{
  const int32_t weighted_sum =
      params.scalar.bias
      + (int32_t) a * params.scalar.a_multiplier
      + (int32_t) b * params.scalar.b_multiplier;

  const int32_t shifted = math_asr_s32(weighted_sum, params.scalar.shift);

  const int32_t clamped = math_min_s32(
      math_max_s32(shifted, params.scalar.output_min_less_zero_point),
      params.scalar.output_max_less_zero_point);

  return (int8_t) (clamped + params.scalar.output_zero_point);
}
238
xnn_qs8_quantize(float val,float scale,int32_t zero_point)239 inline static int8_t xnn_qs8_quantize(float val, float scale, int32_t zero_point)
240 {
241 return (int8_t) lrintf(fminf(fmaxf(val / scale + (float) zero_point, -128.0f), 127.0f));
242 }
243
xnn_qu8_quantize(float val,float scale,int32_t zero_point)244 inline static uint8_t xnn_qu8_quantize(float val, float scale, int32_t zero_point)
245 {
246 return (uint8_t) lrintf(fminf(fmaxf(val / scale + (float) zero_point, 0.0f), 255.0f));
247 }
248