/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace mul {

// This file has three implementations of Mul.
enum KernelType {
  kReference,
  kGenericOptimized,  // Neon-free
  kNeonOptimized,
};

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpData {
  bool requires_broadcast;

  // Parameters used in the quantized paths where the output is 8-bit.
  int32 output_activation_min;
  int32 output_activation_max;

  // Parameters used in all quantized paths.
  int32_t output_multiplier;
  int output_shift;
};

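// Lifecycle hooks: Init allocates the per-node OpData and Free releases it,
// so the state filled in by Prepare (broadcast flag, activation range,
// quantized multiplier) survives between invocations of the kernel.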
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  data->requires_broadcast = false;
  return data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

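// Prepare validates the node (two inputs, one output, matching input types),
// records whether broadcasting is needed, precomputes the activation range for
// 8-bit outputs, and derives the quantized output multiplier/shift from
// real_multiplier = input1_scale * input2_scale / output_scale. For example
// (illustrative scales only), 0.02 * 0.03 / 0.05 = 0.012, which
// QuantizeMultiplierSmallerThanOneExp splits into a fixed-point multiplier and
// a right shift. Finally, the output tensor is resized to the computed shape.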
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_EQ(context, input1->type, input2->type);

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  if (output->type == kTfLiteUInt8) {
    CalculateActivationRangeUint8(params->activation, output,
                                  &data->output_activation_min,
                                  &data->output_activation_max);
  }
  if (output->type == kTfLiteInt8) {
    CalculateActivationRangeInt8(params->activation, output,
                                 &data->output_activation_min,
                                 &data->output_activation_max);
  }

  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
      output->type == kTfLiteInt16) {
    double real_multiplier =
        input1->params.scale * input2->params.scale / output->params.scale;
    QuantizeMultiplierSmallerThanOneExp(
        real_multiplier, &data->output_multiplier, &data->output_shift);
  }

  return context->ResizeTensor(context, output, output_size);
}

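// EvalMul handles the float32 and int32 paths. The TF_LITE_MUL macro expands
// to a call into either reference_ops or optimized_ops, selecting the
// broadcasting variant when the input shapes differ.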
template <KernelType kernel_type>
void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
             const OpData* data, const TfLiteTensor* input1,
             const TfLiteTensor* input2, TfLiteTensor* output) {
#define TF_LITE_MUL(type, opname, data_type)                             \
  data_type output_activation_min, output_activation_max;                \
  CalculateActivationRange(params->activation, &output_activation_min,   \
                           &output_activation_max);                      \
  tflite::ArithmeticParams op_params;                                    \
  SetActivationParams(output_activation_min, output_activation_max,      \
                      &op_params);                                       \
  type::opname(op_params, GetTensorShape(input1),                        \
               GetTensorData<data_type>(input1), GetTensorShape(input2), \
               GetTensorData<data_type>(input2), GetTensorShape(output), \
               GetTensorData<data_type>(output))

  if (output->type == kTfLiteInt32) {
    if (kernel_type == kReference) {
      if (data->requires_broadcast) {
        TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
      } else {
        TF_LITE_MUL(reference_ops, Mul, int32_t);
      }
    } else {
      if (data->requires_broadcast) {
        TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
      } else {
        TF_LITE_MUL(optimized_ops, Mul, int32_t);
      }
    }
  } else if (output->type == kTfLiteFloat32) {
    if (kernel_type == kReference) {
      if (data->requires_broadcast) {
        TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
      } else {
        TF_LITE_MUL(reference_ops, Mul, float);
      }
    } else {
      if (data->requires_broadcast) {
        TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float);
      } else {
        TF_LITE_MUL(optimized_ops, Mul, float);
      }
    }
  }
#undef TF_LITE_MUL
}

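// EvalQuantized handles the uint8/int8 path (zero-point offsets plus the
// precomputed fixed-point multiplier), the pure int16 path, and the mixed
// int16-input / 8-bit-output path; unsupported type combinations are reported
// as errors.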
template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteMulParams* params, const OpData* data,
                           const TfLiteTensor* input1,
                           const TfLiteTensor* input2, TfLiteTensor* output) {
  if (input1->type == input2->type && input1->type == output->type &&
      (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8)) {
    tflite::ArithmeticParams op_params;
    SetActivationParams(data->output_activation_min,
                        data->output_activation_max, &op_params);
    op_params.input1_offset = -input1->params.zero_point;
    op_params.input2_offset = -input2->params.zero_point;
    op_params.output_offset = output->params.zero_point;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
    bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
        GetTensorShape(input1), GetTensorShape(input2), &op_params);
#define TF_LITE_MUL(type, opname, dtype)                             \
  type::opname(op_params, GetTensorShape(input1),                    \
               GetTensorData<dtype>(input1), GetTensorShape(input2), \
               GetTensorData<dtype>(input2), GetTensorShape(output), \
               GetTensorData<dtype>(output))
    if (input1->type == kTfLiteInt8) {
      if (need_broadcast) {
        TF_LITE_MUL(reference_integer_ops, BroadcastMul4DSlow, int8_t);
      } else {
        TF_LITE_MUL(reference_integer_ops, Mul, int8_t);
      }
    } else {
      // type == kTfLiteUInt8
      if (kernel_type == kReference) {
        if (need_broadcast) {
          TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, uint8_t);
        } else {
          TF_LITE_MUL(reference_ops, Mul, uint8_t);
        }
      } else {
        if (need_broadcast) {
          TF_LITE_MUL(optimized_ops, BroadcastMulFivefold, uint8_t);
        } else {
          TF_LITE_MUL(optimized_ops, Mul, uint8_t);
        }
      }
    }
#undef TF_LITE_MUL
  } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
             output->type == kTfLiteInt16) {
#define TF_LITE_MUL(type, opname)                                      \
  tflite::ArithmeticParams op_params;                                  \
  type::opname(op_params, GetTensorShape(input1),                      \
               GetTensorData<int16_t>(input1), GetTensorShape(input2), \
               GetTensorData<int16_t>(input2), GetTensorShape(output), \
               GetTensorData<int16_t>(output))
    if (kernel_type == kReference) {
      TF_LITE_MUL(reference_ops, Mul);
    } else {
      TF_LITE_MUL(optimized_ops, Mul);
    }
#undef TF_LITE_MUL
  } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
             (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8)) {
#define TF_LITE_MUL(type, opname, output_dtype)                        \
  tflite::ArithmeticParams op_params;                                  \
  SetActivationParams(data->output_activation_min,                     \
                      data->output_activation_max, &op_params);        \
  op_params.output_offset = output->params.zero_point;                 \
  type::opname(op_params, GetTensorShape(input1),                      \
               GetTensorData<int16_t>(input1), GetTensorShape(input2), \
               GetTensorData<int16_t>(input2), GetTensorShape(output), \
               GetTensorData<output_dtype>(output))
    if (output->type == kTfLiteInt8) {
      TF_LITE_MUL(reference_integer_ops, Mul, int8_t);
    } else {
      if (kernel_type == kReference) {
        TF_LITE_MUL(reference_ops, Mul, uint8_t);
      } else {
        TF_LITE_MUL(optimized_ops, Mul, uint8_t);
      }
    }
#undef TF_LITE_MUL
  } else {
    context->ReportError(
        context, "Unsupported combination of input and output types in Mul.");
    return kTfLiteError;
  }
  return kTfLiteOk;
}

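// Eval dispatches on the output type: float32/int32 go to EvalMul, quantized
// types go to EvalQuantized, and anything else is an error.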
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
    EvalMul<kernel_type>(context, node, params, data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
             output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_OK(
        context, EvalQuantized<kernel_type>(context, node, params, data, input1,
                                            input2, output));
  } else {
    context->ReportError(context,
                         "Mul only supports FLOAT32, INT32 and quantized UINT8,"
                         " INT8 and INT16 now, got %d.",
                         output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace mul

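// Each Register_MUL_* returns a statically allocated TfLiteRegistration that
// wires the mul lifecycle hooks (Init/Free/Prepare/Eval) for one kernel
// variant; Register_MUL selects the NEON variant when USE_NEON is defined and
// the generic optimized variant otherwise.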
TfLiteRegistration* Register_MUL_REF() {
  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
                                 mul::Eval<mul::kReference>};
  return &r;
}

TfLiteRegistration* Register_MUL_GENERIC_OPT() {
  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
                                 mul::Eval<mul::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_MUL_NEON_OPT() {
  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
                                 mul::Eval<mul::kNeonOptimized>};
  return &r;
}

TfLiteRegistration* Register_MUL() {
#ifdef USE_NEON
  return Register_MUL_NEON_OPT();
#else
  return Register_MUL_GENERIC_OPT();
#endif
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite