• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace gather_nd {
// Tensor slots used by this kernel: params (input 0), indices (input 1),
// and the gathered result (output 0).
constexpr int kParams{0};
constexpr int kIndices{1};
constexpr int kOutputTensor{0};
31 
Prepare(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
33   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
34   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
35 
36   const TfLiteTensor* params;
37   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
38   const TfLiteTensor* indices;
39   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
40   TfLiteTensor* output;
41   TF_LITE_ENSURE_OK(context,
42                     GetOutputSafe(context, node, kOutputTensor, &output));
43 
44   switch (params->type) {
45     case kTfLiteFloat32:
46     case kTfLiteUInt8:
47     case kTfLiteInt8:
48     case kTfLiteInt16:
49     case kTfLiteInt64:
50     case kTfLiteInt32:
51     case kTfLiteString:
52       break;
53     default:
54       context->ReportError(
55           context, "Params of type '%s' are not supported by gather_nd.",
56           TfLiteTypeGetName(params->type));
57       return kTfLiteError;
58   }
59   switch (indices->type) {
60     case kTfLiteInt64:
61     case kTfLiteInt32:
62       break;
63     default:
64       context->ReportError(
65           context, "Indices of type '%s' are not supported by gather_nd.",
66           TfLiteTypeGetName(indices->type));
67       return kTfLiteError;
68   }
69 
70   const int params_rank = NumDimensions(params);
71   const int indices_rank = NumDimensions(indices);
72   const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
73   if (params_rank < 1) {
74     context->ReportError(context, "Params must be at least a vector.");
75     return kTfLiteError;
76   }
77   if (indices_rank < 1) {
78     context->ReportError(context, "Indices must be at least a vector.");
79     return kTfLiteError;
80   }
81   if (indices_nd > params_rank) {
82     context->ReportError(
83         context, "Index innermost dimension length must be <= params rank.");
84     return kTfLiteError;
85   }
86 
87   // Assign to output the input type.
88   output->type = params->type;
89 
90   // The result shape is
91   // indices.shape[:-1] + params.shape[indices.shape[-1]:]
92   const int output_rank = indices_rank + params_rank - indices_nd - 1;
93   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
94   int output_index = 0;
95   for (int i = 0; i < indices_rank - 1; ++i) {
96     output_shape->data[output_index++] = indices->dims->data[i];
97   }
98   for (int i = indices_nd; i < params_rank; ++i) {
99     output_shape->data[output_index++] = params->dims->data[i];
100   }
101   return context->ResizeTensor(context, output, output_shape);
102 }
103 
104 template <typename ParamsT, typename IndicesT>
GatherNd(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)105 TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices,
106                       TfLiteTensor* output) {
107   reference_ops::GatherNd(
108       GetTensorShape(params), GetTensorData<ParamsT>(params),
109       GetTensorShape(indices), GetTensorData<IndicesT>(indices),
110       GetTensorShape(output), GetTensorData<ParamsT>(output));
111   return kTfLiteOk;
112 }
113 
114 template <typename IndicesT>
GatherNdString(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)115 TfLiteStatus GatherNdString(const TfLiteTensor* params,
116                             const TfLiteTensor* indices, TfLiteTensor* output) {
117   reference_ops::GatherNdString(
118       GetTensorShape(params), params, GetTensorShape(indices),
119       GetTensorData<IndicesT>(indices), GetTensorShape(output), output);
120   return kTfLiteOk;
121 }
122 
123 template <typename IndicesT>
EvalGatherNd(TfLiteContext * context,const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)124 TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
125                           const TfLiteTensor* indices, TfLiteTensor* output) {
126   switch (params->type) {
127     case kTfLiteFloat32:
128       return GatherNd<float, IndicesT>(params, indices, output);
129     case kTfLiteUInt8:
130       return GatherNd<uint8_t, IndicesT>(params, indices, output);
131     case kTfLiteInt8:
132       return GatherNd<int8_t, IndicesT>(params, indices, output);
133     case kTfLiteInt16:
134       return GatherNd<int16_t, IndicesT>(params, indices, output);
135     case kTfLiteInt32:
136       return GatherNd<int32_t, IndicesT>(params, indices, output);
137     case kTfLiteInt64:
138       return GatherNd<int64_t, IndicesT>(params, indices, output);
139     case kTfLiteString:
140       return GatherNdString<IndicesT>(params, indices, output);
141     default:
142       context->ReportError(context,
143                            "Params type '%s' are not supported by gather_nd.",
144                            TfLiteTypeGetName(params->type));
145       return kTfLiteError;
146   }
147 }
148 
Eval(TfLiteContext * context,TfLiteNode * node)149 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
150   const TfLiteTensor* params;
151   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
152   const TfLiteTensor* indices;
153   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
154   TfLiteTensor* output;
155   TF_LITE_ENSURE_OK(context,
156                     GetOutputSafe(context, node, kOutputTensor, &output));
157 
158   switch (indices->type) {
159     case kTfLiteInt32:
160       return EvalGatherNd<int32_t>(context, params, indices, output);
161     case kTfLiteInt64:
162       return EvalGatherNd<int64_t>(context, params, indices, output);
163     default:
164       context->ReportError(
165           context, "Indices of type '%s' are not supported by gather_nd.",
166           TfLiteTypeGetName(indices->type));
167       return kTfLiteError;
168   }
169 }
170 }  // namespace gather_nd
171 
Register_GATHER_ND()172 TfLiteRegistration* Register_GATHER_ND() {
173   static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
174                                  gather_nd::Prepare, gather_nd::Eval};
175   return &r;
176 }
177 }  // namespace builtin
178 }  // namespace ops
179 }  // namespace tflite
180