1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/c/builtin_op_data.h"
16 #include "tensorflow/lite/c/c_api_internal.h"
17 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
18 #include "tensorflow/lite/kernels/internal/tensor.h"
19 #include "tensorflow/lite/kernels/kernel_util.h"
20 #include "tensorflow/lite/kernels/op_macros.h"
21
22 namespace tflite {
23 namespace ops {
24 namespace builtin {
25 namespace gather_nd {
// Slot numbers handed to GetInput()/GetOutput() to locate this op's tensors.
constexpr int kParams{0};        // input: tensor slices are gathered from
constexpr int kIndices{1};       // input: integer coordinates into params
constexpr int kOutputTensor{0};  // output: gathered result
29
Prepare(TfLiteContext * context,TfLiteNode * node)30 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
31 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
32 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
33
34 const TfLiteTensor* params = GetInput(context, node, kParams);
35 const TfLiteTensor* indices = GetInput(context, node, kIndices);
36 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
37
38 switch (params->type) {
39 case kTfLiteFloat32:
40 case kTfLiteUInt8:
41 case kTfLiteInt8:
42 case kTfLiteInt64:
43 case kTfLiteInt32:
44 break;
45 default:
46 context->ReportError(
47 context, "Params of type '%s' are not supported by gather_nd.",
48 TfLiteTypeGetName(params->type));
49 return kTfLiteError;
50 }
51 switch (indices->type) {
52 case kTfLiteInt64:
53 case kTfLiteInt32:
54 break;
55 default:
56 context->ReportError(
57 context, "Indices of type '%s' are not supported by gather_nd.",
58 TfLiteTypeGetName(indices->type));
59 return kTfLiteError;
60 }
61
62 const int params_rank = NumDimensions(params);
63 const int indices_rank = NumDimensions(indices);
64 const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
65 if (params_rank < 1) {
66 context->ReportError(context, "Params must be at least a vector.");
67 return kTfLiteError;
68 }
69 if (indices_rank < 1) {
70 context->ReportError(context, "Indices must be at least a vector.");
71 return kTfLiteError;
72 }
73 if (indices_nd > params_rank) {
74 context->ReportError(
75 context, "Index innermost dimension length must be <= params rank.");
76 return kTfLiteError;
77 }
78
79 // Assign to output the input type.
80 output->type = params->type;
81
82 // The result shape is
83 // indices.shape[:-1] + params.shape[indices.shape[-1]:]
84 const int output_rank = indices_rank + params_rank - indices_nd - 1;
85 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
86 int output_index = 0;
87 for (int i = 0; i < indices_rank - 1; ++i) {
88 output_shape->data[output_index++] = indices->dims->data[i];
89 }
90 for (int i = indices_nd; i < params_rank; ++i) {
91 output_shape->data[output_index++] = params->dims->data[i];
92 }
93 return context->ResizeTensor(context, output, output_shape);
94 }
95
96 template <typename ParamsT, typename IndicesT>
GatherNd(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)97 TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices,
98 TfLiteTensor* output) {
99 reference_ops::GatherNd(
100 GetTensorShape(params), GetTensorData<ParamsT>(params),
101 GetTensorShape(indices), GetTensorData<IndicesT>(indices),
102 GetTensorShape(output), GetTensorData<ParamsT>(output));
103 return kTfLiteOk;
104 }
105
106 template <typename IndicesT>
EvalGatherNd(TfLiteContext * context,const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)107 TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
108 const TfLiteTensor* indices, TfLiteTensor* output) {
109 switch (params->type) {
110 case kTfLiteFloat32:
111 return GatherNd<float, IndicesT>(params, indices, output);
112 case kTfLiteUInt8:
113 return GatherNd<uint8_t, IndicesT>(params, indices, output);
114 case kTfLiteInt8:
115 return GatherNd<int8_t, IndicesT>(params, indices, output);
116 case kTfLiteInt32:
117 return GatherNd<int32_t, IndicesT>(params, indices, output);
118 case kTfLiteInt64:
119 return GatherNd<int64_t, IndicesT>(params, indices, output);
120 default:
121 context->ReportError(context,
122 "Params type '%s' are not supported by gather_nd.",
123 TfLiteTypeGetName(params->type));
124 return kTfLiteError;
125 }
126 }
127
Eval(TfLiteContext * context,TfLiteNode * node)128 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
129 const TfLiteTensor* params = GetInput(context, node, kParams);
130 const TfLiteTensor* indices = GetInput(context, node, kIndices);
131 TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
132
133 switch (indices->type) {
134 case kTfLiteInt32:
135 return EvalGatherNd<int32_t>(context, params, indices, output);
136 case kTfLiteInt64:
137 return EvalGatherNd<int64_t>(context, params, indices, output);
138 default:
139 context->ReportError(
140 context, "Indices of type '%s' are not supported by gather_nd.",
141 TfLiteTypeGetName(indices->type));
142 return kTfLiteError;
143 }
144 }
145 } // namespace gather_nd
146
Register_GATHER_ND()147 TfLiteRegistration* Register_GATHER_ND() {
148 static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
149 gather_nd::Prepare, gather_nd::Eval};
150 return &r;
151 }
152 } // namespace builtin
153 } // namespace ops
154 } // namespace tflite
155