1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <string.h>
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/c_api_internal.h"
18 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
19 #include "tensorflow/lite/kernels/internal/tensor.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/kernels/op_macros.h"
22 #include "tensorflow/lite/string_util.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace gather {
// Tensor slots on the GATHER node.
// Inputs: slot 0 holds the tensor to gather from, slot 1 the indices.
constexpr int kInputTensor{0};
constexpr int kInputPositions{1};
// Output: slot 0 receives the gathered values.
constexpr int kOutputTensor{0};
31
Prepare(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
33 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
34 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
35
36 const auto* params =
37 reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
38 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
39 const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
40 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
41
42 switch (positions->type) {
43 case kTfLiteInt64:
44 case kTfLiteInt32:
45 break;
46 default:
47 context->ReportError(
48 context, "Positions of type '%s' are not supported by gather.",
49 TfLiteTypeGetName(positions->type));
50 return kTfLiteError;
51 }
52
53 // Assign to output the input type.
54 output->type = input->type;
55
56 // Check conditions for different types.
57 switch (input->type) {
58 case kTfLiteFloat32:
59 case kTfLiteUInt8:
60 case kTfLiteInt8:
61 case kTfLiteInt64:
62 case kTfLiteInt32:
63 break;
64 case kTfLiteString: {
65 // Only 1D input is supported.
66 TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
67 } break;
68 default:
69 context->ReportError(context, "Type '%s' is not supported by gather.",
70 TfLiteTypeGetName(input->type));
71 return kTfLiteError;
72 }
73
74 int axis = params->axis;
75 if (axis < 0) {
76 axis += NumDimensions(input);
77 }
78 TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
79
80 const int num_dimensions =
81 NumDimensions(input) + NumDimensions(positions) - 1;
82 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
83 int output_index = 0;
84 for (int i = 0; i < axis; ++i) {
85 output_shape->data[output_index++] = input->dims->data[i];
86 }
87 for (int i = 0; i < positions->dims->size; ++i) {
88 output_shape->data[output_index++] = positions->dims->data[i];
89 }
90 for (int i = axis + 1; i < input->dims->size; ++i) {
91 output_shape->data[output_index++] = input->dims->data[i];
92 }
93 return context->ResizeTensor(context, output, output_shape);
94 }
95
96 template <typename InputT, typename PositionsT>
Gather(const TfLiteGatherParams & params,const TfLiteTensor * input,const TfLiteTensor * positions,TfLiteTensor * output)97 TfLiteStatus Gather(const TfLiteGatherParams& params, const TfLiteTensor* input,
98 const TfLiteTensor* positions, TfLiteTensor* output) {
99 tflite::GatherParams op_params;
100 op_params.axis = params.axis;
101 optimized_ops::Gather(op_params, GetTensorShape(input),
102 GetTensorData<InputT>(input), GetTensorShape(positions),
103 GetTensorData<PositionsT>(positions),
104 GetTensorShape(output), GetTensorData<InputT>(output));
105 return kTfLiteOk;
106 }
107
108 template <typename PositionT>
GatherStrings(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * positions,TfLiteTensor * output)109 TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
110 const TfLiteTensor* positions,
111 TfLiteTensor* output) {
112 // TODO(mgubin): Currently support only for 1D output tensors.
113 DynamicBuffer buffer;
114 const PositionT* indexes = GetTensorData<PositionT>(positions);
115 const PositionT num_strings = GetStringCount(input);
116 for (int i = 0; i < positions->dims->data[0]; ++i) {
117 const PositionT pos = indexes[i];
118 TF_LITE_ENSURE(context, pos < num_strings);
119 const auto string_ref = GetString(input, pos);
120 buffer.AddString(string_ref.str, string_ref.len);
121 }
122 buffer.WriteToTensorAsVector(output);
123 return kTfLiteOk;
124 }
125
Eval(TfLiteContext * context,TfLiteNode * node)126 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
127 const auto* params =
128 reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
129 const TfLiteTensor* input = GetInput(context, node, kInputTensor);
130 const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
131 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
132
133 if (positions->type == kTfLiteInt32) {
134 switch (input->type) {
135 case kTfLiteFloat32:
136 return Gather<float, int32_t>(*params, input, positions, output);
137 case kTfLiteUInt8:
138 return Gather<uint8_t, int32_t>(*params, input, positions, output);
139 case kTfLiteInt8:
140 return Gather<int8_t, int32_t>(*params, input, positions, output);
141 case kTfLiteInt32:
142 return Gather<int32_t, int32_t>(*params, input, positions, output);
143 case kTfLiteInt64:
144 return Gather<int64_t, int32_t>(*params, input, positions, output);
145 case kTfLiteString:
146 return GatherStrings<int32_t>(context, input, positions, output);
147 default:
148 context->ReportError(context, "Type '%s' is not supported by gather.",
149 TfLiteTypeGetName(input->type));
150 return kTfLiteError;
151 }
152 }
153 if (positions->type == kTfLiteInt64) {
154 switch (input->type) {
155 case kTfLiteFloat32:
156 return Gather<float, int64_t>(*params, input, positions, output);
157 case kTfLiteUInt8:
158 return Gather<uint8_t, int64_t>(*params, input, positions, output);
159 case kTfLiteInt8:
160 return Gather<int8_t, int64_t>(*params, input, positions, output);
161 case kTfLiteInt32:
162 return Gather<int32_t, int64_t>(*params, input, positions, output);
163 case kTfLiteInt64:
164 return Gather<int64_t, int64_t>(*params, input, positions, output);
165 case kTfLiteString:
166 return GatherStrings<int64_t>(context, input, positions, output);
167 default:
168 context->ReportError(context, "Type '%s' is not supported by gather.",
169 TfLiteTypeGetName(input->type));
170 return kTfLiteError;
171 }
172 }
173 context->ReportError(context,
174 "Positions of type '%s' are not supported by gather.",
175 TfLiteTypeGetName(positions->type));
176 return kTfLiteError;
177 }
178 } // namespace gather
179
Register_GATHER()180 TfLiteRegistration* Register_GATHER() {
181 static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare,
182 gather::Eval};
183 return &r;
184 }
185
186 } // namespace builtin
187 } // namespace ops
188 } // namespace tflite
189