• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <string.h>

#include <type_traits>
#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace maximum_minimum {
28 
// Reference (portable, unoptimized) implementation of the element-wise
// Maximum and Minimum ops.

// Selects which kernel implementation Eval uses. Only the reference kernel
// exists at the moment.
enum KernelType {
  kReference = 0,
};

// Positions of the operands and the result within the node's tensor lists.
constexpr int kInputTensor1 = 0;  // First operand.
constexpr int kInputTensor2 = 1;  // Second operand.
constexpr int kOutputTensor = 0;  // Element-wise result.
37 
38 struct OpContext {
OpContexttflite::ops::builtin::maximum_minimum::OpContext39   OpContext(TfLiteContext* context, TfLiteNode* node) {
40     input1 = GetInput(context, node, kInputTensor1);
41     input2 = GetInput(context, node, kInputTensor2);
42     output = GetOutput(context, node, kOutputTensor);
43   }
44   const TfLiteTensor* input1;
45   const TfLiteTensor* input2;
46   TfLiteTensor* output;
47 };
48 
Prepare(TfLiteContext * context,TfLiteNode * node)49 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
50   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
51   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
52 
53   OpContext op_context(context, node);
54   TF_LITE_ENSURE_EQ(context, op_context.input1->type, op_context.input2->type);
55   op_context.output->type = op_context.input1->type;
56 
57   bool requires_broadcast =
58       !HaveSameShapes(op_context.input1, op_context.input2);
59 
60   TfLiteIntArray* output_size = nullptr;
61   if (requires_broadcast) {
62     TF_LITE_ENSURE_OK(
63         context, CalculateShapeForBroadcast(context, op_context.input1,
64                                             op_context.input2, &output_size));
65   } else {
66     output_size = TfLiteIntArrayCopy(op_context.input1->dims);
67   }
68 
69   return context->ResizeTensor(context, op_context.output, output_size);
70 }
71 
// Binary functor used by Eval to compute the element-wise maximum.
struct MaximumOp {
  // Returns el1 if el1 > el2, otherwise el2. When the comparison is false
  // (ties, or any NaN operand for floating types), el2 is returned.
  template <typename data_type>
  static data_type op(data_type el1, data_type el2) {
    if (el1 > el2) {
      return el1;
    }
    return el2;
  }
};
78 
// Binary functor used by Eval to compute the element-wise minimum.
struct MinimumOp {
  // Returns el1 if el1 < el2, otherwise el2. When the comparison is false
  // (ties, or any NaN operand for floating types), el2 is returned.
  template <typename data_type>
  static data_type op(data_type el1, data_type el2) {
    if (el1 < el2) {
      return el1;
    }
    return el2;
  }
};
85 
86 template <typename data_type, typename op_type>
TFLiteOperation(TfLiteContext * context,TfLiteNode * node,const OpContext & op_context)87 void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
88                       const OpContext& op_context) {
89   reference_ops::MaximumMinimumBroadcast4DSlow(
90       GetTensorShape(op_context.input1),
91       GetTensorData<data_type>(op_context.input1),
92       GetTensorShape(op_context.input2),
93       GetTensorData<data_type>(op_context.input2),
94       GetTensorShape(op_context.output),
95       GetTensorData<data_type>(op_context.output),
96       op_type::template op<data_type>);
97 }
98 
99 template <KernelType kernel_type, typename OpType>
Eval(TfLiteContext * context,TfLiteNode * node)100 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
101   OpContext op_context(context, node);
102 
103   if (kernel_type == kReference) {
104     switch (op_context.output->type) {
105       case kTfLiteFloat32:
106         TFLiteOperation<float, OpType>(context, node, op_context);
107         break;
108       case kTfLiteUInt8:
109         TFLiteOperation<uint8_t, OpType>(context, node, op_context);
110         break;
111       case kTfLiteInt8:
112         TFLiteOperation<int8_t, OpType>(context, node, op_context);
113         break;
114       case kTfLiteInt32:
115        TFLiteOperation<int32_t, OpType>(context, node, op_context);
116         break;
117       case kTfLiteInt64:
118         TFLiteOperation<int64_t, OpType>(context, node, op_context);
119         break;
120       default:
121         context->ReportError(context,
122                              "Type %d is currently not supported by Maximum.",
123                              op_context.output->type);
124         return kTfLiteError;
125     }
126   } else {
127     context->ReportError(context,
128                          "Type %d is currently not supported by Maximum.",
129                          op_context.output->type);
130     return kTfLiteError;
131   }
132   return kTfLiteOk;
133 }
134 
135 }  // namespace maximum_minimum
136 
Register_MAXIMUM_REF()137 TfLiteRegistration* Register_MAXIMUM_REF() {
138   static TfLiteRegistration r = {
139       nullptr, nullptr, maximum_minimum::Prepare,
140       maximum_minimum::Eval<maximum_minimum::kReference,
141                             maximum_minimum::MaximumOp>};
142   return &r;
143 }
144 
Register_MINIMUM_REF()145 TfLiteRegistration* Register_MINIMUM_REF() {
146   static TfLiteRegistration r = {
147       nullptr, nullptr, maximum_minimum::Prepare,
148       maximum_minimum::Eval<maximum_minimum::kReference,
149                             maximum_minimum::MinimumOp>};
150   return &r;
151 }
Register_MAXIMUM()152 TfLiteRegistration* Register_MAXIMUM() { return Register_MAXIMUM_REF(); }
Register_MINIMUM()153 TfLiteRegistration* Register_MINIMUM() { return Register_MINIMUM_REF(); }
154 
155 }  // namespace builtin
156 }  // namespace ops
157 }  // namespace tflite
158