• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <type_traits>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
25 
26 namespace tflite {
27 namespace ops {
28 namespace custom {
29 namespace ragged {
30 namespace ragged_range {
31 namespace {
// Positions of the three input tensors inside node->inputs.
constexpr int kInputStarts{0};
constexpr int kInputLimits{1};
constexpr int kInputDeltas{2};

// Positions of the two output tensors inside node->outputs.
constexpr int kOutputNestedSplits{0};
constexpr int kOutputDenseValues{1};
38 
IntArrayFromInt(int x)39 TfLiteIntArray* IntArrayFromInt(int x) {
40   TfLiteIntArray* result = TfLiteIntArrayCreate(1);
41   result->data[0] = x;
42   return result;
43 }
44 
// Integral implementation of RangeSize (selected through std::true_type).
// Works on the unsigned counterpart of T so that neither `limit - start` nor
// the magnitude of `delta` can overflow, and never goes through floating
// point, so 64-bit counts stay exact.
template <typename T, typename SPLITS_TYPE>
SPLITS_TYPE RangeSizeImpl(T start, T limit, T delta,
                          std::true_type /*is_integral*/) {
  using U = typename std::make_unsigned<T>::type;
  // Unsigned subtraction wraps modulo 2^N, which yields the exact distance
  // between the two signed values.
  const U span = (limit > start)
                     ? static_cast<U>(static_cast<U>(limit) -
                                      static_cast<U>(start))
                     : static_cast<U>(static_cast<U>(start) -
                                      static_cast<U>(limit));
  // |delta|, valid even for the most negative value of T.
  const U step = (delta > 0)
                     ? static_cast<U>(delta)
                     : static_cast<U>(U{0} - static_cast<U>(delta));
  // ceil(span / step) without forming span + step - 1 (which could wrap).
  const U count =
      static_cast<U>(span / step + static_cast<U>(span % step != 0));

  constexpr SPLITS_TYPE kMaxSize = std::numeric_limits<SPLITS_TYPE>::max();
  if (static_cast<std::uint64_t>(count) >
      static_cast<std::uint64_t>(kMaxSize)) {
    return kMaxSize;
  }
  return static_cast<SPLITS_TYPE>(count);
}

// Floating-point implementation of RangeSize (selected through
// std::false_type). The arithmetic is done in T, exactly as
// ceil(|(limit - start) / delta|).
template <typename T, typename SPLITS_TYPE>
SPLITS_TYPE RangeSizeImpl(T start, T limit, T delta,
                          std::false_type /*is_integral*/) {
  const T count = std::ceil(std::abs((limit - start) / delta));
  constexpr SPLITS_TYPE kMaxSize = std::numeric_limits<SPLITS_TYPE>::max();
  // Written as !(a < b) so that NaN and +inf also take the saturating path;
  // converting such a value to an integer would be undefined behaviour.
  if (!(static_cast<double>(count) < static_cast<double>(kMaxSize))) {
    return kMaxSize;
  }
  return static_cast<SPLITS_TYPE>(count);
}

// Returns the number of elements in the range [start, limit) stepped by
// `delta`, i.e. ceil(|limit - start| / |delta|).
//
// Returns 0 when `delta` points away from `limit` (positive delta with
// limit < start, or negative delta with limit > start). A count that does not
// fit in SPLITS_TYPE is saturated to the largest SPLITS_TYPE value instead of
// being narrowed silently.
//
// Precondition: delta != 0 (for integral T a zero delta divides by zero).
// EvalT in this file rejects a zero delta before calling this function.
template <typename T, typename SPLITS_TYPE>
SPLITS_TYPE RangeSize(T start, T limit, T delta) {
  if (((delta > 0) && (limit < start)) || ((delta < 0) && (limit > start))) {
    return 0;
  }
  return RangeSizeImpl<T, SPLITS_TYPE>(start, limit, delta,
                                       std::is_integral<T>());
}
57 
58 template <typename T, typename SPLITS_TYPE>
EvalT(TfLiteContext * context,TfLiteNode * node)59 TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) {
60   TfLiteTensor& input_starts =
61       context->tensors[node->inputs->data[kInputStarts]];
62   TfLiteTensor& input_limits =
63       context->tensors[node->inputs->data[kInputLimits]];
64   TfLiteTensor& input_deltas =
65       context->tensors[node->inputs->data[kInputDeltas]];
66   // Determine which tensors we need to broadcast.
67   const bool broadcast_starts = NumElements(&input_starts) == 1;
68   const bool broadcast_limits = NumElements(&input_limits) == 1;
69   const bool broadcast_deltas = NumElements(&input_deltas) == 1;
70 
71   // nrows (number of output rows) is the size of the non-broadcast inputs,
72   // or 1 if all inputs are scalars.
73   std::vector<int> in_sizes;
74   if (!broadcast_starts) in_sizes.push_back(input_starts.dims->data[0]);
75   if (!broadcast_limits) in_sizes.push_back(input_limits.dims->data[0]);
76   if (!broadcast_deltas) in_sizes.push_back(input_deltas.dims->data[0]);
77   if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes),
78                          std::not_equal_to<>()) != std::end(in_sizes)) {
79     context->ReportError(
80         context,
81         "Invalid argument: starts, limits, and deltas must have the "
82         "same shape");
83     return kTfLiteError;
84   }
85 
86   const SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes.front();
87 
88   const T* starts = GetTensorData<T>(&input_starts);
89   const T* limits = GetTensorData<T>(&input_limits);
90   const T* deltas = GetTensorData<T>(&input_deltas);
91 
92   TfLiteTensor& rt_nested_splits_out =
93       context->tensors[node->outputs->data[kOutputNestedSplits]];
94   TF_LITE_ENSURE_OK(context,
95                     context->ResizeTensor(context, &rt_nested_splits_out,
96                                           IntArrayFromInt(nrows + 1)));
97   SPLITS_TYPE* rt_nested_splits =
98       GetTensorData<SPLITS_TYPE>(&rt_nested_splits_out);
99   rt_nested_splits[0] = 0;
100 
101   for (int row = 0; row < nrows; ++row) {
102     const T start = broadcast_starts ? starts[0] : starts[row];
103     const T limit = broadcast_limits ? limits[0] : limits[row];
104     const T delta = broadcast_deltas ? deltas[0] : deltas[row];
105     if (delta == 0) {
106       context->ReportError(context, "Invalid argument: Requires delta != 0");
107       return kTfLiteError;
108     }
109     rt_nested_splits[row + 1] =
110         rt_nested_splits[row] + RangeSize<T, SPLITS_TYPE>(start, limit, delta);
111   }
112   const SPLITS_TYPE nvals = rt_nested_splits[nrows];
113 
114   TfLiteTensor& rt_dense_values_out =
115       context->tensors[node->outputs->data[kOutputDenseValues]];
116   TF_LITE_ENSURE_OK(context,
117                     context->ResizeTensor(context, &rt_dense_values_out,
118                                           IntArrayFromInt(nvals)));
119   T* rt_dense_values = GetTensorData<T>(&rt_dense_values_out);
120   int value_index = 0;
121   for (int row = 0; row < nrows; ++row) {
122     const SPLITS_TYPE row_size =
123         rt_nested_splits[row + 1] - rt_nested_splits[row];
124     T value = broadcast_starts ? starts[0] : starts[row];
125     const T delta = broadcast_deltas ? deltas[0] : deltas[row];
126     for (SPLITS_TYPE i = 0; i < row_size; ++i) {
127       rt_dense_values[value_index++] = value;
128       value += delta;
129     }
130   }
131   return kTfLiteOk;
132 }
133 
134 template <typename SPLITS_TYPE>
EvalSplitsT(TfLiteContext * context,TfLiteNode * node)135 TfLiteStatus EvalSplitsT(TfLiteContext* context, TfLiteNode* node) {
136   TfLiteTensor& rt_dense_values_out =
137       context->tensors[node->outputs->data[kOutputDenseValues]];
138   switch (rt_dense_values_out.type) {
139     case kTfLiteInt32:
140       return EvalT<int32_t, SPLITS_TYPE>(context, node);
141     case kTfLiteInt64:
142       return EvalT<int64_t, SPLITS_TYPE>(context, node);
143     case kTfLiteFloat32:
144       return EvalT<float, SPLITS_TYPE>(context, node);
145     case kTfLiteFloat64:
146       return EvalT<double, SPLITS_TYPE>(context, node);
147     default:
148       context->ReportError(context,
149                            "Invalid argument: Not supported VALUES type");
150       return kTfLiteError;
151   }
152 }
153 }  // namespace
154 
Prepare(TfLiteContext * context,TfLiteNode * node)155 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
156   // Set outputs dynamic.
157   TfLiteTensor& nested_splits =
158       context->tensors[node->outputs->data[kOutputNestedSplits]];
159   SetTensorToDynamic(&nested_splits);
160   TfLiteTensor& dense_values =
161       context->tensors[node->outputs->data[kOutputDenseValues]];
162   SetTensorToDynamic(&dense_values);
163   return kTfLiteOk;
164 }
165 
Eval(TfLiteContext * context,TfLiteNode * node)166 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
167   TfLiteTensor& rt_nested_splits_out =
168       context->tensors[node->outputs->data[kOutputNestedSplits]];
169   switch (rt_nested_splits_out.type) {
170     case kTfLiteInt32:
171       return EvalSplitsT<int32_t>(context, node);
172     case kTfLiteInt64:
173       return EvalSplitsT<int64_t>(context, node);
174     default:
175       context->ReportError(context,
176                            "Invalid argument: Not supported ROW_SPLITS type");
177       return kTfLiteError;
178   }
179 }
180 
181 }  // namespace ragged_range
182 }  // namespace ragged
Register_RAGGED_RANGE()183 TfLiteRegistration* Register_RAGGED_RANGE() {
184   static TfLiteRegistration r = {nullptr /*Initialize*/, nullptr /*Free*/,
185                                  ragged::ragged_range::Prepare,
186                                  ragged::ragged_range::Eval};
187   return &r;
188 }
189 
190 }  // namespace custom
191 }  // namespace ops
192 }  // namespace tflite
193