1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <string.h>
16 #include <vector>
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/c_api_internal.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/kernel_util.h"
22 #include "tensorflow/lite/kernels/op_macros.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace maximum_minimum {
28
// This file has a reference implementation of TFMaximum/TFMinimum.
enum KernelType {
  kReference,  // Portable scalar implementation; no optimized variant yet.
};

// Indices of this op's tensors within the node's input/output arrays.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
37
38 struct OpContext {
OpContexttflite::ops::builtin::maximum_minimum::OpContext39 OpContext(TfLiteContext* context, TfLiteNode* node) {
40 input1 = GetInput(context, node, kInputTensor1);
41 input2 = GetInput(context, node, kInputTensor2);
42 output = GetOutput(context, node, kOutputTensor);
43 }
44 const TfLiteTensor* input1;
45 const TfLiteTensor* input2;
46 TfLiteTensor* output;
47 };
48
Prepare(TfLiteContext * context,TfLiteNode * node)49 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
50 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
51 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
52
53 OpContext op_context(context, node);
54 TF_LITE_ENSURE_EQ(context, op_context.input1->type, op_context.input2->type);
55 op_context.output->type = op_context.input1->type;
56
57 bool requires_broadcast =
58 !HaveSameShapes(op_context.input1, op_context.input2);
59
60 TfLiteIntArray* output_size = nullptr;
61 if (requires_broadcast) {
62 TF_LITE_ENSURE_OK(
63 context, CalculateShapeForBroadcast(context, op_context.input1,
64 op_context.input2, &output_size));
65 } else {
66 output_size = TfLiteIntArrayCopy(op_context.input1->dims);
67 }
68
69 return context->ResizeTensor(context, op_context.output, output_size);
70 }
71
// Element-wise functor for MAXIMUM; plugged into the shared broadcast kernel.
struct MaximumOp {
  template <typename data_type>
  static data_type op(data_type lhs, data_type rhs) {
    // Guard-clause form of `lhs > rhs ? lhs : rhs` — the comparison is kept
    // identical so float NaN handling matches the original exactly.
    if (lhs > rhs) {
      return lhs;
    }
    return rhs;
  }
};
78
// Element-wise functor for MINIMUM; plugged into the shared broadcast kernel.
struct MinimumOp {
  template <typename data_type>
  static data_type op(data_type lhs, data_type rhs) {
    // Guard-clause form of `lhs < rhs ? lhs : rhs` — the comparison is kept
    // identical so float NaN handling matches the original exactly.
    if (lhs < rhs) {
      return lhs;
    }
    return rhs;
  }
};
85
86 template <typename data_type, typename op_type>
TFLiteOperation(TfLiteContext * context,TfLiteNode * node,const OpContext & op_context)87 void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
88 const OpContext& op_context) {
89 reference_ops::MaximumMinimumBroadcast4DSlow(
90 GetTensorShape(op_context.input1),
91 GetTensorData<data_type>(op_context.input1),
92 GetTensorShape(op_context.input2),
93 GetTensorData<data_type>(op_context.input2),
94 GetTensorShape(op_context.output),
95 GetTensorData<data_type>(op_context.output),
96 op_type::template op<data_type>);
97 }
98
99 template <KernelType kernel_type, typename OpType>
Eval(TfLiteContext * context,TfLiteNode * node)100 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
101 OpContext op_context(context, node);
102
103 if (kernel_type == kReference) {
104 switch (op_context.output->type) {
105 case kTfLiteFloat32:
106 TFLiteOperation<float, OpType>(context, node, op_context);
107 break;
108 case kTfLiteUInt8:
109 TFLiteOperation<uint8_t, OpType>(context, node, op_context);
110 break;
111 case kTfLiteInt8:
112 TFLiteOperation<int8_t, OpType>(context, node, op_context);
113 break;
114 case kTfLiteInt32:
115 TFLiteOperation<int32_t, OpType>(context, node, op_context);
116 break;
117 case kTfLiteInt64:
118 TFLiteOperation<int64_t, OpType>(context, node, op_context);
119 break;
120 default:
121 context->ReportError(context,
122 "Type %d is currently not supported by Maximum.",
123 op_context.output->type);
124 return kTfLiteError;
125 }
126 } else {
127 context->ReportError(context,
128 "Type %d is currently not supported by Maximum.",
129 op_context.output->type);
130 return kTfLiteError;
131 }
132 return kTfLiteOk;
133 }
134
135 } // namespace maximum_minimum
136
// Registration for the reference MAXIMUM kernel. The aggregate fields are
// positional: {init, free, prepare, invoke}. This op keeps no per-node state,
// so init/free are null.
TfLiteRegistration* Register_MAXIMUM_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kReference,
                            maximum_minimum::MaximumOp>};
  return &r;
}
144
// Registration for the reference MINIMUM kernel. The aggregate fields are
// positional: {init, free, prepare, invoke}. This op keeps no per-node state,
// so init/free are null.
TfLiteRegistration* Register_MINIMUM_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kReference,
                            maximum_minimum::MinimumOp>};
  return &r;
}
// Default MAXIMUM registration: only the reference kernel exists today.
TfLiteRegistration* Register_MAXIMUM() { return Register_MAXIMUM_REF(); }
// Default MINIMUM registration: only the reference kernel exists today.
TfLiteRegistration* Register_MINIMUM() { return Register_MINIMUM_REF(); }
154
155 } // namespace builtin
156 } // namespace ops
157 } // namespace tflite
158