• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) Facebook, Inc. and its affiliates.
2 // All rights reserved.
3 //
4 // Copyright 2019 Google LLC
5 //
6 // This source code is licensed under the BSD-style license found in the
7 // LICENSE file in the root directory of this source tree.
8 
9 #pragma once
10 
11 #include <stdint.h>
12 #include <stddef.h>
13 #include <assert.h>
14 #include <math.h>
15 
16 #include <fp16.h>
17 
18 #include <xnnpack/common.h>
19 #include <xnnpack/math.h>
20 #include <xnnpack/params.h>
21 
22 
// Pointer to a scalar requantization routine with a signed 8-bit result.
//
// The pointed-to function receives a 32-bit integer input, a floating-point
// scale, and three int8_t values named output_zero_point, output_min and
// output_max, and returns an int8_t.
typedef int8_t (*xnn_qs8_requantize_fn)(
    int32_t input, float scale,
    int8_t output_zero_point, int8_t output_min, int8_t output_max);
29 
// Pointer to a scalar requantization routine with an unsigned 8-bit result.
//
// The pointed-to function receives a 32-bit integer input, a floating-point
// scale, and three uint8_t values named output_zero_point, output_min and
// output_max, and returns a uint8_t.
typedef uint8_t (*xnn_qu8_requantize_fn)(
    int32_t input, float scale,
    uint8_t output_zero_point, uint8_t output_min, uint8_t output_max);
36 
// Requantizes a 32-bit integer to int8_t using single-precision arithmetic.
//
// The input is converted to float and multiplied by scale, clamped to the
// range [min - zero_point, max - zero_point], rounded to an integer with
// lrintf(), and finally offset by zero_point.
//
// scale is expected to lie in [2**-32, 256); this is enforced only by assert().
static inline int8_t xnn_qs8_requantize_fp32(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  // Express the clamping bounds relative to the zero point so that they can
  // be applied before the zero point is added back.
  const int32_t offset = (int32_t) zero_point;
  const float lower_bound = (float) ((int32_t) min - offset);
  const float upper_bound = (float) ((int32_t) max - offset);

  const float scaled = (float) input * scale;
  const float clamped = math_min_f32(math_max_f32(scaled, lower_bound), upper_bound);

  // lrintf() rounds according to the current floating-point rounding mode.
  const int32_t rounded = (int32_t) lrintf(clamped);
  return (int8_t) (rounded + offset);
}
57 
// Requantizes a 32-bit integer to uint8_t using single-precision arithmetic.
//
// The input is converted to float and multiplied by scale, clamped to the
// range [min - zero_point, max - zero_point], rounded to an integer with
// lrintf(), and finally offset by zero_point.
//
// scale is expected to lie in [2**-32, 256); this is enforced only by assert().
static inline uint8_t xnn_qu8_requantize_fp32(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  // Express the clamping bounds relative to the zero point so that they can
  // be applied before the zero point is added back.
  const int32_t offset = (int32_t) zero_point;
  const float lower_bound = (float) ((int32_t) min - offset);
  const float upper_bound = (float) ((int32_t) max - offset);

  const float scaled = (float) input * scale;
  const float clamped = math_min_f32(math_max_f32(scaled, lower_bound), upper_bound);

  // lrintf() rounds according to the current floating-point rounding mode.
  const int32_t rounded = (int32_t) lrintf(clamped);
  return (uint8_t) (rounded + offset);
}
78 
// Requantizes a 32-bit integer to int8_t using integer-only arithmetic.
//
// The scale is decomposed into a 24-bit integer multiplier and a right shift.
// The magnitude of the input is multiplied, half of the divisor is added, and
// the sum is shifted right, so ties round away from zero once the sign is
// restored. The result is clamped to [min - zero_point, max - zero_point] and
// then offset by zero_point.
//
// scale is expected to lie in [2**-32, 256); this is enforced only by assert().
static inline int8_t xnn_qs8_requantize_rndna(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  // Low 23 bits of the float encoding with bit 23 set form the multiplier;
  // the bits above bit 22 determine the shift amount.
  const uint32_t bits = fp32_to_bits(scale);
  const uint32_t mantissa = UINT32_C(0x00800000) | (bits & UINT32_C(0x007FFFFF));
  const uint32_t shift = 127 + 23 - (bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int32_t offset = (int32_t) zero_point;
  const int32_t lower_bound = (int32_t) min - offset;
  const int32_t upper_bound = (int32_t) max - offset;

  // Magnitude of the input; the negation is done in unsigned arithmetic.
  const uint32_t magnitude = input < 0 ? UINT32_C(0) - (uint32_t) input : (uint32_t) input;

  const uint64_t half = UINT64_C(1) << (shift - 1);
  const uint64_t product = (uint64_t) magnitude * (uint64_t) mantissa;
  // Note: the shifted value is narrowed to 32 bits before it is clamped.
  const uint32_t scaled_magnitude = (uint32_t) ((product + half) >> shift);

  // Restore the sign of the input.
  int32_t result = (int32_t) scaled_magnitude;
  result = input < 0 ? -result : result;

  result = math_min_s32(math_max_s32(result, lower_bound), upper_bound);
  return (int8_t) (result + offset);
}
116 
// Requantizes a 32-bit integer to uint8_t using integer-only arithmetic.
//
// The scale is decomposed into a 24-bit integer multiplier and a right shift.
// The magnitude of the input is multiplied, half of the divisor is added, and
// the sum is shifted right, so ties round away from zero once the sign is
// restored. The result is clamped to [min - zero_point, max - zero_point] and
// then offset by zero_point.
//
// scale is expected to lie in [2**-32, 256); this is enforced only by assert().
static inline uint8_t xnn_qu8_requantize_rndna(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale >= 0x1.0p-32f);
  assert(scale < 256.0f);

  // Low 23 bits of the float encoding with bit 23 set form the multiplier;
  // the bits above bit 22 determine the shift amount.
  const uint32_t bits = fp32_to_bits(scale);
  const uint32_t mantissa = UINT32_C(0x00800000) | (bits & UINT32_C(0x007FFFFF));
  const uint32_t shift = 127 + 23 - (bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int32_t offset = (int32_t) zero_point;
  const int32_t lower_bound = (int32_t) min - offset;
  const int32_t upper_bound = (int32_t) max - offset;

  // Magnitude of the input; the negation is done in unsigned arithmetic.
  const uint32_t magnitude = input < 0 ? UINT32_C(0) - (uint32_t) input : (uint32_t) input;

  const uint64_t half = UINT64_C(1) << (shift - 1);
  const uint64_t product = (uint64_t) magnitude * (uint64_t) mantissa;
  // Note: the shifted value is narrowed to 32 bits before it is clamped.
  const uint32_t scaled_magnitude = (uint32_t) ((product + half) >> shift);

  // Restore the sign of the input.
  int32_t result = (int32_t) scaled_magnitude;
  result = input < 0 ? -result : result;

  result = math_min_s32(math_max_s32(result, lower_bound), upper_bound);
  return (uint8_t) (result + offset);
}
154 
// Requantizes a 32-bit integer to int8_t using signed integer-only arithmetic.
//
// The scale is decomposed into a 24-bit integer multiplier and a right shift.
// The signed input is multiplied, half of the divisor is added, and the sum
// is shifted right with asr_s64(). The result is clamped to
// [min - zero_point, max - zero_point] and then offset by zero_point.
// NOTE(review): assuming asr_s64() is an arithmetic (flooring) right shift,
// ties round toward positive infinity -- confirm against its definition.
//
// scale is expected to lie in [2**-32, 256); this is enforced only by assert().
static inline int8_t xnn_qs8_requantize_rndnu(
  int32_t input,
  float scale,
  int8_t zero_point,
  int8_t min,
  int8_t max)
{
  assert(scale < 256.0f);
  assert(scale >= 0x1.0p-32f);

  // Low 23 bits of the float encoding with bit 23 set form the multiplier;
  // the bits above bit 22 determine the shift amount.
  const uint32_t bits = fp32_to_bits(scale);
  const int32_t mantissa = INT32_C(0x00800000) | ((int32_t) bits & INT32_C(0x007FFFFF));
  const uint32_t shift = 127 + 23 - (bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int32_t offset = (int32_t) zero_point;
  const int32_t lower_bound = (int32_t) min - offset;
  const int32_t upper_bound = (int32_t) max - offset;

  const int64_t half = INT64_C(1) << (shift - 1);
  const int64_t product = (int64_t) input * (int64_t) mantissa;

  // Note: the shifted value is narrowed to 32 bits before it is clamped.
  const int32_t shifted = (int32_t) asr_s64(product + half, shift);
  const int32_t clamped = math_min_s32(math_max_s32(shifted, lower_bound), upper_bound);
  return (int8_t) (clamped + offset);
}
181 
// Requantizes a 32-bit integer to uint8_t using signed integer-only arithmetic.
//
// The scale is decomposed into a 24-bit integer multiplier and a right shift.
// The signed input is multiplied, half of the divisor is added, and the sum
// is shifted right with asr_s64(). The result is clamped to
// [min - zero_point, max - zero_point] and then offset by zero_point.
// NOTE(review): assuming asr_s64() is an arithmetic (flooring) right shift,
// ties round toward positive infinity -- confirm against its definition.
//
// scale is expected to lie in [2**-32, 256); this is enforced only by assert().
static inline uint8_t xnn_qu8_requantize_rndnu(
  int32_t input,
  float scale,
  uint8_t zero_point,
  uint8_t min,
  uint8_t max)
{
  assert(scale < 256.0f);
  assert(scale >= 0x1.0p-32f);

  // Low 23 bits of the float encoding with bit 23 set form the multiplier;
  // the bits above bit 22 determine the shift amount.
  const uint32_t bits = fp32_to_bits(scale);
  const int32_t mantissa = INT32_C(0x00800000) | ((int32_t) bits & INT32_C(0x007FFFFF));
  const uint32_t shift = 127 + 23 - (bits >> 23);
  assert(shift >= 16);
  assert(shift < 56);

  const int32_t offset = (int32_t) zero_point;
  const int32_t lower_bound = (int32_t) min - offset;
  const int32_t upper_bound = (int32_t) max - offset;

  const int64_t half = INT64_C(1) << (shift - 1);
  const int64_t product = (int64_t) input * (int64_t) mantissa;

  // Note: the shifted value is narrowed to 32 bits before it is clamped.
  const int32_t shifted = (int32_t) asr_s64(product + half, shift);
  const int32_t clamped = math_min_s32(math_max_s32(shifted, lower_bound), upper_bound);
  return (uint8_t) (clamped + offset);
}
208 
// Scalar quantized addition of two unsigned 8-bit values.
//
// Computes bias + a * a_multiplier + b * b_multiplier from params.scalar,
// shifts the sum right by params.scalar.shift, clamps it to
// [output_min_less_zero_point, output_max_less_zero_point], adds
// output_zero_point, and returns the result narrowed to uint8_t.
static inline uint8_t xnn_qu8_quantize_add(
  uint8_t a, uint8_t b,
  union xnn_qu8_addsub_minmax_params params)
{
  // Multiply by factors and accumulate products.
  int32_t acc = params.scalar.bias + (int32_t) (uint32_t) a * params.scalar.a_multiplier + (int32_t) (uint32_t) b * params.scalar.b_multiplier;

  // Shift right by the precomputed amount.
  // NOTE(review): any rounding term is presumably folded into params.scalar.bias -- confirm against the params initialization.
  acc = asr_s32(acc, params.scalar.shift);

  // Clamp and add output zero point.
  acc = math_max_s32(acc, params.scalar.output_min_less_zero_point);
  acc = math_min_s32(acc, params.scalar.output_max_less_zero_point);

  // Narrow to the unsigned return type directly: going through int8_t would make
  // the conversion implementation-defined for results in the range 128..255.
  return (uint8_t) (acc + params.scalar.output_zero_point);
}
224 
// Scalar quantized addition of two signed 8-bit values.
//
// Computes bias + a * a_multiplier + b * b_multiplier from params.scalar,
// shifts the sum right by params.scalar.shift, clamps it to
// [output_min_less_zero_point, output_max_less_zero_point], adds
// output_zero_point, and returns the result narrowed to int8_t.
static inline int8_t xnn_qs8_quantize_add(
  int8_t a, int8_t b,
  union xnn_qs8_addsub_minmax_params params)
{
  // Weighted sum of both operands on top of the precomputed bias.
  const int32_t weighted_sum =
    params.scalar.bias + (int32_t) a * params.scalar.a_multiplier + (int32_t) b * params.scalar.b_multiplier;

  // Shift right by the precomputed amount.
  const int32_t shifted = asr_s32(weighted_sum, params.scalar.shift);

  // Clamp relative to the zero point, then add the zero point back.
  const int32_t clamped = math_min_s32(
    math_max_s32(shifted, params.scalar.output_min_less_zero_point),
    params.scalar.output_max_less_zero_point);
  return (int8_t) (clamped + params.scalar.output_zero_point);
}
240