/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace pow {
namespace {

// Input/output tensor indices.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

// Op data for the pow op.
struct OpData {
  bool requires_broadcast;
};

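// Creates the per-node state: a heap-allocated OpData whose
// `requires_broadcast` flag is filled in later by Prepare. The custom-op
// `buffer`/`length` arguments are unused by POW.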
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  data->requires_broadcast = false;
  return data;
}

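// Releases the OpData allocated by Init.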
void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

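// Validates the node (two inputs, one output, both inputs of the same
// int32/float32 type), records whether broadcasting is needed, and resizes
// the output tensor to the (possibly broadcast) result shape.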
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);

  const TfLiteType type = input1->type;
  if (type != kTfLiteInt32 && type != kTfLiteFloat32) {
    TF_LITE_KERNEL_LOG(context, "Unsupported data type %s.",
                       TfLiteTypeGetName(type));
    return kTfLiteError;
  }
  output->type = type;

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  return context->ResizeTensor(context, output, output_size);
}

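// Dispatches to the broadcasting or elementwise Pow implementation for the
// element type T (int32_t or float).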
template <typename T>
void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
             TfLiteTensor* output, bool requires_broadcast) {
  if (requires_broadcast) {
    optimized_ops::BroadcastPow4D(
        GetTensorShape(input1), GetTensorData<T>(input1),
        GetTensorShape(input2), GetTensorData<T>(input2),
        GetTensorShape(output), GetTensorData<T>(output));
  } else {
    reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1),
                       GetTensorShape(input2), GetTensorData<T>(input2),
                       GetTensorShape(output), GetTensorData<T>(output));
  }
}

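// Returns an error if any element of `input` is negative. Used to reject
// negative int32 exponents, which POW does not support.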
TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
  const int64_t num_elements = NumElements(input);
  const int32_t* data = GetTensorData<int32_t>(input);
  for (int i = 0; i < num_elements; ++i) {
    if (data[i] < 0) {
      TF_LITE_KERNEL_LOG(context,
                         "POW does not support negative values for int32.");
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

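// Computes the elementwise power. For int32 the exponent tensor is first
// checked for negative values, since integer pow cannot represent the
// fractional results a negative exponent would produce.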
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (output->type) {
    case kTfLiteInt32: {
      // TensorFlow does not support negative exponents for int32.
      TF_LITE_ENSURE_OK(context, CheckValue(context, input2));
      PowImpl<int32_t>(input1, input2, output, data->requires_broadcast);
      break;
    }
    case kTfLiteFloat32: {
      PowImpl<float>(input1, input2, output, data->requires_broadcast);
      break;
    }
    default: {
      TF_LITE_KERNEL_LOG(context, "Unsupported data type %s.",
                         TfLiteTypeGetName(output->type));
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

}  // namespace
}  // namespace pow

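// Returns the registration for the builtin POW op, wiring up the
// Init/Free/Prepare/Eval lifecycle callbacks above.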
TfLiteRegistration* Register_POW() {
  static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite