1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/string_util.h"
26
27 namespace tflite {
28 namespace ops {
29 namespace builtin {
30 namespace gather {
// Index of the values tensor to gather from.
constexpr int kInputTensor = 0;
// Index of the positions (indices) tensor; must be int32 or int64.
constexpr int kInputPositions = 1;
// Index of the gathered output tensor.
constexpr int kOutputTensor = 0;
34
// Validates the GATHER node's inputs and resizes its output.
//
// Checks that:
//   * there are exactly 2 inputs (values, positions) and 1 output;
//   * positions are int32 or int64;
//   * the value type is supported (string inputs must additionally be 1-D);
//   * axis and batch_dims, after normalizing negative values, are in range
//     and the leading batch dimensions of input and positions agree.
//
// On success the output is resized to:
//   input.shape[:axis] ++ positions.shape[batch_dims:] ++ input.shape[axis+1:]
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const auto* params =
      reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* positions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputPositions, &positions));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Only integer index tensors are supported.
  switch (positions->type) {
    case kTfLiteInt64:
    case kTfLiteInt32:
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Positions of type '%s' are not supported by gather.",
                         TfLiteTypeGetName(positions->type));
      return kTfLiteError;
  }

  // Assign to output the input type.
  output->type = input->type;

  // Check conditions for different types.
  switch (input->type) {
    case kTfLiteFloat32:
    case kTfLiteUInt8:
    case kTfLiteInt8:
    case kTfLiteInt16:
    case kTfLiteInt64:
    case kTfLiteInt32:
    case kTfLiteBool:
      break;
    case kTfLiteString: {
      // Only 1D input is supported.
      TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
    } break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }

  // A negative axis counts back from the last dimension of the input.
  int axis = params->axis;
  if (axis < 0) {
    axis += NumDimensions(input);
  }
  TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));

  int batch_dims = params->batch_dims;
  // batch_dims should be in range: [-rank(positions), rank(positions)].
  // Negative batch_dims is added with rank of positions.
  if (batch_dims < 0) {
    batch_dims += NumDimensions(positions);
  }
  TF_LITE_ENSURE(context, batch_dims <= axis);
  TF_LITE_ENSURE(context, 0 <= batch_dims && batch_dims < NumDimensions(input));
  TF_LITE_ENSURE(context, batch_dims <= NumDimensions(positions));
  // The leading batch_dims dimensions of input and positions must match.
  for (int i = 0; i < batch_dims; ++i) {
    TF_LITE_ENSURE_EQ(context, input->dims->data[i], positions->dims->data[i]);
  }

  // Output rank = rank(input) + rank(positions) - 1 - batch_dims.
  const int num_dimensions =
      NumDimensions(input) + NumDimensions(positions) - 1 - batch_dims;
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
  int output_index = 0;
  // input.shape[:axis]
  for (int i = 0; i < axis; ++i) {
    output_shape->data[output_index++] = input->dims->data[i];
  }
  // positions.shape[batch_dims:]
  for (int i = batch_dims; i < positions->dims->size; ++i) {
    output_shape->data[output_index++] = positions->dims->data[i];
  }
  // input.shape[axis+1:]
  for (int i = axis + 1; i < input->dims->size; ++i) {
    output_shape->data[output_index++] = input->dims->data[i];
  }
  return context->ResizeTensor(context, output, output_shape);
}
118
119 template <typename InputT, typename PositionsT>
Gather(TfLiteContext * context,const TfLiteGatherParams & params,const TfLiteTensor * input,const TfLiteTensor * positions,TfLiteTensor * output)120 TfLiteStatus Gather(TfLiteContext* context, const TfLiteGatherParams& params,
121 const TfLiteTensor* input, const TfLiteTensor* positions,
122 TfLiteTensor* output) {
123 const PositionsT* indexes = GetTensorData<PositionsT>(positions);
124 bool indices_has_only_positive_elements = true;
125 const size_t num_indices = positions->bytes / sizeof(PositionsT);
126 for (size_t i = 0; i < num_indices; i++) {
127 if (indexes[i] < 0) {
128 indices_has_only_positive_elements = false;
129 break;
130 }
131 }
132 TF_LITE_ENSURE(context, indices_has_only_positive_elements);
133
134 tflite::GatherParams op_params;
135 op_params.axis = params.axis;
136 op_params.batch_dims = params.batch_dims;
137 optimized_ops::Gather(op_params, GetTensorShape(input),
138 GetTensorData<InputT>(input), GetTensorShape(positions),
139 GetTensorData<PositionsT>(positions),
140 GetTensorShape(output), GetTensorData<InputT>(output));
141 return kTfLiteOk;
142 }
143
144 template <typename PositionT>
GatherStrings(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * positions,TfLiteTensor * output)145 TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
146 const TfLiteTensor* positions,
147 TfLiteTensor* output) {
148 DynamicBuffer buffer;
149
150 const PositionT* indexes = GetTensorData<PositionT>(positions);
151 bool indices_has_only_positive_elements = true;
152 const size_t num_indices = positions->bytes / sizeof(PositionT);
153 for (size_t i = 0; i < num_indices; i++) {
154 if (indexes[i] < 0) {
155 indices_has_only_positive_elements = false;
156 break;
157 }
158 }
159 TF_LITE_ENSURE(context, indices_has_only_positive_elements);
160
161 const PositionT num_strings = GetStringCount(input);
162 const int num_indexes = NumElements(positions);
163
164 for (int i = 0; i < num_indexes; ++i) {
165 const PositionT pos = indexes[i];
166 TF_LITE_ENSURE(context, pos < num_strings);
167 const auto string_ref = GetString(input, pos);
168 buffer.AddString(string_ref.str, string_ref.len);
169 }
170 buffer.WriteToTensor(output, /*new_shape=*/nullptr);
171 return kTfLiteOk;
172 }
173
Eval(TfLiteContext * context,TfLiteNode * node)174 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
175 const auto* params =
176 reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
177 const TfLiteTensor* input;
178 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
179 const TfLiteTensor* positions;
180 TF_LITE_ENSURE_OK(context,
181 GetInputSafe(context, node, kInputPositions, &positions));
182 TfLiteTensor* output;
183 TF_LITE_ENSURE_OK(context,
184 GetOutputSafe(context, node, kOutputTensor, &output));
185
186 if (positions->type == kTfLiteInt32) {
187 switch (input->type) {
188 case kTfLiteFloat32:
189 return Gather<float, int32_t>(context, *params, input, positions,
190 output);
191 case kTfLiteUInt8:
192 return Gather<uint8_t, int32_t>(context, *params, input, positions,
193 output);
194 case kTfLiteInt8:
195 return Gather<int8_t, int32_t>(context, *params, input, positions,
196 output);
197 case kTfLiteInt16:
198 return Gather<int16_t, int32_t>(context, *params, input, positions,
199 output);
200 case kTfLiteInt32:
201 return Gather<int32_t, int32_t>(context, *params, input, positions,
202 output);
203 case kTfLiteInt64:
204 return Gather<int64_t, int32_t>(context, *params, input, positions,
205 output);
206 case kTfLiteBool:
207 return Gather<bool, int32_t>(context, *params, input, positions,
208 output);
209 case kTfLiteString:
210 return GatherStrings<int32_t>(context, input, positions, output);
211 default:
212 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
213 TfLiteTypeGetName(input->type));
214 return kTfLiteError;
215 }
216 }
217 if (positions->type == kTfLiteInt64) {
218 switch (input->type) {
219 case kTfLiteFloat32:
220 return Gather<float, int64_t>(context, *params, input, positions,
221 output);
222 case kTfLiteUInt8:
223 return Gather<uint8_t, int64_t>(context, *params, input, positions,
224 output);
225 case kTfLiteInt8:
226 return Gather<int8_t, int64_t>(context, *params, input, positions,
227 output);
228 case kTfLiteInt16:
229 return Gather<int16_t, int64_t>(context, *params, input, positions,
230 output);
231 case kTfLiteInt32:
232 return Gather<int32_t, int64_t>(context, *params, input, positions,
233 output);
234 case kTfLiteInt64:
235 return Gather<int64_t, int64_t>(context, *params, input, positions,
236 output);
237 case kTfLiteBool:
238 return Gather<bool, int64_t>(context, *params, input, positions,
239 output);
240 case kTfLiteString:
241 return GatherStrings<int64_t>(context, input, positions, output);
242 default:
243 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
244 TfLiteTypeGetName(input->type));
245 return kTfLiteError;
246 }
247 }
248 TF_LITE_KERNEL_LOG(context,
249 "Positions of type '%s' are not supported by gather.",
250 TfLiteTypeGetName(positions->type));
251 return kTfLiteError;
252 }
253 } // namespace gather
254
Register_GATHER()255 TfLiteRegistration* Register_GATHER() {
256 static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare,
257 gather::Eval};
258 return &r;
259 }
260
261 } // namespace builtin
262 } // namespace ops
263 } // namespace tflite
264