/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <math.h>
#include <stdint.h>
#include <stdlib.h>

#include <functional>
#include <type_traits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

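// Implements the TFLite builtin Range op (the counterpart of tf.range): given
// scalar start, limit and delta inputs, it produces a 1-D tensor containing
// start, start + delta, start + 2 * delta, ..., stopping before limit.
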
namespace tflite {
namespace ops {
namespace builtin {
namespace range {
namespace {

constexpr int kStartTensor = 0;
constexpr int kLimitTensor = 1;
constexpr int kDeltaTensor = 2;
constexpr int kOutputTensor = 0;

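// Computes the number of elements the range will contain, i.e.
// ceil(|limit - start| / |delta|) (an exact ceiling division for integral T,
// std::ceil for floating point). Fails if delta is zero or has the wrong sign
// for reaching limit from start.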
template <typename T>
TfLiteStatus GetSize(TfLiteContext* context, T start, T limit, T delta,
                     int* size) {
  TF_LITE_ENSURE(context, !std::equal_to<T>()(delta, 0));
  TF_LITE_ENSURE(
      context, (start >= limit && delta < 0) || (start <= limit && delta > 0));
  *size =
      (std::is_integral<T>::value
           ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
           : std::ceil(std::abs((limit - start) / delta)));
  return kTfLiteOk;
}

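// Resizes the output to a 1-D tensor whose length is computed from the scalar
// start/limit/delta inputs. Only int32 and float32 element types are handled.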
TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* start,
                          const TfLiteTensor* limit, const TfLiteTensor* delta,
                          TfLiteTensor* output) {
  // The output will always be a 1-d array.
  int size = 0;
  switch (start->type) {
    case kTfLiteInt32: {
      TF_LITE_ENSURE_OK(context,
                        GetSize(context, *GetTensorData<int32_t>(start),
                                *GetTensorData<int32_t>(limit),
                                *GetTensorData<int32_t>(delta), &size));
      break;
    }
    case kTfLiteFloat32: {
      TF_LITE_ENSURE_OK(context, GetSize(context, *GetTensorData<float>(start),
                                         *GetTensorData<float>(limit),
                                         *GetTensorData<float>(delta), &size));
      break;
    }
    default: {
      TF_LITE_KERNEL_LOG(context, "Unknown data type: %d", start->type);
      return kTfLiteError;
    }
  }
  TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(1);
  output_shape_array->data[0] = size;
  return context->ResizeTensor(context, output, output_shape_array);
}

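// Checks that the three inputs are scalars of a supported, matching type.
// When all of them are constant the output is resized up front; otherwise it
// is marked dynamic and resized in Eval.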
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* start;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
  const TfLiteTensor* limit;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
  const TfLiteTensor* delta;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));
  // Make sure all the inputs are scalars.
  TF_LITE_ENSURE_EQ(context, NumDimensions(start), 0);
  TF_LITE_ENSURE_EQ(context, NumDimensions(limit), 0);
  TF_LITE_ENSURE_EQ(context, NumDimensions(delta), 0);

  // Currently only supports int32 and float.
  // TODO(b/117912892): Support quantization as well.
  const auto dtype = start->type;
  if (dtype != kTfLiteFloat32 && dtype != kTfLiteInt32) {
    TF_LITE_KERNEL_LOG(context, "Unknown index output data type: %s",
                       TfLiteTypeGetName(dtype));
    return kTfLiteError;
  }

  TF_LITE_ENSURE_TYPES_EQ(context, limit->type, dtype);
  TF_LITE_ENSURE_TYPES_EQ(context, delta->type, dtype);

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  output->type = dtype;

  if (IsConstantTensor(start) && IsConstantTensor(limit) &&
      IsConstantTensor(delta)) {
    return ResizeOutput(context, start, limit, delta, output);
  }

  SetTensorToDynamic(output);
  return kTfLiteOk;
}

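// Writes start, start + delta, start + 2 * delta, ... into the output, for as
// many elements as the output was resized to hold; limit is not needed here
// because the element count already encodes it.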
template <typename T>
void EvalImpl(const TfLiteTensor* start, const TfLiteTensor* delta,
              TfLiteTensor* output) {
  const T start_value = *GetTensorData<T>(start);
  const T delta_value = *GetTensorData<T>(delta);
  T* output_data = GetTensorData<T>(output);
  const int num_elements = NumElements(output);
  T value = start_value;
  for (int i = 0; i < num_elements; ++i) {
    output_data[i] = value;
    value += delta_value;
  }
}

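// Resizes a dynamic output (i.e. the inputs were not all constant at Prepare
// time) and then fills it by dispatching on the output type.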
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* start;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStartTensor, &start));
  const TfLiteTensor* limit;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kLimitTensor, &limit));
  const TfLiteTensor* delta;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kDeltaTensor, &delta));

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  if (IsDynamicTensor(output)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeOutput(context, start, limit, delta, output));
  }

  switch (output->type) {
    case kTfLiteInt32: {
      EvalImpl<int32_t>(start, delta, output);
      break;
    }
    case kTfLiteFloat32: {
      EvalImpl<float>(start, delta, output);
      break;
    }
    default: {
      TF_LITE_KERNEL_LOG(context, "Unsupported data type: %d", output->type);
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

}  // namespace
}  // namespace range

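// The Range kernel keeps no per-node state, so init and free are null and
// only Prepare/Eval are registered.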
TfLiteRegistration* Register_RANGE() {
  static TfLiteRegistration r = {nullptr, nullptr, range::Prepare, range::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite