1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8
9 #pragma once
10
11 #include <stdint.h>
12 #include <stddef.h>
13 #include <assert.h>
14 #include <math.h>
15
16 #include <fp16.h>
17
18 #include <xnnpack/common.h>
19 #include <xnnpack/math.h>
20 #include <xnnpack/params.h>
21
22
// Function-pointer type matching the signed 8-bit (int8_t) requantization
// routines in this header: converts a 32-bit accumulator to int8_t given a
// float scale, an output zero point, and output bounds.
typedef int8_t (*xnn_qs8_requantize_fn)(
  int32_t input, float scale,
  int8_t output_zero_point, int8_t output_min, int8_t output_max);

// Function-pointer type matching the unsigned 8-bit (uint8_t) requantization
// routines in this header; same parameters as xnn_qs8_requantize_fn with
// uint8_t zero point and bounds.
typedef uint8_t (*xnn_qu8_requantize_fn)(
  int32_t input, float scale,
  uint8_t output_zero_point, uint8_t output_min, uint8_t output_max);
36
// Requantizes a 32-bit accumulator to int8_t using single-precision float math.
//
// input * scale is clamped (in float) to [min - zero_point, max - zero_point],
// rounded to an integer with lrintf (i.e. per the current floating-point
// rounding mode, round-to-nearest-even by default), and zero_point is added.
// scale is asserted to lie in [2^-32, 256).
static inline int8_t xnn_qs8_requantize_fp32(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  // Bounds are expressed relative to the zero point so clamping can happen
  // before the zero point is re-applied.
  const int32_t zp = (int32_t) zero_point;
  const float lower = (float) ((int32_t) min - zp);
  const float upper = (float) ((int32_t) max - zp);

  const float clamped = math_min_f32(math_max_f32((float) input * scale, lower), upper);
  return (int8_t) ((int32_t) lrintf(clamped) + zp);
}
57
// Requantizes a 32-bit accumulator to uint8_t using single-precision float math.
//
// input * scale is clamped (in float) to [min - zero_point, max - zero_point],
// rounded to an integer with lrintf (i.e. per the current floating-point
// rounding mode, round-to-nearest-even by default), and zero_point is added.
// scale is asserted to lie in [2^-32, 256).
static inline uint8_t xnn_qu8_requantize_fp32(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  // Bounds are expressed relative to the zero point so clamping can happen
  // before the zero point is re-applied.
  const int32_t zp = (int32_t) zero_point;
  const float lower = (float) ((int32_t) min - zp);
  const float upper = (float) ((int32_t) max - zp);

  const float clamped = math_min_f32(math_max_f32((float) input * scale, lower), upper);
  return (uint8_t) ((int32_t) lrintf(clamped) + zp);
}
78
// Requantizes a 32-bit accumulator to int8_t with fixed-point arithmetic,
// rounding to nearest with ties away from zero ("rndna").
//
// scale (asserted to lie in [2^-32, 256)) is split into a 24-bit integer
// multiplier (the mantissa with its implicit leading one) and a right shift
// derived from the exponent, so input * scale == input * multiplier / 2^shift.
// The magnitude of input is scaled and rounded half-up, then the sign is
// restored, which gives ties-away-from-zero rounding. The result is clamped to
// [min - zero_point, max - zero_point] and zero_point is added.
//
// When scale >= 1 the scaled magnitude can exceed 32 bits; such inputs
// saturate to the clamping bounds instead of wrapping around.
static inline int8_t xnn_qs8_requantize_rndna(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const uint64_t rounding = UINT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // Negate in unsigned arithmetic so that INT32_MIN maps to 2^31 without signed overflow.
  uint32_t abs_input = (uint32_t) input;
  if (input < 0) {
    abs_input = -abs_input;
  }

  // abs_input <= 2^31 and multiplier < 2^24, so product + rounding < 2^56 fits in 64 bits.
  const uint64_t abs_prescaled_input = (uint64_t) abs_input * (uint64_t) multiplier;
  // Keep the scaled magnitude in 64 bits: with shift as small as 16 it can reach ~2^39, and
  // truncating it to 32 bits before clamping would wrap out-of-range results back into range.
  const uint64_t abs_scaled_input = (abs_prescaled_input + rounding) >> shift;

  // abs_scaled_input < 2^40, so it and its negation are both representable in int64_t.
  int64_t output = (int64_t) abs_scaled_input;
  if (input < 0) {
    output = -output;
  }

  // Clamp in 64 bits (lower bound first, then upper) before narrowing to 32 bits.
  if (output < (int64_t) min_less_zero_point) {
    output = min_less_zero_point;
  }
  if (output > (int64_t) max_less_zero_point) {
    output = max_less_zero_point;
  }
  return (int8_t) ((int32_t) output + (int32_t) zero_point);
}
116
// Requantizes a 32-bit accumulator to uint8_t with fixed-point arithmetic,
// rounding to nearest with ties away from zero ("rndna").
//
// scale (asserted to lie in [2^-32, 256)) is split into a 24-bit integer
// multiplier (the mantissa with its implicit leading one) and a right shift
// derived from the exponent, so input * scale == input * multiplier / 2^shift.
// The magnitude of input is scaled and rounded half-up, then the sign is
// restored, which gives ties-away-from-zero rounding. The result is clamped to
// [min - zero_point, max - zero_point] and zero_point is added.
//
// When scale >= 1 the scaled magnitude can exceed 32 bits; such inputs
// saturate to the clamping bounds instead of wrapping around.
static inline uint8_t xnn_qu8_requantize_rndna(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  const uint32_t multiplier = (scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const uint64_t rounding = UINT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // Negate in unsigned arithmetic so that INT32_MIN maps to 2^31 without signed overflow.
  uint32_t abs_input = (uint32_t) input;
  if (input < 0) {
    abs_input = -abs_input;
  }

  // abs_input <= 2^31 and multiplier < 2^24, so product + rounding < 2^56 fits in 64 bits.
  const uint64_t abs_prescaled_input = (uint64_t) abs_input * (uint64_t) multiplier;
  // Keep the scaled magnitude in 64 bits: with shift as small as 16 it can reach ~2^39, and
  // truncating it to 32 bits before clamping would wrap out-of-range results back into range.
  const uint64_t abs_scaled_input = (abs_prescaled_input + rounding) >> shift;

  // abs_scaled_input < 2^40, so it and its negation are both representable in int64_t.
  int64_t output = (int64_t) abs_scaled_input;
  if (input < 0) {
    output = -output;
  }

  // Clamp in 64 bits (lower bound first, then upper) before narrowing to 32 bits.
  if (output < (int64_t) min_less_zero_point) {
    output = min_less_zero_point;
  }
  if (output > (int64_t) max_less_zero_point) {
    output = max_less_zero_point;
  }
  return (uint8_t) ((int32_t) output + (int32_t) zero_point);
}
154
// Requantizes a 32-bit accumulator to int8_t with fixed-point arithmetic,
// rounding to nearest with ties rounded up toward +infinity ("rndnu").
//
// scale (asserted to lie in [2^-32, 256)) is split into a 24-bit integer
// multiplier (the mantissa with its implicit leading one) and a right shift
// derived from the exponent, so input * scale == input * multiplier / 2^shift.
// The signed product plus 2^(shift-1) is arithmetically shifted right by
// shift, clamped to [min - zero_point, max - zero_point], and zero_point is
// added.
//
// When scale >= 1 the scaled value can exceed the int32_t range; such inputs
// saturate to the clamping bounds instead of wrapping around.
static inline int8_t xnn_qs8_requantize_rndnu(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale < 256.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int64_t rounding = INT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // |input * multiplier| <= 2^31 * 2^24 = 2^55 and rounding <= 2^54, so the sum fits in int64_t.
  const int64_t prescaled_input = (int64_t) input * (int64_t) multiplier;
  // The shifted value can still reach ~2^39 in magnitude (shift may be as small as 16),
  // so it is clamped in 64 bits rather than narrowed to int32_t first.
  int64_t output = asr_s64(prescaled_input + rounding, shift);

  // Clamp lower bound first, then upper, before narrowing.
  if (output < (int64_t) min_less_zero_point) {
    output = min_less_zero_point;
  }
  if (output > (int64_t) max_less_zero_point) {
    output = max_less_zero_point;
  }
  return (int8_t) ((int32_t) output + (int32_t) zero_point);
}
181
// Requantizes a 32-bit accumulator to uint8_t with fixed-point arithmetic,
// rounding to nearest with ties rounded up toward +infinity ("rndnu").
//
// scale (asserted to lie in [2^-32, 256)) is split into a 24-bit integer
// multiplier (the mantissa with its implicit leading one) and a right shift
// derived from the exponent, so input * scale == input * multiplier / 2^shift.
// The signed product plus 2^(shift-1) is arithmetically shifted right by
// shift, clamped to [min - zero_point, max - zero_point], and zero_point is
// added.
//
// When scale >= 1 the scaled value can exceed the int32_t range; such inputs
// saturate to the clamping bounds instead of wrapping around.
static inline uint8_t xnn_qu8_requantize_rndnu(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale < 256.0f);
  assert(scale >= 0x1.0p-32f);

  const uint32_t scale_bits = fp32_to_bits(scale);
  const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
  const uint32_t shift = 127 + 23 - (scale_bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int64_t rounding = INT64_C(1) << (shift - 1);
  const int32_t min_less_zero_point = (int32_t) min - (int32_t) zero_point;
  const int32_t max_less_zero_point = (int32_t) max - (int32_t) zero_point;

  // |input * multiplier| <= 2^31 * 2^24 = 2^55 and rounding <= 2^54, so the sum fits in int64_t.
  const int64_t prescaled_input = (int64_t) input * (int64_t) multiplier;
  // The shifted value can still reach ~2^39 in magnitude (shift may be as small as 16),
  // so it is clamped in 64 bits rather than narrowed to int32_t first.
  int64_t output = asr_s64(prescaled_input + rounding, shift);

  // Clamp lower bound first, then upper, before narrowing.
  if (output < (int64_t) min_less_zero_point) {
    output = min_less_zero_point;
  }
  if (output > (int64_t) max_less_zero_point) {
    output = max_less_zero_point;
  }
  return (uint8_t) ((int32_t) output + (int32_t) zero_point);
}
208
// Scalar reference for quantized addition of two uint8_t values.
//
// Computes asr_s32(bias + a * a_multiplier + b * b_multiplier, shift) from
// params.scalar, clamps the result to
// [output_min_less_zero_point, output_max_less_zero_point], and adds
// output_zero_point. The params are presumably prepared by the project's
// params-initialization code (not visible here), including any rounding
// offset folded into bias — confirm against the initializer.
static inline uint8_t xnn_qu8_quantize_add(
  uint8_t a, uint8_t b,
  union xnn_qu8_addsub_minmax_params params)
{
  // Multiply each input by its factor and accumulate on top of the bias.
  int32_t acc = params.scalar.bias + (int32_t) a * params.scalar.a_multiplier + (int32_t) b * params.scalar.b_multiplier;

  // Arithmetic right shift back to the output scale. asr_s32 itself does not add a
  // rounding term; NOTE(review): rounding is presumably pre-added via bias — confirm.
  acc = asr_s32(acc, params.scalar.shift);

  // Clamp relative to the zero point, then add the output zero point.
  acc = math_max_s32(acc, params.scalar.output_min_less_zero_point);
  acc = math_min_s32(acc, params.scalar.output_max_less_zero_point);
  // The result is an unsigned 8-bit value: convert via uint8_t, not int8_t, so values
  // in (127, 255] do not go through an implementation-defined signed conversion.
  return (uint8_t) (acc + params.scalar.output_zero_point);
}
224
// Scalar reference for quantized addition of two int8_t values.
//
// Computes asr_s32(bias + a * a_multiplier + b * b_multiplier, shift) from
// params.scalar, clamps the result to
// [output_min_less_zero_point, output_max_less_zero_point], and adds
// output_zero_point. Any rounding offset is presumably folded into bias by
// the params initializer (not visible here) — confirm.
static inline int8_t xnn_qs8_quantize_add(
  int8_t a, int8_t b,
  union xnn_qs8_addsub_minmax_params params)
{
  // Weighted sum of both inputs on top of the precomputed bias.
  const int32_t weighted_sum =
    params.scalar.bias +
    (int32_t) a * params.scalar.a_multiplier +
    (int32_t) b * params.scalar.b_multiplier;

  // Arithmetic right shift back to the output scale.
  const int32_t shifted = asr_s32(weighted_sum, params.scalar.shift);

  // Clamp to the zero-point-relative output range, then re-apply the zero point.
  const int32_t clamped = math_min_s32(
    math_max_s32(shifted, params.scalar.output_min_less_zero_point),
    params.scalar.output_max_less_zero_point);
  return (int8_t) (clamped + params.scalar.output_zero_point);
}
240