• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/micro/kernels/kernel_util.h"
22 #include "tensorflow/lite/micro/micro_utils.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace micro {
27 namespace elementwise {
28 namespace {
29 
IsNumericSupportedType(const TfLiteType type)30 bool IsNumericSupportedType(const TfLiteType type) {
31   return type == kTfLiteFloat32;
32 }
33 
IsLogicalSupportedType(const TfLiteType type)34 bool IsLogicalSupportedType(const TfLiteType type) {
35   return type == kTfLiteBool;
36 }
37 
38 typedef bool (*IsSupportedType)(TfLiteType);
39 template <IsSupportedType>
GenericPrepare(TfLiteContext * context,TfLiteNode * node)40 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
41   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
42   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
43   const TfLiteTensor* input = GetInput(context, node, 0);
44   TF_LITE_ENSURE(context, input != nullptr);
45   TfLiteTensor* output = GetOutput(context, node, 0);
46   TF_LITE_ENSURE(context, output != nullptr);
47   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
48   if (!IsSupportedType(input->type)) {
49     TF_LITE_KERNEL_LOG(context, "Input data type %s (%d) is not supported.",
50                        TfLiteTypeGetName(input->type), input->type);
51     return kTfLiteError;
52   }
53   return kTfLiteOk;
54 }
55 
56 template <typename T>
EvalImpl(TfLiteContext * context,TfLiteNode * node,T func (T),TfLiteType expected_type)57 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
58                              T func(T), TfLiteType expected_type) {
59   const TfLiteEvalTensor* input = tflite::micro::GetEvalInput(context, node, 0);
60   TfLiteEvalTensor* output = tflite::micro::GetEvalOutput(context, node, 0);
61   TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
62   const size_t num_elements = ElementCount(*input->dims);
63   const T* in_data = tflite::micro::GetTensorData<T>(input);
64   T* out_data = tflite::micro::GetTensorData<T>(output);
65   for (size_t i = 0; i < num_elements; ++i) {
66     out_data[i] = func(in_data[i]);
67   }
68   return kTfLiteOk;
69 }
70 
EvalNumeric(TfLiteContext * context,TfLiteNode * node,float float_func (float))71 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
72                                 float float_func(float)) {
73   return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
74 }
75 
EvalLogical(TfLiteContext * context,TfLiteNode * node,bool bool_func (bool))76 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
77                                 bool bool_func(bool)) {
78   return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
79 }
80 
AbsEval(TfLiteContext * context,TfLiteNode * node)81 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
82   return EvalNumeric(context, node, std::abs);
83 }
84 
SinEval(TfLiteContext * context,TfLiteNode * node)85 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
86   return EvalNumeric(context, node, std::sin);
87 }
88 
CosEval(TfLiteContext * context,TfLiteNode * node)89 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
90   return EvalNumeric(context, node, std::cos);
91 }
92 
LogEval(TfLiteContext * context,TfLiteNode * node)93 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
94   return EvalNumeric(context, node, std::log);
95 }
96 
SqrtEval(TfLiteContext * context,TfLiteNode * node)97 TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
98   return EvalNumeric(context, node, std::sqrt);
99 }
100 
RsqrtEval(TfLiteContext * context,TfLiteNode * node)101 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
102   return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
103 }
104 
SquareEval(TfLiteContext * context,TfLiteNode * node)105 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
106   return EvalNumeric(context, node, [](float f) { return f * f; });
107 }
108 
LogicalNotEval(TfLiteContext * context,TfLiteNode * node)109 TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
110   return EvalLogical(context, node, [](bool v) { return !v; });
111 }
112 
113 }  // namespace
114 }  // namespace elementwise
115 
Register_ABS()116 TfLiteRegistration Register_ABS() {
117   return {/*init=*/nullptr,
118           /*free=*/nullptr,
119           /*prepare=*/
120           elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
121           /*invoke=*/elementwise::AbsEval,
122           /*profiling_string=*/nullptr,
123           /*builtin_code=*/0,
124           /*custom_name=*/nullptr,
125           /*version=*/0};
126 }
127 
Register_SIN()128 TfLiteRegistration Register_SIN() {
129   return {/*init=*/nullptr,
130           /*free=*/nullptr,
131           /*prepare=*/
132           elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
133           /*invoke=*/elementwise::SinEval,
134           /*profiling_string=*/nullptr,
135           /*builtin_code=*/0,
136           /*custom_name=*/nullptr,
137           /*version=*/0};
138 }
139 
Register_COS()140 TfLiteRegistration Register_COS() {
141   return {/*init=*/nullptr,
142           /*free=*/nullptr,
143           /*prepare=*/
144           elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
145           /*invoke=*/elementwise::CosEval,
146           /*profiling_string=*/nullptr,
147           /*builtin_code=*/0,
148           /*custom_name=*/nullptr,
149           /*version=*/0};
150 }
151 
Register_LOG()152 TfLiteRegistration Register_LOG() {
153   return {/*init=*/nullptr,
154           /*free=*/nullptr,
155           /*prepare=*/
156           elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
157           /*invoke=*/elementwise::LogEval,
158           /*profiling_string=*/nullptr,
159           /*builtin_code=*/0,
160           /*custom_name=*/nullptr,
161           /*version=*/0};
162 }
163 
Register_SQRT()164 TfLiteRegistration Register_SQRT() {
165   return {/*init=*/nullptr,
166           /*free=*/nullptr,
167           /*prepare=*/
168           elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
169           /*invoke=*/elementwise::SqrtEval,
170           /*profiling_string=*/nullptr,
171           /*builtin_code=*/0,
172           /*custom_name=*/nullptr,
173           /*version=*/0};
174 }
175 
Register_RSQRT()176 TfLiteRegistration Register_RSQRT() {
177   return {/*init=*/nullptr,
178           /*free=*/nullptr,
179           /*prepare=*/
180           elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
181           /*invoke=*/elementwise::RsqrtEval,
182           /*profiling_string=*/nullptr,
183           /*builtin_code=*/0,
184           /*custom_name=*/nullptr,
185           /*version=*/0};
186 }
187 
Register_SQUARE()188 TfLiteRegistration Register_SQUARE() {
189   return {/*init=*/nullptr,
190           /*free=*/nullptr,
191           /*prepare=*/
192           elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
193           /*invoke=*/elementwise::SquareEval,
194           /*profiling_string=*/nullptr,
195           /*builtin_code=*/0,
196           /*custom_name=*/nullptr,
197           /*version=*/0};
198 }
199 
Register_LOGICAL_NOT()200 TfLiteRegistration Register_LOGICAL_NOT() {
201   return {/*init=*/nullptr,
202           /*free=*/nullptr,
203           /*prepare=*/
204           elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
205           /*invoke=*/elementwise::LogicalNotEval,
206           /*profiling_string=*/nullptr,
207           /*builtin_code=*/0,
208           /*custom_name=*/nullptr,
209           /*version=*/0};
210 }
211 
212 }  // namespace micro
213 }  // namespace ops
214 }  // namespace tflite
215