/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_

#include <cmath>
#define EIGEN_USE_THREADS

// This is a set of functions that standardizes how quantized values are
// interpreted as float numbers.
// All of the current implementations are for reference and have not been
// optimized. They should be implementable using fixed point representations
// to avoid a dependency on floating-point hardware.

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define QUANTIZATION_UTILS_USE_NEON
#include <arm_neon.h>
#endif

#include <array>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/threadpool.h"

namespace tensorflow {

// We have to be able to detect and handle overflows in int32, so this function
// uses doubles and int64's to make sure we have enough room.
template <class T>
inline int64 FloatToQuantizedUnclamped(float input, float range_min,
                                       float range_max) {
  const int64 lowest_quantized =
      static_cast<double>(Eigen::NumTraits<T>::lowest());
  if (range_min == range_max) {
    return lowest_quantized;
  }
  const int number_of_bits = sizeof(T) * 8;
  const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
  const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
  const double range = ((range_max - range_min) * range_adjust);
  const double range_scale = (number_of_steps / range);
  int64 quantized =
      (round(input * range_scale) - round(range_min * range_scale));
  quantized += lowest_quantized;
  return quantized;
}

template <>
inline int64 FloatToQuantizedUnclamped<float>(float input, float range_min,
                                              float range_max) {
  return -1;
}

// This converts the float into the final quantized type, clamping/saturating
// any over or underflows.
template <class T>
T FloatToQuantized(float input, float range_min, float range_max) {
  if (std::is_same<T, float>::value) {
    // Specialization for float. This is used in reference implementation
    // for float which is useful to compare performance between float
    // and quantized type.
    return input;
  }
  int64 quantized = FloatToQuantizedUnclamped<T>(input, range_min, range_max);
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<T>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<T>::highest());
  quantized = std::max(quantized, lowest_quantized);
  quantized = std::min(quantized, highest_quantized);
  return static_cast<T>(static_cast<int32>(quantized));
}

template <class T>
float QuantizedToFloat(T input, float range_min, float range_max) {
  if (std::is_same<T, float>::value) {
    // Specialization for float. This is used in reference implementation
    // for float which is useful to compare performance between float
    // and quantized type.
    return input;
  }
  if (range_min == range_max) {
    return range_min;
  }
  const int number_of_bits = sizeof(T) * 8;
  const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
  const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
  const double range = ((range_max - range_min) * range_adjust);
  const double range_scale = (range / number_of_steps);
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<T>::lowest());
  const double offset_input = static_cast<double>(input) - lowest_quantized;
  // For compatibility with DEQUANTIZE_WITH_EIGEN, we should convert
  // range_scale to a float, otherwise range_min_rounded might be slightly
  // different.
  const double range_min_rounded =
      std::round(range_min / static_cast<float>(range_scale)) *
      static_cast<float>(range_scale);
  const double result = range_min_rounded + (offset_input * range_scale);
  return static_cast<float>(result);
}
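
// Illustrative usage (a sketch, not part of the library API): quantizing a
// float to quint8 over an assumed range of [-1.0f, 1.0f] and converting it
// back. The recovered value only matches the input up to one quantization
// step, since 256 levels are spread across the range.
//
//   const float min = -1.0f;
//   const float max = 1.0f;
//   const quint8 quantized = FloatToQuantized<quint8>(0.5f, min, max);
//   const float restored = QuantizedToFloat<quint8>(quantized, min, max);
//   // 'restored' is approximately 0.5f.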

template <class T>
float FloatForOneQuantizedLevel(float range_min, float range_max) {
  const int64 highest = static_cast<int64>(Eigen::NumTraits<T>::highest());
  const int64 lowest = static_cast<int64>(Eigen::NumTraits<T>::lowest());
  const float float_for_one_quantized_level =
      (range_max - range_min) / (highest - lowest);
  return float_for_one_quantized_level;
}

template <class T1, class T2, class T3>
void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b,
                                        float max_b, float* min_c,
                                        float* max_c) {
  const float a_float_for_one_quant_level =
      FloatForOneQuantizedLevel<T1>(min_a, max_a);
  const float b_float_for_one_quant_level =
      FloatForOneQuantizedLevel<T2>(min_b, max_b);

  const int64 c_highest = static_cast<int64>(Eigen::NumTraits<T3>::highest());
  const int64 c_lowest = static_cast<int64>(Eigen::NumTraits<T3>::lowest());
  const float c_float_for_one_quant_level =
      a_float_for_one_quant_level * b_float_for_one_quant_level;

  *min_c = c_float_for_one_quant_level * c_lowest;
  *max_c = c_float_for_one_quant_level * c_highest;
}
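
// Illustrative usage (a sketch, not part of the library API): given two
// quint8 inputs quantized over assumed ranges [0.0f, 6.0f] and [-1.0f, 1.0f],
// compute the range a qint32 product of the two should be interpreted with.
//
//   float min_product, max_product;
//   QuantizationRangeForMultiplication<quint8, quint8, qint32>(
//       0.0f, 6.0f, -1.0f, 1.0f, &min_product, &max_product);
//   // min_product/max_product now span the full qint32 code space, with one
//   // qint32 step equal to the product of one quantized step of each input.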

// input_array is an eigen Tensor. q2f is a QuantizedToFloatStruct.
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f);
#define DEQUANTIZE_WITH_EIGEN(input_array, q2f)                          \
  ((q2f.range_min_rounded - q2f.lowest_quantized() * q2f.range_scale) + \
   input_array.template cast<float>() * q2f.range_scale)

// input_array is an eigen Tensor. f2q is a FloatToQuantizedStruct.
// OutputType is the type of output (e.g. quint8).
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = QUANTIZE_WITH_EIGEN(input_tensor, f2q, T);
#define QUANTIZE_WITH_EIGEN(input_array, f2q, OutputType) \
  ((input_array * f2q.range_scale).round() -              \
   (f2q.range_min_scaled - f2q.lowest_quantized()))       \
      .cwiseMax(f2q.lower_bound_float())                  \
      .cwiseMin(f2q.upper_bound_float())                  \
      .template cast<int32>()                             \
      .template cast<OutputType>()

// For use with DEQUANTIZE_WITH_EIGEN.
template <typename T>
struct QuantizedToFloatStruct {
  static constexpr int number_of_bits = sizeof(T) * 8;
  static constexpr int64 number_of_steps = static_cast<int64>(1)
                                           << number_of_bits;

  static float lowest_quantized() {
    return static_cast<float>(Eigen::NumTraits<T>::lowest());
  }

  QuantizedToFloatStruct(float range_min, float range_max)
      : range_min(range_min),
        range_scale((range_max - range_min) / (number_of_steps - 1.0)),
        range_min_rounded(range_max == range_min
                              ? range_min
                              : std::round(range_min / range_scale) *
                                    range_scale) {}

  const float range_min;
  const float range_scale;
  const float range_min_rounded;
};

// For use with QUANTIZE_WITH_EIGEN.
template <typename T>
struct FloatToQuantizedStruct {
  static constexpr int number_of_bits = sizeof(T) * 8;
  static constexpr int64 number_of_steps = static_cast<int64>(1)
                                           << number_of_bits;
  static constexpr double range_adjust =
      (number_of_steps / (number_of_steps - 1.0));

  // Casting QInt32's lowest or highest to a float gives a float that can't be
  // cast back to int32 or QInt32. Instead, use bounds that can be converted
  // back to int32 without going outside the range of an int32.
  static float lower_bound_float() {
    return Eigen::numext::maxi(
        static_cast<float>(Eigen::NumTraits<T>::lowest()), -2.147483648e+09f);
  }
  static float upper_bound_float() {
    return Eigen::numext::mini(
        static_cast<float>(Eigen::NumTraits<T>::highest()), +2.147483520e+09f);
  }

  static float lowest_quantized() {
    return static_cast<float>(Eigen::NumTraits<T>::lowest());
  }

  FloatToQuantizedStruct(float range_min, float range_max)
      : range_min(range_min),
        range_scale(range_max == range_min
                        ? 0.0
                        : (number_of_steps - 1.0) / (range_max - range_min)),
        range_min_scaled(std::round(range_min * range_scale)) {}

  const float range_min;
  const float range_scale;
  const float range_min_scaled;
};
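
// Illustrative usage (a sketch, not part of the library API): dequantizing a
// flat quint8 expression into floats with an assumed range of [0.0f, 255.0f].
// 'input_tensor', 'output_tensor', and 'device' are assumed to exist in the
// caller.
//
//   auto input_array = input_tensor.flat<quint8>();
//   QuantizedToFloatStruct<quint8> q2f(0.0f, 255.0f);
//   auto floats = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
//   output_tensor->flat<float>().device(device) = floats;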

template <class T1, class T2>
inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input,
                               float min_new, float max_new) {
  const float input_float = QuantizedToFloat<T1>(input, min_input, max_input);
  return FloatToQuantized<T2>(input_float, min_new, max_new);
}

template <class T1, class T2>
inline void RequantizeManyInNewRange(const T1* input, int64 count,
                                     float min_input, float max_input,
                                     float min_output, float max_output,
                                     T2* output) {
  for (size_t index = 0; index < count; ++index) {
    const float input_float =
        QuantizedToFloat<T1>(input[index], min_input, max_input);
    output[index] = FloatToQuantized<T2>(input_float, min_output, max_output);
  }
}

// Because converting 32-bit accumulated results down to eight bit is a common
// case, we have a specialized code path to handle it as efficiently as
// possible using only fixed-point math for the inner loop.
inline void RequantizeManyInNewRangeReference(const qint32* input, int64 count,
                                              float min_input, float max_input,
                                              float min_output,
                                              float max_output,
                                              quint8* output) {
  // Initially we calculate all the constants we need once, before we go into
  // the inner loop. If this is updated, also update the Eigen version.
  const int fp_shift = 16;
  const float input_range = max_input - min_input;
  const float output_range = max_output - min_output;
  const float recip_output_range =
      output_range == 0.0 ? 0.0 : (255.0 / output_range);
  const float input_rezero = (min_input + max_input) / 2.0;
  const int64 range_scale_fp =
      output_range == 0.0 ? 0.0
                          : static_cast<int64>(255.0 * (1 << fp_shift) *
                                               input_range / output_range);
  const int64 input_offset_fp =
      static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
  const int64 output_offset_fp =
      output_range == 0.0
          ? 0
          : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
                               output_range);
  const int64 rounding_delta = 1 << (fp_shift - 1);

  // Inside this loop we just do minimal adds, multiplies, and shifts, in a way
  // that could be easily adapted for a SIMD implementation. It should also be
  // possible to perform all the calculations in 32-bit rather than 64, but
  // that's not been implemented yet.
  for (tensorflow::int64 index = 0; index < count; ++index) {
    const int64 input_value = static_cast<int64>(input[index]);
    const int64 fp_value =
        ((input_value * range_scale_fp) >> 32) + input_offset_fp;
    const int64 offset_intermediate = fp_value - output_offset_fp;
    const int64 round_intermediate = offset_intermediate + rounding_delta;
    int64 quantized_int64 = round_intermediate >> fp_shift;
    quantized_int64 = std::max(quantized_int64, int64{0});
    quantized_int64 = std::min(quantized_int64, int64{255});
    output[index] = static_cast<quint8>(static_cast<int32>(quantized_int64));
  }
}
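
// Illustrative usage (a sketch, not part of the library API): requantizing a
// buffer of 32-bit accumulators down to quint8. The ranges and the buffer
// names here are assumed for illustration only.
//
//   // 'accumulators' points at 'count' qint32 values, assumed to have been
//   // produced over the range [-32.0f, 32.0f]; 'requantized' points at
//   // 'count' quint8 outputs interpreted over [0.0f, 6.0f].
//   RequantizeManyInNewRange<qint32, quint8>(accumulators, count, -32.0f,
//                                            32.0f, 0.0f, 6.0f, requantized);
//   // On ARM builds this specialization uses the NEON implementation below
//   // (falling back to the reference version for very large range ratios);
//   // elsewhere it calls RequantizeManyInNewRangeReference.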

// Another common case is converting eight bit inputs up to thirty two bits, so
// we have specialized fixed-point code to accelerate that. There is also a NEON
// version for ARM devices below.
inline void RequantizeManyInNewRange8To32BitReference(
    const quint8* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, qint32* output) {
  const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
  const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);
  const int64 code_0_int64 =
      FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
  const int64 code_1_int64 =
      FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);
  const int32 mult_int32 = code_1_int64 - code_0_int64;
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());
  for (int64 i = 0; i < count; ++i) {
    const int64 input_value = static_cast<int64>(input[i]);
    int64 output_value = code_0_int64 + (input_value * mult_int32);
    output_value = std::max(output_value, lowest_quantized);
    output_value = std::min(output_value, highest_quantized);
    output[i] = static_cast<int32>(output_value);
  }
}

#ifdef QUANTIZATION_UTILS_USE_NEON
// Speeds up the 32->8bit conversion using fixed-point arithmetic and NEON SIMD
// intrinsics for ARM platforms.
inline void RequantizeManyInNewRangeNeon(const qint32* input, int64 count,
                                         float min_input, float max_input,
                                         float min_output, float max_output,
                                         quint8* output) {
  // Initially we calculate all the constants we need once, before we go into
  // the inner loop. If this is updated, also update the Eigen version.
  const int fp_shift = 16;

  // Calculate range variables in advance.
  // Input range.
  const float input_range = max_input - min_input;
  // Output range.
  const float output_range = max_output - min_output;
  // Ratio of output range.
  const float recip_output_range =
      output_range == 0.0 ? 0.0 : (255.0 / output_range);
  // Average of input range as zero position of input.
  const float input_rezero = (min_input + max_input) / 2.0;
  // In-out range scale.
  const int32 range_scale_fp =
      output_range == 0.0 ? 0.0
                          : static_cast<int32>(255.0 * (1 << (fp_shift - 16)) *
                                               input_range / output_range);
  // Input zero position offset to output.
  const int32 input_offset_fp =
      static_cast<int32>(input_rezero * recip_output_range * (1 << fp_shift));
  // Output min offset.
  const int32 output_offset_fp =
      output_range == 0.0
          ? 0
          : static_cast<int32>((1 << fp_shift) * (min_output * 255.0) /
                               output_range);
  const int32 rounding_delta = 1 << (fp_shift - 1);

  // Broadcast range to each lane.
  const int32x4_t range_scale_fp_32x4 = vmovq_n_s32(range_scale_fp);
  const int32x4_t input_offset_fp_32x4 = vmovq_n_s32(input_offset_fp);
  const int32x4_t output_offset_fp_32x4 = vmovq_n_s32(output_offset_fp);
  const int32x4_t rounding_delta_32x4 = vmovq_n_s32(rounding_delta);

  int64 index = 0;
  // Use SIMD to requantize.
  for (; index < (count - 7); index += 8) {
    const int32* input_ptr = &(input->value) + index;
    const int32x4_t input_value_low_32x4 = vld1q_s32(input_ptr);
    const int32x4_t input_value_high_32x4 = vld1q_s32(input_ptr + 4);
    const int32x4_t fp_value_low_32x4 = vaddq_s32(
        input_offset_fp_32x4,
        vmulq_s32(vshrq_n_s32(input_value_low_32x4, 16), range_scale_fp_32x4));
    const int32x4_t fp_value_high_32x4 = vaddq_s32(
        input_offset_fp_32x4,
        vmulq_s32(vshrq_n_s32(input_value_high_32x4, 16), range_scale_fp_32x4));
    const int32x4_t offset_intermediate_low_32x4 =
        vsubq_s32(fp_value_low_32x4, output_offset_fp_32x4);
    const int32x4_t offset_intermediate_high_32x4 =
        vsubq_s32(fp_value_high_32x4, output_offset_fp_32x4);
    const int32x4_t round_intermediate_low_32x4 =
        vaddq_s32(offset_intermediate_low_32x4, rounding_delta_32x4);
    const int32x4_t round_intermediate_high_32x4 =
        vaddq_s32(offset_intermediate_high_32x4, rounding_delta_32x4);
    const int16x4_t quantized_low_16x4 =
        vqmovn_s32(vshrq_n_s32(round_intermediate_low_32x4, fp_shift));
    const int16x4_t quantized_high_16x4 =
        vqmovn_s32(vshrq_n_s32(round_intermediate_high_32x4, fp_shift));
    const uint8x8_t quantized_8x8 =
        vqmovun_s16(vcombine_s16(quantized_low_16x4, quantized_high_16x4));
    uint8* output_ptr = &(output->value) + index;
    vst1_u8(output_ptr, quantized_8x8);
  }

  // Requantize remaining elements in array without SIMD.
  for (; index < count; ++index) {
    const int32 input_value = static_cast<int32>(input[index]);
    const int32 fp_value =
        static_cast<int32>(
            (static_cast<int32>(input_value >> 16) * (range_scale_fp))) +
        input_offset_fp;
    const int32 offset_intermediate = fp_value - output_offset_fp;
    const int32 round_intermediate = offset_intermediate + rounding_delta;
    int32 quantized_int32 = round_intermediate >> fp_shift;
    quantized_int32 = std::max(quantized_int32, 0);
    quantized_int32 = std::min(quantized_int32, 255);
    output[index] = static_cast<quint8>(static_cast<int32>(quantized_int32));
  }
}

template <>
inline void RequantizeManyInNewRange<qint32, quint8>(
    const qint32* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, quint8* output) {
  const float input_range = max_input - min_input;
  const float output_range = max_output - min_output;
  if ((input_range / output_range) > 16384.0f) {
    // Our NEON implementation uses 32-bit math and can't handle very
    // large ranges, so fall back to the reference implementation. We don't
    // expect these to be common in models, so this shouldn't be a performance
    // problem in practice.
    RequantizeManyInNewRangeReference(input, count, min_input, max_input,
                                      min_output, max_output, output);
  } else {
    RequantizeManyInNewRangeNeon(input, count, min_input, max_input,
                                 min_output, max_output, output);
  }
}

// NEON accelerated 16bit rounded division by 2^n.
template <int POW>
inline int16x8_t Divide16x8PowRound(const int16x8_t val) {
  const int16x8_t val_sign = vshrq_n_s16(val, 15);
  const int16x8_t val_xor = veorq_s16(val, val_sign);
  const int16x8_t val_pos = vsubq_s16(val_xor, val_sign);
  const int16x8_t shifted_val_pos = vrshrq_n_s16(val_pos, POW);
  const int16x8_t shifted_val_pos_xor = veorq_s16(shifted_val_pos, val_sign);
  const int16x8_t shifted_val = vsubq_s16(shifted_val_pos_xor, val_sign);
  return shifted_val;
}

// NEON accelerated 64bit rounded division by 2^n.
template <int POW>
inline int64x2_t Divide64x2PowRound(const int64x2_t val) {
  const int64x2_t val_sign = vshrq_n_s64(val, 63);
  const int64x2_t val_xor = veorq_s64(val, val_sign);
  const int64x2_t val_pos = vsubq_s64(val_xor, val_sign);
  const int64x2_t shifted_val_pos = vrshrq_n_s64(val_pos, POW);
  const int64x2_t shifted_val_pos_xor = veorq_s64(shifted_val_pos, val_sign);
  const int64x2_t shifted_val = vsubq_s64(shifted_val_pos_xor, val_sign);
  return shifted_val;
}

// NEON accelerated 16bit division by 2^n.
// CAVEAT: The input must be greater than min-int16 to avoid underflow.
template <int POW>
inline int16x8_t Divide16x8Pow(const int16x8_t val) {
  static constexpr int16 FIRST_BIT_VAL = 0x0000000000000001;
  static const int16x8_t FIRST_BIT = vmovq_n_s16(FIRST_BIT_VAL);
  const int16x8_t val_sign = vshrq_n_s16(val, 15);
  const int16x8_t neg_offset = vandq_s16(val_sign, FIRST_BIT);
  const int16x8_t val_with_offset = vsubq_s16(val, neg_offset);
  const int16x8_t shifted_wo_offset =
      vsraq_n_s16(neg_offset, val_with_offset, POW);
  return shifted_wo_offset;
}

// NEON accelerated 64bit division by 2^n.
// CAVEAT: The input must be greater than min-int64 to avoid underflow.
template <int POW>
inline int64x2_t Divide64x2Pow(const int64x2_t val) {
  static constexpr int64 FIRST_BIT_VAL = 0x0000000000000001;
  static const int64x2_t FIRST_BIT = vmovq_n_s64(FIRST_BIT_VAL);
  const int64x2_t val_sign = vshrq_n_s64(val, 63);
  const int64x2_t neg_offset = vandq_s64(val_sign, FIRST_BIT);
  const int64x2_t val_with_offset = vsubq_s64(val, neg_offset);
  const int64x2_t shifted_wo_offset =
      vsraq_n_s64(neg_offset, val_with_offset, POW);
  return shifted_wo_offset;
}

// 32bit x 2 NEON accelerated lerp computation.
template <int RESOLUTION>
inline int32x2_t ComputeLerp32x2(const int32x2_t top_left,
                                 const int32x2_t top_right,
                                 const int32x2_t bottom_left,
                                 const int32x2_t bottom_right,
                                 const int32x2_t x_lerp,
                                 const int32x2_t y_lerp) {
  static_assert(RESOLUTION < 31, "RESOLUTION must be less than 31");
  constexpr int32 RESOLUTION_MULT32 = (1 << RESOLUTION);
  static const int32x2_t RESOLUTION_MULT32x2 = vmov_n_s32(RESOLUTION_MULT32);

  const int64x2_t top_left_x_res = vmull_s32(top_left, RESOLUTION_MULT32x2);
  const int64x2_t bottom_left_x_res =
      vmull_s32(bottom_left, RESOLUTION_MULT32x2);

  const int32x2_t top_right_sub_top_left = vsub_s32(top_right, top_left);
  const int64x2_t top_x_res =
      vmlal_s32(top_left_x_res, top_right_sub_top_left, x_lerp);
  const int32x2_t bottom_right_sub_bottom_left =
      vsub_s32(bottom_right, bottom_left);
  const int64x2_t bottom_x_res =
      vmlal_s32(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);

  const int64x2_t bottom_sub_top_x_res = vsubq_s64(bottom_x_res, top_x_res);
  const int64x2_t bottom_sub_top =
      Divide64x2Pow<RESOLUTION>(bottom_sub_top_x_res);
  const int32x2_t bottom_sub_top_32 = vqmovn_s64(bottom_sub_top);
  const int64x2_t top_add_bottom_sub_top_mul_ylerp_x_res =
      vmlal_s32(top_x_res, bottom_sub_top_32, y_lerp);
  const int64x2_t retval =
      Divide64x2PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
  const int32x2_t retval32 = vqmovn_s64(retval);
  return retval32;
}

// 8bit x 8 NEON accelerated lerp computation.
template <int RESOLUTION>
inline uint8x8_t ComputeLerp8x8(const uint8x8_t top_left8x8,
                                const uint8x8_t top_right8x8,
                                const uint8x8_t bottom_left8x8,
                                const uint8x8_t bottom_right8x8,
                                const int16x8_t x_lerp,
                                const int16x8_t y_lerp) {
  static_assert(RESOLUTION < 8, "RESOLUTION must be less than 8");
  constexpr uint8 RESOLUTION_MULT_VAL = (1 << RESOLUTION);
  static const uint8x8_t RESOLUTION_MULT = vdup_n_u8(RESOLUTION_MULT_VAL);

  const int16x8_t top_left_x_res =
      vreinterpretq_s16_u16(vmull_u8(top_left8x8, RESOLUTION_MULT));
  const int16x8_t bottom_left_x_res =
      vreinterpretq_s16_u16(vmull_u8(bottom_left8x8, RESOLUTION_MULT));

  const int16x8_t top_right_sub_top_left =
      vreinterpretq_s16_u16(vsubl_u8(top_right8x8, top_left8x8));
  const int16x8_t top_x_res =
      vmlaq_s16(top_left_x_res, top_right_sub_top_left, x_lerp);

  const int16x8_t bottom_right_sub_bottom_left =
      vreinterpretq_s16_u16(vsubl_u8(bottom_right8x8, bottom_left8x8));
  const int16x8_t bottom_x_res =
      vmlaq_s16(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);

  const int16x8_t bottom_sub_top_x_res = vsubq_s16(bottom_x_res, top_x_res);
  const int16x8_t bottom_sub_top =
      Divide16x8Pow<RESOLUTION>(bottom_sub_top_x_res);
  const int16x8_t top_add_bottom_sub_top_mul_ylerp_x_res =
      vmlaq_s16(top_x_res, bottom_sub_top, y_lerp);
  const int16x8_t retval16 =
      Divide16x8PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
  const uint8x8_t retval = vmovn_u16(vreinterpretq_u16_s16(retval16));
  return retval;
}

// Requantizes eight quint8 values to eight qint32 values in parallel using
// NEON. Returns a std::array instead of writing through a pointer so the
// caller can rely on return value optimization.
inline std::array<int32x4_t, 2> Requantize8x8To32Neon(
    const uint8* input_ptr, const int64x2_t input_0_64x2,
    const int32x2_t input_mult_32x2) {
  const uint8x8_t input_value_8x8 = vld1_u8(input_ptr);
  const int16x8_t input_value_16x8 =
      vreinterpretq_s16_u16(vmovl_u8(input_value_8x8));
  const int16x4_t input_value_low_16x4 = vget_low_s16(input_value_16x8);
  const int16x4_t input_value_high_16x4 = vget_high_s16(input_value_16x8);
  const int32x4_t input_value_low_32x4 = vmovl_s16(input_value_low_16x4);
  const int32x4_t input_value_high_32x4 = vmovl_s16(input_value_high_16x4);
  const int32x2_t input_value_low_low_32x2 = vget_low_s32(input_value_low_32x4);
  const int32x2_t input_value_low_high_32x2 =
      vget_high_s32(input_value_low_32x4);
  const int32x2_t input_value_high_low_32x2 =
      vget_low_s32(input_value_high_32x4);
  const int32x2_t input_value_high_high_32x2 =
      vget_high_s32(input_value_high_32x4);
  const int64x2_t mult_result_low_low_64x2 =
      vmlal_s32(input_0_64x2, input_value_low_low_32x2, input_mult_32x2);
  const int64x2_t mult_result_low_high_64x2 =
      vmlal_s32(input_0_64x2, input_value_low_high_32x2, input_mult_32x2);
  const int64x2_t mult_result_high_low_64x2 =
      vmlal_s32(input_0_64x2, input_value_high_low_32x2, input_mult_32x2);
  const int64x2_t mult_result_high_high_64x2 =
      vmlal_s32(input_0_64x2, input_value_high_high_32x2, input_mult_32x2);
  const int32x2_t output_value_low_low_32x2 =
      vqmovn_s64(mult_result_low_low_64x2);
  const int32x2_t output_value_low_high_32x2 =
      vqmovn_s64(mult_result_low_high_64x2);
  const int32x2_t output_value_high_low_32x2 =
      vqmovn_s64(mult_result_high_low_64x2);
  const int32x2_t output_value_high_high_32x2 =
      vqmovn_s64(mult_result_high_high_64x2);
  const int32x4_t output_value_low_32x4 =
      vcombine_s32(output_value_low_low_32x2, output_value_low_high_32x2);
  const int32x4_t output_value_high_32x4 =
      vcombine_s32(output_value_high_low_32x2, output_value_high_high_32x2);
  return std::array<int32x4_t, 2>{
      {output_value_low_32x4, output_value_high_32x4}};
}

// Speeds up the 8->32bit conversion using fixed-point arithmetic and NEON SIMD
// intrinsics for ARM platforms.
template <>
inline void RequantizeManyInNewRange<quint8, qint32>(
    const quint8* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, qint32* output) {
  // Pre-calculate the zero position and multiplier.
  // Calculate the float values of quantized codes 0 and 1.
  const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
  const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);

  // Convert the 0 and 1 values to int64 in the output range.
  const int64 code_0_int64 =
      FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
  const int64 code_1_int64 =
      FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);

  // Calculate the multiplier.
  const int32 mult_int32 = static_cast<int32>(code_1_int64 - code_0_int64);

  // Broadcast the 0 position and multiplier to the NEON lanes.
  const int64x2_t code_0_64x2 = vmovq_n_s64(code_0_int64);
  const int32x2_t mult_32x2 = vmov_n_s32(mult_int32);

  int64 i = 0;

  // Use SIMD to requantize the array.
  for (; i < (count - 7); i += 8) {
    const uint8* input_ptr = &(input->value) + i;
    int32* output_ptr = &(output->value) + i;
    const std::array<int32x4_t, 2> output_value =
        Requantize8x8To32Neon(input_ptr, code_0_64x2, mult_32x2);
    vst1q_s32(output_ptr + 0, output_value[0]);
    vst1q_s32(output_ptr + 4, output_value[1]);
  }

  // Requantize the remaining elements in the array without SIMD.
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  for (; i < count; ++i) {
    const int64 input_value = static_cast<int64>(input[i]);
    int64 output_value = code_0_int64 + (input_value * mult_int32);
    output_value = std::max(output_value, lowest_quantized);
    output_value = std::min(output_value, highest_quantized);
    output[i] = static_cast<int32>(output_value);
  }
}

#else

// If SIMD implementations aren't available, then use these default reference
// versions.
template <>
inline void RequantizeManyInNewRange<qint32, quint8>(
    const qint32* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, quint8* output) {
  RequantizeManyInNewRangeReference(input, count, min_input, max_input,
                                    min_output, max_output, output);
}

template <>
inline void RequantizeManyInNewRange<quint8, qint32>(
    const quint8* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, qint32* output) {
  RequantizeManyInNewRange8To32BitReference(input, count, min_input, max_input,
                                            min_output, max_output, output);
}

#endif

template <int shift>
struct int64_right_shift_op {
  EIGEN_EMPTY_STRUCT_CTOR(int64_right_shift_op)
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE const int64 operator()(const int64& a) const {
    return a >> shift;
  }
};

// See RequantizeManyInNewRange() for a non-eigen reference implementation.
template <class T1, class T2>
inline void RequantizeManyInNewRangeUsingEigen(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
    float max_input, float min_output, float max_output, Tensor* output) {
  auto input_array = input.flat<T1>();
  QuantizedToFloatStruct<T1> q2f(min_input, max_input);
  auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
  FloatToQuantizedStruct<T2> f2q(min_output, max_output);
  auto input_requantized = QUANTIZE_WITH_EIGEN(input_float, f2q, T2);

  output->flat<T2>().device(device) = input_requantized;
}

// See RequantizeManyInNewRange() for a non-eigen reference implementation.
//
// Because converting 32-bit accumulated results down to eight bit is a common
// case, we have a specialized code path to handle it as efficiently as
// possible using only fixed-point math for the inner loop.
template <>
inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
    float max_input, float min_output, float max_output, Tensor* output) {
  // Initially we calculate all the constants we need once, before we go into
  // the inner loop. If this is updated, also update the non-Eigen version.
  const int fp_shift = 16;
  const float input_range = max_input - min_input;
  const float output_range = max_output - min_output;
  const float recip_output_range =
      output_range == 0.0 ? 0.0 : (255.0 / output_range);
  const float input_rezero = (min_input + max_input) / 2.0;
  const int64 range_scale_fp =
      output_range == 0.0 ? 0.0
                          : static_cast<int64>(255.0 * (1 << fp_shift) *
                                               input_range / output_range);
  const int64 input_offset_fp =
      static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
  const int64 output_offset_fp =
      output_range == 0.0
          ? 0
          : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
                               output_range);
  const int64 rounding_delta = 1 << (fp_shift - 1);

  // Inside this eigen expression we just do minimal adds, multiplies, and
  // shifts. It should be possible to perform all the calculations in 32-bit
  // rather than 64, but that's not been implemented yet.
  auto input_array = input.flat<qint32>();
  auto fp_value = ((input_array.template cast<int64>() * range_scale_fp)
                       .unaryExpr(int64_right_shift_op<32>())) +
                  (input_offset_fp - output_offset_fp + rounding_delta);
  auto intermediate = fp_value.unaryExpr(int64_right_shift_op<fp_shift>());
  auto input_requantized = intermediate.cwiseMax(int64{0})
                               .cwiseMin(int64{255})
                               .template cast<int32>()
                               .template cast<quint8>();
  output->flat<quint8>().device(device) = input_requantized;
}

// REQUIRES: 'result->NumElements() == input.NumElements()'
template <class T>
void FloatTensorToQuantizedInPlaceUsingEigen(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
    float max, Tensor* result) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
  auto flat_input = input.flat<float>();
  auto flat_result = result->flat<T>();
  DCHECK_EQ(flat_input.size(), flat_result.size());

  FloatToQuantizedStruct<T> f2q(min, max);
  flat_result.device(device) = QUANTIZE_WITH_EIGEN(flat_input, f2q, T);
}

template <class T>
void FloatTensorToQuantizedInPlace(const Tensor& input, float min, float max,
                                   Tensor* result) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
  auto flat_input = input.flat<float>();
  auto flat_result = result->flat<T>();
  const int data_size = flat_input.size();
  DCHECK(data_size == flat_result.size());
  for (int i = 0; i < data_size; ++i) {
    flat_result(i) = FloatToQuantized<T>(flat_input(i), min, max);
  }
}

template <class T>
Tensor FloatTensorToQuantized(const Tensor& input, float min, float max) {
  Tensor result(DataTypeToEnum<T>::v(), input.shape());
  FloatTensorToQuantizedInPlace<T>(input, min, max, &result);
  return result;
}

// REQUIRES: 'result->NumElements() == input.NumElements()'
template <class T>
void QuantizedTensorToFloatInPlaceUsingEigen(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
    float max, Tensor* result) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
  auto flat_input = input.flat<T>();
  auto flat_result = result->flat<float>();
  const int data_size = flat_input.size();
  DCHECK(data_size == flat_result.size());

  QuantizedToFloatStruct<T> q2f(min, max);
  flat_result.device(device) = DEQUANTIZE_WITH_EIGEN(flat_input, q2f);
}

// REQUIRES: 'result->NumElements() == input.NumElements()'
template <class T>
void QuantizedTensorToFloatInPlace(const Tensor& input, float min, float max,
                                   Tensor* result) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
  auto flat_input = input.flat<T>();
  auto flat_result = result->flat<float>();
  const int data_size = flat_input.size();
  DCHECK(data_size == flat_result.size());
  for (int i = 0; i < data_size; ++i) {
    flat_result(i) = QuantizedToFloat<T>(flat_input(i), min, max);
  }
}

template <class T>
Tensor QuantizedTensorToFloat(const Tensor& input, float min, float max) {
  Tensor result(DT_FLOAT, input.shape());
  QuantizedTensorToFloatInPlace<T>(input, min, max, &result);
  return result;
}
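
// Illustrative usage (a sketch, not part of the library API): converting a
// float Tensor to quint8 over an assumed range of [0.0f, 6.0f] and back.
//
//   Tensor float_tensor(DT_FLOAT, TensorShape({2, 3}));
//   float_tensor.flat<float>().setConstant(1.0f);
//   Tensor quantized =
//       FloatTensorToQuantized<quint8>(float_tensor, 0.0f, 6.0f);
//   Tensor restored = QuantizedTensorToFloat<quint8>(quantized, 0.0f, 6.0f);
//   // 'restored' holds values close to 1.0f, up to one quantization step.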

void GetOutputMinAndMaxForQuantizedAdd(float input_min, float input_max,
                                       float smaller_input_min,
                                       float smaller_input_max,
                                       float* output_min, float* output_max);

// Add <input> and <smaller_input>. If <smaller_input> has fewer elements than
// <input>, then it is broadcast onto <input>.
template <typename T1, typename T2, typename T3>
void QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice& device,
                            const Tensor& input, float input_min,
                            float input_max, const Tensor& smaller_input,
                            float smaller_input_min, float smaller_input_max,
                            Tensor* output, float* output_min,
                            float* output_max) {
  const auto& input_flat = input.flat<T1>();
  const auto& smaller_input_flat = smaller_input.flat<T2>();
  auto output_flat = output->flat<T3>();

  GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
                                    smaller_input_max, output_min, output_max);
  // To do addition properly, we need to compensate for a possibly unbalanced
  // zero point in the total representation. The quantized value that
  // represents the real number zero needs to be subtracted before addition to
  // make sure that the identity of zero + zero = zero holds.
  const T3 zero_in_total_space =
      FloatToQuantized<T3>(0.0f, *output_min, *output_max);

  const int64 input_element_count = input.NumElements();
  const int64 smaller_input_element_count = smaller_input.NumElements();

  QuantizedToFloatStruct<T1> input_q2f(input_min, input_max);
  QuantizedToFloatStruct<T2> smaller_input_q2f(smaller_input_min,
                                               smaller_input_max);
  FloatToQuantizedStruct<T3> f2q(*output_min, *output_max);

  auto smaller_input_float =
      DEQUANTIZE_WITH_EIGEN(smaller_input_flat, smaller_input_q2f);
  auto smaller_input_in_total_space =
      QUANTIZE_WITH_EIGEN(smaller_input_float, f2q, T3);

  auto input_float = DEQUANTIZE_WITH_EIGEN(input_flat, input_q2f);
  auto input_in_total_space = QUANTIZE_WITH_EIGEN(input_float, f2q, T3);

  Eigen::array<Eigen::DenseIndex, 1> bcast;
  bcast[0] = input_element_count / smaller_input_element_count;
  output_flat.device(device) =
      input_in_total_space +
      (smaller_input_in_total_space.broadcast(bcast) + zero_in_total_space);
}

// This is a reference implementation of the bias addition for quantized
// buffers, designed to provide a clear specification for the result we
// want. We'll want to specialize this for particular hardware, and
// probably even fuse it with matrix multiplications in a lot of cases. It's
// important to show the clamping behavior we want in particular.
template <typename T1, typename T2, typename T3>
void QuantizedAdd(const Eigen::ThreadPoolDevice& device, const Tensor& input,
                  float input_min, float input_max, const Tensor& smaller_input,
                  float smaller_input_min, float smaller_input_max,
                  Tensor* output, float* output_min, float* output_max) {
  const auto& input_flat = input.flat<T1>();
  const auto& smaller_input_flat = smaller_input.flat<T2>();
  auto output_flat = output->flat<T3>();

  GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
                                    smaller_input_max, output_min, output_max);
  // To do addition properly, we need to compensate for a possibly unbalanced
  // zero point in the total representation. The quantized value that
  // represents the real number zero needs to be subtracted before addition to
  // make sure that the identity of zero + zero = zero holds.
  const T3 zero_in_total_space =
      FloatToQuantized<T3>(0.0f, *output_min, *output_max);

  const int64 input_element_count = input.NumElements();
  const int64 smaller_input_element_count = smaller_input.NumElements();

  float total_min = *output_min;
  float total_max = *output_max;
  const size_t how_many_iterations =
      (input_element_count / smaller_input_element_count);
  for (size_t iteration = 0; iteration < how_many_iterations; ++iteration) {
    const size_t offset = iteration * smaller_input_element_count;
    for (int c = 0; c < smaller_input_element_count; ++c) {
      const int index = (offset + c);
      // The two numbers we're going to add can each be in very different
      // ranges (e.g. the quantized value '127' may represent very different
      // real numbers in both) so we need to convert them to a common range
      // before we sum them.
      const T1 input_value = input_flat(index);
      const T3 input_in_total_space = RequantizeInNewRange<T1, T3>(
          input_value, input_min, input_max, total_min, total_max);
      const T2 smaller_input_value = smaller_input_flat(c);
      const T3 smaller_input_in_total_space =
          RequantizeInNewRange<T2, T3>(smaller_input_value, smaller_input_min,
                                       smaller_input_max, total_min, total_max);
      const T3 total_pre = input_in_total_space + smaller_input_in_total_space;
      // As noted above, we need to compensate for the offset of the actual
      // zero point in the space we're operating in.
      const T3 total = total_pre + zero_in_total_space;
      output_flat(index) = total;
    }
  }
}

// See gemmlowp/internal/multi_thread_gemm.h for the semantics of Execute.
class TensorflowGemmlowpWorkersPool {
 public:
  TensorflowGemmlowpWorkersPool(thread::ThreadPool* workers)
      : workers_(workers) {}

  ~TensorflowGemmlowpWorkersPool() {
    // This workaround ensures that all worker tasks have exited methods in the
    // BlockingCounter. Without this, there is a race where the context is torn
    // down while the counter is in use.
    counter_to_decrement_when_ready_.Reset(0);
  }

  void Execute(const std::vector<gemmlowp::Task*>& tasks) {
    assert(!tasks.empty());
    assert(workers_ != nullptr);
    counter_to_decrement_when_ready_.Reset(tasks.size());
    for (gemmlowp::Task* task : tasks) {
      workers_->Schedule([this, task]() {
        // TODO(cwhipkey): get a local_allocator from a thread local storage.
        gemmlowp::Allocator local_allocator;
        CHECK(task != nullptr);
        task->local_allocator = &local_allocator;
        task->Run();
        counter_to_decrement_when_ready_.DecrementCount();
      });
    }
    counter_to_decrement_when_ready_.Wait();
    for (gemmlowp::Task* task : tasks) {
      delete task;
    }
  }

 private:
  thread::ThreadPool* const workers_;

  // The BlockingCounter used to wait for the workers.
  gemmlowp::BlockingCounter counter_to_decrement_when_ready_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmlowpWorkersPool);
};

class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase {
 public:
  TensorflowGemmContext(int num_threads, thread::ThreadPool* workers)
      : workers_pool_(workers) {
    set_max_num_threads(num_threads);
  }

  TensorflowGemmlowpWorkersPool* workers_pool() { return &workers_pool_; }

 private:
  TensorflowGemmlowpWorkersPool workers_pool_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmContext);
};
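
// Illustrative usage (a sketch, assuming it runs inside an OpKernel's
// Compute() with an OpKernelContext* named 'context'): build a gemm context
// backed by TensorFlow's CPU worker threads so gemmlowp work is scheduled on
// the same thread pool as the rest of the op.
//
//   auto* worker_threads =
//       context->device()->tensorflow_cpu_worker_threads();
//   TensorflowGemmContext gemm_context(worker_threads->num_threads,
//                                      worker_threads->workers);
//   // 'gemm_context' can then be passed as the context argument of gemmlowp
//   // routines such as gemmlowp::GemmWithOutputPipeline(...).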

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_