• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cmath>
#include <cstdint>
#include <functional>

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
20 
21 namespace tflite {
22 namespace ops {
23 namespace builtin {
24 namespace floor_div {
25 namespace {
26 
// Positions of the operands within the node's input/output tensor lists:
// input 0 is divided by input 1 and the floored quotient goes to output 0.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
constexpr int kOutputTensor{0};
31 
// Per-node state for the floor_div op. It is allocated in Init(), stored in
// node->user_data, and filled in by Prepare().
struct OpData {
  // True when the two inputs have different shapes, so Eval must take the
  // broadcasting path. Defaults to false so that a default-initialized OpData
  // never holds an indeterminate value.
  bool requires_broadcast = false;
};
36 
// Returns floor(input1 / input2): the quotient rounded toward negative
// infinity. For example, FloorDiv(-7, 2) == -4, whereas C++ integer `/`
// truncates to -3.
//
// Both operands are widened to double before dividing, and the floored
// result is converted back to T on return. This function does not check for a
// zero divisor; EvalImpl rejects zero divisors before calling it.
//
// NOTE(review): for T = int32_t, INT32_MIN / -1 floors to 2^31, which is
// outside the range of int32_t. Converting that double back to T is undefined
// behavior, and nothing in this file rejects that input pair.
template <typename T>
T FloorDiv(T input1, T input2) {
  const double numerator = static_cast<double>(input1);
  const double denominator = static_cast<double>(input2);
  return std::floor(numerator / denominator);
}
42 
Init(TfLiteContext * context,const char * buffer,size_t length)43 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
44   auto* data = new OpData;
45   data->requires_broadcast = false;
46   return data;
47 }
48 
Free(TfLiteContext * context,void * buffer)49 void Free(TfLiteContext* context, void* buffer) {
50   delete reinterpret_cast<OpData*>(buffer);
51 }
52 
Prepare(TfLiteContext * context,TfLiteNode * node)53 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
54   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
55   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
56 
57   // Reinterprete the opaque data provided by user.
58   OpData* data = reinterpret_cast<OpData*>(node->user_data);
59 
60   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
61   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
62   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
63 
64   TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
65 
66   const TfLiteType type = input1->type;
67   if (type != kTfLiteInt32) {
68     context->ReportError(context, "Currently floor_div only supports int32.");
69     return kTfLiteError;
70   }
71   output->type = type;
72 
73   data->requires_broadcast = !HaveSameShapes(input1, input2);
74 
75   TfLiteIntArray* output_size = nullptr;
76   if (data->requires_broadcast) {
77     TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
78                                    context, input1, input2, &output_size));
79   } else {
80     output_size = TfLiteIntArrayCopy(input1->dims);
81   }
82 
83   return context->ResizeTensor(context, output, output_size);
84 }
85 
86 template <typename T>
EvalImpl(TfLiteContext * context,bool requires_broadcast,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)87 TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
88                       const TfLiteTensor* input1, const TfLiteTensor* input2,
89                       TfLiteTensor* output) {
90   const T* denominator_data = GetTensorData<T>(input2);
91 
92   // Validate the denominator.
93   for (int i = 0; i < NumElements(input2); ++i) {
94     if (std::equal_to<T>()(denominator_data[i], 0)) {
95       context->ReportError(context, "Division by 0");
96       return kTfLiteError;
97     }
98   }
99   if (requires_broadcast) {
100     reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
101         GetTensorShape(input1), GetTensorData<T>(input1),
102         GetTensorShape(input2), denominator_data, GetTensorShape(output),
103         GetTensorData<T>(output), FloorDiv<T>);
104   } else {
105     reference_ops::BinaryFunction<T, T, T>(
106         GetTensorShape(input1), GetTensorData<T>(input1),
107         GetTensorShape(input2), GetTensorData<T>(input2),
108         GetTensorShape(output), GetTensorData<T>(output), FloorDiv<T>);
109   }
110 
111   return kTfLiteOk;
112 }
113 
Eval(TfLiteContext * context,TfLiteNode * node)114 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
115   OpData* data = reinterpret_cast<OpData*>(node->user_data);
116 
117   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
118   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
119   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
120 
121   switch (input1->type) {
122     case kTfLiteInt32: {
123       return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
124                                input2, output);
125     }
126     default: {
127       context->ReportError(context, "Currently floor_div only supports int32.");
128       return kTfLiteError;
129     }
130   }
131 }
132 
133 }  // namespace
134 }  // namespace floor_div
135 
Register_FLOOR_DIV()136 TfLiteRegistration* Register_FLOOR_DIV() {
137   // Init, Free, Prepare, Eval are satisfying the Interface required by
138   // TfLiteRegistration.
139   static TfLiteRegistration r = {floor_div::Init, floor_div::Free,
140                                  floor_div::Prepare, floor_div::Eval};
141   return &r;
142 }
143 
144 }  // namespace builtin
145 }  // namespace ops
146 }  // namespace tflite
147