• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <cmath>
16 
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/custom_ops_register.h"
19 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 
22 namespace tflite {
23 namespace ops {
24 namespace custom {
25 namespace sign {
26 
27 // Performs common preparation for pointwise, unary ops, i.e., type checks and
28 // output tensor resizing.
PointwiseUnaryOpPrepare(TfLiteContext * context,TfLiteNode * node)29 TfLiteStatus PointwiseUnaryOpPrepare(TfLiteContext* context, TfLiteNode* node) {
30   TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 1);
31 
32   const TfLiteTensor* input = tflite::GetInput(context, node, 0);
33   TfLiteTensor* output = tflite::GetOutput(context, node, 0);
34 
35   // Validate size and type constraints
36   TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
37   TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
38   return context->ResizeTensor(context, output, output_shape);
39 }
40 
41 // Applies the operator Op pointwise to data of type T.
42 template <typename Op, typename T>
PointwiseUnaryOpDoEval(TfLiteContext * context,const TfLiteTensor * input,TfLiteTensor * output)43 TfLiteStatus PointwiseUnaryOpDoEval(
44     TfLiteContext* context,
45     const TfLiteTensor* input,
46     TfLiteTensor* output) {
47   const T* data = tflite::GetTensorData<T>(input);
48   T* data_output = tflite::GetTensorData<T>(output);
49 
50   const int64_t num_elements = NumElements(input);
51   for (int64_t i = 0; i < num_elements; ++i) {
52     data_output[i] = Op::template Eval<T>(data[i]);
53   }
54 
55   return TfLiteStatus::kTfLiteOk;
56 }
57 
58 // A generic evaluation function where the actual data processing is handled
59 // by the Op::Eval<T> function.
60 template <typename Op>
PointwiseUnaryOpEval(TfLiteContext * context,TfLiteNode * node)61 TfLiteStatus PointwiseUnaryOpEval(TfLiteContext* context, TfLiteNode* node) {
62   const TfLiteTensor* input = tflite::GetInput(context, node, 0);
63   TfLiteTensor* output = tflite::GetOutput(context, node, 0);
64 
65   switch (output->type) {
66     case kTfLiteFloat32:
67       TF_LITE_ENSURE_OK(
68           context,
69           (PointwiseUnaryOpDoEval<Op, float>(context, input, output)));
70       break;
71     case kTfLiteFloat64:
72       TF_LITE_ENSURE_OK(
73           context,
74           (PointwiseUnaryOpDoEval<Op, double>(context, input, output)));
75       break;
76     default:
77       TF_LITE_KERNEL_LOG(
78           context,
79           "Unsupported datatype for atan2 output: %s",
80           TfLiteTypeGetName(output->type));
81   }
82 
83   return TfLiteStatus::kTfLiteOk;
84 }
85 
// Operator that computes the sign function. Eval returns:
//   1 when x > 0,
//  -1 when x < 0,
//   0 otherwise.
// "Otherwise" covers both zeros and NaN, because NaN compares false against 0
// in both directions.
struct Sign {
  template <typename T>
  static T Eval(T x) {
    const bool is_positive = x > 0;
    const bool is_negative = x < 0;
    if (!is_positive && !is_negative) {
      return 0;
    }
    return is_positive ? 1 : -1;
  }
};
99 
100 }  // namespace sign
101 
Register_SIGN()102 TfLiteRegistration* Register_SIGN() {
103   static TfLiteRegistration r = {nullptr, nullptr,
104                                  sign::PointwiseUnaryOpPrepare,
105                                  sign::PointwiseUnaryOpEval<sign::Sign>};
106   return &r;
107 }
108 
109 }  // namespace custom
110 }  // namespace ops
111 }  // namespace tflite
112