1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <cmath>
16
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
19 #include "tensorflow/lite/kernels/kernel_util.h"
20
21 namespace tflite {
22 namespace ops {
23 namespace builtin {
24 namespace atan2 {
25
EnsureSameShape(TfLiteContext * context,const TfLiteTensor * a,const TfLiteTensor * b)26 TfLiteStatus EnsureSameShape(
27 TfLiteContext* context,
28 const TfLiteTensor* a, const TfLiteTensor* b) {
29 TF_LITE_ENSURE_EQ(context,
30 tflite::NumDimensions(a),
31 tflite::NumDimensions(b));
32
33 return TfLiteStatus::kTfLiteOk;
34 }
35
Atan2Prepare(TfLiteContext * context,TfLiteNode * node)36 TfLiteStatus Atan2Prepare(TfLiteContext* context, TfLiteNode* node) {
37 TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
38 TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
39
40 const TfLiteTensor* input_y = tflite::GetInput(context, node, 0);
41 const TfLiteTensor* input_x = tflite::GetInput(context, node, 1);
42 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
43
44 // Validate size and type constraints
45 TF_LITE_ENSURE_OK(context, EnsureSameShape(context, input_y, input_x));
46 TF_LITE_ENSURE_TYPES_EQ(context, input_y->type, input_x->type);
47 TF_LITE_ENSURE_TYPES_EQ(context, input_y->type, output->type);
48 TF_LITE_ENSURE(context,
49 input_y->type == kTfLiteFloat32 ||
50 input_y->type == kTfLiteFloat64);
51
52 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input_y->dims);
53
54 return context->ResizeTensor(context, output, output_shape);
55 }
56
57 template<typename Float>
Atan2(const TfLiteTensor * input_y,const TfLiteTensor * input_x,TfLiteTensor * output)58 TfLiteStatus Atan2(const TfLiteTensor* input_y,
59 const TfLiteTensor* input_x,
60 TfLiteTensor* output) {
61 const Float* data_y = tflite::GetTensorData<Float>(input_y);
62 const Float* data_x = tflite::GetTensorData<Float>(input_x);
63 Float* data_output = tflite::GetTensorData<Float>(output);
64
65 const int64_t num_elements = NumElements(input_y);
66 for (int64_t i = 0; i < num_elements; ++i) {
67 data_output[i] = std::atan2(data_y[i], data_x[i]);
68 }
69
70 return TfLiteStatus::kTfLiteOk;
71 }
72
Atan2Eval(TfLiteContext * context,TfLiteNode * node)73 TfLiteStatus Atan2Eval(TfLiteContext* context, TfLiteNode* node) {
74 const TfLiteTensor* input_y = tflite::GetInput(context, node, 0);
75 const TfLiteTensor* input_x = tflite::GetInput(context, node, 1);
76 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
77
78 switch (output->type) {
79 case kTfLiteFloat32:
80 TF_LITE_ENSURE_OK(context, Atan2<float>(input_y, input_x, output));
81 break;
82 case kTfLiteFloat64:
83 TF_LITE_ENSURE_OK(context, Atan2<double>(input_y, input_x, output));
84 break;
85 default:
86 TF_LITE_KERNEL_LOG(
87 context,
88 "Unsupported datatype for atan2 output: %s",
89 TfLiteTypeGetName(output->type));
90 }
91
92 return TfLiteStatus::kTfLiteOk;
93 }
94
95 } // namespace atan2
96
Register_ATAN2()97 TfLiteRegistration* Register_ATAN2() {
98 static TfLiteRegistration r = {
99 nullptr, nullptr, atan2::Atan2Prepare, atan2::Atan2Eval};
100 return &r;
101 }
102
103 } // namespace builtin
104 } // namespace ops
105 } // namespace tflite
106