• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Implements a quantized eight-bit version of the matmul operation.
17 
18 #define EIGEN_USE_THREADS
19 
20 #if defined(__ARM_NEON__) || defined(__ARM_NEON)
21 #define USE_NEON
22 #define QUANTIZED_ADD_USE_NEON
23 #include <arm_neon.h>
24 #endif
25 
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/kernels/meta_support.h"
29 #include "tensorflow/core/kernels/quantization_utils.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/util/bcast.h"
32 
33 // There are implementations for three broadcast patterns for add:
34 //  - Scalar * Array
35 //  - Array * Array
36 //  - Array * Shorter Array (repeated to match first)
37 //
38 // These handle a lot of common broadcast patterns, and we have NEON SIMD
39 // versions to accelerate performance on ARM platforms.
40 
41 namespace tensorflow {
42 namespace {
43 
44 template <class T, class Toutput>
ScalarAddition(OpKernelContext * context,const T * full_input,float full_input_min,float full_input_max,int64 num_elements,T scalar_input,float scalar_input_min,float scalar_input_max,float output_min,float output_max,Toutput * output)45 void ScalarAddition(OpKernelContext* context, const T* full_input,
46                     float full_input_min, float full_input_max,
47                     int64 num_elements, T scalar_input, float scalar_input_min,
48                     float scalar_input_max, float output_min, float output_max,
49                     Toutput* output) {
50   const Toutput scalar_in_output_range = RequantizeInNewRange<T, Toutput>(
51       scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);
52   for (int i = 0; i < num_elements; ++i) {
53     const Toutput full_input_in_output_range = RequantizeInNewRange<T, Toutput>(
54         full_input[i], full_input_min, full_input_max, output_min, output_max);
55     output[i] = full_input_in_output_range + scalar_in_output_range;
56   }
57 }
58 
59 #ifdef QUANTIZED_ADD_USE_NEON
60 
template <>
// NEON-accelerated specialization for quint8 inputs producing qint32 output.
// Processes eight elements per iteration with SIMD, then falls back to a
// scalar tail loop for the remainder.
void ScalarAddition(OpKernelContext* context, const quint8* full_input,
                    float full_input_min, float full_input_max,
                    int64 num_elements, quint8 scalar_input,
                    float scalar_input_min, float scalar_input_max,
                    float output_min, float output_max, qint32* output) {
  // Requantize the scalar operand once into the output range.
  const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>(
      scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);

  // Build an affine map (offset input_0_int64, slope input_mult_int32) that
  // converts a quint8 value in the input range directly into a qint32 value
  // in the output range, by requantizing the quantized points 0 and 1.
  const float input_0_float =
      QuantizedToFloat<quint8>(0, full_input_min, full_input_max);
  const float input_1_float =
      QuantizedToFloat<quint8>(1, full_input_min, full_input_max);
  const int64 input_0_int64 =
      FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max);
  const int64 input_1_int64 =
      FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max);
  const int32 input_mult_int32 = input_1_int64 - input_0_int64;

  // Saturation bounds for the scalar tail loop (the NEON helper saturates
  // internally).
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  // Broadcast the affine parameters and the scalar addend into NEON
  // registers once, before the main loop.
  const int64x2_t input_0_64x2 = vmovq_n_s64(input_0_int64);
  const int32x2_t input_mult_32x2 = vmov_n_s32(input_mult_int32);
  const int32x4_t scalar_in_output_range_32x4 =
      vmovq_n_s32(scalar_in_output_range);
  int64 i = 0;
  // Main SIMD loop: requantize 8 quint8 values to two int32x4 vectors, add
  // the broadcast scalar, and store 8 int32 results.
  for (; i < (num_elements - 7); i += 8) {
    const uint8* full_input_ptr = &(full_input->value) + i;
    const std::array<int32x4_t, 2> output_value =
        Requantize8x8To32Neon(full_input_ptr, input_0_64x2, input_mult_32x2);
    const int32x4_t result_low_32x4 =
        vaddq_s32(output_value[0], scalar_in_output_range_32x4);
    const int32x4_t result_high_32x4 =
        vaddq_s32(output_value[1], scalar_in_output_range_32x4);
    int32* output_ptr = &(output->value) + i;
    vst1q_s32(output_ptr + 0, result_low_32x4);
    vst1q_s32(output_ptr + 4, result_high_32x4);
  }
  // Scalar tail loop for the final (num_elements % 8) elements, applying the
  // same affine map with explicit clamping to the qint32 range.
  for (; i < num_elements; ++i) {
    const int64 full_input_value = static_cast<int64>(full_input[i]);
    int64 full_input_in_output_range_64 =
        input_0_int64 + (full_input_value * input_mult_int32);
    full_input_in_output_range_64 =
        std::max(full_input_in_output_range_64, lowest_quantized);
    full_input_in_output_range_64 =
        std::min(full_input_in_output_range_64, highest_quantized);
    const int32 full_input_in_output_range =
        static_cast<int32>(full_input_in_output_range_64);
    output[i] = full_input_in_output_range + scalar_in_output_range;
  }
}
115 
116 #else  // QUANTIZED_ADD_USE_NEON
117 
118 template <>
ScalarAddition(OpKernelContext * context,const quint8 * full_input,float full_input_min,float full_input_max,int64 num_elements,quint8 scalar_input,float scalar_input_min,float scalar_input_max,float output_min,float output_max,qint32 * output)119 void ScalarAddition(OpKernelContext* context, const quint8* full_input,
120                     float full_input_min, float full_input_max,
121                     int64 num_elements, quint8 scalar_input,
122                     float scalar_input_min, float scalar_input_max,
123                     float output_min, float output_max, qint32* output) {
124   const int32 scalar_in_output_range = RequantizeInNewRange<quint8, qint32>(
125       scalar_input, scalar_input_min, scalar_input_max, output_min, output_max);
126 
127   const float input_0_float =
128       QuantizedToFloat<quint8>(0, full_input_min, full_input_max);
129   const float input_1_float =
130       QuantizedToFloat<quint8>(1, full_input_min, full_input_max);
131   const int64 input_0_int64 =
132       FloatToQuantizedUnclamped<qint32>(input_0_float, output_min, output_max);
133   const int64 input_1_int64 =
134       FloatToQuantizedUnclamped<qint32>(input_1_float, output_min, output_max);
135   const int32 input_mult_int32 = input_1_int64 - input_0_int64;
136 
137   const int64 lowest_quantized =
138       static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
139   const int64 highest_quantized =
140       static_cast<int64>(Eigen::NumTraits<qint32>::highest());
141 
142   for (int i = 0; i < num_elements; ++i) {
143     const int64 full_input_value = static_cast<int64>(full_input[i]);
144     int64 full_input_in_output_range_64 =
145         input_0_int64 + (full_input_value * input_mult_int32);
146     full_input_in_output_range_64 =
147         std::max(full_input_in_output_range_64, lowest_quantized);
148     full_input_in_output_range_64 =
149         std::min(full_input_in_output_range_64, highest_quantized);
150     const int32 full_input_in_output_range =
151         static_cast<int32>(full_input_in_output_range_64);
152     output[i] = full_input_in_output_range + scalar_in_output_range;
153   }
154 }
155 
156 #endif  // QUANTIZED_ADD_USE_NEON
157 
158 template <class T, class Toutput>
VectorAddition(OpKernelContext * context,const T * x_data,float min_x,float max_x,const T * y_data,float min_y,float max_y,int64 num_elements,float output_min,float output_max,Toutput * output)159 void VectorAddition(OpKernelContext* context, const T* x_data, float min_x,
160                     float max_x, const T* y_data, float min_y, float max_y,
161                     int64 num_elements, float output_min, float output_max,
162                     Toutput* output) {
163   for (int i = 0; i < num_elements; ++i) {
164     const Toutput x_in_output_range = RequantizeInNewRange<T, Toutput>(
165         x_data[i], min_x, max_x, output_min, output_max);
166     const Toutput y_in_output_range = RequantizeInNewRange<T, Toutput>(
167         y_data[i], min_y, max_y, output_min, output_max);
168     output[i] = x_in_output_range + y_in_output_range;
169   }
170 }
171 
172 #ifdef QUANTIZED_ADD_USE_NEON
173 
template <>
// NEON-accelerated specialization for quint8 inputs producing qint32 output.
// Requantizes and adds eight elements at a time, with a scalar tail loop.
void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x,
                    float max_x, const quint8* y_data, float min_y, float max_y,
                    int64 num_elements, float output_min, float output_max,
                    qint32* output) {
  // Affine map for x: out = x_0_int64 + x * x_mult_int32, derived by
  // requantizing the quantized points 0 and 1 into the output range.
  const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x);
  const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x);
  const int64 x_0_int64 =
      FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max);
  const int64 x_1_int64 =
      FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max);
  const int32 x_mult_int32 = x_1_int64 - x_0_int64;

  // Same affine map construction for y.
  const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y);
  const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y);
  const int64 y_0_int64 =
      FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max);
  const int64 y_1_int64 =
      FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max);
  const int32 y_mult_int32 = y_1_int64 - y_0_int64;

  // Saturation bounds for the scalar tail loop.
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  // Broadcast the per-operand affine parameters into NEON registers.
  const int64x2_t x_0_64x2 = vmovq_n_s64(x_0_int64);
  const int32x2_t x_mult_32x2 = vmov_n_s32(x_mult_int32);

  const int64x2_t y_0_64x2 = vmovq_n_s64(y_0_int64);
  const int32x2_t y_mult_32x2 = vmov_n_s32(y_mult_int32);

  int64 i = 0;
  // Main SIMD loop: requantize 8 elements of each operand into two int32x4
  // vectors, add pairwise, and store 8 int32 results.
  for (; i < (num_elements - 7); i += 8) {
    const uint8* x_ptr = &(x_data->value) + i;
    const std::array<int32x4_t, 2> x_output_value =
        Requantize8x8To32Neon(x_ptr, x_0_64x2, x_mult_32x2);
    const uint8* y_ptr = &(y_data->value) + i;
    const std::array<int32x4_t, 2> y_output_value =
        Requantize8x8To32Neon(y_ptr, y_0_64x2, y_mult_32x2);

    const int32x4_t result_low_32x4 =
        vaddq_s32(x_output_value[0], y_output_value[0]);
    const int32x4_t result_high_32x4 =
        vaddq_s32(x_output_value[1], y_output_value[1]);
    int32* output_ptr = &(output->value) + i;
    vst1q_s32(output_ptr + 0, result_low_32x4);
    vst1q_s32(output_ptr + 4, result_high_32x4);
  }

  // Scalar tail loop for the final (num_elements % 8) elements, with explicit
  // clamping to the qint32 range.
  for (; i < num_elements; ++i) {
    const int64 x_value = static_cast<int64>(x_data[i]);
    int64 x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32);
    x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized);
    x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized);
    const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64);

    const int64 y_value = static_cast<int64>(y_data[i]);
    int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32);
    y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized);
    y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized);
    const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64);

    output[i] = x_in_output_range + y_in_output_range;
  }
}
240 
241 #else  // QUANTIZED_ADD_USE_NEON
242 
243 template <>
VectorAddition(OpKernelContext * context,const quint8 * x_data,float min_x,float max_x,const quint8 * y_data,float min_y,float max_y,int64 num_elements,float output_min,float output_max,qint32 * output)244 void VectorAddition(OpKernelContext* context, const quint8* x_data, float min_x,
245                     float max_x, const quint8* y_data, float min_y, float max_y,
246                     int64 num_elements, float output_min, float output_max,
247                     qint32* output) {
248   const float x_0_float = QuantizedToFloat<quint8>(0, min_x, max_x);
249   const float x_1_float = QuantizedToFloat<quint8>(1, min_x, max_x);
250   const int64 x_0_int64 =
251       FloatToQuantizedUnclamped<qint32>(x_0_float, output_min, output_max);
252   const int64 x_1_int64 =
253       FloatToQuantizedUnclamped<qint32>(x_1_float, output_min, output_max);
254   const int32 x_mult_int32 = x_1_int64 - x_0_int64;
255 
256   const float y_0_float = QuantizedToFloat<quint8>(0, min_y, max_y);
257   const float y_1_float = QuantizedToFloat<quint8>(1, min_y, max_y);
258   const int64 y_0_int64 =
259       FloatToQuantizedUnclamped<qint32>(y_0_float, output_min, output_max);
260   const int64 y_1_int64 =
261       FloatToQuantizedUnclamped<qint32>(y_1_float, output_min, output_max);
262   const int32 y_mult_int32 = y_1_int64 - y_0_int64;
263 
264   const int64 lowest_quantized =
265       static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
266   const int64 highest_quantized =
267       static_cast<int64>(Eigen::NumTraits<qint32>::highest());
268 
269   for (int i = 0; i < num_elements; ++i) {
270     const int64 x_value = static_cast<int64>(x_data[i]);
271     int64 x_in_output_range_64 = x_0_int64 + (x_value * x_mult_int32);
272     x_in_output_range_64 = std::max(x_in_output_range_64, lowest_quantized);
273     x_in_output_range_64 = std::min(x_in_output_range_64, highest_quantized);
274     const int32 x_in_output_range = static_cast<int32>(x_in_output_range_64);
275 
276     const int64 y_value = static_cast<int64>(y_data[i]);
277     int64 y_in_output_range_64 = y_0_int64 + (y_value * y_mult_int32);
278     y_in_output_range_64 = std::max(y_in_output_range_64, lowest_quantized);
279     y_in_output_range_64 = std::min(y_in_output_range_64, highest_quantized);
280     const int32 y_in_output_range = static_cast<int32>(y_in_output_range_64);
281 
282     output[i] = x_in_output_range + y_in_output_range;
283   }
284 }
285 
286 #endif  // QUANTIZED_ADD_USE_NEON
287 
288 template <class T, class Toutput>
VectorTensorAddition(const T * vector_data,float min_vector,float max_vector,int64 vector_num_elements,const T * tensor_data,float min_tensor,float max_tensor,int64 tensor_num_elements,float output_min,float output_max,Toutput * output)289 void VectorTensorAddition(const T* vector_data, float min_vector,
290                           float max_vector, int64 vector_num_elements,
291                           const T* tensor_data, float min_tensor,
292                           float max_tensor, int64 tensor_num_elements,
293                           float output_min, float output_max, Toutput* output) {
294   for (int i = 0; i < tensor_num_elements; ++i) {
295     const int64 vector_i = i % vector_num_elements;
296     const Toutput vector_in_output_range = RequantizeInNewRange<T, Toutput>(
297         vector_data[vector_i], min_vector, max_vector, output_min, output_max);
298     const Toutput tensor_in_output_range = RequantizeInNewRange<T, Toutput>(
299         tensor_data[i], min_tensor, max_tensor, output_min, output_max);
300     output[i] = vector_in_output_range + tensor_in_output_range;
301   }
302 }
303 
304 #ifdef QUANTIZED_ADD_USE_NEON
305 
template <>
// NEON-accelerated specialization for quint8 inputs producing qint32 output.
// The outer loop walks the tensor one vector-length stripe at a time; within
// each stripe the vector is re-read from its start, implementing the tiling.
void VectorTensorAddition(const quint8* vector_data, float min_vector,
                          float max_vector, int64 vector_num_elements,
                          const quint8* tensor_data, float min_tensor,
                          float max_tensor, int64 tensor_num_elements,
                          float output_min, float output_max, qint32* output) {
  // Affine map for the vector operand: out = vector_0_int64 +
  // v * vector_mult_int32, derived by requantizing quantized points 0 and 1.
  const float vector_0_float =
      QuantizedToFloat<quint8>(0, min_vector, max_vector);
  const float vector_1_float =
      QuantizedToFloat<quint8>(1, min_vector, max_vector);
  const int64 vector_0_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max);
  const int64 vector_1_int64 =
      FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max);
  const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64;

  // Same affine map construction for the tensor operand.
  const float tensor_0_float =
      QuantizedToFloat<quint8>(0, min_tensor, max_tensor);
  const float tensor_1_float =
      QuantizedToFloat<quint8>(1, min_tensor, max_tensor);
  const int64 tensor_0_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max);
  const int64 tensor_1_int64 =
      FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max);
  const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64;

  // Saturation bounds for the scalar tail loop.
  const int64 lowest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
  const int64 highest_quantized =
      static_cast<int64>(Eigen::NumTraits<qint32>::highest());

  // Broadcast the per-operand affine parameters into NEON registers.
  const int64x2_t vector_0_64x2 = vmovq_n_s64(vector_0_int64);
  const int32x2_t vector_mult_32x2 = vmov_n_s32(vector_mult_int32);

  const int64x2_t tensor_0_64x2 = vmovq_n_s64(tensor_0_int64);
  const int32x2_t tensor_mult_32x2 = vmov_n_s32(tensor_mult_int32);

  // One outer iteration per tiled copy of the vector. Assumes
  // tensor_num_elements is a multiple of vector_num_elements.
  for (int64 base_i = 0; base_i < tensor_num_elements;
       base_i += vector_num_elements) {
    int64 i = base_i;       // index into the tensor/output
    int64 vector_i = 0;     // index into the (repeating) vector
    // SIMD loop: requantize 8 vector and 8 tensor elements, add, store.
    for (; vector_i < (vector_num_elements - 7); vector_i += 8, i += 8) {
      const uint8* vector_ptr = &(vector_data->value) + vector_i;
      const std::array<int32x4_t, 2> vector_output_value =
          Requantize8x8To32Neon(vector_ptr, vector_0_64x2, vector_mult_32x2);
      const uint8* tensor_ptr = &(tensor_data->value) + i;
      const std::array<int32x4_t, 2> tensor_output_value =
          Requantize8x8To32Neon(tensor_ptr, tensor_0_64x2, tensor_mult_32x2);

      const int32x4_t result_low_32x4 =
          vaddq_s32(vector_output_value[0], tensor_output_value[0]);
      const int32x4_t result_high_32x4 =
          vaddq_s32(vector_output_value[1], tensor_output_value[1]);
      int32* output_ptr = &(output->value) + i;
      vst1q_s32(output_ptr + 0, result_low_32x4);
      vst1q_s32(output_ptr + 4, result_high_32x4);
    }
    // Scalar tail loop for the final (vector_num_elements % 8) elements of
    // this stripe, with explicit clamping to the qint32 range.
    for (; vector_i < vector_num_elements; ++vector_i, ++i) {
      const int64 vector_value = static_cast<int64>(vector_data[vector_i]);
      int64 vector_in_output_range_64 =
          vector_0_int64 + (vector_value * vector_mult_int32);
      vector_in_output_range_64 =
          std::max(vector_in_output_range_64, lowest_quantized);
      vector_in_output_range_64 =
          std::min(vector_in_output_range_64, highest_quantized);
      const int32 vector_in_output_range =
          static_cast<int32>(vector_in_output_range_64);

      const int64 tensor_value = static_cast<int64>(tensor_data[i]);
      int64 tensor_in_output_range_64 =
          tensor_0_int64 + (tensor_value * tensor_mult_int32);
      tensor_in_output_range_64 =
          std::max(tensor_in_output_range_64, lowest_quantized);
      tensor_in_output_range_64 =
          std::min(tensor_in_output_range_64, highest_quantized);
      const int32 tensor_in_output_range =
          static_cast<int32>(tensor_in_output_range_64);

      output[i] = vector_in_output_range + tensor_in_output_range;
    }
  }
}
388 
389 #else  // QUANTIZED_ADD_USE_NEON
390 
391 template <>
VectorTensorAddition(const quint8 * vector_data,float min_vector,float max_vector,int64 vector_num_elements,const quint8 * tensor_data,float min_tensor,float max_tensor,int64 tensor_num_elements,float output_min,float output_max,qint32 * output)392 void VectorTensorAddition(const quint8* vector_data, float min_vector,
393                           float max_vector, int64 vector_num_elements,
394                           const quint8* tensor_data, float min_tensor,
395                           float max_tensor, int64 tensor_num_elements,
396                           float output_min, float output_max, qint32* output) {
397   const float vector_0_float =
398       QuantizedToFloat<quint8>(0, min_vector, max_vector);
399   const float vector_1_float =
400       QuantizedToFloat<quint8>(1, min_vector, max_vector);
401   const int64 vector_0_int64 =
402       FloatToQuantizedUnclamped<qint32>(vector_0_float, output_min, output_max);
403   const int64 vector_1_int64 =
404       FloatToQuantizedUnclamped<qint32>(vector_1_float, output_min, output_max);
405   const int32 vector_mult_int32 = vector_1_int64 - vector_0_int64;
406 
407   const float tensor_0_float =
408       QuantizedToFloat<quint8>(0, min_tensor, max_tensor);
409   const float tensor_1_float =
410       QuantizedToFloat<quint8>(1, min_tensor, max_tensor);
411   const int64 tensor_0_int64 =
412       FloatToQuantizedUnclamped<qint32>(tensor_0_float, output_min, output_max);
413   const int64 tensor_1_int64 =
414       FloatToQuantizedUnclamped<qint32>(tensor_1_float, output_min, output_max);
415   const int32 tensor_mult_int32 = tensor_1_int64 - tensor_0_int64;
416 
417   const int64 lowest_quantized =
418       static_cast<int64>(Eigen::NumTraits<qint32>::lowest());
419   const int64 highest_quantized =
420       static_cast<int64>(Eigen::NumTraits<qint32>::highest());
421 
422   for (int i = 0; i < tensor_num_elements; ++i) {
423     const int64 vector_i = i % vector_num_elements;
424     const int64 vector_value = static_cast<int64>(vector_data[vector_i]);
425     int64 vector_in_output_range_64 =
426         vector_0_int64 + (vector_value * vector_mult_int32);
427     vector_in_output_range_64 =
428         std::max(vector_in_output_range_64, lowest_quantized);
429     vector_in_output_range_64 =
430         std::min(vector_in_output_range_64, highest_quantized);
431     const int32 vector_in_output_range =
432         static_cast<int32>(vector_in_output_range_64);
433 
434     const int64 tensor_value = static_cast<int64>(tensor_data[i]);
435     int64 tensor_in_output_range_64 =
436         tensor_0_int64 + (tensor_value * tensor_mult_int32);
437     tensor_in_output_range_64 =
438         std::max(tensor_in_output_range_64, lowest_quantized);
439     tensor_in_output_range_64 =
440         std::min(tensor_in_output_range_64, highest_quantized);
441     const int32 tensor_in_output_range =
442         static_cast<int32>(tensor_in_output_range_64);
443 
444     output[i] = vector_in_output_range + tensor_in_output_range;
445   }
446 }
447 
448 #endif  // QUANTIZED_ADD_USE_NEON
449 
450 }  // namespace
451 
template <class T, class Toutput>
// CPU kernel for element-wise addition of two quantized tensors.
//
// Inputs: x, y (quantized values of type T), followed by four scalar floats
// giving [min_x, max_x] and [min_y, max_y], the real-number ranges the
// quantized values represent.
// Outputs: z (type Toutput) plus two scalar floats with z's range.
//
// Supports three broadcast patterns (scalar+array, array+array,
// vector tiled against a larger tensor); any other broadcast shape returns
// an Unimplemented error.
class QuantizedAddOp : public OpKernel {
 public:
  explicit QuantizedAddOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& x = context->input(0);
    const Tensor& y = context->input(1);
    // Scalar range inputs describing what the quantized values mean.
    const float min_x = context->input(2).flat<float>()(0);
    const float max_x = context->input(3).flat<float>()(0);
    const float min_y = context->input(4).flat<float>()(0);
    const float max_y = context->input(5).flat<float>()(0);

    BCast bcast(BCast::FromShape(x.shape()), BCast::FromShape(y.shape()));
    if (!bcast.IsValid()) {
      context->SetStatus(errors::InvalidArgument(
          "Incompatible shapes: ", x.shape().DebugString(), " vs. ",
          y.shape().DebugString()));
      return;
    }
    Tensor* z;
    OP_REQUIRES_OK(context, context->allocate_output(
                                0, BCast::ToShape(bcast.output_shape()), &z));

    // Make sure that we have valid quantization ranges for the input buffers.
    // If the difference between the min and max is negative or zero, it makes
    // it hard to do meaningful intermediate operations on the values.
    OP_REQUIRES(context, (max_x > min_x),
                errors::InvalidArgument("max_x must be larger than min_x."));
    OP_REQUIRES(context, (max_y > min_y),
                errors::InvalidArgument("max_y must be larger than min_y."));
    const T* x_data = x.flat<T>().data();
    const T* y_data = y.flat<T>().data();
    Toutput* z_data = z->flat<Toutput>().data();

    // We want the range of the output to be symmetrical around zero so that
    // adding zero leaves the result unchanged, and to contain the largest of
    // the two input values with some room to spare.
    const float smallest_min = std::min(min_x, min_y);
    const float largest_max = std::max(max_x, max_y);
    const float biggest_range =
        std::max(std::abs(smallest_min), std::abs(largest_max));
    // The 2^14 factor widens the range so qint32 sums have headroom;
    // presumably chosen to match the requantization precision elsewhere --
    // TODO(review): confirm the rationale for this particular constant.
    const float output_range = (biggest_range * (1 << 14));
    const float min_z_value = -output_range;
    const float max_z_value = output_range;

    const int ndims = bcast.x_reshape().size();
    if (ndims <= 1) {
      // Scalar + array (either order), or two same-length arrays.
      if (x.NumElements() == 1) {
        ScalarAddition<T, Toutput>(context, y_data, min_y, max_y,
                                   y.NumElements(), x_data[0], min_x, max_x,
                                   min_z_value, max_z_value, z_data);
      } else if (y.NumElements() == 1) {
        ScalarAddition<T, Toutput>(context, x_data, min_x, max_x,
                                   x.NumElements(), y_data[0], min_y, max_y,
                                   min_z_value, max_z_value, z_data);
      } else {
        VectorAddition<T, Toutput>(context, x_data, min_x, max_x, y_data, min_y,
                                   max_y, x.NumElements(), min_z_value,
                                   max_z_value, z_data);
      }
    } else if (ndims == 2) {
      // Shorter operand is tiled against the longer one; pick which is which
      // by element count.
      const T* vector_data;
      int64 vector_num_elements;
      float vector_min;
      float vector_max;
      const T* tensor_data;
      int64 tensor_num_elements;
      float tensor_min;
      float tensor_max;
      if (x.NumElements() < y.NumElements()) {
        vector_data = x_data;
        vector_num_elements = x.NumElements();
        vector_min = min_x;
        vector_max = max_x;
        tensor_data = y_data;
        tensor_num_elements = y.NumElements();
        tensor_min = min_y;
        tensor_max = max_y;
      } else {
        vector_data = y_data;
        vector_num_elements = y.NumElements();
        vector_min = min_y;
        vector_max = max_y;
        tensor_data = x_data;
        tensor_num_elements = x.NumElements();
        tensor_min = min_x;
        tensor_max = max_x;
      }
      VectorTensorAddition<T, Toutput>(
          vector_data, vector_min, vector_max, vector_num_elements, tensor_data,
          tensor_min, tensor_max, tensor_num_elements, min_z_value, max_z_value,
          z_data);
    } else {
      // Unsupported broadcast pattern: log diagnostics and fail.
      LOG(INFO) << "ndims=" << ndims;
      LOG(INFO) << "bcast.x_reshape()="
                << TensorShape(bcast.x_reshape()).DebugString();
      LOG(INFO) << "bcast.y_reshape()="
                << TensorShape(bcast.y_reshape()).DebugString();
      LOG(INFO) << "bcast.x_bcast()="
                << TensorShape(bcast.x_bcast()).DebugString();
      LOG(INFO) << "bcast.y_bcast()="
                << TensorShape(bcast.y_bcast()).DebugString();

      context->SetStatus(errors::Unimplemented(
          "Broadcast between ", context->input(0).shape().DebugString(),
          " and ", context->input(1).shape().DebugString(),
          " is not supported yet."));
      return;
    }

    // Emit the output range as two scalar tensors alongside the values.
    Tensor* z_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &z_min));
    z_min->flat<float>()(0) = min_z_value;

    Tensor* z_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &z_max));
    z_max->flat<float>()(0) = max_z_value;
  }
};
572 
// Register the quint8 + quint8 -> qint32 variant of QuantizedAdd for CPU.
REGISTER_KERNEL_BUILDER(Name("QuantizedAdd")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T1")
                            .TypeConstraint<quint8>("T2")
                            .TypeConstraint<qint32>("Toutput"),
                        QuantizedAddOp<quint8, qint32>);
579 
580 }  // namespace tensorflow
581