1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <math.h>
#include <stdint.h>
#include <stdlib.h>

#include <cmath>
#include <functional>
#include <limits>
#include <type_traits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
28
29 namespace tflite {
30 namespace ops {
31 namespace builtin {
32 namespace range {
33 namespace {
34
// Positions of the RANGE op's operands in the node's input list
// (start, limit, delta) and of its single result in the output list.
constexpr int kStartTensor{0};
constexpr int kLimitTensor{1};
constexpr int kDeltaTensor{2};
constexpr int kOutputTensor{0};
39
40 template <typename T>
GetSize(TfLiteContext * context,T start,T limit,T delta,int * size)41 TfLiteStatus GetSize(TfLiteContext* context, T start, T limit, T delta,
42 int* size) {
43 TF_LITE_ENSURE(context, !std::equal_to<T>()(delta, 0));
44 TF_LITE_ENSURE(
45 context, (start >= limit && delta < 0) || (start <= limit && delta > 0));
46 *size =
47 (std::is_integral<T>::value
48 ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
49 : std::ceil(std::abs((limit - start) / delta)));
50 return kTfLiteOk;
51 }
52
ResizeOutput(TfLiteContext * context,const TfLiteTensor * start,const TfLiteTensor * limit,const TfLiteTensor * delta,TfLiteTensor * output)53 TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* start,
54 const TfLiteTensor* limit, const TfLiteTensor* delta,
55 TfLiteTensor* output) {
56 // The output will always be a 1-d array.
57 int size = 0;
58 switch (start->type) {
59 case kTfLiteInt32: {
60 TF_LITE_ENSURE_OK(context,
61 GetSize(context, *GetTensorData<int32_t>(start),
62 *GetTensorData<int32_t>(limit),
63 *GetTensorData<int32_t>(delta), &size));
64 break;
65 }
66 case kTfLiteFloat32: {
67 TF_LITE_ENSURE_OK(context, GetSize(context, *GetTensorData<float>(start),
68 *GetTensorData<float>(limit),
69 *GetTensorData<float>(delta), &size));
70 break;
71 }
72 default: {
73 TF_LITE_KERNEL_LOG(context, "Unknown data type: %d", start->type);
74 return kTfLiteError;
75 }
76 }
77 TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(1);
78 output_shape_array->data[0] = size;
79 return context->ResizeTensor(context, output, output_shape_array);
80 }
81
Prepare(TfLiteContext * context,TfLiteNode * node)82 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
83 TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
84 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
85
86 const TfLiteTensor* start;
87 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
88 const TfLiteTensor* limit;
89 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
90 const TfLiteTensor* delta;
91 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
92 // Make sure all the inputs are scalars.
93 TF_LITE_ENSURE_EQ(context, NumDimensions(start), 0);
94 TF_LITE_ENSURE_EQ(context, NumDimensions(limit), 0);
95 TF_LITE_ENSURE_EQ(context, NumDimensions(delta), 0);
96
97 // Currently only supports int32 and float.
98 // TODO(b/117912892): Support quantization as well.
99 const auto dtype = start->type;
100 if (dtype != kTfLiteFloat32 && dtype != kTfLiteInt32) {
101 TF_LITE_KERNEL_LOG(context, "Unknown index output data type: %s",
102 TfLiteTypeGetName(dtype));
103 return kTfLiteError;
104 }
105
106 TF_LITE_ENSURE_TYPES_EQ(context, limit->type, dtype);
107 TF_LITE_ENSURE_TYPES_EQ(context, delta->type, dtype);
108
109 TfLiteTensor* output;
110 TF_LITE_ENSURE_OK(context,
111 GetOutputSafe(context, node, kOutputTensor, &output));
112 output->type = dtype;
113
114 if (IsConstantTensor(start) && IsConstantTensor(limit) &&
115 IsConstantTensor(delta)) {
116 return ResizeOutput(context, start, limit, delta, output);
117 }
118
119 SetTensorToDynamic(output);
120 return kTfLiteOk;
121 }
122
123 template <typename T>
EvalImpl(const TfLiteTensor * start,const TfLiteTensor * delta,TfLiteTensor * output)124 void EvalImpl(const TfLiteTensor* start, const TfLiteTensor* delta,
125 TfLiteTensor* output) {
126 const T start_value = *GetTensorData<T>(start);
127 const T delta_value = *GetTensorData<T>(delta);
128 T* output_data = GetTensorData<T>(output);
129 const int num_elements = NumElements(output);
130 T value = start_value;
131 for (int i = 0; i < num_elements; ++i) {
132 output_data[i] = value;
133 value += delta_value;
134 }
135 }
136
Eval(TfLiteContext * context,TfLiteNode * node)137 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
138 const TfLiteTensor* start;
139 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
140 const TfLiteTensor* limit;
141 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
142 const TfLiteTensor* delta;
143 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
144
145 TfLiteTensor* output;
146 TF_LITE_ENSURE_OK(context,
147 GetOutputSafe(context, node, kOutputTensor, &output));
148
149 if (IsDynamicTensor(output)) {
150 TF_LITE_ENSURE_OK(context,
151 ResizeOutput(context, start, limit, delta, output));
152 }
153
154 switch (output->type) {
155 case kTfLiteInt32: {
156 EvalImpl<int32_t>(start, delta, output);
157 break;
158 }
159 case kTfLiteFloat32: {
160 EvalImpl<float>(start, delta, output);
161 break;
162 }
163 default: {
164 TF_LITE_KERNEL_LOG(context, "Unsupported data type: %d", output->type);
165 return kTfLiteError;
166 }
167 }
168 return kTfLiteOk;
169 }
170
171 } // namespace
172 } // namespace range
173
Register_RANGE()174 TfLiteRegistration* Register_RANGE() {
175 static TfLiteRegistration r = {nullptr, nullptr, range::Prepare, range::Eval};
176 return &r;
177 }
178
179 } // namespace builtin
180 } // namespace ops
181 } // namespace tflite
182