// Quantized calculation utilities.
// TODO(vddang): Replace this with tensorflow/lite/kernels/internal/tensor_utils(common).h
// after TFLite module has been synced.

#ifndef ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H
#define ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H

#include <public/gemmlowp.h>

#include <cassert>
#include <limits>
#include <memory>
#include <type_traits>

#include "LegacyUtils.h"
#include "OperationsUtils.h"

namespace android {
namespace nn {

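// Rescales x by the fixed-point multiplier described by (quantized_multiplier,
// shift): a positive shift is applied as a left shift before the saturating
// rounding doubling high-multiply, a negative shift as a rounding right shift
// afterwards.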
inline int32_t MultiplyByQuantizedMultiplier(int32_t x, int32_t quantized_multiplier, int shift) {
    using gemmlowp::RoundingDivideByPOT;
    using gemmlowp::SaturatingRoundingDoublingHighMul;
    int left_shift = shift > 0 ? shift : 0;
    int right_shift = shift > 0 ? 0 : -shift;
    return RoundingDivideByPOT(
            SaturatingRoundingDoublingHighMul(x * (1 << left_shift), quantized_multiplier),
            right_shift);
}

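// For each batch, computes input * transpose(input_to_gate_weights) + bias,
// rescales each int32 accumulator with (multiplier, shift), adds output_zp and
// the existing output value, and saturates the result to the range of T
// (e.g. int8_t or int16_t).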
template <typename T>
void MatrixBatchVectorMultiplyAccumulate(const int8_t* input, const int32_t* bias,
                                         const int8_t* input_to_gate_weights, int32_t multiplier,
                                         int32_t shift, int32_t n_batch, int32_t n_input,
                                         int32_t n_output, int32_t output_zp, T* output) {
    const int16_t output_max = std::numeric_limits<T>::max();
    const int16_t output_min = std::numeric_limits<T>::min();
    for (int batch = 0; batch < n_batch; ++batch) {
        for (int row = 0; row < n_output; ++row) {
            int32_t acc = bias[row];
            for (int col = 0; col < n_input; ++col) {
                int8_t input_val = input[batch * n_input + col];
                int8_t weights_val = input_to_gate_weights[row * n_input + col];
                acc += input_val * weights_val;
            }
            acc = MultiplyByQuantizedMultiplier(acc, multiplier, shift);
            acc += output_zp;
            acc += output[batch * n_output + row];
            if (acc > output_max) {
                acc = output_max;
            }
            if (acc < output_min) {
                acc = output_min;
            }
            output[batch * n_output + row] = static_cast<T>(acc);
        }
    }
}

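// Returns the number of leading zero bits in an unsigned integer; an input of
// zero yields the full bit width of T.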
template <typename T>
int CountLeadingZeros(T integer_input) {
    static_assert(std::is_unsigned<T>::value, "Only unsigned integer types handled.");
#if defined(__GNUC__)
    return integer_input ? __builtin_clz(integer_input) : std::numeric_limits<T>::digits;
#else
    if (integer_input == 0) {
        return std::numeric_limits<T>::digits;
    }

    const T one_in_leading_positive = static_cast<T>(1) << (std::numeric_limits<T>::digits - 1);
    int leading_zeros = 0;
    while (integer_input < one_in_leading_positive) {
        integer_input <<= 1;
        ++leading_zeros;
    }
    return leading_zeros;
#endif
}

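// Computes an approximation of 1 / sqrt(input) as a quantized multiplier
// (*output_inv_sqrt) plus a shift exponent (*output_shift), using a
// fixed-point Newton-Raphson iteration. The computed right-shift exponent is
// multiplied by reverse_shift (pass -1 to express it as a left shift).
// Returns false if input is negative.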
inline bool GetInvSqrtQuantizedMultiplierExp(int32_t input, int reverse_shift,
                                             int32_t* output_inv_sqrt, int* output_shift) {
    NN_RET_CHECK_GE(input, 0);
    if (input <= 1) {
        // Handle the input value 1 separately to avoid overflow in that case
        // in the general computation below. Also handle 0 as if it
        // were a 1. 0 is an invalid input here (divide by zero) and 1 is a valid
        // but rare/unrealistic input value. We can expect both to occur in some
        // incompletely trained models, but probably not in fully trained models.
        *output_inv_sqrt = std::numeric_limits<std::int32_t>::max();
        *output_shift = 0;
        return true;
    }

    *output_shift = 11;
    while (input >= (1 << 29)) {
        input /= 4;
        ++*output_shift;
    }
    const unsigned max_left_shift_bits = CountLeadingZeros(static_cast<uint32_t>(input)) - 1;
    const unsigned max_left_shift_bit_pairs = max_left_shift_bits / 2;
    const unsigned left_shift_bit_pairs = max_left_shift_bit_pairs - 1;
    *output_shift -= left_shift_bit_pairs;
    input <<= 2 * left_shift_bit_pairs;
    NN_RET_CHECK_GE(input, (1 << 27));
    NN_RET_CHECK_LT(input, (1 << 29));
    using gemmlowp::FixedPoint;
    using gemmlowp::Rescale;
    using gemmlowp::SaturatingRoundingMultiplyByPOT;
    // Using 3 integer bits gives us enough room for the internal arithmetic in
    // this Newton-Raphson iteration.
    using F3 = FixedPoint<int32_t, 3>;
    using F0 = FixedPoint<int32_t, 0>;
    const F3 fixedpoint_input = F3::FromRaw(input >> 1);
    const F3 fixedpoint_half_input = SaturatingRoundingMultiplyByPOT<-1>(fixedpoint_input);
    const F3 fixedpoint_half_three =
            GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F3, (1 << 28) + (1 << 27), 1.5);
    // Newton-Raphson iteration
    // Naive unoptimized starting guess: x = 1
    F3 x = F3::One();
    // Naive unoptimized number of iterations: 5
    for (int i = 0; i < 5; i++) {
        const F3 x3 = Rescale<3>(x * x * x);
        x = Rescale<3>(fixedpoint_half_three * x - fixedpoint_half_input * x3);
    }
    const F0 fixedpoint_half_sqrt_2 =
            GEMMLOWP_CHECKED_FIXEDPOINT_CONSTANT(F0, 1518500250, std::sqrt(2.) / 2.);
    x = x * fixedpoint_half_sqrt_2;
    *output_inv_sqrt = x.raw();
    if (*output_shift < 0) {
        *output_inv_sqrt <<= -*output_shift;
        *output_shift = 0;
    }
    // Convert right shift (right is positive) to left shift.
    *output_shift *= reverse_shift;
    return true;
}

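// Applies layer normalization over each batch row of a quantized int16 tensor
// of shape [n_batch, n_input], scaling by layer_norm_weights with the
// (layer_norm_scale_a, layer_norm_scale_b) multiplier and adding bias.
// Mirrors the TFLite tensor_utils helper of the same name.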
void ApplyLayerNorm(const int16_t* input, const int16_t* layer_norm_weights, const int32_t* bias,
                    int32_t layer_norm_scale_a, int32_t layer_norm_scale_b, int32_t variance_limit,
                    int n_batch, int n_input, int16_t* output);

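// Accumulates scalar * (sum of each int8 matrix row) into the int32 output,
// one entry per row. Mirrors the TFLite tensor_utils helper of the same name.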
void MatrixScalarMultiplyAccumulate(const int8_t* matrix, int32_t scalar, int32_t n_row,
                                    int32_t n_col, int32_t* output);

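// Precomputes bias[row] + zero_point * (sum of weight row) for each row of
// weight_tensor, allocating the int32 results into *output. Returns false on
// failure. Mirrors the TFLite quantized-LSTM helper of the same name.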
bool PrecomputeZeroPointTimesWeightWithBias(int32_t zero_point, const int8_t* weight_tensor,
                                            const Shape& weight_shape, const int32_t* bias_tensor,
                                            std::unique_ptr<int32_t[]>* output);

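// Applies the logistic (sigmoid) function elementwise to a quantized int16
// tensor of shape [n_batch, n_input].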
void ApplySigmoid(const int16_t* input, int32_t n_batch, int32_t n_input, int16_t* output);

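// Applies tanh elementwise to a quantized int16 tensor of shape
// [n_batch, n_input]; the input is interpreted as fixed point with IntegerBits
// integer bits and the output is produced in Q0.15.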
template <int IntegerBits>
void ApplyTanh(const int16_t* input, int32_t n_batch, int32_t n_input, int16_t* output) {
    using FX = gemmlowp::FixedPoint<std::int16_t, IntegerBits>;
    using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
    for (int batch = 0; batch < n_batch; ++batch) {
        for (int i = 0; i < n_input; ++i) {
            const int index = batch * n_input + i;
            FX tanh_input = FX::FromRaw(input[index]);
            F0 tanh_output = gemmlowp::tanh(tanh_input);
            output[index] = tanh_output.raw();
        }
    }
}

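// Runtime dispatch to the ApplyTanh<IntegerBits> instantiation matching
// integer_bits (0..6); other values fall through to the default case and do
// nothing.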
inline void ApplyTanh(int32_t integer_bits, const int16_t* input, int32_t n_batch, int32_t n_input,
                      int16_t* output) {
    assert(integer_bits <= 6);
#define DISPATCH_TANH(i)                               \
    case i:                                            \
        ApplyTanh<i>(input, n_batch, n_input, output); \
        break;
    switch (integer_bits) {
        DISPATCH_TANH(0);
        DISPATCH_TANH(1);
        DISPATCH_TANH(2);
        DISPATCH_TANH(3);
        DISPATCH_TANH(4);
        DISPATCH_TANH(5);
        DISPATCH_TANH(6);
        default:
            return;
    }
#undef DISPATCH_TANH
}

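// Elementwise multiplication of two int16 tensors of shape [n_batch, n_input].
// The first overload narrows the int32 product back to int16 with a right
// shift; the second rescales it with (multiplier, shift) and output_zp to
// produce int8 output. Both mirror the TFLite tensor_utils helpers.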
void CwiseMul(const int16_t* input_1, const int16_t* input_2, int n_batch, int n_input, int shift,
              int16_t* output);
void CwiseMul(const int16_t* input_1, const int16_t* input_2, int32_t multiplier, int32_t shift,
              int32_t n_batch, int32_t n_input, int32_t output_zp, int8_t* output);

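// Returns true and stores log2(x) in *log2_result if x is an integral power
// of two (within a small tolerance); otherwise returns false.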
bool CheckedLog2(const float x, int* log2_result);

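// Elementwise addition of two int16 tensors of shape [n_batch, n_input], with
// the result saturated to the int16 range.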
void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch, int n_input,
              int16_t* output);

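// Computes result[v] = 1.0 - vector[v] in Q0.15 fixed point
// (kOne == 32767 represents 1.0).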
inline void Sub1Vector(const int16_t* vector, int v_size, int16_t* result) {
    static const int16_t kOne = 32767;
    for (int v = 0; v < v_size; v++) {
        *result++ = kOne - *vector++;
    }
}

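// Clamps each element of the [n_batch, n_input] tensor in place to the range
// [-clipping_value, clipping_value].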
void CwiseClipping(int16_t* input, const int16_t clipping_value, int32_t n_batch, int32_t n_input);

void CwiseClipping(int8_t* input, const int8_t clipping_value, int32_t n_batch, int32_t n_input);

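// For each batch, multiplies `vector` elementwise with the corresponding row
// of `batch_vector`, rescales the product with (multiplier, shift), and
// accumulates into `result`. Mirrors the TFLite tensor_utils helper.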
void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
                                             const int16_t* batch_vector, int n_batch,
                                             int32_t multiplier, int shift, int16_t* result);

}  // namespace nn
}  // namespace android

#endif  // ANDROID_FRAMEWORKS_ML_NN_COMMON_QUANTUTILS_H