/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_SUB_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_SUB_H_

#include <algorithm>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/avx2_quantization_utils.h"
#include "tensorflow/lite/kernels/internal/reference/sub.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace optimized_integer_ops {

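// Elementwise subtraction of two quantized int16 tensors with the same flat
// size. When compiled with AVX2 support, most elements are handled by the
// vectorized loop below; the remainder falls back to a scalar loop.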
inline void SubElementwiseInt16(int size, const ArithmeticParams& params,
                                const int16* input1_data,
                                const int16* input2_data, int16* output_data) {
  ruy::profiler::ScopeLabel label("SubElementwiseInt16/16bit");
  int i = 0;
  TFLITE_DCHECK_GT(params.input1_offset, -32768);
  TFLITE_DCHECK_GT(params.input2_offset, -32768);
  TFLITE_DCHECK_LT(params.input1_offset, 32768);
  TFLITE_DCHECK_LT(params.input2_offset, 32768);

#ifdef __AVX2__
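  // AVX2 fast path: broadcast the quantization offsets and the activation
  // clamp bounds into 256-bit registers once, outside the loop.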
  const int32_t input1_left_shift = params.left_shift + params.input1_shift;
  const int32_t input2_left_shift = params.left_shift + params.input2_shift;
  const __m256i input1_offset = _mm256_set1_epi32(params.input1_offset);
  const __m256i input2_offset = _mm256_set1_epi32(params.input2_offset);
  const __m256i output_offset = _mm256_set1_epi32(params.output_offset);
  const __m256i clamp_max_v =
      _mm256_set1_epi32(params.quantized_activation_max);
  const __m256i clamp_min_v =
      _mm256_set1_epi32(params.quantized_activation_min);

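  // Each iteration loads 16 int16 values from each input, widens them to
  // int32, applies the input offsets and multipliers, subtracts, rescales to
  // the output scale, clamps, and stores 16 int16 results.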
  for (; i <= size - 16; i += 16) {
    const __m256i input1_val_original =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(input1_data + i));
    const __m256i input2_val_original =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(input2_data + i));

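    // Widen the low and high 128-bit halves of each register into two
    // vectors of eight int32 lanes each.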
    __m256i s11 =
        _mm256_cvtepi16_epi32(_mm256_castsi256_si128(input1_val_original));
    __m256i s12 =
        _mm256_cvtepi16_epi32(_mm256_extracti128_si256(input1_val_original, 1));
    __m256i s21 =
        _mm256_cvtepi16_epi32(_mm256_castsi256_si128(input2_val_original));
    __m256i s22 =
        _mm256_cvtepi16_epi32(_mm256_extracti128_si256(input2_val_original, 1));

    s11 = _mm256_add_epi32(s11, input1_offset);
    s12 = _mm256_add_epi32(s12, input1_offset);
    s21 = _mm256_add_epi32(s21, input2_offset);
    s22 = _mm256_add_epi32(s22, input2_offset);

    s11 = avx2_utils::MultiplyByQuantizedMultiplier(
        s11, params.input1_multiplier, input1_left_shift);
    s12 = avx2_utils::MultiplyByQuantizedMultiplier(
        s12, params.input1_multiplier, input1_left_shift);
    s21 = avx2_utils::MultiplyByQuantizedMultiplier(
        s21, params.input2_multiplier, input2_left_shift);
    s22 = avx2_utils::MultiplyByQuantizedMultiplier(
        s22, params.input2_multiplier, input2_left_shift);

    __m256i s1 = _mm256_sub_epi32(s11, s21);
    __m256i s2 = _mm256_sub_epi32(s12, s22);

    s1 = avx2_utils::MultiplyByQuantizedMultiplier(s1, params.output_multiplier,
                                                   params.output_shift);
    s2 = avx2_utils::MultiplyByQuantizedMultiplier(s2, params.output_multiplier,
                                                   params.output_shift);

    s1 = _mm256_add_epi32(s1, output_offset);
    s2 = _mm256_add_epi32(s2, output_offset);

    s1 = _mm256_min_epi32(s1, clamp_max_v);
    s1 = _mm256_max_epi32(s1, clamp_min_v);
    s2 = _mm256_min_epi32(s2, clamp_max_v);
    s2 = _mm256_max_epi32(s2, clamp_min_v);

    avx2_utils::CastInt32ToInt16AndStore(output_data + i, s1);
    avx2_utils::CastInt32ToInt16AndStore(output_data + i + 8, s2);
  }
#endif  // __AVX2__

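  // Scalar path: handles the tail left by the vectorized loop, or all
  // elements when AVX2 is not available.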
  for (; i < size; ++i) {
    const int32_t input1_val = params.input1_offset + input1_data[i];
    const int32_t input2_val = params.input2_offset + input2_data[i];
    const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
    const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
    const int32_t scaled_input1_val =
        MultiplyByQuantizedMultiplierSmallerThanOneExp(
            shifted_input1_val, params.input1_multiplier, params.input1_shift);
    const int32_t scaled_input2_val =
        MultiplyByQuantizedMultiplierSmallerThanOneExp(
            shifted_input2_val, params.input2_multiplier, params.input2_shift);
    const int32_t raw_sum = scaled_input1_val - scaled_input2_val;
    const int32_t raw_output =
        MultiplyByQuantizedMultiplierSmallerThanOneExp(
            raw_sum, params.output_multiplier, params.output_shift) +
        params.output_offset;
    const int32_t clamped_output =
        std::min(params.quantized_activation_max,
                 std::max(params.quantized_activation_min, raw_output));
    output_data[i] = static_cast<int16>(clamped_output);
  }
}

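// Broadcast subtraction for the "fivefold" broadcast pattern, in which both
// shapes collapse to five dimensions y0..y4 and the innermost runs of length
// y4 are handed to SubElementwiseInt16.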
inline void BroadcastSubFiveFold(const ArithmeticParams& unswitched_params,
                                 const RuntimeShape& input1_shape,
                                 const int16* unswitched_input1_data,
                                 const RuntimeShape& input2_shape,
                                 const int16* unswitched_input2_data,
                                 const RuntimeShape& output_shape,
                                 int16* output_data) {
  ruy::profiler::ScopeLabel label("BroadcastSubFiveFold/16bit");

  ArithmeticParams switched_params = unswitched_params;
  switched_params.input1_offset = unswitched_params.input2_offset;
  switched_params.input1_multiplier = unswitched_params.input2_multiplier;
  switched_params.input1_shift = unswitched_params.input2_shift;
  switched_params.input2_offset = unswitched_params.input1_offset;
  switched_params.input2_multiplier = unswitched_params.input1_multiplier;
  switched_params.input2_shift = unswitched_params.input1_shift;

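  // When it is the second input that broadcasts fast, swap the two inputs
  // (and their quantization parameters) so that the fivefold loop below only
  // has to handle one layout.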
  const bool use_unswitched =
      unswitched_params.broadcast_category ==
      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;

  const ArithmeticParams& params =
      use_unswitched ? unswitched_params : switched_params;
  const int16_t* input1_data =
      use_unswitched ? unswitched_input1_data : unswitched_input2_data;
  const int16_t* input2_data =
      use_unswitched ? unswitched_input2_data : unswitched_input1_data;

  int16_t* output_data_ptr = output_data;
  const int16_t* input1_data_ptr = input1_data;
  const int16_t* input2_data_reset = input2_data;
  // In the fivefold pattern, y0, y2 and y4 are not broadcast, and so are
  // shared between the input shapes. y3 is always broadcast for input 1, so
  // its dimension there is 1, whereas y1 may optionally be broadcast for
  // input 2. The flat size of each input is:
  // input1.shape.FlatSize = y0 * y1 * y2 * y4,
  // input2.shape.FlatSize = y0 * y2 * y3 * y4.
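  // For example (illustrative values), y0=2, y1=3, y2=4, y3=6, y4=5 gives
  // input1 flat size 2*3*4*5 = 120, input2 flat size 2*4*6*5 = 240, and an
  // output flat size of 2*3*4*6*5 = 720.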
  const int y0 = params.broadcast_shape[0];
  const int y1 = params.broadcast_shape[1];
  const int y2 = params.broadcast_shape[2];
  const int y3 = params.broadcast_shape[3];
  const int y4 = params.broadcast_shape[4];
  for (int i0 = 0; i0 < y0; ++i0) {
    const int16_t* input2_data_ptr = nullptr;
    for (int i1 = 0; i1 < y1; ++i1) {
      input2_data_ptr = input2_data_reset;
      for (int i2 = 0; i2 < y2; ++i2) {
        for (int i3 = 0; i3 < y3; ++i3) {
          if (use_unswitched) {
            SubElementwiseInt16(y4, params, input1_data_ptr, input2_data_ptr,
                                output_data_ptr);
          } else {
            // When input1 and input2 have been switched, compute
            // (input2 - input1) with unswitched_params, since the switched
            // inputs are swapped back into their original order here.
            SubElementwiseInt16(y4, unswitched_params, input2_data_ptr,
                                input1_data_ptr, output_data_ptr);
          }
          input2_data_ptr += y4;
          output_data_ptr += y4;
        }
        // We have broadcast y4 of input1 data y3 times, and now move on.
        input1_data_ptr += y4;
      }
    }
    // We have broadcast y2*y3*y4 of input2 data y1 times, and now move on.
    input2_data_reset = input2_data_ptr;
  }
}

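// Elementwise int16 subtraction for inputs and output with matching element
// counts (no broadcasting).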
inline void Sub(const ArithmeticParams& params,
                const RuntimeShape& input1_shape, const int16* input1_data,
                const RuntimeShape& input2_shape, const int16* input2_data,
                const RuntimeShape& output_shape, int16* output_data) {
  ruy::profiler::ScopeLabel label("SubInt16/16bit");
  TFLITE_DCHECK_LE(params.quantized_activation_min,
                   params.quantized_activation_max);
  TFLITE_DCHECK_GT(params.input1_offset, -32768);
  TFLITE_DCHECK_GT(params.input2_offset, -32768);
  TFLITE_DCHECK_LT(params.input1_offset, 32768);
  TFLITE_DCHECK_LT(params.input2_offset, 32768);

  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  SubElementwiseInt16(flat_size, params, input1_data, input2_data, output_data);
}

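// Dispatches broadcast int16 subtraction: generic broadcasts fall back to the
// slow reference kernel, while all other broadcast categories use the
// fivefold fast path.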
inline void BroadcastSubDispatch(const ArithmeticParams& params,
                                 const RuntimeShape& input1_shape,
                                 const int16* input1_data,
                                 const RuntimeShape& input2_shape,
                                 const int16* input2_data,
                                 const RuntimeShape& output_shape,
                                 int16* output_data) {
  ruy::profiler::ScopeLabel label("BroadcastSubDispatchInt16/16bit");
  TFLITE_DCHECK_LE(params.quantized_activation_min,
                   params.quantized_activation_max);
  TFLITE_DCHECK_GT(params.input1_offset, -32768);
  TFLITE_DCHECK_GT(params.input2_offset, -32768);
  TFLITE_DCHECK_LT(params.input1_offset, 32768);
  TFLITE_DCHECK_LT(params.input2_offset, 32768);

  if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
    return reference_ops::BroadcastQuantSubSlow(
        params, input1_shape, input1_data, input2_shape, input2_data,
        output_shape, output_data);
  }

  BroadcastSubFiveFold(params, input1_shape, input1_data, input2_shape,
                       input2_data, output_shape, output_data);
}
}  // namespace optimized_integer_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_INTEGER_OPS_SUB_H_