1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <type_traits>
#include <vector>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/model.h"
25
26 namespace tflite {
27 namespace ops {
28 namespace custom {
29 namespace ragged {
30 namespace ragged_range {
31 namespace {
// Input tensor indices: per-row range starts, (exclusive) limits and steps.
constexpr int kInputStarts = 0;
constexpr int kInputLimits = kInputStarts + 1;
constexpr int kInputDeltas = kInputLimits + 1;

// Output tensor indices: row-split offsets and the flattened row values.
constexpr int kOutputNestedSplits = 0;
constexpr int kOutputDenseValues = kOutputNestedSplits + 1;
38
IntArrayFromInt(int x)39 TfLiteIntArray* IntArrayFromInt(int x) {
40 TfLiteIntArray* result = TfLiteIntArrayCreate(1);
41 result->data[0] = x;
42 return result;
43 }
44
// Integral case of RangeSize(): exact ceil(|limit - start| / |delta|).
// The distance and |delta| are computed in the unsigned counterpart of T,
// so extreme inputs (start = INT_MIN with limit = INT_MAX, or
// delta = INT_MIN) cannot hit signed overflow, which would be undefined
// behavior. Requires delta != 0.
template <typename T, typename SPLITS_TYPE>
SPLITS_TYPE RangeSizeImpl(T start, T limit, T delta,
                          std::true_type /*is_integral*/) {
  using U = typename std::make_unsigned<T>::type;
  // Unsigned subtraction is exact here: the true distance is at most
  // max(U), and the larger operand is always on the left.
  const U distance =
      limit >= start
          ? static_cast<U>(static_cast<U>(limit) - static_cast<U>(start))
          : static_cast<U>(static_cast<U>(start) - static_cast<U>(limit));
  const U step = delta >= 0 ? static_cast<U>(delta)
                            : static_cast<U>(U{0} - static_cast<U>(delta));
  // Ceiling division. It cannot overflow: when step == 1 the remainder is 0,
  // and otherwise distance / step <= max(U) / 2.
  const U size =
      static_cast<U>(distance / step + (distance % step != 0 ? 1 : 0));
  if (static_cast<uint64_t>(size) >
      static_cast<uint64_t>(std::numeric_limits<SPLITS_TYPE>::max())) {
    return std::numeric_limits<SPLITS_TYPE>::max();
  }
  return static_cast<SPLITS_TYPE>(size);
}

// Floating-point case of RangeSize(): ceil(|(limit - start) / delta|), using
// the same formula as tensorflow::RangeOp::Compute(). NaN and results too
// large for SPLITS_TYPE are saturated instead of converted, because that
// float-to-integer conversion would be undefined behavior.
template <typename T, typename SPLITS_TYPE>
SPLITS_TYPE RangeSizeImpl(T start, T limit, T delta,
                          std::false_type /*is_integral*/) {
  const T size = std::ceil(std::abs((limit - start) / delta));
  // Written as !(size < bound) so that a NaN size also saturates.
  if (!(size < static_cast<T>(std::numeric_limits<SPLITS_TYPE>::max()))) {
    return std::numeric_limits<SPLITS_TYPE>::max();
  }
  return static_cast<SPLITS_TYPE>(size);
}

// Returns the number of elements in the range [start, limit) with step
// `delta`. The result is 0 when `delta` points away from `limit` (the range
// is empty). It is also 0 when `delta` is zero: callers must reject a zero
// delta themselves; this only avoids dividing by zero.
//
// Sizes that do not fit in SPLITS_TYPE (including sizes derived from NaN or
// infinite float inputs) saturate to std::numeric_limits<SPLITS_TYPE>::max().
// Callers needing exact sizes should request a wide SPLITS_TYPE and treat
// that value as an overflow.
template <typename T, typename SPLITS_TYPE>
SPLITS_TYPE RangeSize(T start, T limit, T delta) {
  if (delta == 0 || ((delta > 0) && (limit < start)) ||
      ((delta < 0) && (limit > start))) {
    return 0;
  }
  return RangeSizeImpl<T, SPLITS_TYPE>(start, limit, delta,
                                       std::is_integral<T>());
}
57
58 template <typename T, typename SPLITS_TYPE>
EvalT(TfLiteContext * context,TfLiteNode * node)59 TfLiteStatus EvalT(TfLiteContext* context, TfLiteNode* node) {
60 TfLiteTensor& input_starts =
61 context->tensors[node->inputs->data[kInputStarts]];
62 TfLiteTensor& input_limits =
63 context->tensors[node->inputs->data[kInputLimits]];
64 TfLiteTensor& input_deltas =
65 context->tensors[node->inputs->data[kInputDeltas]];
66 // Determine which tensors we need to broadcast.
67 const bool broadcast_starts = NumElements(&input_starts) == 1;
68 const bool broadcast_limits = NumElements(&input_limits) == 1;
69 const bool broadcast_deltas = NumElements(&input_deltas) == 1;
70
71 // nrows (number of output rows) is the size of the non-broadcast inputs,
72 // or 1 if all inputs are scalars.
73 std::vector<int> in_sizes;
74 if (!broadcast_starts) in_sizes.push_back(input_starts.dims->data[0]);
75 if (!broadcast_limits) in_sizes.push_back(input_limits.dims->data[0]);
76 if (!broadcast_deltas) in_sizes.push_back(input_deltas.dims->data[0]);
77 if (std::adjacent_find(std::begin(in_sizes), std::end(in_sizes),
78 std::not_equal_to<>()) != std::end(in_sizes)) {
79 context->ReportError(
80 context,
81 "Invalid argument: starts, limits, and deltas must have the "
82 "same shape");
83 return kTfLiteError;
84 }
85
86 const SPLITS_TYPE nrows = in_sizes.empty() ? 1 : in_sizes.front();
87
88 const T* starts = GetTensorData<T>(&input_starts);
89 const T* limits = GetTensorData<T>(&input_limits);
90 const T* deltas = GetTensorData<T>(&input_deltas);
91
92 TfLiteTensor& rt_nested_splits_out =
93 context->tensors[node->outputs->data[kOutputNestedSplits]];
94 TF_LITE_ENSURE_OK(context,
95 context->ResizeTensor(context, &rt_nested_splits_out,
96 IntArrayFromInt(nrows + 1)));
97 SPLITS_TYPE* rt_nested_splits =
98 GetTensorData<SPLITS_TYPE>(&rt_nested_splits_out);
99 rt_nested_splits[0] = 0;
100
101 for (int row = 0; row < nrows; ++row) {
102 const T start = broadcast_starts ? starts[0] : starts[row];
103 const T limit = broadcast_limits ? limits[0] : limits[row];
104 const T delta = broadcast_deltas ? deltas[0] : deltas[row];
105 if (delta == 0) {
106 context->ReportError(context, "Invalid argument: Requires delta != 0");
107 return kTfLiteError;
108 }
109 rt_nested_splits[row + 1] =
110 rt_nested_splits[row] + RangeSize<T, SPLITS_TYPE>(start, limit, delta);
111 }
112 const SPLITS_TYPE nvals = rt_nested_splits[nrows];
113
114 TfLiteTensor& rt_dense_values_out =
115 context->tensors[node->outputs->data[kOutputDenseValues]];
116 TF_LITE_ENSURE_OK(context,
117 context->ResizeTensor(context, &rt_dense_values_out,
118 IntArrayFromInt(nvals)));
119 T* rt_dense_values = GetTensorData<T>(&rt_dense_values_out);
120 int value_index = 0;
121 for (int row = 0; row < nrows; ++row) {
122 const SPLITS_TYPE row_size =
123 rt_nested_splits[row + 1] - rt_nested_splits[row];
124 T value = broadcast_starts ? starts[0] : starts[row];
125 const T delta = broadcast_deltas ? deltas[0] : deltas[row];
126 for (SPLITS_TYPE i = 0; i < row_size; ++i) {
127 rt_dense_values[value_index++] = value;
128 value += delta;
129 }
130 }
131 return kTfLiteOk;
132 }
133
134 template <typename SPLITS_TYPE>
EvalSplitsT(TfLiteContext * context,TfLiteNode * node)135 TfLiteStatus EvalSplitsT(TfLiteContext* context, TfLiteNode* node) {
136 TfLiteTensor& rt_dense_values_out =
137 context->tensors[node->outputs->data[kOutputDenseValues]];
138 switch (rt_dense_values_out.type) {
139 case kTfLiteInt32:
140 return EvalT<int32_t, SPLITS_TYPE>(context, node);
141 case kTfLiteInt64:
142 return EvalT<int64_t, SPLITS_TYPE>(context, node);
143 case kTfLiteFloat32:
144 return EvalT<float, SPLITS_TYPE>(context, node);
145 case kTfLiteFloat64:
146 return EvalT<double, SPLITS_TYPE>(context, node);
147 default:
148 context->ReportError(context,
149 "Invalid argument: Not supported VALUES type");
150 return kTfLiteError;
151 }
152 }
153 } // namespace
154
Prepare(TfLiteContext * context,TfLiteNode * node)155 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
156 // Set outputs dynamic.
157 TfLiteTensor& nested_splits =
158 context->tensors[node->outputs->data[kOutputNestedSplits]];
159 SetTensorToDynamic(&nested_splits);
160 TfLiteTensor& dense_values =
161 context->tensors[node->outputs->data[kOutputDenseValues]];
162 SetTensorToDynamic(&dense_values);
163 return kTfLiteOk;
164 }
165
Eval(TfLiteContext * context,TfLiteNode * node)166 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
167 TfLiteTensor& rt_nested_splits_out =
168 context->tensors[node->outputs->data[kOutputNestedSplits]];
169 switch (rt_nested_splits_out.type) {
170 case kTfLiteInt32:
171 return EvalSplitsT<int32_t>(context, node);
172 case kTfLiteInt64:
173 return EvalSplitsT<int64_t>(context, node);
174 default:
175 context->ReportError(context,
176 "Invalid argument: Not supported ROW_SPLITS type");
177 return kTfLiteError;
178 }
179 }
180
181 } // namespace ragged_range
182 } // namespace ragged
Register_RAGGED_RANGE()183 TfLiteRegistration* Register_RAGGED_RANGE() {
184 static TfLiteRegistration r = {nullptr /*Initialize*/, nullptr /*Free*/,
185 ragged::ragged_range::Prepare,
186 ragged::ragged_range::Eval};
187 return &r;
188 }
189
190 } // namespace custom
191 } // namespace ops
192 } // namespace tflite
193