• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stddef.h>
16 #include <stdint.h>
17 
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/kernel_util.h"
24 
25 namespace tflite {
26 namespace ops {
27 namespace builtin {
28 namespace pow {
29 namespace {
30 
// Positions of this kernel's tensors within the node's input/output lists.
constexpr int kInputTensor1{0};  // first operand (base)
constexpr int kInputTensor2{1};  // second operand (exponent)
constexpr int kOutputTensor{0};
35 
// Per-node state for the pow op: written in Prepare and read in Eval.
struct OpData {
  // True when the two inputs have different shapes and must be broadcast.
  // Defaults to false so that a default-constructed OpData never holds an
  // indeterminate value.
  bool requires_broadcast = false;
};
40 
Init(TfLiteContext * context,const char * buffer,size_t length)41 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
42   auto* data = new OpData;
43   data->requires_broadcast = false;
44   return data;
45 }
46 
Free(TfLiteContext * context,void * buffer)47 void Free(TfLiteContext* context, void* buffer) {
48   delete reinterpret_cast<OpData*>(buffer);
49 }
50 
Prepare(TfLiteContext * context,TfLiteNode * node)51 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
52   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
53   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
54 
55   OpData* data = reinterpret_cast<OpData*>(node->user_data);
56 
57   const TfLiteTensor* input1;
58   TF_LITE_ENSURE_OK(context,
59                     GetInputSafe(context, node, kInputTensor1, &input1));
60   const TfLiteTensor* input2;
61   TF_LITE_ENSURE_OK(context,
62                     GetInputSafe(context, node, kInputTensor2, &input2));
63   TfLiteTensor* output;
64   TF_LITE_ENSURE_OK(context,
65                     GetOutputSafe(context, node, kOutputTensor, &output));
66 
67   TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
68 
69   const TfLiteType type = input1->type;
70   if (type != kTfLiteInt32 && type != kTfLiteFloat32) {
71     TF_LITE_KERNEL_LOG(context, "Unsupported data type %s.",
72                        TfLiteTypeGetName(type));
73     return kTfLiteError;
74   }
75   output->type = type;
76 
77   data->requires_broadcast = !HaveSameShapes(input1, input2);
78 
79   TfLiteIntArray* output_size = nullptr;
80   if (data->requires_broadcast) {
81     TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
82                                    context, input1, input2, &output_size));
83   } else {
84     output_size = TfLiteIntArrayCopy(input1->dims);
85   }
86 
87   return context->ResizeTensor(context, output, output_size);
88 }
89 
90 template <typename T>
PowImpl(const TfLiteTensor * input1,const TfLiteTensor * input2,TfLiteTensor * output,bool requires_broadcast)91 void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
92              TfLiteTensor* output, bool requires_broadcast) {
93   if (requires_broadcast) {
94     optimized_ops::BroadcastPow4D(
95         GetTensorShape(input1), GetTensorData<T>(input1),
96         GetTensorShape(input2), GetTensorData<T>(input2),
97         GetTensorShape(output), GetTensorData<T>(output));
98   } else {
99     reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1),
100                        GetTensorShape(input2), GetTensorData<T>(input2),
101                        GetTensorShape(output), GetTensorData<T>(output));
102   }
103 }
104 
CheckValue(TfLiteContext * context,const TfLiteTensor * input)105 TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
106   const int64_t num_elements = NumElements(input);
107   const int32_t* data = GetTensorData<int32_t>(input);
108   for (int i = 0; i < num_elements; ++i) {
109     if (data[i] < 0) {
110       context->ReportError(context,
111                            "POW does not support negative value for int32.");
112       return kTfLiteError;
113     }
114   }
115   return kTfLiteOk;
116 }
117 
Eval(TfLiteContext * context,TfLiteNode * node)118 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
119   OpData* data = reinterpret_cast<OpData*>(node->user_data);
120 
121   const TfLiteTensor* input1;
122   TF_LITE_ENSURE_OK(context,
123                     GetInputSafe(context, node, kInputTensor1, &input1));
124   const TfLiteTensor* input2;
125   TF_LITE_ENSURE_OK(context,
126                     GetInputSafe(context, node, kInputTensor2, &input2));
127   TfLiteTensor* output;
128   TF_LITE_ENSURE_OK(context,
129                     GetOutputSafe(context, node, kOutputTensor, &output));
130 
131   switch (output->type) {
132     case kTfLiteInt32: {
133       // TensorFlow does not support negative for int32.
134       TF_LITE_ENSURE_OK(context, CheckValue(context, input2));
135       PowImpl<int32_t>(input1, input2, output, data->requires_broadcast);
136       break;
137     }
138     case kTfLiteFloat32: {
139       PowImpl<float>(input1, input2, output, data->requires_broadcast);
140       break;
141     }
142     default: {
143       context->ReportError(context, "Unsupported data type: %d", output->type);
144       return kTfLiteError;
145     }
146   }
147   return kTfLiteOk;
148 }
149 
150 }  // namespace
151 }  // namespace pow
152 
Register_POW()153 TfLiteRegistration* Register_POW() {
154   static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval};
155   return &r;
156 }
157 
158 }  // namespace builtin
159 }  // namespace ops
160 }  // namespace tflite
161