1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/c/c_api_internal.h"
16 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
17 #include "tensorflow/lite/kernels/internal/tensor.h"
18 #include "tensorflow/lite/kernels/kernel_util.h"
19 #include "tensorflow/lite/kernels/op_macros.h"
20
21 namespace tflite {
22 namespace ops {
23 namespace builtin {
24 namespace floor_div {
25 namespace {
26
// Indices of this op's tensors within the node's input/output lists.
constexpr int kInputTensor1 = 0;  // numerator
constexpr int kInputTensor2 = 1;  // denominator
constexpr int kOutputTensor = 0;

// Per-node state for the floor_div op; allocated in Init(), released in Free().
struct OpData {
  // True when the two input shapes differ, so Eval must use the broadcasting
  // kernel. Computed in Prepare().
  bool requires_broadcast;
};
36
// Returns floor(numerator / denominator), computed in double precision.
// Double has a 53-bit mantissa, so the conversion is exact for the int32
// values this kernel accepts (see the type check in Prepare).
template <typename T>
T FloorDiv(T numerator, T denominator) {
  const double quotient =
      static_cast<double>(numerator) / static_cast<double>(denominator);
  return static_cast<T>(std::floor(quotient));
}
42
Init(TfLiteContext * context,const char * buffer,size_t length)43 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
44 auto* data = new OpData;
45 data->requires_broadcast = false;
46 return data;
47 }
48
Free(TfLiteContext * context,void * buffer)49 void Free(TfLiteContext* context, void* buffer) {
50 delete reinterpret_cast<OpData*>(buffer);
51 }
52
// Validates the node and sizes the output tensor.
//
// Ensures there are exactly two inputs of the same type (currently only
// int32 is supported), records whether broadcasting is needed, and resizes
// the output to either the broadcast shape or input1's shape.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // Reinterpret the opaque per-node state allocated in Init().
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  // Both inputs must share one element type; the output inherits it.
  TF_LITE_ENSURE_EQ(context, input1->type, input2->type);

  const TfLiteType type = input1->type;
  if (type != kTfLiteInt32) {
    context->ReportError(context, "Currently floor_div only supports int32.");
    return kTfLiteError;
  }
  output->type = type;

  // Broadcasting is required whenever the two input shapes differ.
  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  // NOTE(review): ResizeTensor is expected to take ownership of output_size
  // (TFLite API convention), hence no explicit free here — confirm.
  return context->ResizeTensor(context, output, output_size);
}
85
86 template <typename T>
EvalImpl(TfLiteContext * context,bool requires_broadcast,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)87 TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
88 const TfLiteTensor* input1, const TfLiteTensor* input2,
89 TfLiteTensor* output) {
90 const T* denominator_data = GetTensorData<T>(input2);
91
92 // Validate the denominator.
93 for (int i = 0; i < NumElements(input2); ++i) {
94 if (std::equal_to<T>()(denominator_data[i], 0)) {
95 context->ReportError(context, "Division by 0");
96 return kTfLiteError;
97 }
98 }
99 if (requires_broadcast) {
100 reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
101 GetTensorShape(input1), GetTensorData<T>(input1),
102 GetTensorShape(input2), denominator_data, GetTensorShape(output),
103 GetTensorData<T>(output), FloorDiv<T>);
104 } else {
105 reference_ops::BinaryFunction<T, T, T>(
106 GetTensorShape(input1), GetTensorData<T>(input1),
107 GetTensorShape(input2), GetTensorData<T>(input2),
108 GetTensorShape(output), GetTensorData<T>(output), FloorDiv<T>);
109 }
110
111 return kTfLiteOk;
112 }
113
Eval(TfLiteContext * context,TfLiteNode * node)114 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
115 OpData* data = reinterpret_cast<OpData*>(node->user_data);
116
117 const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
118 const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
119 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
120
121 switch (input1->type) {
122 case kTfLiteInt32: {
123 return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
124 input2, output);
125 }
126 default: {
127 context->ReportError(context, "Currently floor_div only supports int32.");
128 return kTfLiteError;
129 }
130 }
131 }
132
133 } // namespace
134 } // namespace floor_div
135
// Returns the registration record for the builtin FLOOR_DIV op.
TfLiteRegistration* Register_FLOOR_DIV() {
  // Init, Free, Prepare, Eval satisfy the interface required by
  // TfLiteRegistration; the static has static storage duration, so the same
  // record is returned on every call.
  static TfLiteRegistration r = {floor_div::Init, floor_div::Free,
                                 floor_div::Prepare, floor_div::Eval};
  return &r;
}
143
144 } // namespace builtin
145 } // namespace ops
146 } // namespace tflite
147