• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/c/builtin_op_data.h"
16 #include "tensorflow/lite/c/c_api_internal.h"
17 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
18 #include "tensorflow/lite/kernels/internal/tensor.h"
19 #include "tensorflow/lite/kernels/kernel_util.h"
20 #include "tensorflow/lite/kernels/op_macros.h"
21 
22 namespace tflite {
23 namespace ops {
24 namespace builtin {
25 namespace gather_nd {
// Input tensor slots of a GATHER_ND node: the tensor to gather from, and the
// index tuples (int32 or int64, enforced in Prepare) selecting slices of it.
constexpr int kParams{0};
constexpr int kIndices{1};
// Output tensor slot holding the gathered result.
constexpr int kOutputTensor{0};
29 
Prepare(TfLiteContext * context,TfLiteNode * node)30 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
31   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
32   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
33 
34   const TfLiteTensor* params = GetInput(context, node, kParams);
35   const TfLiteTensor* indices = GetInput(context, node, kIndices);
36   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
37 
38   switch (params->type) {
39     case kTfLiteFloat32:
40     case kTfLiteUInt8:
41     case kTfLiteInt8:
42     case kTfLiteInt64:
43     case kTfLiteInt32:
44       break;
45     default:
46       context->ReportError(
47           context, "Params of type '%s' are not supported by gather_nd.",
48           TfLiteTypeGetName(params->type));
49       return kTfLiteError;
50   }
51   switch (indices->type) {
52     case kTfLiteInt64:
53     case kTfLiteInt32:
54       break;
55     default:
56       context->ReportError(
57           context, "Indices of type '%s' are not supported by gather_nd.",
58           TfLiteTypeGetName(indices->type));
59       return kTfLiteError;
60   }
61 
62   const int params_rank = NumDimensions(params);
63   const int indices_rank = NumDimensions(indices);
64   const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
65   if (params_rank < 1) {
66     context->ReportError(context, "Params must be at least a vector.");
67     return kTfLiteError;
68   }
69   if (indices_rank < 1) {
70     context->ReportError(context, "Indices must be at least a vector.");
71     return kTfLiteError;
72   }
73   if (indices_nd > params_rank) {
74     context->ReportError(
75         context, "Index innermost dimension length must be <= params rank.");
76     return kTfLiteError;
77   }
78 
79   // Assign to output the input type.
80   output->type = params->type;
81 
82   // The result shape is
83   // indices.shape[:-1] + params.shape[indices.shape[-1]:]
84   const int output_rank = indices_rank + params_rank - indices_nd - 1;
85   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
86   int output_index = 0;
87   for (int i = 0; i < indices_rank - 1; ++i) {
88     output_shape->data[output_index++] = indices->dims->data[i];
89   }
90   for (int i = indices_nd; i < params_rank; ++i) {
91     output_shape->data[output_index++] = params->dims->data[i];
92   }
93   return context->ResizeTensor(context, output, output_shape);
94 }
95 
96 template <typename ParamsT, typename IndicesT>
GatherNd(const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)97 TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices,
98                       TfLiteTensor* output) {
99   reference_ops::GatherNd(
100       GetTensorShape(params), GetTensorData<ParamsT>(params),
101       GetTensorShape(indices), GetTensorData<IndicesT>(indices),
102       GetTensorShape(output), GetTensorData<ParamsT>(output));
103   return kTfLiteOk;
104 }
105 
106 template <typename IndicesT>
EvalGatherNd(TfLiteContext * context,const TfLiteTensor * params,const TfLiteTensor * indices,TfLiteTensor * output)107 TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
108                           const TfLiteTensor* indices, TfLiteTensor* output) {
109   switch (params->type) {
110     case kTfLiteFloat32:
111       return GatherNd<float, IndicesT>(params, indices, output);
112     case kTfLiteUInt8:
113       return GatherNd<uint8_t, IndicesT>(params, indices, output);
114     case kTfLiteInt8:
115       return GatherNd<int8_t, IndicesT>(params, indices, output);
116     case kTfLiteInt32:
117       return GatherNd<int32_t, IndicesT>(params, indices, output);
118     case kTfLiteInt64:
119       return GatherNd<int64_t, IndicesT>(params, indices, output);
120     default:
121       context->ReportError(context,
122                            "Params type '%s' are not supported by gather_nd.",
123                            TfLiteTypeGetName(params->type));
124       return kTfLiteError;
125   }
126 }
127 
Eval(TfLiteContext * context,TfLiteNode * node)128 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
129   const TfLiteTensor* params = GetInput(context, node, kParams);
130   const TfLiteTensor* indices = GetInput(context, node, kIndices);
131   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
132 
133   switch (indices->type) {
134     case kTfLiteInt32:
135       return EvalGatherNd<int32_t>(context, params, indices, output);
136     case kTfLiteInt64:
137       return EvalGatherNd<int64_t>(context, params, indices, output);
138     default:
139       context->ReportError(
140           context, "Indices of type '%s' are not supported by gather_nd.",
141           TfLiteTypeGetName(indices->type));
142       return kTfLiteError;
143   }
144 }
145 }  // namespace gather_nd
146 
Register_GATHER_ND()147 TfLiteRegistration* Register_GATHER_ND() {
148   static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
149                                  gather_nd::Prepare, gather_nd::Eval};
150   return &r;
151 }
152 }  // namespace builtin
153 }  // namespace ops
154 }  // namespace tflite
155