• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cmath>
#include <functional>
#include <type_traits>

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
22 
23 // TODO(b/117523611): We should factor out a binary_op and put binary ops there.
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace floor_mod {
28 namespace {
29 
// Positions of the two operands among the node's inputs, and of the result
// among the node's outputs.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
constexpr int kOutputTensor{0};
34 
// Per-node state for the floor_mod op, computed in Prepare() and read in
// Eval().
struct OpData {
  // True when the two inputs have different shapes and must be broadcast.
  // Defaults to false so a default-initialized OpData (e.g. `new OpData`)
  // never exposes an indeterminate value.
  bool requires_broadcast = false;
};
39 
// Functor returning the truncated (C-style) remainder of two floats, i.e. the
// floating-point counterpart of std::modulus: the result has the sign of
// `lhs` (or is zero).
struct FloatMod {
  float operator()(const float lhs, const float rhs) const {
    return std::fmod(lhs, rhs);
  }
};

// TODO(b/117912007): Move the implementation to reference_ops.h
// TODO(b/117912880): Support quantization.
// Returns the floored modulus of `input1` by `input2`, i.e. the remainder of
// floor division. The result has the sign of `input2` (or is zero), e.g.
// FloorMod(-7, 3) == 2 and FloorMod(7, -3) == -2.
//
// Integral T uses std::modulus<T>; other T use FloatMod. A zero divisor is
// not handled here; for integral T the caller must reject it beforehand.
template <typename T>
T FloorMod(T input1, T input2) {
  using ModFunc = typename std::conditional<std::is_integral<T>::value,
                                            std::modulus<T>, FloatMod>::type;

  // x % -1 is 0 for every signed integer x, but evaluating it when x is the
  // minimum representable value overflows the quotient. That is undefined
  // behavior, and it traps with SIGFPE on x86. Short-circuit this divisor.
  if (std::is_integral<T>::value && std::is_signed<T>::value &&
      input2 == static_cast<T>(-1)) {
    return T(0);
  }

  ModFunc mod_func;
  // Truncated remainder: carries the sign of input1 (or is zero).
  const T trunc_mod = mod_func(input1, input2);
  // With matching operand signs, the truncated and floored remainders agree.
  // Otherwise, shift by one divisor into input2's sign. The outer mod maps
  // the exact-multiple case (trunc_mod == 0) back to 0. The addition cannot
  // overflow because |trunc_mod| < |input2| and, when non-zero, trunc_mod
  // has the opposite sign to input2.
  return (input1 < T(0)) == (input2 < T(0))
             ? trunc_mod
             : mod_func(trunc_mod + input2, input2);
}
59 
Init(TfLiteContext * context,const char * buffer,size_t length)60 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
61   auto* data = new OpData;
62   data->requires_broadcast = false;
63   return data;
64 }
65 
Free(TfLiteContext * context,void * buffer)66 void Free(TfLiteContext* context, void* buffer) {
67   delete reinterpret_cast<OpData*>(buffer);
68 }
69 
Prepare(TfLiteContext * context,TfLiteNode * node)70 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
71   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
72   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
73 
74   // Reinterprete the opaque data provided by user.
75   OpData* data = reinterpret_cast<OpData*>(node->user_data);
76 
77   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
78   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
79   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
80 
81   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
82 
83   const TfLiteType type = input1->type;
84   if (type != kTfLiteInt32 && type != kTfLiteFloat32 && type != kTfLiteInt64) {
85     context->ReportError(context, "Type '%s' is not supported by floor_mod.",
86                          TfLiteTypeGetName(type));
87     return kTfLiteError;
88   }
89   output->type = type;
90 
91   data->requires_broadcast = !HaveSameShapes(input1, input2);
92 
93   TfLiteIntArray* output_size = nullptr;
94   if (data->requires_broadcast) {
95     TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
96                                    context, input1, input2, &output_size));
97   } else {
98     output_size = TfLiteIntArrayCopy(input1->dims);
99   }
100 
101   return context->ResizeTensor(context, output, output_size);
102 }
103 
104 template <typename T>
EvalImpl(TfLiteContext * context,bool requires_broadcast,const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output)105 TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
106                       const TfLiteTensor* input1, const TfLiteTensor* input2,
107                       TfLiteTensor* output) {
108   const T* denominator_data = GetTensorData<T>(input2);
109 
110   if (input2->type == kTfLiteInt32 || input2->type == kTfLiteInt64) {
111     // Validate the denominator only for integer.
112     const int num_elements = NumElements(input2);
113     for (int i = 0; i < num_elements; ++i) {
114       if (denominator_data[i] == 0) {
115         context->ReportError(context, "Division by 0");
116         return kTfLiteError;
117       }
118     }
119   }
120   if (requires_broadcast) {
121     reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
122         GetTensorShape(input1), GetTensorData<T>(input1),
123         GetTensorShape(input2), denominator_data, GetTensorShape(output),
124         GetTensorData<T>(output), FloorMod<T>);
125   } else {
126     reference_ops::BinaryFunction<T, T, T>(
127         GetTensorShape(input1), GetTensorData<T>(input1),
128         GetTensorShape(input2), GetTensorData<T>(input2),
129         GetTensorShape(output), GetTensorData<T>(output), FloorMod<T>);
130   }
131 
132   return kTfLiteOk;
133 }
134 
Eval(TfLiteContext * context,TfLiteNode * node)135 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
136   OpData* data = reinterpret_cast<OpData*>(node->user_data);
137 
138   const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
139   const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
140   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
141 
142   switch (input1->type) {
143     case kTfLiteInt32: {
144       return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
145                                input2, output);
146     }
147     case kTfLiteInt64: {
148       return EvalImpl<int64_t>(context, data->requires_broadcast, input1,
149                                input2, output);
150     }
151     case kTfLiteFloat32: {
152       return EvalImpl<float>(context, data->requires_broadcast, input1, input2,
153                              output);
154     }
155     default: {
156       context->ReportError(context, "Type '%s' is not supported by floor_mod.",
157                            TfLiteTypeGetName(input1->type));
158       return kTfLiteError;
159     }
160   }
161 }
162 
163 }  // namespace
164 }  // namespace floor_mod
165 
Register_FLOOR_MOD()166 TfLiteRegistration* Register_FLOOR_MOD() {
167   // Init, Free, Prepare, Eval are satisfying the Interface required by
168   // TfLiteRegistration.
169   static TfLiteRegistration r = {floor_mod::Init, floor_mod::Free,
170                                  floor_mod::Prepare, floor_mod::Eval};
171   return &r;
172 }
173 
174 }  // namespace builtin
175 }  // namespace ops
176 }  // namespace tflite
177