1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include "tensorflow/lite/c/common.h"
18 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace gather_nd {
// Positions of the tensors in the node's input and output lists.
constexpr int kParams{0};        // input: tensor gathered from
constexpr int kIndices{1};       // input: int32/int64 index tensor
constexpr int kOutputTensor{0};  // output: gathered result
31
Prepare(TfLiteContext * context,TfLiteNode * node)32 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
33 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
34 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
35
36 const TfLiteTensor* params;
37 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, ¶ms));
38 const TfLiteTensor* indices;
39 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
40 TfLiteTensor* output;
41 TF_LITE_ENSURE_OK(context,
42 GetOutputSafe(context, node, kOutputTensor, &output));
43
44 switch (params->type) {
45 case kTfLiteFloat32:
46 case kTfLiteUInt8:
47 case kTfLiteInt8:
48 case kTfLiteInt16:
49 case kTfLiteInt64:
50 case kTfLiteInt32:
51 case kTfLiteString:
52 break;
53 default:
54 context->ReportError(
55 context, "Params of type '%s' are not supported by gather_nd.",
56 TfLiteTypeGetName(params->type));
57 return kTfLiteError;
58 }
59 switch (indices->type) {
60 case kTfLiteInt64:
61 case kTfLiteInt32:
62 break;
63 default:
64 context->ReportError(
65 context, "Indices of type '%s' are not supported by gather_nd.",
66 TfLiteTypeGetName(indices->type));
67 return kTfLiteError;
68 }
69
70 const int params_rank = NumDimensions(params);
71 const int indices_rank = NumDimensions(indices);
72 const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
73 if (params_rank < 1) {
74 context->ReportError(context, "Params must be at least a vector.");
75 return kTfLiteError;
76 }
77 if (indices_rank < 1) {
78 context->ReportError(context, "Indices must be at least a vector.");
79 return kTfLiteError;
80 }
81 if (indices_nd > params_rank) {
82 context->ReportError(
83 context, "Index innermost dimension length must be <= params rank.");
84 return kTfLiteError;
85 }
86
87 // Assign to output the input type.
88 output->type = params->type;
89
90 // The result shape is
91 // indices.shape[:-1] + params.shape[indices.shape[-1]:]
92 const int output_rank = indices_rank + params_rank - indices_nd - 1;
93 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
94 int output_index = 0;
95 for (int i = 0; i < indices_rank - 1; ++i) {
96 output_shape->data[output_index++] = indices->dims->data[i];
97 }
98 for (int i = indices_nd; i < params_rank; ++i) {
99 output_shape->data[output_index++] = params->dims->data[i];
100 }
101 return context->ResizeTensor(context, output, output_shape);
102 }
103
104 template <typename ParamsT, typename IndicesT>
GatherNd(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)105 TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices,
106 TfLiteTensor* output) {
107 reference_ops::GatherNd(
108 GetTensorShape(params), GetTensorData<ParamsT>(params),
109 GetTensorShape(indices), GetTensorData<IndicesT>(indices),
110 GetTensorShape(output), GetTensorData<ParamsT>(output));
111 return kTfLiteOk;
112 }
113
114 template <typename IndicesT>
GatherNdString(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)115 TfLiteStatus GatherNdString(const TfLiteTensor* params,
116 const TfLiteTensor* indices, TfLiteTensor* output) {
117 reference_ops::GatherNdString(
118 GetTensorShape(params), params, GetTensorShape(indices),
119 GetTensorData<IndicesT>(indices), GetTensorShape(output), output);
120 return kTfLiteOk;
121 }
122
123 template <typename IndicesT>
EvalGatherNd(TfLiteContext * context,const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)124 TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
125 const TfLiteTensor* indices, TfLiteTensor* output) {
126 switch (params->type) {
127 case kTfLiteFloat32:
128 return GatherNd<float, IndicesT>(params, indices, output);
129 case kTfLiteUInt8:
130 return GatherNd<uint8_t, IndicesT>(params, indices, output);
131 case kTfLiteInt8:
132 return GatherNd<int8_t, IndicesT>(params, indices, output);
133 case kTfLiteInt16:
134 return GatherNd<int16_t, IndicesT>(params, indices, output);
135 case kTfLiteInt32:
136 return GatherNd<int32_t, IndicesT>(params, indices, output);
137 case kTfLiteInt64:
138 return GatherNd<int64_t, IndicesT>(params, indices, output);
139 case kTfLiteString:
140 return GatherNdString<IndicesT>(params, indices, output);
141 default:
142 context->ReportError(context,
143 "Params type '%s' are not supported by gather_nd.",
144 TfLiteTypeGetName(params->type));
145 return kTfLiteError;
146 }
147 }
148
Eval(TfLiteContext * context,TfLiteNode * node)149 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
150 const TfLiteTensor* params;
151 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, ¶ms));
152 const TfLiteTensor* indices;
153 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
154 TfLiteTensor* output;
155 TF_LITE_ENSURE_OK(context,
156 GetOutputSafe(context, node, kOutputTensor, &output));
157
158 switch (indices->type) {
159 case kTfLiteInt32:
160 return EvalGatherNd<int32_t>(context, params, indices, output);
161 case kTfLiteInt64:
162 return EvalGatherNd<int64_t>(context, params, indices, output);
163 default:
164 context->ReportError(
165 context, "Indices of type '%s' are not supported by gather_nd.",
166 TfLiteTypeGetName(indices->type));
167 return kTfLiteError;
168 }
169 }
170 } // namespace gather_nd
171
Register_GATHER_ND()172 TfLiteRegistration* Register_GATHER_ND() {
173 static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
174 gather_nd::Prepare, gather_nd::Eval};
175 return &r;
176 }
177 } // namespace builtin
178 } // namespace ops
179 } // namespace tflite
180