• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <cmath>
16 
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
19 #include "tensorflow/lite/kernels/kernel_util.h"
20 
21 namespace tflite {
22 namespace ops {
23 namespace builtin {
24 namespace atan2 {
25 
EnsureSameShape(TfLiteContext * context,const TfLiteTensor * a,const TfLiteTensor * b)26 TfLiteStatus EnsureSameShape(
27     TfLiteContext* context,
28     const TfLiteTensor* a, const TfLiteTensor* b) {
29   TF_LITE_ENSURE_EQ(context,
30                     tflite::NumDimensions(a),
31                     tflite::NumDimensions(b));
32 
33   return TfLiteStatus::kTfLiteOk;
34 }
35 
Atan2Prepare(TfLiteContext * context,TfLiteNode * node)36 TfLiteStatus Atan2Prepare(TfLiteContext* context, TfLiteNode* node) {
37   TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
38   TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
39 
40   const TfLiteTensor* input_y = tflite::GetInput(context, node, 0);
41   const TfLiteTensor* input_x = tflite::GetInput(context, node, 1);
42   TfLiteTensor* output = tflite::GetOutput(context, node, 0);
43 
44   // Validate size and type constraints
45   TF_LITE_ENSURE_OK(context, EnsureSameShape(context, input_y, input_x));
46   TF_LITE_ENSURE_TYPES_EQ(context, input_y->type, input_x->type);
47   TF_LITE_ENSURE_TYPES_EQ(context, input_y->type, output->type);
48   TF_LITE_ENSURE(context,
49                  input_y->type == kTfLiteFloat32 ||
50                  input_y->type == kTfLiteFloat64);
51 
52   TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input_y->dims);
53 
54   return context->ResizeTensor(context, output, output_shape);
55 }
56 
57 template<typename Float>
Atan2(const TfLiteTensor * input_y,const TfLiteTensor * input_x,TfLiteTensor * output)58 TfLiteStatus Atan2(const TfLiteTensor* input_y,
59                    const TfLiteTensor* input_x,
60                    TfLiteTensor* output) {
61   const Float* data_y = tflite::GetTensorData<Float>(input_y);
62   const Float* data_x = tflite::GetTensorData<Float>(input_x);
63   Float* data_output = tflite::GetTensorData<Float>(output);
64 
65   const int64_t num_elements = NumElements(input_y);
66   for (int64_t i = 0; i < num_elements; ++i) {
67     data_output[i] = std::atan2(data_y[i], data_x[i]);
68   }
69 
70   return TfLiteStatus::kTfLiteOk;
71 }
72 
Atan2Eval(TfLiteContext * context,TfLiteNode * node)73 TfLiteStatus Atan2Eval(TfLiteContext* context, TfLiteNode* node) {
74   const TfLiteTensor* input_y = tflite::GetInput(context, node, 0);
75   const TfLiteTensor* input_x = tflite::GetInput(context, node, 1);
76   TfLiteTensor* output = tflite::GetOutput(context, node, 0);
77 
78   switch (output->type) {
79     case kTfLiteFloat32:
80       TF_LITE_ENSURE_OK(context, Atan2<float>(input_y, input_x, output));
81       break;
82     case kTfLiteFloat64:
83       TF_LITE_ENSURE_OK(context, Atan2<double>(input_y, input_x, output));
84       break;
85     default:
86       TF_LITE_KERNEL_LOG(
87           context,
88           "Unsupported datatype for atan2 output: %s",
89           TfLiteTypeGetName(output->type));
90   }
91 
92   return TfLiteStatus::kTfLiteOk;
93 }
94 
95 }  // namespace atan2
96 
Register_ATAN2()97 TfLiteRegistration* Register_ATAN2() {
98   static TfLiteRegistration r = {
99     nullptr, nullptr, atan2::Atan2Prepare, atan2::Atan2Eval};
100   return &r;
101 }
102 
103 }  // namespace builtin
104 }  // namespace ops
105 }  // namespace tflite
106