1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cmath>
17
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/micro/kernels/kernel_util.h"
22 #include "tensorflow/lite/micro/micro_utils.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace micro {
27 namespace elementwise {
28 namespace {
29
IsNumericSupportedType(const TfLiteType type)30 bool IsNumericSupportedType(const TfLiteType type) {
31 return type == kTfLiteFloat32;
32 }
33
IsLogicalSupportedType(const TfLiteType type)34 bool IsLogicalSupportedType(const TfLiteType type) {
35 return type == kTfLiteBool;
36 }
37
38 typedef bool (*IsSupportedType)(TfLiteType);
39 template <IsSupportedType>
GenericPrepare(TfLiteContext * context,TfLiteNode * node)40 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
41 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
42 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
43 const TfLiteTensor* input = GetInput(context, node, 0);
44 TF_LITE_ENSURE(context, input != nullptr);
45 TfLiteTensor* output = GetOutput(context, node, 0);
46 TF_LITE_ENSURE(context, output != nullptr);
47 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
48 if (!IsSupportedType(input->type)) {
49 TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
50 TfLiteTypeGetName(input->type), input->type);
51 return kTfLiteError;
52 }
53 return kTfLiteOk;
54 }
55
56 template <typename T>
EvalImpl(TfLiteContext * context,TfLiteNode * node,T func (T),TfLiteType expected_type)57 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
58 T func(T), TfLiteType expected_type) {
59 const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
60 TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
61 TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
62 const size_t num_elements = ElementCount(*input->dims);
63 const T* in_data = tflite::micro::GetTensorData<T>(input);
64 T* out_data = tflite::micro::GetTensorData<T>(output);
65 for (size_t i = 0; i < num_elements; ++i) {
66 out_data[i] = func(in_data[i]);
67 }
68 return kTfLiteOk;
69 }
70
EvalNumeric(TfLiteContext * context,TfLiteNode * node,float float_func (float))71 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
72 float float_func(float)) {
73 return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
74 }
75
EvalLogical(TfLiteContext * context,TfLiteNode * node,bool bool_func (bool))76 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
77 bool bool_func(bool)) {
78 return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
79 }
80
AbsEval(TfLiteContext * context,TfLiteNode * node)81 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
82 return EvalNumeric(context, node, std::abs);
83 }
84
SinEval(TfLiteContext * context,TfLiteNode * node)85 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
86 return EvalNumeric(context, node, std::sin);
87 }
88
CosEval(TfLiteContext * context,TfLiteNode * node)89 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
90 return EvalNumeric(context, node, std::cos);
91 }
92
LogEval(TfLiteContext * context,TfLiteNode * node)93 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
94 return EvalNumeric(context, node, std::log);
95 }
96
SqrtEval(TfLiteContext * context,TfLiteNode * node)97 TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
98 return EvalNumeric(context, node, std::sqrt);
99 }
100
RsqrtEval(TfLiteContext * context,TfLiteNode * node)101 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
102 return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
103 }
104
SquareEval(TfLiteContext * context,TfLiteNode * node)105 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
106 return EvalNumeric(context, node, [](float f) { return f * f; });
107 }
108
LogicalNotEval(TfLiteContext * context,TfLiteNode * node)109 TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
110 return EvalLogical(context, node, [](bool v) { return !v; });
111 }
112
113 } // namespace
114 } // namespace elementwise
115
Register_ABS()116 TfLiteRegistration Register_ABS() {
117 return {/*init=*/nullptr,
118 /*free=*/nullptr,
119 /*prepare=*/
120 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
121 /*invoke=*/elementwise::AbsEval,
122 /*profiling_string=*/nullptr,
123 /*builtin_code=*/0,
124 /*custom_name=*/nullptr,
125 /*version=*/0};
126 }
127
Register_SIN()128 TfLiteRegistration Register_SIN() {
129 return {/*init=*/nullptr,
130 /*free=*/nullptr,
131 /*prepare=*/
132 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
133 /*invoke=*/elementwise::SinEval,
134 /*profiling_string=*/nullptr,
135 /*builtin_code=*/0,
136 /*custom_name=*/nullptr,
137 /*version=*/0};
138 }
139
Register_COS()140 TfLiteRegistration Register_COS() {
141 return {/*init=*/nullptr,
142 /*free=*/nullptr,
143 /*prepare=*/
144 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
145 /*invoke=*/elementwise::CosEval,
146 /*profiling_string=*/nullptr,
147 /*builtin_code=*/0,
148 /*custom_name=*/nullptr,
149 /*version=*/0};
150 }
151
Register_LOG()152 TfLiteRegistration Register_LOG() {
153 return {/*init=*/nullptr,
154 /*free=*/nullptr,
155 /*prepare=*/
156 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
157 /*invoke=*/elementwise::LogEval,
158 /*profiling_string=*/nullptr,
159 /*builtin_code=*/0,
160 /*custom_name=*/nullptr,
161 /*version=*/0};
162 }
163
Register_SQRT()164 TfLiteRegistration Register_SQRT() {
165 return {/*init=*/nullptr,
166 /*free=*/nullptr,
167 /*prepare=*/
168 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
169 /*invoke=*/elementwise::SqrtEval,
170 /*profiling_string=*/nullptr,
171 /*builtin_code=*/0,
172 /*custom_name=*/nullptr,
173 /*version=*/0};
174 }
175
Register_RSQRT()176 TfLiteRegistration Register_RSQRT() {
177 return {/*init=*/nullptr,
178 /*free=*/nullptr,
179 /*prepare=*/
180 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
181 /*invoke=*/elementwise::RsqrtEval,
182 /*profiling_string=*/nullptr,
183 /*builtin_code=*/0,
184 /*custom_name=*/nullptr,
185 /*version=*/0};
186 }
187
Register_SQUARE()188 TfLiteRegistration Register_SQUARE() {
189 return {/*init=*/nullptr,
190 /*free=*/nullptr,
191 /*prepare=*/
192 elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
193 /*invoke=*/elementwise::SquareEval,
194 /*profiling_string=*/nullptr,
195 /*builtin_code=*/0,
196 /*custom_name=*/nullptr,
197 /*version=*/0};
198 }
199
Register_LOGICAL_NOT()200 TfLiteRegistration Register_LOGICAL_NOT() {
201 return {/*init=*/nullptr,
202 /*free=*/nullptr,
203 /*prepare=*/
204 elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
205 /*invoke=*/elementwise::LogicalNotEval,
206 /*profiling_string=*/nullptr,
207 /*builtin_code=*/0,
208 /*custom_name=*/nullptr,
209 /*version=*/0};
210 }
211
212 } // namespace micro
213 } // namespace ops
214 } // namespace tflite
215