• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <stdint.h>
16 
17 #include "tensorflow/lite/c/builtin_op_data.h"
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 #include "tensorflow/lite/string_util.h"
26 
27 namespace tflite {
28 namespace ops {
29 namespace builtin {
30 namespace gather {
// Tensor slots used by the GATHER node.
// Inputs: [0] the tensor to gather from, [1] the index ("positions") tensor.
constexpr int kInputTensor{0};
constexpr int kInputPositions{1};
// Output: [0] the gathered result.
constexpr int kOutputTensor{0};
34 
Prepare(TfLiteContext * context,TfLiteNode * node)35 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
36   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
37   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
38 
39   const auto* params =
40       reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
41   const TfLiteTensor* input;
42   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
43   const TfLiteTensor* positions;
44   TF_LITE_ENSURE_OK(context,
45                     GetInputSafe(context, node, kInputPositions, &positions));
46   TfLiteTensor* output;
47   TF_LITE_ENSURE_OK(context,
48                     GetOutputSafe(context, node, kOutputTensor, &output));
49 
50   switch (positions->type) {
51     case kTfLiteInt64:
52     case kTfLiteInt32:
53       break;
54     default:
55       TF_LITE_KERNEL_LOG(context,
56                          "Positions of type '%s' are not supported by gather.",
57                          TfLiteTypeGetName(positions->type));
58       return kTfLiteError;
59   }
60 
61   // Assign to output the input type.
62   output->type = input->type;
63 
64   // Check conditions for different types.
65   switch (input->type) {
66     case kTfLiteFloat32:
67     case kTfLiteUInt8:
68     case kTfLiteInt8:
69     case kTfLiteInt16:
70     case kTfLiteInt64:
71     case kTfLiteInt32:
72     case kTfLiteBool:
73       break;
74     case kTfLiteString: {
75       // Only 1D input is supported.
76       TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
77     } break;
78     default:
79       TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
80                          TfLiteTypeGetName(input->type));
81       return kTfLiteError;
82   }
83 
84   int axis = params->axis;
85   if (axis < 0) {
86     axis += NumDimensions(input);
87   }
88   TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
89 
90   int batch_dims = params->batch_dims;
91   // batch_dims should be in range: [-rank(positions), rank(positions)].
92   // Negative batch_dims is added with rank of positions.
93   if (batch_dims < 0) {
94     batch_dims += NumDimensions(positions);
95   }
96   TF_LITE_ENSURE(context, batch_dims <= axis);
97   TF_LITE_ENSURE(context, 0 <= batch_dims && batch_dims < NumDimensions(input));
98   TF_LITE_ENSURE(context, batch_dims <= NumDimensions(positions));
99   for (int i = 0; i < batch_dims; ++i) {
100     TF_LITE_ENSURE_EQ(context, input->dims->data[i], positions->dims->data[i]);
101   }
102 
103   const int num_dimensions =
104       NumDimensions(input) + NumDimensions(positions) - 1 - batch_dims;
105   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
106   int output_index = 0;
107   for (int i = 0; i < axis; ++i) {
108     output_shape->data[output_index++] = input->dims->data[i];
109   }
110   for (int i = batch_dims; i < positions->dims->size; ++i) {
111     output_shape->data[output_index++] = positions->dims->data[i];
112   }
113   for (int i = axis + 1; i < input->dims->size; ++i) {
114     output_shape->data[output_index++] = input->dims->data[i];
115   }
116   return context->ResizeTensor(context, output, output_shape);
117 }
118 
119 template <typename InputT, typename PositionsT>
Gather(TfLiteContext * context,const TfLiteGatherParams & params,const TfLiteTensor * input,const TfLiteTensor * positions,TfLiteTensor * output)120 TfLiteStatus Gather(TfLiteContext* context, const TfLiteGatherParams& params,
121                     const TfLiteTensor* input, const TfLiteTensor* positions,
122                     TfLiteTensor* output) {
123   const PositionsT* indexes = GetTensorData<PositionsT>(positions);
124   bool indices_has_only_positive_elements = true;
125   const size_t num_indices = positions->bytes / sizeof(PositionsT);
126   for (size_t i = 0; i < num_indices; i++) {
127     if (indexes[i] < 0) {
128       indices_has_only_positive_elements = false;
129       break;
130     }
131   }
132   TF_LITE_ENSURE(context, indices_has_only_positive_elements);
133 
134   tflite::GatherParams op_params;
135   op_params.axis = params.axis;
136   op_params.batch_dims = params.batch_dims;
137   optimized_ops::Gather(op_params, GetTensorShape(input),
138                         GetTensorData<InputT>(input), GetTensorShape(positions),
139                         GetTensorData<PositionsT>(positions),
140                         GetTensorShape(output), GetTensorData<InputT>(output));
141   return kTfLiteOk;
142 }
143 
144 template <typename PositionT>
GatherStrings(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * positions,TfLiteTensor * output)145 TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
146                            const TfLiteTensor* positions,
147                            TfLiteTensor* output) {
148   DynamicBuffer buffer;
149 
150   const PositionT* indexes = GetTensorData<PositionT>(positions);
151   bool indices_has_only_positive_elements = true;
152   const size_t num_indices = positions->bytes / sizeof(PositionT);
153   for (size_t i = 0; i < num_indices; i++) {
154     if (indexes[i] < 0) {
155       indices_has_only_positive_elements = false;
156       break;
157     }
158   }
159   TF_LITE_ENSURE(context, indices_has_only_positive_elements);
160 
161   const PositionT num_strings = GetStringCount(input);
162   const int num_indexes = NumElements(positions);
163 
164   for (int i = 0; i < num_indexes; ++i) {
165     const PositionT pos = indexes[i];
166     TF_LITE_ENSURE(context, pos < num_strings);
167     const auto string_ref = GetString(input, pos);
168     buffer.AddString(string_ref.str, string_ref.len);
169   }
170   buffer.WriteToTensor(output, /*new_shape=*/nullptr);
171   return kTfLiteOk;
172 }
173 
Eval(TfLiteContext * context,TfLiteNode * node)174 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
175   const auto* params =
176       reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
177   const TfLiteTensor* input;
178   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
179   const TfLiteTensor* positions;
180   TF_LITE_ENSURE_OK(context,
181                     GetInputSafe(context, node, kInputPositions, &positions));
182   TfLiteTensor* output;
183   TF_LITE_ENSURE_OK(context,
184                     GetOutputSafe(context, node, kOutputTensor, &output));
185 
186   if (positions->type == kTfLiteInt32) {
187     switch (input->type) {
188       case kTfLiteFloat32:
189         return Gather<float, int32_t>(context, *params, input, positions,
190                                       output);
191       case kTfLiteUInt8:
192         return Gather<uint8_t, int32_t>(context, *params, input, positions,
193                                         output);
194       case kTfLiteInt8:
195         return Gather<int8_t, int32_t>(context, *params, input, positions,
196                                        output);
197       case kTfLiteInt16:
198         return Gather<int16_t, int32_t>(context, *params, input, positions,
199                                         output);
200       case kTfLiteInt32:
201         return Gather<int32_t, int32_t>(context, *params, input, positions,
202                                         output);
203       case kTfLiteInt64:
204         return Gather<int64_t, int32_t>(context, *params, input, positions,
205                                         output);
206       case kTfLiteBool:
207         return Gather<bool, int32_t>(context, *params, input, positions,
208                                      output);
209       case kTfLiteString:
210         return GatherStrings<int32_t>(context, input, positions, output);
211       default:
212         TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
213                            TfLiteTypeGetName(input->type));
214         return kTfLiteError;
215     }
216   }
217   if (positions->type == kTfLiteInt64) {
218     switch (input->type) {
219       case kTfLiteFloat32:
220         return Gather<float, int64_t>(context, *params, input, positions,
221                                       output);
222       case kTfLiteUInt8:
223         return Gather<uint8_t, int64_t>(context, *params, input, positions,
224                                         output);
225       case kTfLiteInt8:
226         return Gather<int8_t, int64_t>(context, *params, input, positions,
227                                        output);
228       case kTfLiteInt16:
229         return Gather<int16_t, int64_t>(context, *params, input, positions,
230                                         output);
231       case kTfLiteInt32:
232         return Gather<int32_t, int64_t>(context, *params, input, positions,
233                                         output);
234       case kTfLiteInt64:
235         return Gather<int64_t, int64_t>(context, *params, input, positions,
236                                         output);
237       case kTfLiteBool:
238         return Gather<bool, int64_t>(context, *params, input, positions,
239                                      output);
240       case kTfLiteString:
241         return GatherStrings<int64_t>(context, input, positions, output);
242       default:
243         TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
244                            TfLiteTypeGetName(input->type));
245         return kTfLiteError;
246     }
247   }
248   TF_LITE_KERNEL_LOG(context,
249                      "Positions of type '%s' are not supported by gather.",
250                      TfLiteTypeGetName(positions->type));
251   return kTfLiteError;
252 }
253 }  // namespace gather
254 
Register_GATHER()255 TfLiteRegistration* Register_GATHER() {
256   static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare,
257                                  gather::Eval};
258   return &r;
259 }
260 
261 }  // namespace builtin
262 }  // namespace ops
263 }  // namespace tflite
264