/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_

#define EIGEN_USE_THREADS

// This is a set of functions that standardizes how quantized values are
// interpreted as float numbers.
// All of the current implementations are for reference and have not been
// optimized. They should be implementable using fixed point representations
// to avoid a dependency on floating-point hardware.

#if defined(__ARM_NEON__) || defined(__ARM_NEON)
#define QUANTIZATION_UTILS_USE_NEON
#include <arm_neon.h>
#endif

#include <array>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/threadpool.h"

namespace tensorflow {

// We have to be able to detect and handle overflows in int32, so this function
// uses doubles and int64's to make sure we have enough room.
template <class T>
int64 FloatToQuantizedUnclamped(float input, float range_min, float range_max) {
  const int64 lowest_quantized =
      static_cast<double>(Eigen::NumTraits<T>::lowest());
  if (range_min == range_max) {
    return lowest_quantized;
  }
  const int number_of_bits = sizeof(T) * 8;
  const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
  const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
  const double range = ((range_max - range_min) * range_adjust);
  const double range_scale = (number_of_steps / range);
  int64 quantized =
      (round(input * range_scale) - round(range_min * range_scale));
  quantized += lowest_quantized;
  return quantized;
}

// This converts the float into the final quantized type, clamping/saturating
// any over or underflows.
template <class T>
T FloatToQuantized(float input, float range_min, float range_max) {
  if (std::is_same<T, float>::value) {
    // Specialization for float. This is used in the reference implementation
    // for float, which is useful for comparing performance between the float
    // and quantized paths.
    return input;
  }
  int64 quantized = FloatToQuantizedUnclamped<T>(input, range_min, range_max);
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<T>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<T>::highest());
  quantized = std::max(quantized, lowest_quantized);
  quantized = std::min(quantized, highest_quantized);
  return static_cast<T>(static_cast<int32>(quantized));
}

template <class T>
float QuantizedToFloat(T input, float range_min, float range_max) {
  if (std::is_same<T, float>::value) {
    // Specialization for float. This is used in the reference implementation
    // for float, which is useful for comparing performance between the float
    // and quantized paths.
    return input;
  }
  if (range_min == range_max) {
    return range_min;
  }
  const int number_of_bits = sizeof(T) * 8;
  const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
  const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
  const double range = ((range_max - range_min) * range_adjust);
  const double range_scale = (range / number_of_steps);
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<T>::lowest());
  const double offset_input = static_cast<double>(input) - lowest_quantized;
  // For compatibility with DEQUANTIZE_WITH_EIGEN, we should convert
  // range_scale to a float, otherwise range_min_rounded might be slightly
  // different.
  const double range_min_rounded =
      round(range_min / static_cast<float>(range_scale)) *
      static_cast<float>(range_scale);
  const double result = range_min_rounded + (offset_input * range_scale);
  return static_cast<float>(result);
}
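
// Example usage (illustrative sketch, not part of the original header): a
// float is mapped into quint8 and back, giving a value close to the input.
// The range endpoints below are arbitrary assumptions chosen for the example.
//
//   const float min = -1.0f;
//   const float max = 1.0f;
//   const quint8 q = FloatToQuantized<quint8>(0.5f, min, max);
//   const float roundtrip = QuantizedToFloat<quint8>(q, min, max);
//   // roundtrip is approximately 0.5f, up to one quantization step.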

template <class T>
float FloatForOneQuantizedLevel(float range_min, float range_max) {
  const int64 highest = static_cast<int64>(Eigen::NumTraits<T>::highest());
  const int64 lowest = static_cast<int64>(Eigen::NumTraits<T>::lowest());
  const float float_for_one_quantized_level =
      (range_max - range_min) / (highest - lowest);
  return float_for_one_quantized_level;
}

template <class T1, class T2, class T3>
void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b,
                                        float max_b, float* min_c,
                                        float* max_c) {
  const float a_float_for_one_quant_level =
      FloatForOneQuantizedLevel<T1>(min_a, max_a);
  const float b_float_for_one_quant_level =
      FloatForOneQuantizedLevel<T2>(min_b, max_b);

  const int64 c_highest = static_cast<int64>(Eigen::NumTraits<T3>::highest());
  const int64 c_lowest = static_cast<int64>(Eigen::NumTraits<T3>::lowest());
  const float c_float_for_one_quant_level =
      a_float_for_one_quant_level * b_float_for_one_quant_level;

  *min_c = c_float_for_one_quant_level * c_lowest;
  *max_c = c_float_for_one_quant_level * c_highest;
}
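
// Example usage (illustrative sketch; the ranges here are arbitrary
// assumptions): compute the float range of a qint32 accumulator that holds
// products of quint8 values from two differently-scaled inputs.
//
//   float min_c, max_c;
//   QuantizationRangeForMultiplication<quint8, quint8, qint32>(
//       0.0f, 6.0f, -1.0f, 1.0f, &min_c, &max_c);
//   // min_c / max_c now bound every representable product in qint32 space.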

// input_array is an eigen Tensor.  q2f is a QuantizedToFloatStruct.
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f);
#define DEQUANTIZE_WITH_EIGEN(input_array, q2f)                         \
  ((q2f.range_min_rounded - q2f.lowest_quantized() * q2f.range_scale) + \
   input_array.template cast<float>() * q2f.range_scale)

// input_array is an eigen Tensor.  f2q is a FloatToQuantizedStruct.
// OutputType is the type of output (e.g. quint8).
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = QUANTIZE_WITH_EIGEN(input_tensor, f2q, T);
#define QUANTIZE_WITH_EIGEN(input_array, f2q, OutputType) \
  ((input_array * f2q.range_scale).round() -              \
   (f2q.range_min_scaled - f2q.lowest_quantized()))       \
      .cwiseMax(f2q.lower_bound_float())                  \
      .cwiseMin(f2q.upper_bound_float())                  \
      .template cast<int32>()                             \
      .template cast<OutputType>()

// For use with DEQUANTIZE_WITH_EIGEN.
template <typename T>
struct QuantizedToFloatStruct {
  static constexpr int number_of_bits = sizeof(T) * 8;
  static constexpr int64 number_of_steps = static_cast<int64>(1)
                                           << number_of_bits;

  static float lowest_quantized() {
    return static_cast<float>(Eigen::NumTraits<T>::lowest());
  }

  QuantizedToFloatStruct(float range_min, float range_max)
      : range_min(range_min),
        range_scale((range_max - range_min) / (number_of_steps - 1.0)),
        range_min_rounded(range_max == range_min
                              ? range_min
                              : round(range_min / range_scale) * range_scale) {}

  const float range_min;
  const float range_scale;
  const float range_min_rounded;
};

// For use with QUANTIZE_WITH_EIGEN.
template <typename T>
struct FloatToQuantizedStruct {
  static constexpr int number_of_bits = sizeof(T) * 8;
  static constexpr int64 number_of_steps = static_cast<int64>(1)
                                           << number_of_bits;
  static constexpr double range_adjust =
      (number_of_steps / (number_of_steps - 1.0));

  // Casting QInt32's lowest or highest to a float gives a float that can't be
  // cast back to int32 or QInt32.  Instead, use bounds that can be converted
  // back to int32 without going outside the range of an int32.
  static float lower_bound_float() {
    return Eigen::numext::maxi(
        static_cast<float>(Eigen::NumTraits<T>::lowest()), -2.147483648e+09f);
  }
  static float upper_bound_float() {
    return Eigen::numext::mini(
        static_cast<float>(Eigen::NumTraits<T>::highest()), +2.147483520e+09f);
  }

  static float lowest_quantized() {
    return static_cast<float>(Eigen::NumTraits<T>::lowest());
  }

  FloatToQuantizedStruct(float range_min, float range_max)
      : range_min(range_min),
        range_scale(range_max == range_min
                        ? 0.0
                        : (number_of_steps - 1.0) / (range_max - range_min)),
        range_min_scaled(round(range_min * range_scale)) {}

  const float range_min;
  const float range_scale;
  const float range_min_scaled;
};
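
// Example usage (illustrative sketch; `input` is assumed to be an existing
// DT_FLOAT Tensor, `output` a quint8 Tensor of the same shape, and the
// [0.0f, 6.0f] range is an arbitrary assumption):
//
//   auto input_flat = input.flat<float>();
//   FloatToQuantizedStruct<quint8> f2q(0.0f, 6.0f);
//   output.flat<quint8>() = QUANTIZE_WITH_EIGEN(input_flat, f2q, quint8);
//
// DEQUANTIZE_WITH_EIGEN is used the same way with a QuantizedToFloatStruct to
// go back to float.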

template <class T1, class T2>
inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input,
                               float min_new, float max_new) {
  const float input_float = QuantizedToFloat<T1>(input, min_input, max_input);
  return FloatToQuantized<T2>(input_float, min_new, max_new);
}

template <class T1, class T2>
inline void RequantizeManyInNewRange(const T1* input, int64 count,
                                     float min_input, float max_input,
                                     float min_output, float max_output,
                                     T2* output) {
  for (size_t index = 0; index < count; ++index) {
    const float input_float =
        QuantizedToFloat<T1>(input[index], min_input, max_input);
    output[index] = FloatToQuantized<T2>(input_float, min_output, max_output);
  }
}
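
// Example usage (illustrative sketch; the buffer contents and ranges below are
// arbitrary assumptions): widen a small quint8 buffer into qint32 codes that
// live in a different float range.
//
//   quint8 in[3];
//   in[0] = 0;
//   in[1] = 128;
//   in[2] = 255;
//   qint32 out[3];
//   RequantizeManyInNewRange<quint8, qint32>(in, 3, 0.0f, 1.0f, -10.0f, 10.0f,
//                                            out);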

// Because converting 32-bit accumulated results down to eight bit is a common
// case, we have a specialized code path to handle it as efficiently as
// possible using only fixed-point math for the inner loop.
inline void RequantizeManyInNewRangeReference(const qint32* input, int64 count,
                                              float min_input, float max_input,
                                              float min_output,
                                              float max_output,
                                              quint8* output) {
  // Initially we calculate all the constants we need once, before we go into
  // the inner loop.  If this is updated, also update the Eigen version.
  const int fp_shift = 16;
  const float input_range = max_input - min_input;
  const float output_range = max_output - min_output;
  const float recip_output_range =
      output_range == 0.0 ? 0.0 : (255.0 / output_range);
  const float input_rezero = (min_input + max_input) / 2.0;
  const int64 range_scale_fp =
      output_range == 0.0 ? 0.0
                          : static_cast<int64>(255.0 * (1 << fp_shift) *
                                               input_range / output_range);
  const int64 input_offset_fp =
      static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
  const int64 output_offset_fp =
      output_range == 0.0
          ? 0
          : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
                               output_range);
  const int64 rounding_delta = 1 << (fp_shift - 1);

  // Inside this loop we just do minimal adds, multiplies, and shifts, in a way
  // that could be easily adapted for a SIMD implementation. It should also be
  // possible to perform all the calculations in 32-bit rather than 64, but
  // that's not been implemented yet.
  for (size_t index = 0; index < count; ++index) {
    const int64 input_value = static_cast<int64>(input[index]);
    const int64 fp_value =
        ((input_value * range_scale_fp) >> 32) + input_offset_fp;
    const int64 offset_intermediate = fp_value - output_offset_fp;
    const int64 round_intermediate = offset_intermediate + rounding_delta;
    int64 quantized_int64 = round_intermediate >> fp_shift;
    quantized_int64 = std::max(quantized_int64, int64{0});
    quantized_int64 = std::min(quantized_int64, int64{255});
    output[index] = static_cast<quint8>(static_cast<int32>(quantized_int64));
  }
}
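
// Condensed view of the loop above (a restatement for readability, not a
// separate implementation): each input value x is mapped to
//
//   clamp((((x * range_scale_fp) >> 32) + input_offset_fp - output_offset_fp
//          + rounding_delta) >> fp_shift, 0, 255)
//
// which is the float requantization formula folded into fixed-point constants.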

// Another common case is converting eight bit inputs up to thirty two bits, so
// we have specialized fixed-point code to accelerate that. There is also a NEON
// version for ARM devices below.
inline void RequantizeManyInNewRange8To32BitReference(
    const quint8* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, qint32* output) {
  const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
  const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);
  const int64 code_0_int64 =
      FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
  const int64 code_1_int64 =
      FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);
  const int32 mult_int32 = code_1_int64 - code_0_int64;
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());
  for (int64 i = 0; i < count; ++i) {
    const int64 input_value = static_cast<int64>(input[i]);
    int64 output_value = code_0_int64 + (input_value * mult_int32);
    output_value = std::max(output_value, lowest_quantized);
    output_value = std::min(output_value, highest_quantized);
    output[i] = static_cast<int32>(output_value);
  }
}
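
// In other words (a restatement of the loop above, not a separate
// implementation): because the 8-to-32-bit mapping is affine, it reduces to
//
//   output[i] = clamp(code_0_int64 + input[i] * mult_int32)
//
// where code_0_int64 is the qint32 code for input code 0 and mult_int32 is the
// qint32 distance between two adjacent quint8 codes.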

#ifdef QUANTIZATION_UTILS_USE_NEON
// Speeds up the 32->8bit conversion using fixed-point arithmetic and NEON SIMD
// intrinsics for ARM platforms.
inline void RequantizeManyInNewRangeNeon(const qint32* input, int64 count,
                                         float min_input, float max_input,
                                         float min_output, float max_output,
                                         quint8* output) {
  // Initially we calculate all the constants we need once, before we go into
  // the inner loop.  If this is updated, also update the Eigen version.
  const int fp_shift = 16;

  // Calculate range variables in advance.
  // Input range.
  const float input_range = max_input - min_input;
  // Output range.
  const float output_range = max_output - min_output;
  // Ratio of output range.
  const float recip_output_range =
      output_range == 0.0 ? 0.0 : (255.0 / output_range);
  // Average of input range as zero position of input.
  const float input_rezero = (min_input + max_input) / 2.0;
  // In-out range scale.
  const int32 range_scale_fp =
      output_range == 0.0 ? 0.0
                          : static_cast<int32>(255.0 * (1 << (fp_shift - 16)) *
                                               input_range / output_range);
  // Input zero position offset to output.
  const int32 input_offset_fp =
      static_cast<int32>(input_rezero * recip_output_range * (1 << fp_shift));
  // Output min offset.
  const int32 output_offset_fp =
      output_range == 0.0
          ? 0
          : static_cast<int32>((1 << fp_shift) * (min_output * 255.0) /
                               output_range);
  const int32 rounding_delta = 1 << (fp_shift - 1);

  // broadcast range to each lane
  const int32x4_t range_scale_fp_32x4 = vmovq_n_s32(range_scale_fp);
  const int32x4_t input_offset_fp_32x4 = vmovq_n_s32(input_offset_fp);
  const int32x4_t output_offset_fp_32x4 = vmovq_n_s32(output_offset_fp);
  const int32x4_t rounding_delta_32x4 = vmovq_n_s32(rounding_delta);

  int64 index = 0;
  // Use SIMD to requantize.
  for (; index < (count - 7); index += 8) {
    const int32* input_ptr = &(input->value) + index;
    const int32x4_t input_value_low_32x4 = vld1q_s32(input_ptr);
    const int32x4_t input_value_high_32x4 = vld1q_s32(input_ptr + 4);
    const int32x4_t fp_value_low_32x4 = vaddq_s32(
        input_offset_fp_32x4,
        vmulq_s32(vshrq_n_s32(input_value_low_32x4, 16), range_scale_fp_32x4));
    const int32x4_t fp_value_high_32x4 = vaddq_s32(
        input_offset_fp_32x4,
        vmulq_s32(vshrq_n_s32(input_value_high_32x4, 16), range_scale_fp_32x4));
    const int32x4_t offset_intermediate_low_32x4 =
        vsubq_s32(fp_value_low_32x4, output_offset_fp_32x4);
    const int32x4_t offset_intermediate_high_32x4 =
        vsubq_s32(fp_value_high_32x4, output_offset_fp_32x4);
    const int32x4_t round_intermediate_low_32x4 =
        vaddq_s32(offset_intermediate_low_32x4, rounding_delta_32x4);
    const int32x4_t round_intermediate_high_32x4 =
        vaddq_s32(offset_intermediate_high_32x4, rounding_delta_32x4);
    const int16x4_t quantized_low_16x4 =
        vqmovn_s32(vshrq_n_s32(round_intermediate_low_32x4, fp_shift));
    const int16x4_t quantized_high_16x4 =
        vqmovn_s32(vshrq_n_s32(round_intermediate_high_32x4, fp_shift));
    const uint8x8_t quantized_8x8 =
        vqmovun_s16(vcombine_s16(quantized_low_16x4, quantized_high_16x4));
    uint8* output_ptr = &(output->value) + index;
    vst1_u8(output_ptr, quantized_8x8);
  }

  // Requantize remaining elements in array without SIMD.
  for (; index < count; ++index) {
    const int32 input_value = static_cast<int32>(input[index]);
    const int32 fp_value =
        static_cast<int32>(
            (static_cast<int32>(input_value >> 16) * (range_scale_fp))) +
        input_offset_fp;
    const int32 offset_intermediate = fp_value - output_offset_fp;
    const int32 round_intermediate = offset_intermediate + rounding_delta;
    int32 quantized_int32 = round_intermediate >> fp_shift;
    quantized_int32 = std::max(quantized_int32, 0);
    quantized_int32 = std::min(quantized_int32, 255);
    output[index] = static_cast<quint8>(static_cast<int32>(quantized_int32));
  }
}

template <>
inline void RequantizeManyInNewRange<qint32, quint8>(
    const qint32* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, quint8* output) {
  const float input_range = max_input - min_input;
  const float output_range = max_output - min_output;
  if ((input_range / output_range) > 16384.0f) {
    // Our NEON implementation uses 32-bit math and can't handle very
    // large ranges, so fall back to the reference implementation. We don't
    // expect these to be common in models, so this shouldn't be a performance
    // problem in practice.
    RequantizeManyInNewRangeReference(input, count, min_input, max_input,
                                      min_output, max_output, output);
  } else {
    RequantizeManyInNewRangeNeon(input, count, min_input, max_input, min_output,
                                 max_output, output);
  }
}

// NEON accelerated 16bit rounded division by 2^n.
template <int POW>
inline int16x8_t Divide16x8PowRound(const int16x8_t val) {
  const int16x8_t val_sign = vshrq_n_s16(val, 15);
  const int16x8_t val_xor = veorq_s16(val, val_sign);
  const int16x8_t val_pos = vsubq_s16(val_xor, val_sign);
  const int16x8_t shifted_val_pos = vrshrq_n_s16(val_pos, POW);
  const int16x8_t shifted_val_pos_xor = veorq_s16(shifted_val_pos, val_sign);
  const int16x8_t shifted_val = vsubq_s16(shifted_val_pos_xor, val_sign);
  return shifted_val;
}
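
// Scalar view of the rounded division above (added here for readability; the
// NEON code is authoritative): each lane x becomes
//   sign(x) * ((|x| + (1 << (POW - 1))) >> POW)
// i.e. division by 2^POW with rounding half away from zero. The 64-bit
// variant below follows the same pattern.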

// NEON accelerated 64bit rounded division by 2^n.
template <int POW>
inline int64x2_t Divide64x2PowRound(const int64x2_t val) {
  const int64x2_t val_sign = vshrq_n_s64(val, 63);
  const int64x2_t val_xor = veorq_s64(val, val_sign);
  const int64x2_t val_pos = vsubq_s64(val_xor, val_sign);
  const int64x2_t shifted_val_pos = vrshrq_n_s64(val_pos, POW);
  const int64x2_t shifted_val_pos_xor = veorq_s64(shifted_val_pos, val_sign);
  const int64x2_t shifted_val = vsubq_s64(shifted_val_pos_xor, val_sign);
  return shifted_val;
}

// NEON accelerated 16bit division by 2^n.
// CAVEAT: The input must be greater than min-int16 to avoid underflow.
template <int POW>
inline int16x8_t Divide16x8Pow(const int16x8_t val) {
  static constexpr int16 FIRST_BIT_VAL = 0x0000000000000001;
  static const int16x8_t FIRST_BIT = vmovq_n_s16(FIRST_BIT_VAL);
  const int16x8_t val_sign = vshrq_n_s16(val, 15);
  const int16x8_t neg_offset = vandq_s16(val_sign, FIRST_BIT);
  const int16x8_t val_with_offset = vsubq_s16(val, neg_offset);
  const int16x8_t shifted_wo_offset =
      vsraq_n_s16(neg_offset, val_with_offset, POW);
  return shifted_wo_offset;
}
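
// Scalar view of the (unrounded) division above (added for readability; the
// NEON code is authoritative): each lane x becomes x / (1 << POW) with C-style
// truncation toward zero. The 64-bit variant below is identical in structure.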

// NEON accelerated 64bit division by 2^n.
// CAVEAT: The input must be greater than min-int64 to avoid underflow.
template <int POW>
inline int64x2_t Divide64x2Pow(const int64x2_t val) {
  static constexpr int64 FIRST_BIT_VAL = 0x0000000000000001;
  static const int64x2_t FIRST_BIT = vmovq_n_s64(FIRST_BIT_VAL);
  const int64x2_t val_sign = vshrq_n_s64(val, 63);
  const int64x2_t neg_offset = vandq_s64(val_sign, FIRST_BIT);
  const int64x2_t val_with_offset = vsubq_s64(val, neg_offset);
  const int64x2_t shifted_wo_offset =
      vsraq_n_s64(neg_offset, val_with_offset, POW);
  return shifted_wo_offset;
}

// 32bit x 2 NEON accelerated lerp computation.
template <int RESOLUTION>
inline int32x2_t ComputeLerp32x2(const int32x2_t top_left,
                                 const int32x2_t top_right,
                                 const int32x2_t bottom_left,
                                 const int32x2_t bottom_right,
                                 const int32x2_t x_lerp,
                                 const int32x2_t y_lerp) {
  static_assert(RESOLUTION < 31, "RESOLUTION must be less than 31");
  constexpr int32 RESOLUTION_MULT32 = (1 << RESOLUTION);
  static const int32x2_t RESOLUTION_MULT32x2 = vmov_n_s32(RESOLUTION_MULT32);

  const int64x2_t top_left_x_res = vmull_s32(top_left, RESOLUTION_MULT32x2);
  const int64x2_t bottom_left_x_res =
      vmull_s32(bottom_left, RESOLUTION_MULT32x2);

  const int32x2_t top_right_sub_top_left = vsub_s32(top_right, top_left);
  const int64x2_t top_x_res =
      vmlal_s32(top_left_x_res, top_right_sub_top_left, x_lerp);
  const int32x2_t bottom_right_sub_bottom_left =
      vsub_s32(bottom_right, bottom_left);
  const int64x2_t bottom_x_res =
      vmlal_s32(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);

  const int64x2_t bottom_sub_top_x_res = vsubq_s64(bottom_x_res, top_x_res);
  const int64x2_t bottom_sub_top =
      Divide64x2Pow<RESOLUTION>(bottom_sub_top_x_res);
  const int32x2_t bottom_sub_top_32 = vqmovn_s64(bottom_sub_top);
  const int64x2_t top_add_bottom_sub_top_mul_ylerp_x_res =
      vmlal_s32(top_x_res, bottom_sub_top_32, y_lerp);
  const int64x2_t retval =
      Divide64x2PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
  const int32x2_t retval32 = vqmovn_s64(retval);
  return retval32;
}
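
// The fixed-point bilinear interpolation above (and the 8-bit version below)
// corresponds to the float formula
//   top    = top_left + (top_right - top_left) * x_lerp
//   bottom = bottom_left + (bottom_right - bottom_left) * x_lerp
//   result = top + (bottom - top) * y_lerp
// with x_lerp and y_lerp expressed in units of 1 / (1 << RESOLUTION). This
// note is a readability aid; the intrinsics are the authoritative definition.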

// 8bit x 8 NEON accelerated lerp computation.
template <int RESOLUTION>
inline uint8x8_t ComputeLerp8x8(const uint8x8_t top_left8x8,
                                const uint8x8_t top_right8x8,
                                const uint8x8_t bottom_left8x8,
                                const uint8x8_t bottom_right8x8,
                                const int16x8_t x_lerp,
                                const int16x8_t y_lerp) {
  static_assert(RESOLUTION < 8, "RESOLUTION must be less than 8");
  constexpr uint8 RESOLUTION_MULT_VAL = (1 << RESOLUTION);
  static const uint8x8_t RESOLUTION_MULT = vdup_n_u8(RESOLUTION_MULT_VAL);

  const int16x8_t top_left_x_res =
      vreinterpretq_s16_u16(vmull_u8(top_left8x8, RESOLUTION_MULT));
  const int16x8_t bottom_left_x_res =
      vreinterpretq_s16_u16(vmull_u8(bottom_left8x8, RESOLUTION_MULT));

  const int16x8_t top_right_sub_top_left =
      vreinterpretq_s16_u16(vsubl_u8(top_right8x8, top_left8x8));
  const int16x8_t top_x_res =
      vmlaq_s16(top_left_x_res, top_right_sub_top_left, x_lerp);

  const int16x8_t bottom_right_sub_bottom_left =
      vreinterpretq_s16_u16(vsubl_u8(bottom_right8x8, bottom_left8x8));
  const int16x8_t bottom_x_res =
      vmlaq_s16(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);

  const int16x8_t bottom_sub_top_x_res = vsubq_s16(bottom_x_res, top_x_res);
  const int16x8_t bottom_sub_top =
      Divide16x8Pow<RESOLUTION>(bottom_sub_top_x_res);
  const int16x8_t top_add_bottom_sub_top_mul_ylerp_x_res =
      vmlaq_s16(top_x_res, bottom_sub_top, y_lerp);
  const int16x8_t retval16 =
      Divide16x8PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
  const uint8x8_t retval = vmovn_u16(vreinterpretq_u16_s16(retval16));
  return retval;
}

// Requantizes eight quint8 values to eight qint32 values in parallel using
// NEON. Returns a std::array rather than writing through a pointer so the
// result can benefit from return value optimization.
inline std::array<int32x4_t, 2> Requantize8x8To32Neon(
    const uint8* input_ptr, const int64x2_t input_0_64x2,
    const int32x2_t input_mult_32x2) {
  const uint8x8_t input_value_8x8 = vld1_u8(input_ptr);
  const int16x8_t input_value_16x8 =
      vreinterpretq_s16_u16(vmovl_u8(input_value_8x8));
  const int16x4_t input_value_low_16x4 = vget_low_s16(input_value_16x8);
  const int16x4_t input_value_high_16x4 = vget_high_s16(input_value_16x8);
  const int32x4_t input_value_low_32x4 = vmovl_s16(input_value_low_16x4);
  const int32x4_t input_value_high_32x4 = vmovl_s16(input_value_high_16x4);
  const int32x2_t input_value_low_low_32x2 = vget_low_s32(input_value_low_32x4);
  const int32x2_t input_value_low_high_32x2 =
      vget_high_s32(input_value_low_32x4);
  const int32x2_t input_value_high_low_32x2 =
      vget_low_s32(input_value_high_32x4);
  const int32x2_t input_value_high_high_32x2 =
      vget_high_s32(input_value_high_32x4);
  const int64x2_t mult_result_low_low_64x2 =
      vmlal_s32(input_0_64x2, input_value_low_low_32x2, input_mult_32x2);
  const int64x2_t mult_result_low_high_64x2 =
      vmlal_s32(input_0_64x2, input_value_low_high_32x2, input_mult_32x2);
  const int64x2_t mult_result_high_low_64x2 =
      vmlal_s32(input_0_64x2, input_value_high_low_32x2, input_mult_32x2);
  const int64x2_t mult_result_high_high_64x2 =
      vmlal_s32(input_0_64x2, input_value_high_high_32x2, input_mult_32x2);
  const int32x2_t output_value_low_low_32x2 =
      vqmovn_s64(mult_result_low_low_64x2);
  const int32x2_t output_value_low_high_32x2 =
      vqmovn_s64(mult_result_low_high_64x2);
  const int32x2_t output_value_high_low_32x2 =
      vqmovn_s64(mult_result_high_low_64x2);
  const int32x2_t output_value_high_high_32x2 =
      vqmovn_s64(mult_result_high_high_64x2);
  const int32x4_t output_value_low_32x4 =
      vcombine_s32(output_value_low_low_32x2, output_value_low_high_32x2);
  const int32x4_t output_value_high_32x4 =
      vcombine_s32(output_value_high_low_32x2, output_value_high_high_32x2);
  return std::array<int32x4_t, 2>{
      {output_value_low_32x4, output_value_high_32x4}};
}

// Speeds up the 8->32bit conversion using fixed-point arithmetic and NEON SIMD
// intrinsics for ARM platforms.
template <>
inline void RequantizeManyInNewRange<quint8, qint32>(
    const quint8* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, qint32* output) {
  // Pre-calculate zero position and multiplier.
  // Calculate the float values of quantized codes 0 and 1.
  const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
  const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);

  // Convert those values to int64 output codes.
  const int64 code_0_int64 =
      FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
  const int64 code_1_int64 =
      FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);

  // Calculate multiplier.
  const int32 mult_int32 = static_cast<int32>(code_1_int64 - code_0_int64);

  // Broadcast the zero position and multiplier to all lanes.
  const int64x2_t code_0_64x2 = vmovq_n_s64(code_0_int64);
  const int32x2_t mult_32x2 = vmov_n_s32(mult_int32);

  int64 i = 0;

  // Use SIMD to requantize array.
  for (; i < (count - 7); i += 8) {
    const uint8* input_ptr = &(input->value) + i;
    int32* output_ptr = &(output->value) + i;
    const std::array<int32x4_t, 2> output_value =
        Requantize8x8To32Neon(input_ptr, code_0_64x2, mult_32x2);
    vst1q_s32(output_ptr + 0, output_value[0]);
    vst1q_s32(output_ptr + 4, output_value[1]);
  }

  // Requantize remaining elements in array without SIMD.
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  for (; i < count; ++i) {
    const int64 input_value = static_cast<int64>(input[i]);
    int64 output_value = code_0_int64 + (input_value * mult_int32);
    output_value = std::max(output_value, lowest_quantized);
    output_value = std::min(output_value, highest_quantized);
    output[i] = static_cast<int32>(output_value);
  }
}

#else

// If SIMD implementations aren't available, then use these default reference
// versions.
template <>
inline void RequantizeManyInNewRange<qint32, quint8>(
    const qint32* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, quint8* output) {
  RequantizeManyInNewRangeReference(input, count, min_input, max_input,
                                    min_output, max_output, output);
}

template <>
inline void RequantizeManyInNewRange<quint8, qint32>(
    const quint8* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, qint32* output) {
  RequantizeManyInNewRange8To32BitReference(input, count, min_input, max_input,
                                            min_output, max_output, output);
}

#endif

template <int shift>
struct int64_right_shift_op {
  EIGEN_EMPTY_STRUCT_CTOR(int64_right_shift_op)
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE const int64 operator()(const int64& a) const {
    return a >> shift;
  }
};

// See RequantizeManyInNewRange() for a non-eigen reference implementation.
template <class T1, class T2>
inline void RequantizeManyInNewRangeUsingEigen(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
    float max_input, float min_output, float max_output, Tensor* output) {
  auto input_array = input.flat<T1>();
  QuantizedToFloatStruct<T1> q2f(min_input, max_input);
  auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
  FloatToQuantizedStruct<T2> f2q(min_output, max_output);
  auto input_requantized = QUANTIZE_WITH_EIGEN(input_float, f2q, T2);

  output->flat<T2>().device(device) = input_requantized;
}

// See RequantizeManyInNewRange() for a non-eigen reference implementation.
//
// Because converting 32-bit accumulated results down to eight bit is a common
// case, we have a specialized code path to handle it as efficiently as
// possible using only fixed-point math for the inner loop.
template <>
inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
    float max_input, float min_output, float max_output, Tensor* output) {
  // Initially we calculate all the constants we need once, before we go into
  // the inner loop.  If this is updated, also update the non-Eigen version.
  const int fp_shift = 16;
  const float input_range = max_input - min_input;
  const float output_range = max_output - min_output;
  const float recip_output_range =
      output_range == 0.0 ? 0.0 : (255.0 / output_range);
  const float input_rezero = (min_input + max_input) / 2.0;
  const int64 range_scale_fp =
      output_range == 0.0 ? 0.0
                          : static_cast<int64>(255.0 * (1 << fp_shift) *
                                               input_range / output_range);
  const int64 input_offset_fp =
      static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
  const int64 output_offset_fp =
      output_range == 0.0
          ? 0
          : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
                               output_range);
  const int64 rounding_delta = 1 << (fp_shift - 1);

  // Inside this eigen expression we just do minimal adds, multiplies, and
  // shifts. It should be possible to perform all the calculations in 32-bit
  // rather than 64, but that's not been implemented yet.
  auto input_array = input.flat<qint32>();
  auto fp_value = ((input_array.template cast<int64>() * range_scale_fp)
                       .unaryExpr(int64_right_shift_op<32>())) +
                  (input_offset_fp - output_offset_fp + rounding_delta);
  auto intermediate = fp_value.unaryExpr(int64_right_shift_op<fp_shift>());
  auto input_requantized = intermediate.cwiseMax(0LL)
                               .cwiseMin(255LL)
                               .template cast<int32>()
                               .template cast<quint8>();
  output->flat<quint8>().device(device) = input_requantized;
}

// REQUIRES: 'result->NumElements() == input.NumElements()'
template <class T>
void FloatTensorToQuantizedInPlaceUsingEigen(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
    float max, Tensor* result) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
  auto flat_input = input.flat<float>();
  auto flat_result = result->flat<T>();
  DCHECK_EQ(flat_input.size(), flat_result.size());

  FloatToQuantizedStruct<T> f2q(min, max);
  flat_result.device(device) = QUANTIZE_WITH_EIGEN(flat_input, f2q, T);
}

template <class T>
void FloatTensorToQuantizedInPlace(const Tensor& input, float min, float max,
                                   Tensor* result) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
  auto flat_input = input.flat<float>();
  auto flat_result = result->flat<T>();
  const int data_size = flat_input.size();
  DCHECK(data_size == flat_result.size());
  for (int i = 0; i < data_size; ++i) {
    flat_result(i) = FloatToQuantized<T>(flat_input(i), min, max);
  }
}

template <class T>
Tensor FloatTensorToQuantized(const Tensor& input, float min, float max) {
  Tensor result(DataTypeToEnum<T>::v(), input.shape());
  FloatTensorToQuantizedInPlace<T>(input, min, max, &result);
  return result;
}

// REQUIRES: 'result->NumElements() == input.NumElements()'
template <class T>
void QuantizedTensorToFloatInPlaceUsingEigen(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
    float max, Tensor* result) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
  auto flat_input = input.flat<T>();
  auto flat_result = result->flat<float>();
  const int data_size = flat_input.size();
  DCHECK(data_size == flat_result.size());

  QuantizedToFloatStruct<T> q2f(min, max);
  flat_result.device(device) = DEQUANTIZE_WITH_EIGEN(flat_input, q2f);
}

// REQUIRES: 'result->NumElements() == input.NumElements()'
template <class T>
void QuantizedTensorToFloatInPlace(const Tensor& input, float min, float max,
                                   Tensor* result) {
  DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
  auto flat_input = input.flat<T>();
  auto flat_result = result->flat<float>();
  const int data_size = flat_input.size();
  DCHECK(data_size == flat_result.size());
  for (int i = 0; i < data_size; ++i) {
    flat_result(i) = QuantizedToFloat<T>(flat_input(i), min, max);
  }
}

template <class T>
Tensor QuantizedTensorToFloat(const Tensor& input, float min, float max) {
  Tensor result(DT_FLOAT, input.shape());
  QuantizedTensorToFloatInPlace<T>(input, min, max, &result);
  return result;
}
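
// Example usage (illustrative sketch; `floats` is assumed to be an existing
// DT_FLOAT Tensor and the range is an arbitrary assumption): quantize a whole
// tensor and dequantize it again.
//
//   Tensor quantized = FloatTensorToQuantized<quint8>(floats, -3.0f, 3.0f);
//   Tensor recovered = QuantizedTensorToFloat<quint8>(quantized, -3.0f, 3.0f);
//   // recovered matches floats to within one quantization step per element.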

void GetOutputMinAndMaxForQuantizedAdd(float input_min, float input_max,
                                       float smaller_input_min,
                                       float smaller_input_max,
                                       float* output_min, float* output_max);

// Add <input> and <smaller_input>.  If <smaller_input> has fewer elements than
// <input>, then it is broadcast onto <input>.
template <typename T1, typename T2, typename T3>
void QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice& device,
                            const Tensor& input, float input_min,
                            float input_max, const Tensor& smaller_input,
                            float smaller_input_min, float smaller_input_max,
                            Tensor* output, float* output_min,
                            float* output_max) {
  const auto& input_flat = input.flat<T1>();
  const auto& smaller_input_flat = smaller_input.flat<T2>();
  auto output_flat = output->flat<T3>();

  GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
                                    smaller_input_max, output_min, output_max);
  // To do addition properly, we need to compensate for a possibly unbalanced
  // zero point in the total representation. The quantized value that
  // represents the real number zero needs to be subtracted before addition to
  // make sure that the identity of zero + zero = zero holds.
  const T3 zero_in_total_space =
      FloatToQuantized<T3>(0.0f, *output_min, *output_max);

  const int64 input_element_count = input.NumElements();
  const int64 smaller_input_element_count = smaller_input.NumElements();

  QuantizedToFloatStruct<T1> input_q2f(input_min, input_max);
  QuantizedToFloatStruct<T2> smaller_input_q2f(smaller_input_min,
                                               smaller_input_max);
  FloatToQuantizedStruct<T3> f2q(*output_min, *output_max);

  auto smaller_input_float =
      DEQUANTIZE_WITH_EIGEN(smaller_input_flat, smaller_input_q2f);
  auto smaller_input_in_total_space =
      QUANTIZE_WITH_EIGEN(smaller_input_float, f2q, T3);

  auto input_float = DEQUANTIZE_WITH_EIGEN(input_flat, input_q2f);
  auto input_in_total_space = QUANTIZE_WITH_EIGEN(input_float, f2q, T3);

  Eigen::array<Eigen::DenseIndex, 1> bcast;
  bcast[0] = input_element_count / smaller_input_element_count;
  output_flat.device(device) =
      input_in_total_space +
      (smaller_input_in_total_space.broadcast(bcast) + zero_in_total_space);
}

// This is a reference implementation of the bias addition for quantized
// buffers, designed to provide a clear specification for the result we
// want. We'll want to specialize this for particular hardware, and
// probably even fuse it with matrix multiplications in a lot of cases. It's
// important to show the clamping behavior we want in particular.
template <typename T1, typename T2, typename T3>
void QuantizedAdd(const Eigen::ThreadPoolDevice& device, const Tensor& input,
                  float input_min, float input_max, const Tensor& smaller_input,
                  float smaller_input_min, float smaller_input_max,
                  Tensor* output, float* output_min, float* output_max) {
  const auto& input_flat = input.flat<T1>();
  const auto& smaller_input_flat = smaller_input.flat<T2>();
  auto output_flat = output->flat<T3>();

  GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
                                    smaller_input_max, output_min, output_max);
  // To do addition properly, we need to compensate for a possibly unbalanced
  // zero point in the total representation. The quantized value that
  // represents the real number zero needs to be subtracted before addition to
  // make sure that the identity of zero + zero = zero holds.
  const T3 zero_in_total_space =
      FloatToQuantized<T3>(0.0f, *output_min, *output_max);

  const int64 input_element_count = input.NumElements();
  const int64 smaller_input_element_count = smaller_input.NumElements();

  float total_min = *output_min;
  float total_max = *output_max;
  const size_t how_many_iterations =
      (input_element_count / smaller_input_element_count);
  for (size_t iteration = 0; iteration < how_many_iterations; ++iteration) {
    const size_t offset = iteration * smaller_input_element_count;
    for (int c = 0; c < smaller_input_element_count; ++c) {
      const int index = (offset + c);
      // The two numbers we're going to add can each be in very different
      // ranges (e.g. the quantized value '127' may represent very different
      // real numbers in both) so we need to convert them to a common range
      // before we sum them.
      const T1 input_value = input_flat(index);
      const T3 input_in_total_space = RequantizeInNewRange<T1, T3>(
          input_value, input_min, input_max, total_min, total_max);
      const T2 smaller_input_value = smaller_input_flat(c);
      const T3 smaller_input_in_total_space =
          RequantizeInNewRange<T2, T3>(smaller_input_value, smaller_input_min,
                                       smaller_input_max, total_min, total_max);
      const T3 total_pre = input_in_total_space + smaller_input_in_total_space;
      // As noted above, we need to compensate for the offset of the actual
      // zero point in the space we're operating in.
      const T3 total = total_pre + zero_in_total_space;
      output_flat(index) = total;
    }
  }
}

// See gemmlowp/internal/multi_thread_gemm.h for the semantics of Execute.
class TensorflowGemmlowpWorkersPool {
 public:
  TensorflowGemmlowpWorkersPool(thread::ThreadPool* workers)
      : workers_(workers) {}

  ~TensorflowGemmlowpWorkersPool() {
    // This workaround ensures that all worker tasks have exited methods in the
    // BlockingCounter. Without this, there is a race where the context is torn
    // down while the counter is in use.
    counter_to_decrement_when_ready_.Reset(0);
  }

  void Execute(const std::vector<gemmlowp::Task*>& tasks) {
    assert(!tasks.empty());
    assert(workers_ != nullptr);
    counter_to_decrement_when_ready_.Reset(tasks.size());
    for (gemmlowp::Task* task : tasks) {
      workers_->Schedule([this, task]() {
        // TODO(cwhipkey): get a local_allocator from a thread local storage.
        gemmlowp::Allocator local_allocator;
        CHECK(task != nullptr);
        task->local_allocator = &local_allocator;
        task->Run();
        counter_to_decrement_when_ready_.DecrementCount();
      });
    }
    counter_to_decrement_when_ready_.Wait();
    for (gemmlowp::Task* task : tasks) {
      delete task;
    }
  }

 private:
  thread::ThreadPool* const workers_;

  // The BlockingCounter used to wait for the workers.
  gemmlowp::BlockingCounter counter_to_decrement_when_ready_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmlowpWorkersPool);
};

class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase {
 public:
  TensorflowGemmContext(int num_threads, thread::ThreadPool* workers)
      : workers_pool_(workers) {
    set_max_num_threads(num_threads);
  }

  TensorflowGemmlowpWorkersPool* workers_pool() { return &workers_pool_; }

 private:
  TensorflowGemmlowpWorkersPool workers_pool_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmContext);
};
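
// Example usage (illustrative sketch; `context` here is assumed to be an
// OpKernelContext* available inside a kernel's Compute method): build a
// TensorflowGemmContext on top of the kernel's CPU thread pool so gemmlowp
// work is scheduled on TensorFlow's own workers.
//
//   auto* worker_threads =
//       context->device()->tensorflow_cpu_worker_threads();
//   TensorflowGemmContext gemm_context(worker_threads->num_threads,
//                                      worker_threads->workers);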

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_