/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace div {

// This file has three implementations of Div.
enum KernelType {
  kReference,
  kGenericOptimized,  // Neon-free
  kNeonOptimized,
};
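// Editorial note: these three kernel choices correspond one-to-one to the
// Register_DIV_REF / Register_DIV_GENERIC_OPT / Register_DIV_NEON_OPT entry
// points defined at the bottom of this file.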

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpData {
  bool requires_broadcast;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  data->requires_broadcast = false;
  return data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}
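
// Editorial note: per the TfLiteRegistration contract, the runtime invokes
// Init once when a node using this op is created (for a builtin op the
// buffer/length arguments carry no custom data) and Free when the node is
// released, so the OpData allocated in Init lives for the node's lifetime.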

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
  output->type = input2->type;

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}
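
// Editorial example (not from the original source): CalculateShapeForBroadcast
// applies NumPy-style broadcasting, so dividing a [2, 1, 5] tensor by a
// [1, 3, 1] tensor resizes the output to [2, 3, 5]; two [2, 3] inputs skip
// the broadcast path above and the output shape is copied as [2, 3].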

template <KernelType kernel_type>
void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
             const OpData* data, const TfLiteTensor* input1,
             const TfLiteTensor* input2, TfLiteTensor* output) {
#define TF_LITE_DIV(type, opname, data_type)                             \
  tflite::ArithmeticParams op_params;                                    \
  data_type output_activation_min, output_activation_max;                \
  CalculateActivationRange(params->activation, &output_activation_min,   \
                           &output_activation_max);                      \
  SetActivationParams(output_activation_min, output_activation_max,      \
                      &op_params);                                       \
  type::opname(op_params, GetTensorShape(input1),                        \
               GetTensorData<data_type>(input1), GetTensorShape(input2), \
               GetTensorData<data_type>(input2), GetTensorShape(output), \
               GetTensorData<data_type>(output))
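
// Editorial sketch: a single invocation such as
// TF_LITE_DIV(optimized_ops, Div, float) expands to
//
//   tflite::ArithmeticParams op_params;
//   float output_activation_min, output_activation_max;
//   CalculateActivationRange(params->activation, &output_activation_min,
//                            &output_activation_max);
//   SetActivationParams(output_activation_min, output_activation_max,
//                       &op_params);
//   optimized_ops::Div(op_params, GetTensorShape(input1),
//                      GetTensorData<float>(input1), GetTensorShape(input2),
//                      GetTensorData<float>(input2), GetTensorShape(output),
//                      GetTensorData<float>(output));
//
// i.e. it clamps the result to the fused activation range and dispatches to
// the requested namespace's kernel for the given element type.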
  if (output->type == kTfLiteInt32) {
    if (kernel_type == kReference) {
      if (data->requires_broadcast) {
        TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, int32_t);
      } else {
        TF_LITE_DIV(reference_ops, Div, int32_t);
      }
    } else {
      if (data->requires_broadcast) {
        TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, int32_t);
      } else {
        TF_LITE_DIV(optimized_ops, Div, int32_t);
      }
    }
  } else if (output->type == kTfLiteFloat32) {
    if (kernel_type == kReference) {
      if (data->requires_broadcast) {
        TF_LITE_DIV(reference_ops, BroadcastDiv4DSlow, float);
      } else {
        TF_LITE_DIV(reference_ops, Div, float);
      }
    } else {
      if (data->requires_broadcast) {
        TF_LITE_DIV(optimized_ops, BroadcastDiv4DSlow, float);
      } else {
        TF_LITE_DIV(optimized_ops, Div, float);
      }
    }
  }
#undef TF_LITE_DIV
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
    EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
  } else {
    context->ReportError(
        context,
        "Div only supports FLOAT32 and INT32 now, got %d.",
        output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace div

TfLiteRegistration* Register_DIV_REF() {
  static TfLiteRegistration r = {div::Init, div::Free, div::Prepare,
                                 div::Eval<div::kReference>};
  return &r;
}

TfLiteRegistration* Register_DIV_GENERIC_OPT() {
  static TfLiteRegistration r = {div::Init, div::Free, div::Prepare,
                                 div::Eval<div::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_DIV_NEON_OPT() {
  static TfLiteRegistration r = {div::Init, div::Free, div::Prepare,
                                 div::Eval<div::kNeonOptimized>};
  return &r;
}

TfLiteRegistration* Register_DIV() {
#ifdef USE_NEON
  return Register_DIV_NEON_OPT();
#else
  return Register_DIV_GENERIC_OPT();
#endif
}
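
// Editorial sketch (not in the original file): client code typically reaches
// this kernel through an op resolver. Assuming the stock MutableOpResolver
// API from tensorflow/lite/mutable_op_resolver.h, registration looks like:
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_DIV,
//                       tflite::ops::builtin::Register_DIV());
//
// Register_DIV selects the NEON kernel when the build defines USE_NEON and
// the generic optimized kernel otherwise; Register_DIV_REF remains available
// for tests that want the reference implementation.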

}  // namespace builtin
}  // namespace ops
}  // namespace tflite