1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <cmath>
16
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/custom_ops_register.h"
19 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21
22 namespace tflite {
23 namespace ops {
24 namespace custom {
25 namespace sign {
26
27 // Performs common preparation for pointwise, unary ops, i.e., type checks and
28 // output tensor resizing.
PointwiseUnaryOpPrepare(TfLiteContext * context,TfLiteNode * node)29 TfLiteStatus PointwiseUnaryOpPrepare(TfLiteContext* context, TfLiteNode* node) {
30 TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 1);
31
32 const TfLiteTensor* input = tflite::GetInput(context, node, 0);
33 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
34
35 // Validate size and type constraints
36 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
37 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
38 return context->ResizeTensor(context, output, output_shape);
39 }
40
41 // Applies the operator Op pointwise to data of type T.
42 template <typename Op, typename T>
PointwiseUnaryOpDoEval(TfLiteContext * context,const TfLiteTensor * input,TfLiteTensor * output)43 TfLiteStatus PointwiseUnaryOpDoEval(
44 TfLiteContext* context,
45 const TfLiteTensor* input,
46 TfLiteTensor* output) {
47 const T* data = tflite::GetTensorData<T>(input);
48 T* data_output = tflite::GetTensorData<T>(output);
49
50 const int64_t num_elements = NumElements(input);
51 for (int64_t i = 0; i < num_elements; ++i) {
52 data_output[i] = Op::template Eval<T>(data[i]);
53 }
54
55 return TfLiteStatus::kTfLiteOk;
56 }
57
58 // A generic evaluation function where the actual data processing is handled
59 // by the Op::Eval<T> function.
60 template <typename Op>
PointwiseUnaryOpEval(TfLiteContext * context,TfLiteNode * node)61 TfLiteStatus PointwiseUnaryOpEval(TfLiteContext* context, TfLiteNode* node) {
62 const TfLiteTensor* input = tflite::GetInput(context, node, 0);
63 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
64
65 switch (output->type) {
66 case kTfLiteFloat32:
67 TF_LITE_ENSURE_OK(
68 context,
69 (PointwiseUnaryOpDoEval<Op, float>(context, input, output)));
70 break;
71 case kTfLiteFloat64:
72 TF_LITE_ENSURE_OK(
73 context,
74 (PointwiseUnaryOpDoEval<Op, double>(context, input, output)));
75 break;
76 default:
77 TF_LITE_KERNEL_LOG(
78 context,
79 "Unsupported datatype for atan2 output: %s",
80 TfLiteTypeGetName(output->type));
81 }
82
83 return TfLiteStatus::kTfLiteOk;
84 }
85
// Functor implementing the sign (signum) function. It exposes a static,
// templated Eval so it can be used as the Op parameter of the pointwise
// evaluation templates.
struct Sign {
  // Returns 1 if x is greater than zero, -1 if x is less than zero, and 0
  // otherwise. "Otherwise" covers zero as well as any value for which both
  // comparisons are false (e.g. a floating-point NaN).
  template <typename T>
  static T Eval(T x) {
    return x > 0 ? 1 : (x < 0 ? -1 : 0);
  }
};
99
100 } // namespace sign
101
Register_SIGN()102 TfLiteRegistration* Register_SIGN() {
103 static TfLiteRegistration r = {nullptr, nullptr,
104 sign::PointwiseUnaryOpPrepare,
105 sign::PointwiseUnaryOpEval<sign::Sign>};
106 return &r;
107 }
108
109 } // namespace custom
110 } // namespace ops
111 } // namespace tflite
112