/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <functional>
#include <type_traits>

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"

// TODO(b/117523611): We should factor out a binary_op and put binary ops there.
namespace tflite {
namespace ops {
namespace builtin {
namespace floor_mod {
namespace {

// Input/output tensor indices.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

// Op data for floor_mod op.
struct OpData {
  bool requires_broadcast;
};

struct FloatMod {
  float operator()(const float lhs, const float rhs) const {
    return std::fmod(lhs, rhs);
  }
};

// TODO(b/117912007): Move the implementation to reference_ops.h
// TODO(b/117912880): Support quantization.
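// Computes a floored remainder (the result takes the sign of the divisor, as
// in Python's % operator). A truncated remainder is computed first; when the
// operands have different signs it is shifted by the divisor, e.g.
// FloorMod(-7, 3) == 2 and FloorMod(7, -3) == -2.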
template <typename T>
T FloorMod(T input1, T input2) {
  using ModFunc = typename std::conditional<std::is_integral<T>::value,
                                            std::modulus<T>, FloatMod>::type;

  ModFunc mod_func;
  T trunc_mod = mod_func(input1, input2);
  return (input1 < T(0)) == (input2 < T(0))
             ? trunc_mod
             : mod_func(trunc_mod + input2, input2);
}

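// Allocates per-node state; the broadcast flag is determined later in Prepare.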
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  data->requires_broadcast = false;
  return data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // Reinterpret the opaque data provided by the user.
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);

  const TfLiteType type = input1->type;
  if (type != kTfLiteInt32 && type != kTfLiteFloat32 && type != kTfLiteInt64) {
    context->ReportError(context, "Type '%s' is not supported by floor_mod.",
                         TfLiteTypeGetName(type));
    return kTfLiteError;
  }
  output->type = type;

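  // Broadcasting is required whenever the two input shapes differ; the output
  // shape is then the broadcast of the two, otherwise it matches input1.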
  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}

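// Type-specialized evaluation shared by the int32, int64, and float32 paths.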
template <typename T>
TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
                      const TfLiteTensor* input1, const TfLiteTensor* input2,
                      TfLiteTensor* output) {
  const T* denominator_data = GetTensorData<T>(input2);

  if (input2->type == kTfLiteInt32 || input2->type == kTfLiteInt64) {
    // Validate the denominator only for integer types.
    const int num_elements = NumElements(input2);
    for (int i = 0; i < num_elements; ++i) {
      if (denominator_data[i] == 0) {
        context->ReportError(context, "Division by 0");
        return kTfLiteError;
      }
    }
  }
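  // With mismatched shapes, apply FloorMod through the generic 4D broadcast
  // helper; otherwise apply it element-wise over identically shaped tensors.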
  if (requires_broadcast) {
    reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
        GetTensorShape(input1), GetTensorData<T>(input1),
        GetTensorShape(input2), denominator_data, GetTensorShape(output),
        GetTensorData<T>(output), FloorMod<T>);
  } else {
    reference_ops::BinaryFunction<T, T, T>(
        GetTensorShape(input1), GetTensorData<T>(input1),
        GetTensorShape(input2), GetTensorData<T>(input2),
        GetTensorShape(output), GetTensorData<T>(output), FloorMod<T>);
  }

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

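  // Dispatch on the element type; Prepare has already restricted it to
  // int32, int64, or float32.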
  switch (input1->type) {
    case kTfLiteInt32: {
      return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
                               input2, output);
    }
    case kTfLiteInt64: {
      return EvalImpl<int64_t>(context, data->requires_broadcast, input1,
                               input2, output);
    }
    case kTfLiteFloat32: {
      return EvalImpl<float>(context, data->requires_broadcast, input1, input2,
                             output);
    }
    default: {
      context->ReportError(context, "Type '%s' is not supported by floor_mod.",
                           TfLiteTypeGetName(input1->type));
      return kTfLiteError;
    }
  }
}

}  // namespace
}  // namespace floor_mod

TfLiteRegistration* Register_FLOOR_MOD() {
  // Init, Free, Prepare, and Eval satisfy the interface required by
  // TfLiteRegistration.
  static TfLiteRegistration r = {floor_mod::Init, floor_mod::Free,
                                 floor_mod::Prepare, floor_mod::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite