• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <string.h>
16 #include "tensorflow/lite/c/builtin_op_data.h"
17 #include "tensorflow/lite/c/c_api_internal.h"
18 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
19 #include "tensorflow/lite/kernels/internal/tensor.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 #include "tensorflow/lite/kernels/op_macros.h"
22 #include "tensorflow/lite/string_util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace gather {
// Tensor slots used by the GATHER node.
constexpr int kInputTensor{0};     // input #0: tensor to gather from
constexpr int kInputPositions{1};  // input #1: indices to gather
constexpr int kOutputTensor{0};    // output #0: gathered result
31 
Prepare(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
33   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
34   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
35 
36   const auto* params =
37       reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
38   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
39   const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
40   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
41 
42   switch (positions->type) {
43     case kTfLiteInt64:
44     case kTfLiteInt32:
45       break;
46     default:
47       context->ReportError(
48           context, "Positions of type '%s' are not supported by gather.",
49           TfLiteTypeGetName(positions->type));
50       return kTfLiteError;
51   }
52 
53   // Assign to output the input type.
54   output->type = input->type;
55 
56   // Check conditions for different types.
57   switch (input->type) {
58     case kTfLiteFloat32:
59     case kTfLiteUInt8:
60     case kTfLiteInt8:
61     case kTfLiteInt64:
62     case kTfLiteInt32:
63       break;
64     case kTfLiteString: {
65       // Only 1D input is supported.
66       TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
67     } break;
68     default:
69       context->ReportError(context, "Type '%s' is not supported by gather.",
70                            TfLiteTypeGetName(input->type));
71       return kTfLiteError;
72   }
73 
74   int axis = params->axis;
75   if (axis < 0) {
76     axis += NumDimensions(input);
77   }
78   TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));
79 
80   const int num_dimensions =
81       NumDimensions(input) + NumDimensions(positions) - 1;
82   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
83   int output_index = 0;
84   for (int i = 0; i < axis; ++i) {
85     output_shape->data[output_index++] = input->dims->data[i];
86   }
87   for (int i = 0; i < positions->dims->size; ++i) {
88     output_shape->data[output_index++] = positions->dims->data[i];
89   }
90   for (int i = axis + 1; i < input->dims->size; ++i) {
91     output_shape->data[output_index++] = input->dims->data[i];
92   }
93   return context->ResizeTensor(context, output, output_shape);
94 }
95 
96 template <typename InputT, typename PositionsT>
Gather(const TfLiteGatherParams & params,const TfLiteTensor * input,const TfLiteTensor * positions,TfLiteTensor * output)97 TfLiteStatus Gather(const TfLiteGatherParams& params, const TfLiteTensor* input,
98                     const TfLiteTensor* positions, TfLiteTensor* output) {
99   tflite::GatherParams op_params;
100   op_params.axis = params.axis;
101   optimized_ops::Gather(op_params, GetTensorShape(input),
102                         GetTensorData<InputT>(input), GetTensorShape(positions),
103                         GetTensorData<PositionsT>(positions),
104                         GetTensorShape(output), GetTensorData<InputT>(output));
105   return kTfLiteOk;
106 }
107 
108 template <typename PositionT>
GatherStrings(TfLiteContext * context,const TfLiteTensor * input,const TfLiteTensor * positions,TfLiteTensor * output)109 TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
110                            const TfLiteTensor* positions,
111                            TfLiteTensor* output) {
112   // TODO(mgubin): Currently support only for 1D output tensors.
113   DynamicBuffer buffer;
114   const PositionT* indexes = GetTensorData<PositionT>(positions);
115   const PositionT num_strings = GetStringCount(input);
116   for (int i = 0; i < positions->dims->data[0]; ++i) {
117     const PositionT pos = indexes[i];
118     TF_LITE_ENSURE(context, pos < num_strings);
119     const auto string_ref = GetString(input, pos);
120     buffer.AddString(string_ref.str, string_ref.len);
121   }
122   buffer.WriteToTensorAsVector(output);
123   return kTfLiteOk;
124 }
125 
Eval(TfLiteContext * context,TfLiteNode * node)126 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
127   const auto* params =
128       reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
129   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
130   const TfLiteTensor* positions = GetInput(context, node, kInputPositions);
131   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
132 
133   if (positions->type == kTfLiteInt32) {
134     switch (input->type) {
135       case kTfLiteFloat32:
136         return Gather<float, int32_t>(*params, input, positions, output);
137       case kTfLiteUInt8:
138         return Gather<uint8_t, int32_t>(*params, input, positions, output);
139       case kTfLiteInt8:
140         return Gather<int8_t, int32_t>(*params, input, positions, output);
141       case kTfLiteInt32:
142         return Gather<int32_t, int32_t>(*params, input, positions, output);
143       case kTfLiteInt64:
144         return Gather<int64_t, int32_t>(*params, input, positions, output);
145       case kTfLiteString:
146         return GatherStrings<int32_t>(context, input, positions, output);
147       default:
148         context->ReportError(context, "Type '%s' is not supported by gather.",
149                              TfLiteTypeGetName(input->type));
150         return kTfLiteError;
151     }
152   }
153   if (positions->type == kTfLiteInt64) {
154     switch (input->type) {
155       case kTfLiteFloat32:
156         return Gather<float, int64_t>(*params, input, positions, output);
157       case kTfLiteUInt8:
158         return Gather<uint8_t, int64_t>(*params, input, positions, output);
159       case kTfLiteInt8:
160         return Gather<int8_t, int64_t>(*params, input, positions, output);
161       case kTfLiteInt32:
162         return Gather<int32_t, int64_t>(*params, input, positions, output);
163       case kTfLiteInt64:
164         return Gather<int64_t, int64_t>(*params, input, positions, output);
165       case kTfLiteString:
166         return GatherStrings<int64_t>(context, input, positions, output);
167       default:
168         context->ReportError(context, "Type '%s' is not supported by gather.",
169                              TfLiteTypeGetName(input->type));
170         return kTfLiteError;
171     }
172   }
173   context->ReportError(context,
174                        "Positions of type '%s' are not supported by gather.",
175                        TfLiteTypeGetName(positions->type));
176   return kTfLiteError;
177 }
178 }  // namespace gather
179 
Register_GATHER()180 TfLiteRegistration* Register_GATHER() {
181   static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare,
182                                  gather::Eval};
183   return &r;
184 }
185 
186 }  // namespace builtin
187 }  // namespace ops
188 }  // namespace tflite
189