1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
17 #define TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
18
19 #define EIGEN_USE_THREADS
20
21 // This is a set of functions that standardizes how quantized values are
22 // interpreted as float numbers.
23 // All of the current implementations are for reference and have not been
24 // optimized. They should be implementable using fixed point representations
25 // to avoid a dependency on floating-point hardware.
26
27 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
28 #define QUANTIZATION_UTILS_USE_NEON
29 #include <arm_neon.h>
30 #endif
31
32 #include <array>
33
34 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
35 #define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK
36 #include "public/gemmlowp.h"
37 #include "tensorflow/core/framework/tensor.h"
38 #include "tensorflow/core/lib/core/threadpool.h"
39
40 namespace tensorflow {
41
42 // We have to be able to detect and handle overflows in int32, so this function
43 // uses doubles and int64's to make sure we have enough room.
44 template <class T>
FloatToQuantizedUnclamped(float input,float range_min,float range_max)45 int64 FloatToQuantizedUnclamped(float input, float range_min, float range_max) {
46 const int64 lowest_quantized =
47 static_cast<double>(Eigen::NumTraits<T>::lowest());
48 if (range_min == range_max) {
49 return lowest_quantized;
50 }
51 const int number_of_bits = sizeof(T) * 8;
52 const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
53 const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
54 const double range = ((range_max - range_min) * range_adjust);
55 const double range_scale = (number_of_steps / range);
56 int64 quantized =
57 (round(input * range_scale) - round(range_min * range_scale));
58 quantized += lowest_quantized;
59 return quantized;
60 }
61
62 // This converts the float into the final quantized type, clamping/saturating
63 // any over or underflows.
64 template <class T>
FloatToQuantized(float input,float range_min,float range_max)65 T FloatToQuantized(float input, float range_min, float range_max) {
66 if (std::is_same<T, float>::value) {
67 // Specialization for float. This is used in reference implementation
68 // for float which is useful to compare performance between float
69 // and quantized type.
70 return input;
71 }
72 int64 quantized = FloatToQuantizedUnclamped<T>(input, range_min, range_max);
73 const int64 lowest_quantized =
74 static_cast<int64>(Eigen::NumTraits<T>::lowest());
75 const int64 highest_quantized =
76 static_cast<int64>(Eigen::NumTraits<T>::highest());
77 quantized = std::max(quantized, lowest_quantized);
78 quantized = std::min(quantized, highest_quantized);
79 return static_cast<T>(static_cast<int32>(quantized));
80 }
81
82 template <class T>
QuantizedToFloat(T input,float range_min,float range_max)83 float QuantizedToFloat(T input, float range_min, float range_max) {
84 if (std::is_same<T, float>::value) {
85 // Specialization for float. This is used in reference implementation
86 // for float which is useful to compare performance between float
87 // and quantized type.
88 return input;
89 }
90 if (range_min == range_max) {
91 return range_min;
92 }
93 const int number_of_bits = sizeof(T) * 8;
94 const int64 number_of_steps = static_cast<int64>(1) << number_of_bits;
95 const double range_adjust = (number_of_steps / (number_of_steps - 1.0));
96 const double range = ((range_max - range_min) * range_adjust);
97 const double range_scale = (range / number_of_steps);
98 const int64 lowest_quantized =
99 static_cast<int64>(Eigen::NumTraits<T>::lowest());
100 const double offset_input = static_cast<double>(input) - lowest_quantized;
101 // For compatibility with DEQUANTIZE_WITH_EIGEN, we should convert
102 // range_scale to a float, otherwise range_min_rounded might be slightly
103 // different.
104 const double range_min_rounded =
105 round(range_min / static_cast<float>(range_scale)) *
106 static_cast<float>(range_scale);
107 const double result = range_min_rounded + (offset_input * range_scale);
108 return static_cast<float>(result);
109 }
110
111 template <class T>
FloatForOneQuantizedLevel(float range_min,float range_max)112 float FloatForOneQuantizedLevel(float range_min, float range_max) {
113 const int64 highest = static_cast<int64>(Eigen::NumTraits<T>::highest());
114 const int64 lowest = static_cast<int64>(Eigen::NumTraits<T>::lowest());
115 const float float_for_one_quantized_level =
116 (range_max - range_min) / (highest - lowest);
117 return float_for_one_quantized_level;
118 }
119
120 template <class T1, class T2, class T3>
QuantizationRangeForMultiplication(float min_a,float max_a,float min_b,float max_b,float * min_c,float * max_c)121 void QuantizationRangeForMultiplication(float min_a, float max_a, float min_b,
122 float max_b, float* min_c,
123 float* max_c) {
124 const float a_float_for_one_quant_level =
125 FloatForOneQuantizedLevel<T1>(min_a, max_a);
126 const float b_float_for_one_quant_level =
127 FloatForOneQuantizedLevel<T2>(min_b, max_b);
128
129 const int64 c_highest = static_cast<int64>(Eigen::NumTraits<T3>::highest());
130 const int64 c_lowest = static_cast<int64>(Eigen::NumTraits<T3>::lowest());
131 const float c_float_for_one_quant_level =
132 a_float_for_one_quant_level * b_float_for_one_quant_level;
133
134 *min_c = c_float_for_one_quant_level * c_lowest;
135 *max_c = c_float_for_one_quant_level * c_highest;
136 }
137
// input_array is an eigen Tensor. q2f is a QuantizedToFloatStruct.
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = DEQUANTIZE_WITH_EIGEN(input_tensor, q2f);
//
// Per element this computes
//   range_min_rounded + (input - lowest_quantized()) * range_scale
// with the input-independent part grouped into one scalar term. The
// non-Eigen QuantizedToFloat() rounds range_min with a float step size
// so that both paths agree.
#define DEQUANTIZE_WITH_EIGEN(input_array, q2f)                         \
  ((q2f.range_min_rounded - q2f.lowest_quantized() * q2f.range_scale) + \
   input_array.template cast<float>() * q2f.range_scale)
144
// input_array is an eigen Tensor. f2q is a FloatToQuantizedStruct.
// OutputType is the type of output (e.g. quint8).
// This evaluates to an eigen tensor expression, to be used like:
// auto tensor = QUANTIZE_WITH_EIGEN(input_tensor, f2q, T);
//
// Per element this computes
//   round(x * range_scale) - range_min_scaled + lowest_quantized()
// then clamps to [lower_bound_float(), upper_bound_float()] while still in
// float, so that the following float -> int32 cast stays in range, and
// finally casts int32 -> OutputType.
#define QUANTIZE_WITH_EIGEN(input_array, f2q, OutputType) \
  ((input_array * f2q.range_scale).round() -              \
   (f2q.range_min_scaled - f2q.lowest_quantized()))       \
      .cwiseMax(f2q.lower_bound_float())                  \
      .cwiseMin(f2q.upper_bound_float())                  \
      .template cast<int32>()                             \
      .template cast<OutputType>()
156
157 // For use with DEQUANTIZE_WITH_EIGEN.
158 template <typename T>
159 struct QuantizedToFloatStruct {
160 static constexpr int number_of_bits = sizeof(T) * 8;
161 static constexpr int64 number_of_steps = static_cast<int64>(1)
162 << number_of_bits;
163
lowest_quantizedQuantizedToFloatStruct164 static float lowest_quantized() {
165 return static_cast<float>(Eigen::NumTraits<T>::lowest());
166 }
167
QuantizedToFloatStructQuantizedToFloatStruct168 QuantizedToFloatStruct(float range_min, float range_max)
169 : range_min(range_min),
170 range_scale((range_max - range_min) / (number_of_steps - 1.0)),
171 range_min_rounded(range_max == range_min
172 ? range_min
173 : round(range_min / range_scale) * range_scale) {}
174
175 const float range_min;
176 const float range_scale;
177 const float range_min_rounded;
178 };
179
180 // For use with QUANTIZE_WITH_EIGEN.
181 template <typename T>
182 struct FloatToQuantizedStruct {
183 static constexpr int number_of_bits = sizeof(T) * 8;
184 static constexpr int64 number_of_steps = static_cast<int64>(1)
185 << number_of_bits;
186 static constexpr double range_adjust =
187 (number_of_steps / (number_of_steps - 1.0));
188
189 // Casting QInt32's lowest or highest to a float gives a float that can't be
190 // cast back to int32 or QInt32. Instead, use bounds that can be converted
191 // back to int32 without going outside the range of an int32.
lower_bound_floatFloatToQuantizedStruct192 static float lower_bound_float() {
193 return Eigen::numext::maxi(
194 static_cast<float>(Eigen::NumTraits<T>::lowest()), -2.147483648e+09f);
195 }
upper_bound_floatFloatToQuantizedStruct196 static float upper_bound_float() {
197 return Eigen::numext::mini(
198 static_cast<float>(Eigen::NumTraits<T>::highest()), +2.147483520e+09f);
199 }
200
lowest_quantizedFloatToQuantizedStruct201 static float lowest_quantized() {
202 return static_cast<float>(Eigen::NumTraits<T>::lowest());
203 }
204
FloatToQuantizedStructFloatToQuantizedStruct205 FloatToQuantizedStruct(float range_min, float range_max)
206 : range_min(range_min),
207 range_scale(range_max == range_min
208 ? 0.0
209 : (number_of_steps - 1.0) / (range_max - range_min)),
210 range_min_scaled(round(range_min * range_scale)) {}
211
212 const float range_min;
213 const float range_scale;
214 const float range_min_scaled;
215 };
216
217 template <class T1, class T2>
RequantizeInNewRange(T1 input,float min_input,float max_input,float min_new,float max_new)218 inline T2 RequantizeInNewRange(T1 input, float min_input, float max_input,
219 float min_new, float max_new) {
220 const float input_float = QuantizedToFloat<T1>(input, min_input, max_input);
221 return FloatToQuantized<T2>(input_float, min_new, max_new);
222 }
223
224 template <class T1, class T2>
RequantizeManyInNewRange(const T1 * input,int64 count,float min_input,float max_input,float min_output,float max_output,T2 * output)225 inline void RequantizeManyInNewRange(const T1* input, int64 count,
226 float min_input, float max_input,
227 float min_output, float max_output,
228 T2* output) {
229 for (size_t index = 0; index < count; ++index) {
230 const float input_float =
231 QuantizedToFloat<T1>(input[index], min_input, max_input);
232 output[index] = FloatToQuantized<T2>(input_float, min_output, max_output);
233 }
234 }
235
236 // Because converting 32-bit accumulated results down to eight bit is a common
237 // case, we have a specialized code path to handle it as efficiently as
238 // possible using only fixed-point math for the inner loop.
RequantizeManyInNewRangeReference(const qint32 * input,int64 count,float min_input,float max_input,float min_output,float max_output,quint8 * output)239 inline void RequantizeManyInNewRangeReference(const qint32* input, int64 count,
240 float min_input, float max_input,
241 float min_output,
242 float max_output,
243 quint8* output) {
244 // Initially we calculate all the constants we need once, before we go into
245 // the inner loop. If this is updated, also update the Eigen version.
246 const int fp_shift = 16;
247 const float input_range = max_input - min_input;
248 const float output_range = max_output - min_output;
249 const float recip_output_range =
250 output_range == 0.0 ? 0.0 : (255.0 / output_range);
251 const float input_rezero = (min_input + max_input) / 2.0;
252 const int64 range_scale_fp =
253 output_range == 0.0 ? 0.0
254 : static_cast<int64>(255.0 * (1 << fp_shift) *
255 input_range / output_range);
256 const int64 input_offset_fp =
257 static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
258 const int64 output_offset_fp =
259 output_range == 0.0
260 ? 0
261 : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
262 output_range);
263 const int64 rounding_delta = 1 << (fp_shift - 1);
264
265 // Inside this loop we just do minimal adds, multiplies, and shifts, in a way
266 // that could be easily adapted for a SIMD implementation. It should also be
267 // possible to perform all the calculations in 32-bit rather than 64, but
268 // that's not been implemented yet.
269 for (size_t index = 0; index < count; ++index) {
270 const int64 input_value = static_cast<int64>(input[index]);
271 const int64 fp_value =
272 ((input_value * range_scale_fp) >> 32) + input_offset_fp;
273 const int64 offset_intermediate = fp_value - output_offset_fp;
274 const int64 round_intermediate = offset_intermediate + rounding_delta;
275 int64 quantized_int64 = round_intermediate >> fp_shift;
276 quantized_int64 = std::max(quantized_int64, 0LL);
277 quantized_int64 = std::min(quantized_int64, 255LL);
278 output[index] = static_cast<quint8>(static_cast<int32>(quantized_int64));
279 }
280 }
281
282 // Another common case is converting eight bit inputs up to thirty two bits, so
283 // we have specialized fixed-point code to accelerate that. There is also a NEON
284 // version for ARM devices below.
RequantizeManyInNewRange8To32BitReference(const quint8 * input,int64 count,float min_input,float max_input,float min_output,float max_output,qint32 * output)285 inline void RequantizeManyInNewRange8To32BitReference(
286 const quint8* input, int64 count, float min_input, float max_input,
287 float min_output, float max_output, qint32* output) {
288 const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
289 const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);
290 const int64 code_0_int64 =
291 FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
292 const int64 code_1_int64 =
293 FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);
294 const int32 mult_int32 = code_1_int64 - code_0_int64;
295 const int64 lowest_quantized =
296 static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
297 const int64 highest_quantized =
298 static_cast<int64>(Eigen::NumTraits<qint32>::highest());
299 for (int64 i = 0; i < count; ++i) {
300 const int64 input_value = static_cast<int64>(input[i]);
301 int64 output_value = code_0_int64 + (input_value * mult_int32);
302 output_value = std::max(output_value, lowest_quantized);
303 output_value = std::min(output_value, highest_quantized);
304 output[i] = static_cast<int32>(output_value);
305 }
306 }
307
308 #ifdef QUANTIZATION_UTILS_USE_NEON
// Speeds up the 32->8bit conversion using fixed-point arithmetic and NEON SIMD
// intrinsics for ARM platforms.
//
// This follows RequantizeManyInNewRangeReference but keeps every intermediate
// in 32 bits. Instead of multiplying the full input by a scale carrying 16
// fractional bits and shifting the product right by 32, it first shifts each
// input right by 16 and multiplies by an integer scale
// (1 << (fp_shift - 16) == 1). The low 16 bits of every input are therefore
// discarded.
//
// Elements are processed 8 at a time with NEON. For count >= 0 the last
// count % 8 elements use the same 32-bit formula in scalar code. Both paths
// saturate to [0, 255].
//
// NOTE(review): vmulq_s32 and the scalar int32 multiply wrap on overflow.
// (input >> 16) * range_scale_fp can exceed int32 once
// input_range / output_range is above roughly 257 and |input| is near 2^31.
// Confirm that the caller's 16384 ratio threshold is sufficient.
inline void RequantizeManyInNewRangeNeon(const qint32* input, int64 count,
                                         float min_input, float max_input,
                                         float min_output, float max_output,
                                         quint8* output) {
  // Initially we calculate all the constants we need once, before we go into
  // the inner loop. If this is updated, also update the Eigen version.
  const int fp_shift = 16;

  // Calculate range variables in advance.
  // Input range.
  const float input_range = max_input - min_input;
  // Output range.
  const float output_range = max_output - min_output;
  // quint8 steps per float unit of the output range (0 for a degenerate
  // range).
  const float recip_output_range =
      output_range == 0.0 ? 0.0 : (255.0 / output_range);
  // Average of input range as zero position of input.
  const float input_rezero = (min_input + max_input) / 2.0;
  // In-out range scale. Integer-valued here because the inputs are
  // pre-shifted right by 16 in the loops below.
  const int32 range_scale_fp =
      output_range == 0.0 ? 0.0
                          : static_cast<int32>(255.0 * (1 << (fp_shift - 16)) *
                                               input_range / output_range);
  // Input zero position offset to output.
  const int32 input_offset_fp =
      static_cast<int32>(input_rezero * recip_output_range * (1 << fp_shift));
  // Output min offset.
  const int32 output_offset_fp =
      output_range == 0.0
          ? 0
          : static_cast<int32>((1 << fp_shift) * (min_output * 255.0) /
                               output_range);
  // Adding half of 2^fp_shift before the final shift rounds to nearest.
  const int32 rounding_delta = 1 << (fp_shift - 1);

  // broadcast range to each lane
  const int32x4_t range_scale_fp_32x4 = vmovq_n_s32(range_scale_fp);
  const int32x4_t input_offset_fp_32x4 = vmovq_n_s32(input_offset_fp);
  const int32x4_t output_offset_fp_32x4 = vmovq_n_s32(output_offset_fp);
  const int32x4_t rounding_delta_32x4 = vmovq_n_s32(rounding_delta);

  int64 index = 0;
  // Use SIMD to requantize. `index < count - 7` holds exactly while a full
  // group of 8 elements remains; count is signed, so small counts skip the
  // loop.
  for (; index < (count - 7); index += 8) {
    const int32* input_ptr = &(input->value) + index;
    const int32x4_t input_value_low_32x4 = vld1q_s32(input_ptr);
    const int32x4_t input_value_high_32x4 = vld1q_s32(input_ptr + 4);
    const int32x4_t fp_value_low_32x4 = vaddq_s32(
        input_offset_fp_32x4,
        vmulq_s32(vshrq_n_s32(input_value_low_32x4, 16), range_scale_fp_32x4));
    const int32x4_t fp_value_high_32x4 = vaddq_s32(
        input_offset_fp_32x4,
        vmulq_s32(vshrq_n_s32(input_value_high_32x4, 16), range_scale_fp_32x4));
    const int32x4_t offset_intermediate_low_32x4 =
        vsubq_s32(fp_value_low_32x4, output_offset_fp_32x4);
    const int32x4_t offset_intermediate_high_32x4 =
        vsubq_s32(fp_value_high_32x4, output_offset_fp_32x4);
    const int32x4_t round_intermediate_low_32x4 =
        vaddq_s32(offset_intermediate_low_32x4, rounding_delta_32x4);
    const int32x4_t round_intermediate_high_32x4 =
        vaddq_s32(offset_intermediate_high_32x4, rounding_delta_32x4);
    // Saturating narrow to int16, then saturating narrow to uint8. Together
    // they clamp each lane to [0, 255].
    const int16x4_t quantized_low_16x4 =
        vqmovn_s32(vshrq_n_s32(round_intermediate_low_32x4, fp_shift));
    const int16x4_t quantized_high_16x4 =
        vqmovn_s32(vshrq_n_s32(round_intermediate_high_32x4, fp_shift));
    const uint8x8_t quantized_8x8 =
        vqmovun_s16(vcombine_s16(quantized_low_16x4, quantized_high_16x4));
    uint8* output_ptr = &(output->value) + index;
    vst1_u8(output_ptr, quantized_8x8);
  }

  // Requantize remaining elements in array without SIMD.
  for (; index < count; ++index) {
    const int32 input_value = static_cast<int32>(input[index]);
    const int32 fp_value =
        static_cast<int32>(
            (static_cast<int32>(input_value >> 16) * (range_scale_fp))) +
        input_offset_fp;
    const int32 offset_intermediate = fp_value - output_offset_fp;
    const int32 round_intermediate = offset_intermediate + rounding_delta;
    int32 quantized_int32 = round_intermediate >> fp_shift;
    quantized_int32 = std::max(quantized_int32, 0);
    quantized_int32 = std::min(quantized_int32, 255);
    output[index] = static_cast<quint8>(static_cast<int32>(quantized_int32));
  }
}
396
397 template <>
398 inline void RequantizeManyInNewRange<qint32, quint8>(
399 const qint32* input, int64 count, float min_input, float max_input,
400 float min_output, float max_output, quint8* output) {
401 const float input_range = max_input - min_input;
402 const float output_range = max_output - min_output;
403 if ((input_range / output_range) > 16384.0f) {
404 // Our NEON implementation uses 32-bit math and can't handle very
405 // large ranges, so fall back to the reference implementation. We don't
406 // expect these to be common in models, so this shouldn't be a performance
407 // problem in practice.
408 RequantizeManyInNewRangeReference(input, count, min_input, max_input,
409 min_output, max_output, output);
410 } else {
411 RequantizeManyInNewRangeNeon(input, count, min_input, max_input, min_output,
412 max_output, output);
413 }
414 }
415
416 // NEON accelerated 16bit rounded division by 2^n.
417 template <int POW>
Divide16x8PowRound(const int16x8_t val)418 inline int16x8_t Divide16x8PowRound(const int16x8_t val) {
419 const int16x8_t val_sign = vshrq_n_s16(val, 15);
420 const int16x8_t val_xor = veorq_s16(val, val_sign);
421 const int16x8_t val_pos = vsubq_s16(val_xor, val_sign);
422 const int16x8_t shifted_val_pos = vrshrq_n_s16(val_pos, POW);
423 const int16x8_t shifted_val_pos_xor = veorq_s16(shifted_val_pos, val_sign);
424 const int16x8_t shifted_val = vsubq_s16(shifted_val_pos_xor, val_sign);
425 return shifted_val;
426 }
427
428 // NEON accelerated 64bit rounded division by 2^n.
429 template <int POW>
Divide64x2PowRound(const int64x2_t val)430 inline int64x2_t Divide64x2PowRound(const int64x2_t val) {
431 const int64x2_t val_sign = vshrq_n_s64(val, 63);
432 const int64x2_t val_xor = veorq_s64(val, val_sign);
433 const int64x2_t val_pos = vsubq_s64(val_xor, val_sign);
434 const int64x2_t shifted_val_pos = vrshrq_n_s64(val_pos, POW);
435 const int64x2_t shifted_val_pos_xor = veorq_s64(shifted_val_pos, val_sign);
436 const int64x2_t shifted_val = vsubq_s64(shifted_val_pos_xor, val_sign);
437 return shifted_val;
438 }
439
440 // NEON accelerated 16bit division by 2^n.
441 // CAVEAT: The input must be greater than min-int16 to avoid underflow.
442 template <int POW>
Divide16x8Pow(const int16x8_t val)443 inline int16x8_t Divide16x8Pow(const int16x8_t val) {
444 static constexpr int16 FIRST_BIT_VAL = 0x0000000000000001;
445 static const int16x8_t FIRST_BIT = vmovq_n_s16(FIRST_BIT_VAL);
446 const int16x8_t val_sign = vshrq_n_s16(val, 15);
447 const int16x8_t neg_offset = vandq_s16(val_sign, FIRST_BIT);
448 const int16x8_t val_with_offset = vsubq_s16(val, neg_offset);
449 const int16x8_t shifted_wo_offset =
450 vsraq_n_s16(neg_offset, val_with_offset, POW);
451 return shifted_wo_offset;
452 }
453
454 // NEON accelerated 64bit division by 2^n.
455 // CAVEAT: The input must be greater than min-int64 to avoid underflow.
456 template <int POW>
Divide64x2Pow(const int64x2_t val)457 inline int64x2_t Divide64x2Pow(const int64x2_t val) {
458 static constexpr int64 FIRST_BIT_VAL = 0x0000000000000001;
459 static const int64x2_t FIRST_BIT = vmovq_n_s64(FIRST_BIT_VAL);
460 const int64x2_t val_sign = vshrq_n_s64(val, 63);
461 const int64x2_t neg_offset = vandq_s64(val_sign, FIRST_BIT);
462 const int64x2_t val_with_offset = vsubq_s64(val, neg_offset);
463 const int64x2_t shifted_wo_offset =
464 vsraq_n_s64(neg_offset, val_with_offset, POW);
465 return shifted_wo_offset;
466 }
467
// 32bit x 2 NEON accelerated lerp computation.
//
// Bilinear interpolation of two int32 lanes in fixed point with RESOLUTION
// fractional bits:
//   top    = top_left    * 2^R + (top_right    - top_left)    * x_lerp
//   bottom = bottom_left * 2^R + (bottom_right - bottom_left) * x_lerp
//   result = round((top + trunc((bottom - top) / 2^R) * y_lerp) / 2^R)
// Products are widened to int64. The final value is narrowed back to int32
// with saturation.
// NOTE(review): x_lerp / y_lerp presumably lie in [0, 2^RESOLUTION] —
// confirm at call sites. The 32-bit differences (vsub_s32) wrap on overflow.
template <int RESOLUTION>
inline int32x2_t ComputeLerp32x2(const int32x2_t top_left,
                                 const int32x2_t top_right,
                                 const int32x2_t bottom_left,
                                 const int32x2_t bottom_right,
                                 const int32x2_t x_lerp,
                                 const int32x2_t y_lerp) {
  static_assert(RESOLUTION < 31, "RESOLUTION must be less than 31");
  constexpr int32 RESOLUTION_MULT32 = (1 << RESOLUTION);
  static const int32x2_t RESOLUTION_MULT32x2 = vmov_n_s32(RESOLUTION_MULT32);

  // Scale the left corners up to RESOLUTION fractional bits (widening to
  // int64).
  const int64x2_t top_left_x_res = vmull_s32(top_left, RESOLUTION_MULT32x2);
  const int64x2_t bottom_left_x_res =
      vmull_s32(bottom_left, RESOLUTION_MULT32x2);

  // Horizontal interpolation along the top and bottom edges.
  const int32x2_t top_right_sub_top_left = vsub_s32(top_right, top_left);
  const int64x2_t top_x_res =
      vmlal_s32(top_left_x_res, top_right_sub_top_left, x_lerp);
  const int32x2_t bottom_right_sub_bottom_left =
      vsub_s32(bottom_right, bottom_left);
  const int64x2_t bottom_x_res =
      vmlal_s32(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);

  // Vertical interpolation. The top/bottom difference is brought back to
  // integer scale (truncating) so its product with y_lerp is again at
  // RESOLUTION fractional bits.
  const int64x2_t bottom_sub_top_x_res = vsubq_s64(bottom_x_res, top_x_res);
  const int64x2_t bottom_sub_top =
      Divide64x2Pow<RESOLUTION>(bottom_sub_top_x_res);
  const int32x2_t bottom_sub_top_32 = vqmovn_s64(bottom_sub_top);
  const int64x2_t top_add_bottom_sub_top_mul_ylerp_x_res =
      vmlal_s32(top_x_res, bottom_sub_top_32, y_lerp);
  // Drop the fractional bits with rounding, then saturate to int32.
  const int64x2_t retval =
      Divide64x2PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
  const int32x2_t retval32 = vqmovn_s64(retval);
  return retval32;
}
503
// 8bit x 8 NEON accelerated lerp computation.
//
// Bilinear interpolation of eight uint8 lanes in int16 fixed point with
// RESOLUTION fractional bits. It uses the same scheme as ComputeLerp32x2:
// interpolate horizontally along the top and bottom edges, then vertically.
// The final division by 2^R is rounded and the result is narrowed
// (non-saturating) back to uint8.
// NOTE(review): all intermediates are int16 and wrap on overflow.
// Correctness presumably relies on x_lerp / y_lerp lying in
// [0, 2^RESOLUTION] — confirm at call sites.
template <int RESOLUTION>
inline uint8x8_t ComputeLerp8x8(const uint8x8_t top_left8x8,
                                const uint8x8_t top_right8x8,
                                const uint8x8_t bottom_left8x8,
                                const uint8x8_t bottom_right8x8,
                                const int16x8_t x_lerp,
                                const int16x8_t y_lerp) {
  static_assert(RESOLUTION < 8, "RESOLUTION must be less than 8");
  constexpr uint8 RESOLUTION_MULT_VAL = (1 << RESOLUTION);
  static const uint8x8_t RESOLUTION_MULT = vdup_n_u8(RESOLUTION_MULT_VAL);

  // Left corners scaled to RESOLUTION fractional bits. 255 * 2^7 still fits in
  // int16, so the reinterpretation is value-preserving.
  const int16x8_t top_left_x_res =
      vreinterpretq_s16_u16(vmull_u8(top_left8x8, RESOLUTION_MULT));
  const int16x8_t bottom_left_x_res =
      vreinterpretq_s16_u16(vmull_u8(bottom_left8x8, RESOLUTION_MULT));

  // vsubl_u8 wraps negative differences modulo 2^16. Reinterpreting as int16
  // recovers the signed difference.
  const int16x8_t top_right_sub_top_left =
      vreinterpretq_s16_u16(vsubl_u8(top_right8x8, top_left8x8));
  const int16x8_t top_x_res =
      vmlaq_s16(top_left_x_res, top_right_sub_top_left, x_lerp);

  const int16x8_t bottom_right_sub_bottom_left =
      vreinterpretq_s16_u16(vsubl_u8(bottom_right8x8, bottom_left8x8));
  const int16x8_t bottom_x_res =
      vmlaq_s16(bottom_left_x_res, bottom_right_sub_bottom_left, x_lerp);

  // Vertical step: the difference is brought back to integer scale
  // (truncating) before being multiplied by y_lerp.
  const int16x8_t bottom_sub_top_x_res = vsubq_s16(bottom_x_res, top_x_res);
  const int16x8_t bottom_sub_top =
      Divide16x8Pow<RESOLUTION>(bottom_sub_top_x_res);
  const int16x8_t top_add_bottom_sub_top_mul_ylerp_x_res =
      vmlaq_s16(top_x_res, bottom_sub_top, y_lerp);
  // Drop the fractional bits with rounding, then narrow to uint8.
  const int16x8_t retval16 =
      Divide16x8PowRound<RESOLUTION>(top_add_bottom_sub_top_mul_ylerp_x_res);
  const uint8x8_t retval = vmovn_u16(vreinterpretq_u16_s16(retval16));
  return retval;
}
541
// Requantize 8 x 8 quints to 8 x 32 qints in parallel by neon
// Return std::array instead of pointer to leverage return value optimization
//
// Loads 8 uint8 values from `input_ptr` and computes
//   input_0 + value * input_mult
// for each one in int64. Each lane is then narrowed to int32 with saturation.
// Element [0] of the result holds lanes 0-3 and element [1] holds lanes 4-7.
inline std::array<int32x4_t, 2> Requantize8x8To32Neon(
    const uint8* input_ptr, const int64x2_t input_0_64x2,
    const int32x2_t input_mult_32x2) {
  // Widen u8 -> u16 -> s16 -> s32. The values are at most 255, so the signed
  // reinterpretation is value-preserving.
  const uint8x8_t input_value_8x8 = vld1_u8(input_ptr);
  const int16x8_t input_value_16x8 =
      vreinterpretq_s16_u16(vmovl_u8(input_value_8x8));
  const int16x4_t input_value_low_16x4 = vget_low_s16(input_value_16x8);
  const int16x4_t input_value_high_16x4 = vget_high_s16(input_value_16x8);
  const int32x4_t input_value_low_32x4 = vmovl_s16(input_value_low_16x4);
  const int32x4_t input_value_high_32x4 = vmovl_s16(input_value_high_16x4);
  // Split into 2-lane halves for the widening int32 x int32 -> int64
  // multiply-accumulate.
  const int32x2_t input_value_low_low_32x2 = vget_low_s32(input_value_low_32x4);
  const int32x2_t input_value_low_high_32x2 =
      vget_high_s32(input_value_low_32x4);
  const int32x2_t input_value_high_low_32x2 =
      vget_low_s32(input_value_high_32x4);
  const int32x2_t input_value_high_high_32x2 =
      vget_high_s32(input_value_high_32x4);
  const int64x2_t mult_result_low_low_64x2 =
      vmlal_s32(input_0_64x2, input_value_low_low_32x2, input_mult_32x2);
  const int64x2_t mult_result_low_high_64x2 =
      vmlal_s32(input_0_64x2, input_value_low_high_32x2, input_mult_32x2);
  const int64x2_t mult_result_high_low_64x2 =
      vmlal_s32(input_0_64x2, input_value_high_low_32x2, input_mult_32x2);
  const int64x2_t mult_result_high_high_64x2 =
      vmlal_s32(input_0_64x2, input_value_high_high_32x2, input_mult_32x2);
  // Saturating narrow clamps each result to the int32 range.
  const int32x2_t output_value_low_low_32x2 =
      vqmovn_s64(mult_result_low_low_64x2);
  const int32x2_t output_value_low_high_32x2 =
      vqmovn_s64(mult_result_low_high_64x2);
  const int32x2_t output_value_high_low_32x2 =
      vqmovn_s64(mult_result_high_low_64x2);
  const int32x2_t output_value_high_high_32x2 =
      vqmovn_s64(mult_result_high_high_64x2);
  const int32x4_t output_value_low_32x4 =
      vcombine_s32(output_value_low_low_32x2, output_value_low_high_32x2);
  const int32x4_t output_value_high_32x4 =
      vcombine_s32(output_value_high_low_32x2, output_value_high_high_32x2);
  return std::array<int32x4_t, 2>{
      {output_value_low_32x4, output_value_high_32x4}};
}
584
// Speeds up the 8->32bit conversion using fixed-point arithmetic and NEON SIMD
// intrinsics for ARM platforms.
//
// Uses the same affine formulation as RequantizeManyInNewRange8To32BitReference:
//   output[i] = code(0) + input[i] * (code(1) - code(0)), saturated to qint32.
// Groups of 8 elements go through Requantize8x8To32Neon. For count >= 0 the
// last count % 8 elements are handled in scalar code.
template <>
inline void RequantizeManyInNewRange<quint8, qint32>(
    const quint8* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, qint32* output) {
  // Pre-calculate zero position and multiplier.
  // Calculate 0 and 1 value in float.
  const float code_0_float = QuantizedToFloat<quint8>(0, min_input, max_input);
  const float code_1_float = QuantizedToFloat<quint8>(1, min_input, max_input);

  // Cast 0 and 1 value in int64.
  const int64 code_0_int64 =
      FloatToQuantizedUnclamped<qint32>(code_0_float, min_output, max_output);
  const int64 code_1_int64 =
      FloatToQuantizedUnclamped<qint32>(code_1_float, min_output, max_output);

  // Calculate multiplier.
  const int32 mult_int32 = static_cast<int32>(code_1_int64 - code_0_int64);

  // Broadcast 0 position and multiplier to lanes
  const int64x2_t code_0_64x2 = vmovq_n_s64(code_0_int64);
  const int32x2_t mult_32x2 = vmov_n_s32(mult_int32);

  int64 i = 0;

  // Use SIMD to requantize array. `i < count - 7` holds exactly while a full
  // group of 8 remains; count is signed, so small counts skip the loop.
  for (; i < (count - 7); i += 8) {
    const uint8* input_ptr = &(input->value) + i;
    int32* output_ptr = &(output->value) + i;
    const std::array<int32x4_t, 2> output_value =
        Requantize8x8To32Neon(input_ptr, code_0_64x2, mult_32x2);
    vst1q_s32(output_ptr + 0, output_value[0]);
    vst1q_s32(output_ptr + 4, output_value[1]);
  }

  // Requantize remaining elements in array without SIMD.
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  for (; i < count; ++i) {
    const int64 input_value = static_cast<int64>(input[i]);
    int64 output_value = code_0_int64 + (input_value * mult_int32);
    output_value = std::max(output_value, lowest_quantized);
    output_value = std::min(output_value, highest_quantized);
    output[i] = static_cast<int32>(output_value);
  }
}
635
636 #else
637
// If SIMD implementations aren't available, then use these default reference
// versions.
//
// qint32 -> quint8: forwards to the fixed-point reference implementation.
template <>
inline void RequantizeManyInNewRange<qint32, quint8>(
    const qint32* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, quint8* output) {
  RequantizeManyInNewRangeReference(input, count, min_input, max_input,
                                    min_output, max_output, output);
}
647
// Non-SIMD quint8 -> qint32: forwards to the fixed-point reference
// implementation.
template <>
inline void RequantizeManyInNewRange<quint8, qint32>(
    const quint8* input, int64 count, float min_input, float max_input,
    float min_output, float max_output, qint32* output) {
  RequantizeManyInNewRange8To32BitReference(input, count, min_input, max_input,
                                            min_output, max_output, output);
}
655
656 #endif
657
// Eigen unary functor that right-shifts an int64 by the compile-time `shift`.
// It is used via unaryExpr() in the Eigen requantization path. For negative
// values the result depends on the compiler's signed right shift, which is
// arithmetic on mainstream compilers.
template <int shift>
struct int64_right_shift_op {
  EIGEN_EMPTY_STRUCT_CTOR(int64_right_shift_op)
  EIGEN_DEVICE_FUNC
  EIGEN_STRONG_INLINE const int64 operator()(const int64& a) const {
    return a >> shift;
  }
};
666
// See RequantizeManyInNewRange() for a non-eigen reference implementation.
//
// Requantizes every element of `input` (T1 codes over [min_input, max_input])
// into `output` (T2 codes over [min_output, max_output]). It builds one Eigen
// expression that dequantizes to float (DEQUANTIZE_WITH_EIGEN) and
// re-quantizes (QUANTIZE_WITH_EIGEN), and evaluates it on `device` when it is
// assigned to output->flat<T2>().
// NOTE(review): `output` is assumed to be already allocated with T2 dtype and
// the same element count as `input` — confirm at call sites.
template <class T1, class T2>
inline void RequantizeManyInNewRangeUsingEigen(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
    float max_input, float min_output, float max_output, Tensor* output) {
  auto input_array = input.flat<T1>();
  QuantizedToFloatStruct<T1> q2f(min_input, max_input);
  auto input_float = DEQUANTIZE_WITH_EIGEN(input_array, q2f);
  FloatToQuantizedStruct<T2> f2q(min_output, max_output);
  auto input_requantized = QUANTIZE_WITH_EIGEN(input_float, f2q, T2);

  output->flat<T2>().device(device) = input_requantized;
}
680
// See RequantizeManyInNewRange() for a non-eigen reference implementation.
//
// Because converting 32-bit accumulated results down to eight bit is a common
// case, we have a specialized code path to handle it as efficiently as
// possible using only fixed-point math for the inner loop.
//
// All constants below are real-valued factors scaled by 2^fp_shift. For each
// qint32 input value v the expression computes
//   out = clamp((((v * range_scale_fp) >> 32)
//                + input_offset_fp - output_offset_fp + rounding_delta)
//               >> fp_shift, 0, 255)
// Algebraically this maps v as if it represented
//   input_rezero + v * input_range / 2^32
// (the qint32 span centred on the midpoint of [min_input, max_input]). It then
// rescales that real value into 0..255 over [min_output, max_output]. When
// output_range is zero, the scale and both offsets collapse to zero.
template <>
inline void RequantizeManyInNewRangeUsingEigen<qint32, quint8>(
    const Eigen::ThreadPoolDevice& device, const Tensor& input, float min_input,
    float max_input, float min_output, float max_output, Tensor* output) {
  // Initially we calculate all the constants we need once, before we go into
  // the inner loop. If this is updated, also update the non-Eigen version.
  // Number of fractional bits in the fixed-point intermediates.
  const int fp_shift = 16;
  const float input_range = max_input - min_input;
  const float output_range = max_output - min_output;
  // Output quantization steps per unit of real value (0 for an empty range).
  const float recip_output_range =
      output_range == 0.0 ? 0.0 : (255.0 / output_range);
  // Real value that the qint32 value 0 is treated as (midpoint of the input).
  const float input_rezero = (min_input + max_input) / 2.0;
  // Fixed-point multiplier. Together with the >> 32 below it converts a
  // qint32 value into output steps, scaled by 2^fp_shift.
  const int64 range_scale_fp =
      output_range == 0.0 ? 0.0
                          : static_cast<int64>(255.0 * (1 << fp_shift) *
                                               input_range / output_range);
  // input_rezero expressed in output steps, scaled by 2^fp_shift.
  const int64 input_offset_fp =
      static_cast<int64>(input_rezero * recip_output_range * (1 << fp_shift));
  // min_output expressed in output steps, scaled by 2^fp_shift. Subtracting it
  // makes min_output land on quantized 0.
  const int64 output_offset_fp =
      output_range == 0.0
          ? 0
          : static_cast<int64>((1 << fp_shift) * (min_output * 255.0) /
                               output_range);
  // One half in fixed point, added so the final >> fp_shift rounds to nearest
  // instead of truncating.
  const int64 rounding_delta = 1 << (fp_shift - 1);

  // Inside this eigen expression we just do minimal adds, multiplies, and
  // shifts. It should be possible to perform all the calculations in 32-bit
  // rather than 64, but that's not been implemented yet.
  auto input_array = input.flat<qint32>();
  // Widen to int64 so the product with range_scale_fp cannot overflow int32.
  auto fp_value = ((input_array.template cast<int64>() * range_scale_fp)
                       .unaryExpr(int64_right_shift_op<32>())) +
                  (input_offset_fp - output_offset_fp + rounding_delta);
  // Drop the fractional bits to get an integer number of output steps.
  auto intermediate = fp_value.unaryExpr(int64_right_shift_op<fp_shift>());
  // Saturate to the quint8 range before narrowing.
  auto input_requantized = intermediate.cwiseMax(0LL)
                               .cwiseMin(255LL)
                               .template cast<int32>()
                               .template cast<quint8>();
  output->flat<quint8>().device(device) = input_requantized;
}
725
726 // REQUIRES: 'result->NumElements() == input.NumElements()'
727 template <class T>
FloatTensorToQuantizedInPlaceUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float min,float max,Tensor * result)728 void FloatTensorToQuantizedInPlaceUsingEigen(
729 const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
730 float max, Tensor* result) {
731 DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
732 auto flat_input = input.flat<float>();
733 auto flat_result = result->flat<T>();
734 DCHECK_EQ(flat_input.size(), flat_result.size());
735
736 FloatToQuantizedStruct<T> f2q(min, max);
737 flat_result.device(device) = QUANTIZE_WITH_EIGEN(flat_input, f2q, T);
738 }
739
740 template <class T>
FloatTensorToQuantizedInPlace(const Tensor & input,float min,float max,Tensor * result)741 void FloatTensorToQuantizedInPlace(const Tensor& input, float min, float max,
742 Tensor* result) {
743 DCHECK_EQ(DataTypeToEnum<T>::v(), result->dtype());
744 auto flat_input = input.flat<float>();
745 auto flat_result = result->flat<T>();
746 const int data_size = flat_input.size();
747 DCHECK(data_size == flat_result.size());
748 for (int i = 0; i < data_size; ++i) {
749 flat_result(i) = FloatToQuantized<T>(flat_input(i), min, max);
750 }
751 }
752
753 template <class T>
FloatTensorToQuantized(const Tensor & input,float min,float max)754 Tensor FloatTensorToQuantized(const Tensor& input, float min, float max) {
755 Tensor result(DataTypeToEnum<T>::v(), input.shape());
756 FloatTensorToQuantizedInPlace<T>(input, min, max, &result);
757 return result;
758 }
759
760 // REQUIRES: 'result->NumElements() == input.NumElements()'
761 template <class T>
QuantizedTensorToFloatInPlaceUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float min,float max,Tensor * result)762 void QuantizedTensorToFloatInPlaceUsingEigen(
763 const Eigen::ThreadPoolDevice& device, const Tensor& input, float min,
764 float max, Tensor* result) {
765 DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
766 auto flat_input = input.flat<T>();
767 auto flat_result = result->flat<float>();
768 const int data_size = flat_input.size();
769 DCHECK(data_size == flat_result.size());
770
771 QuantizedToFloatStruct<T> q2f(min, max);
772 flat_result.device(device) = DEQUANTIZE_WITH_EIGEN(flat_input, q2f);
773 }
774
775 // REQUIRES: 'result->NumElements() == input.NumElements()'
776 template <class T>
QuantizedTensorToFloatInPlace(const Tensor & input,float min,float max,Tensor * result)777 void QuantizedTensorToFloatInPlace(const Tensor& input, float min, float max,
778 Tensor* result) {
779 DCHECK_EQ(DataTypeToEnum<T>::v(), input.dtype());
780 auto flat_input = input.flat<T>();
781 auto flat_result = result->flat<float>();
782 const int data_size = flat_input.size();
783 DCHECK(data_size == flat_result.size());
784 for (int i = 0; i < data_size; ++i) {
785 flat_result(i) = QuantizedToFloat<T>(flat_input(i), min, max);
786 }
787 }
788
789 template <class T>
QuantizedTensorToFloat(const Tensor & input,float min,float max)790 Tensor QuantizedTensorToFloat(const Tensor& input, float min, float max) {
791 Tensor result(DT_FLOAT, input.shape());
792 QuantizedTensorToFloatInPlace<T>(input, min, max, &result);
793 return result;
794 }
795
// Chooses the range used to represent the sum of a tensor quantized over
// [input_min, input_max] and one quantized over
// [smaller_input_min, smaller_input_max]. The range is written to
// *output_min and *output_max. QuantizedAdd() and QuantizedAddUsingEigen()
// call this first and then requantize both operands into that range.
// Only the declaration lives here. The definition, which fixes the exact rule
// for picking the range, is elsewhere.
void GetOutputMinAndMaxForQuantizedAdd(float input_min, float input_max,
                                       float smaller_input_min,
                                       float smaller_input_max,
                                       float* output_min, float* output_max);
800
801 // Add <input> and <smaller_input>. If <smaller_input> has fewer elements than
802 // <input>, then it is broadcast onto <input>.
803 template <typename T1, typename T2, typename T3>
QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice & device,const Tensor & input,float input_min,float input_max,const Tensor & smaller_input,float smaller_input_min,float smaller_input_max,Tensor * output,float * output_min,float * output_max)804 void QuantizedAddUsingEigen(const Eigen::ThreadPoolDevice& device,
805 const Tensor& input, float input_min,
806 float input_max, const Tensor& smaller_input,
807 float smaller_input_min, float smaller_input_max,
808 Tensor* output, float* output_min,
809 float* output_max) {
810 const auto& input_flat = input.flat<T1>();
811 const auto& smaller_input_flat = smaller_input.flat<T2>();
812 auto output_flat = output->flat<T3>();
813
814 GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
815 smaller_input_max, output_min, output_max);
816 // To do addition properly, we need to compensate for a possibly unbalanced
817 // zero point in the total representation. The quantized value that
818 // represents the real number zero needs to be subtracted before addition to
819 // make sure that the identity of zero + zero = zero holds.
820 const T3 zero_in_total_space =
821 FloatToQuantized<T3>(0.0f, *output_min, *output_max);
822
823 const int64 input_element_count = input.NumElements();
824 const int64 smaller_input_element_count = smaller_input.NumElements();
825
826 QuantizedToFloatStruct<T1> input_q2f(input_min, input_max);
827 QuantizedToFloatStruct<T2> smaller_input_q2f(smaller_input_min,
828 smaller_input_max);
829 FloatToQuantizedStruct<T3> f2q(*output_min, *output_max);
830
831 auto smaller_input_float =
832 DEQUANTIZE_WITH_EIGEN(smaller_input_flat, smaller_input_q2f);
833 auto smaller_input_in_total_space =
834 QUANTIZE_WITH_EIGEN(smaller_input_float, f2q, T3);
835
836 auto input_float = DEQUANTIZE_WITH_EIGEN(input_flat, input_q2f);
837 auto input_in_total_space = QUANTIZE_WITH_EIGEN(input_float, f2q, T3);
838
839 Eigen::array<Eigen::DenseIndex, 1> bcast;
840 bcast[0] = input_element_count / smaller_input_element_count;
841 output_flat.device(device) =
842 input_in_total_space +
843 (smaller_input_in_total_space.broadcast(bcast) + zero_in_total_space);
844 }
845
846 // This is a reference implementation of the bias addition for quantized
847 // buffers, designed to provide a clear specification for the result we
848 // want. We'll want to specialize this for particular hardware, and
849 // probably even fuse it with matrix multiplications in a lot of cases. It's
850 // important to show the clamping behavior we want in particular.
851 template <typename T1, typename T2, typename T3>
QuantizedAdd(const Eigen::ThreadPoolDevice & device,const Tensor & input,float input_min,float input_max,const Tensor & smaller_input,float smaller_input_min,float smaller_input_max,Tensor * output,float * output_min,float * output_max)852 void QuantizedAdd(const Eigen::ThreadPoolDevice& device, const Tensor& input,
853 float input_min, float input_max, const Tensor& smaller_input,
854 float smaller_input_min, float smaller_input_max,
855 Tensor* output, float* output_min, float* output_max) {
856 const auto& input_flat = input.flat<T1>();
857 const auto& smaller_input_flat = smaller_input.flat<T2>();
858 auto output_flat = output->flat<T3>();
859
860 GetOutputMinAndMaxForQuantizedAdd(input_min, input_max, smaller_input_min,
861 smaller_input_max, output_min, output_max);
862 // To do addition properly, we need to compensate for a possibly unbalanced
863 // zero point in the total representation. The quantized value that
864 // represents the real number zero needs to be subtracted before addition to
865 // make sure that the identity of zero + zero = zero holds.
866 const T3 zero_in_total_space =
867 FloatToQuantized<T3>(0.0f, *output_min, *output_max);
868
869 const int64 input_element_count = input.NumElements();
870 const int64 smaller_input_element_count = smaller_input.NumElements();
871
872 float total_min = *output_min;
873 float total_max = *output_max;
874 const size_t how_many_iterations =
875 (input_element_count / smaller_input_element_count);
876 for (size_t iteration = 0; iteration < how_many_iterations; ++iteration) {
877 const size_t offset = iteration * smaller_input_element_count;
878 for (int c = 0; c < smaller_input_element_count; ++c) {
879 const int index = (offset + c);
880 // The two numbers we're going to add can each be in very different
881 // ranges (e.g. the quantized value '127' may represent very different
882 // real numbers in both) so we need to convert them to a common range
883 // before we sum them.
884 const T1 input_value = input_flat(index);
885 const T3 input_in_total_space = RequantizeInNewRange<T1, T3>(
886 input_value, input_min, input_max, total_min, total_max);
887 const T2 smaller_input_value = smaller_input_flat(c);
888 const T3 smaller_input_in_total_space =
889 RequantizeInNewRange<T2, T3>(smaller_input_value, smaller_input_min,
890 smaller_input_max, total_min, total_max);
891 const T3 total_pre = input_in_total_space + smaller_input_in_total_space;
892 // As noted above, we need to compensate for the offset of the actual
893 // zero point in the space we're operating in.
894 const T3 total = total_pre + zero_in_total_space;
895 output_flat(index) = total;
896 }
897 }
898 }
899
// See gemmlowp/internal/multi_thread_gemm.h for the semantics of Execute.
//
// Runs gemmlowp tasks on a TensorFlow thread pool. The pool only stores the
// `workers` pointer given to the constructor and never deletes it, so the
// caller must keep that thread pool alive for this object's lifetime.
class TensorflowGemmlowpWorkersPool {
 public:
  TensorflowGemmlowpWorkersPool(thread::ThreadPool* workers)
      : workers_(workers) {}

  ~TensorflowGemmlowpWorkersPool() {
    // This workaround ensures that all worker tasks have exited methods in the
    // BlockingCounter. Without this, there is a race where the context is torn
    // down while the counter is in use.
    counter_to_decrement_when_ready_.Reset(0);
  }

  // Schedules one closure per task on `workers_`, blocks until every task has
  // run, then deletes all tasks. This call takes ownership of the pointers in
  // `tasks`. Each task gets its own gemmlowp::Allocator, which lives on the
  // running closure's stack only for the duration of task->Run().
  // Capturing `this` and `task` in the closures is safe because this method
  // does not return until Wait() observes every DecrementCount().
  // NOTE(review): the single shared counter is Reset() on every call, so
  // overlapping Execute() calls on the same pool appear unsupported.
  void Execute(const std::vector<gemmlowp::Task*>& tasks) {
    // assert() is compiled out under NDEBUG; these are debug-only checks.
    assert(!tasks.empty());
    assert(workers_ != nullptr);
    counter_to_decrement_when_ready_.Reset(tasks.size());
    for (gemmlowp::Task* task : tasks) {
      workers_->Schedule([this, task]() {
        // TODO(cwhipkey): get a local_allocator from a thread local storage.
        gemmlowp::Allocator local_allocator;
        CHECK(task != nullptr);
        task->local_allocator = &local_allocator;
        task->Run();
        counter_to_decrement_when_ready_.DecrementCount();
      });
    }
    counter_to_decrement_when_ready_.Wait();
    // All closures have signalled completion, so no worker still uses a task.
    for (gemmlowp::Task* task : tasks) {
      delete task;
    }
  }

 private:
  // Non-owning; supplied by the constructor.
  thread::ThreadPool* const workers_;

  // The BlockingCounter used to wait for the workers.
  gemmlowp::BlockingCounter counter_to_decrement_when_ready_;

  TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmlowpWorkersPool);
};
941
942 class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase {
943 public:
TensorflowGemmContext(int num_threads,thread::ThreadPool * workers)944 TensorflowGemmContext(int num_threads, thread::ThreadPool* workers)
945 : workers_pool_(workers) {
946 set_max_num_threads(num_threads);
947 }
948
workers_pool()949 TensorflowGemmlowpWorkersPool* workers_pool() { return &workers_pool_; }
950
951 private:
952 TensorflowGemmlowpWorkersPool workers_pool_;
953
954 TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmContext);
955 };
956
957 } // namespace tensorflow
958
959 #endif // TENSORFLOW_CORE_KERNELS_QUANTIZATION_UTILS_H_
960