• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <cmath>
17 #include "tensorflow/lite/c/c_api_internal.h"
18 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
19 #include "tensorflow/lite/kernels/internal/tensor.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 
22 namespace tflite {
23 namespace ops {
24 namespace builtin {
25 namespace elementwise {
26 namespace {
27 
IsNumericSupportedType(const TfLiteType type)28 bool IsNumericSupportedType(const TfLiteType type) {
29   return type == kTfLiteFloat32;
30 }
31 
IsLogicalSupportedType(const TfLiteType type)32 bool IsLogicalSupportedType(const TfLiteType type) {
33   return type == kTfLiteBool;
34 }
35 
36 typedef bool (*IsSupportedType)(TfLiteType);
37 template <IsSupportedType>
GenericPrepare(TfLiteContext * context,TfLiteNode * node)38 TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
39   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
40   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
41   const TfLiteTensor* input = GetInput(context, node, 0);
42   TfLiteTensor* output = GetOutput(context, node, 0);
43   TF_LITE_ENSURE_EQ(context, input->type, output->type);
44   if (!IsSupportedType(input->type)) {
45     context->ReportError(context, "Current data type %d is not supported.",
46                          input->type);
47     return kTfLiteError;
48   }
49   return context->ResizeTensor(context, output,
50                                TfLiteIntArrayCopy(input->dims));
51 }
52 
53 template <typename T>
EvalImpl(TfLiteContext * context,TfLiteNode * node,T func (T),TfLiteType expected_type)54 inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
55                              T func(T), TfLiteType expected_type) {
56   const TfLiteTensor* input = GetInput(context, node, 0);
57   TfLiteTensor* output = GetOutput(context, node, 0);
58   TF_LITE_ENSURE_EQ(context, input->type, expected_type);
59   const int64_t num_elements = NumElements(input);
60   const T* in_data = GetTensorData<T>(input);
61   T* out_data = GetTensorData<T>(output);
62   for (int64_t i = 0; i < num_elements; ++i) {
63     out_data[i] = func(in_data[i]);
64   }
65   return kTfLiteOk;
66 }
67 
EvalNumeric(TfLiteContext * context,TfLiteNode * node,float float_func (float))68 inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
69                                 float float_func(float)) {
70   return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
71 }
72 
EvalLogical(TfLiteContext * context,TfLiteNode * node,bool bool_func (bool))73 inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
74                                 bool bool_func(bool)) {
75   return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
76 }
77 
AbsEval(TfLiteContext * context,TfLiteNode * node)78 TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
79   return EvalNumeric(context, node, std::abs);
80 }
81 
SinEval(TfLiteContext * context,TfLiteNode * node)82 TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
83   return EvalNumeric(context, node, std::sin);
84 }
85 
CosEval(TfLiteContext * context,TfLiteNode * node)86 TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
87   return EvalNumeric(context, node, std::cos);
88 }
89 
LogEval(TfLiteContext * context,TfLiteNode * node)90 TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
91   return EvalNumeric(context, node, std::log);
92 }
93 
SqrtEval(TfLiteContext * context,TfLiteNode * node)94 TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
95   return EvalNumeric(context, node, std::sqrt);
96 }
97 
RsqrtEval(TfLiteContext * context,TfLiteNode * node)98 TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
99   return EvalNumeric(context, node, [](float f) { return 1.f / std::sqrt(f); });
100 }
101 
SquareEval(TfLiteContext * context,TfLiteNode * node)102 TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
103   return EvalNumeric(context, node, [](float f) { return f * f; });
104 }
105 
LogicalNotEval(TfLiteContext * context,TfLiteNode * node)106 TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
107   return EvalLogical(context, node, [](bool v) { return !v; });
108 }
109 
110 }  // namespace
111 }  // namespace elementwise
112 
Register_ABS()113 TfLiteRegistration* Register_ABS() {
114   static TfLiteRegistration r = {
115       /*init=*/nullptr, /*free=*/nullptr,
116       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
117       elementwise::AbsEval};
118   return &r;
119 }
120 
Register_SIN()121 TfLiteRegistration* Register_SIN() {
122   static TfLiteRegistration r = {
123       /*init=*/nullptr, /*free=*/nullptr,
124       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
125       elementwise::SinEval};
126   return &r;
127 }
128 
Register_COS()129 TfLiteRegistration* Register_COS() {
130   static TfLiteRegistration r = {
131       /*init=*/nullptr, /*free=*/nullptr,
132       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
133       elementwise::CosEval};
134   return &r;
135 }
136 
Register_LOG()137 TfLiteRegistration* Register_LOG() {
138   static TfLiteRegistration r = {
139       /*init=*/nullptr, /*free=*/nullptr,
140       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
141       elementwise::LogEval};
142   return &r;
143 }
144 
Register_SQRT()145 TfLiteRegistration* Register_SQRT() {
146   static TfLiteRegistration r = {
147       /*init=*/nullptr, /*free=*/nullptr,
148       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
149       elementwise::SqrtEval};
150   return &r;
151 }
152 
Register_RSQRT()153 TfLiteRegistration* Register_RSQRT() {
154   static TfLiteRegistration r = {
155       /*init=*/nullptr, /*free=*/nullptr,
156       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
157       elementwise::RsqrtEval};
158   return &r;
159 }
160 
Register_SQUARE()161 TfLiteRegistration* Register_SQUARE() {
162   static TfLiteRegistration r = {
163       /*init=*/nullptr, /*free=*/nullptr,
164       elementwise::GenericPrepare<elementwise::IsNumericSupportedType>,
165       elementwise::SquareEval};
166   return &r;
167 }
168 
Register_LOGICAL_NOT()169 TfLiteRegistration* Register_LOGICAL_NOT() {
170   static TfLiteRegistration r = {
171       /*init=*/nullptr, /*free=*/nullptr,
172       elementwise::GenericPrepare<elementwise::IsLogicalSupportedType>,
173       elementwise::LogicalNotEval};
174   return &r;
175 }
176 
177 }  // namespace builtin
178 }  // namespace ops
179 }  // namespace tflite
180