• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdint.h>
17 
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21 #include "tensorflow/lite/kernels/internal/tensor.h"
22 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23 #include "tensorflow/lite/kernels/internal/types.h"
24 #include "tensorflow/lite/kernels/kernel_util.h"
25 
26 namespace tflite {
27 namespace ops {
28 namespace builtin {
29 namespace scatter_nd {
// Positions of the node's input tensors.
constexpr int kIndices{0};  // index tuples into the output
constexpr int kUpdates{1};  // values scattered into the output
constexpr int kShape{2};    // 1-D tensor holding the output shape
// Position of the node's single output tensor.
constexpr int kOutputTensor{0};
34 
35 template <typename IndicesT>
ResizeOutputTensor(TfLiteContext * context,const TfLiteTensor * shape,TfLiteTensor * output)36 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
37                                 const TfLiteTensor* shape,
38                                 TfLiteTensor* output) {
39   const int shape_rank = SizeOfDimension(shape, 0);
40   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(shape_rank);
41   const auto* shape_data = GetTensorData<IndicesT>(shape);
42 
43   for (int i = 0; i < shape_rank; i++) {
44     output_shape->data[i] = shape_data[i];
45   }
46   return context->ResizeTensor(context, output, output_shape);
47 }
48 
49 template <typename IndicesT>
CheckShapes(TfLiteContext * context,const RuntimeShape & indices,const RuntimeShape & updates,const RuntimeShape & shape_shape,const IndicesT * shape_data)50 TfLiteStatus CheckShapes(TfLiteContext* context, const RuntimeShape& indices,
51                          const RuntimeShape& updates,
52                          const RuntimeShape& shape_shape,
53                          const IndicesT* shape_data) {
54   TF_LITE_ENSURE(context, (indices.DimensionsCount() >= 1) &&
55                               (updates.DimensionsCount() >= 1) &&
56                               (shape_shape.DimensionsCount() == 1));
57 
58   const int outer_dims = indices.DimensionsCount() - 1;
59   for (int i = 0; i < outer_dims; ++i) {
60     TF_LITE_ENSURE_EQ(context, indices.Dims(i), updates.Dims(i));
61   }
62 
63   const int ix = indices.Dims(outer_dims);
64   TF_LITE_ENSURE_EQ(context, updates.DimensionsCount() - outer_dims,
65                     shape_shape.Dims(0) - ix);
66   for (int i = 0; i + outer_dims < updates.DimensionsCount(); ++i) {
67     TF_LITE_ENSURE_EQ(context, updates.Dims(i + outer_dims),
68                       shape_data[ix + i]);
69   }
70   return kTfLiteOk;
71 }
72 
Prepare(TfLiteContext * context,TfLiteNode * node)73 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
74   TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);
75   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
76 
77   const TfLiteTensor* indices;
78   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
79   const TfLiteTensor* updates;
80   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
81   const TfLiteTensor* shape;
82   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));
83 
84   switch (updates->type) {
85     case kTfLiteFloat32:
86     case kTfLiteUInt8:
87     case kTfLiteBool:
88     case kTfLiteInt8:
89     case kTfLiteInt64:
90     case kTfLiteInt32:
91       break;
92     default:
93       TF_LITE_KERNEL_LOG(
94           context, "Updates of type '%s' are not supported by scatter_nd.",
95           TfLiteTypeGetName(updates->type));
96       return kTfLiteError;
97   }
98   if (indices->type != shape->type) {
99     TF_LITE_KERNEL_LOG(context, "Indices and shape must have the same type.");
100     return kTfLiteError;
101   }
102 
103   TfLiteTensor* output;
104   TF_LITE_ENSURE_OK(context,
105                     GetOutputSafe(context, node, kOutputTensor, &output));
106   output->type = updates->type;
107 
108   if (IsConstantTensor(shape)) {
109     switch (indices->type) {
110       case kTfLiteInt32:
111         TF_LITE_ENSURE_OK(
112             context,
113             CheckShapes<int32_t>(context, GetTensorShape(indices),
114                                  GetTensorShape(updates), GetTensorShape(shape),
115                                  GetTensorData<int32_t>(shape)));
116         return ResizeOutputTensor<int32_t>(context, shape, output);
117       default:
118         TF_LITE_KERNEL_LOG(
119             context, "Indices of type '%s' are not supported by scatter_nd.",
120             TfLiteTypeGetName(indices->type));
121         return kTfLiteError;
122     }
123   } else {
124     SetTensorToDynamic(output);
125     return kTfLiteOk;
126   }
127 }
128 
129 template <typename IndicesT, typename UpdatesT>
ScatterNd(const TfLiteTensor * indices,const TfLiteTensor * updates,TfLiteTensor * output)130 TfLiteStatus ScatterNd(const TfLiteTensor* indices, const TfLiteTensor* updates,
131                        TfLiteTensor* output) {
132   return reference_ops::ScatterNd(
133       GetTensorShape(indices), GetTensorData<IndicesT>(indices),
134       GetTensorShape(updates), GetTensorData<UpdatesT>(updates),
135       GetTensorShape(output), GetTensorData<UpdatesT>(output));
136 }
137 
138 template <typename IndicesT>
EvalScatterNd(TfLiteContext * context,const TfLiteTensor * indices,const TfLiteTensor * updates,const TfLiteTensor * shape,TfLiteTensor * output)139 TfLiteStatus EvalScatterNd(TfLiteContext* context, const TfLiteTensor* indices,
140                            const TfLiteTensor* updates,
141                            const TfLiteTensor* shape, TfLiteTensor* output) {
142   if (IsDynamicTensor(output)) {
143     TF_LITE_ENSURE_OK(
144         context, CheckShapes<IndicesT>(
145                      context, GetTensorShape(indices), GetTensorShape(updates),
146                      GetTensorShape(shape), GetTensorData<IndicesT>(shape)));
147     TF_LITE_ENSURE_OK(context,
148                       ResizeOutputTensor<IndicesT>(context, shape, output));
149   }
150 
151   TfLiteStatus status = kTfLiteError;
152   switch (updates->type) {
153     case kTfLiteFloat32:
154       status = ScatterNd<IndicesT, float>(indices, updates, output);
155       break;
156     case kTfLiteUInt8:
157       status = ScatterNd<IndicesT, uint8_t>(indices, updates, output);
158       break;
159     case kTfLiteBool:
160       status = ScatterNd<IndicesT, bool>(indices, updates, output);
161       break;
162     case kTfLiteInt8:
163       status = ScatterNd<IndicesT, int8_t>(indices, updates, output);
164       break;
165     case kTfLiteInt32:
166       status = ScatterNd<IndicesT, int32_t>(indices, updates, output);
167       break;
168     case kTfLiteInt64:
169       status = ScatterNd<IndicesT, int64_t>(indices, updates, output);
170       break;
171     default:
172       TF_LITE_KERNEL_LOG(
173           context, "Updates of type '%s' are not supported by scatter_nd.",
174           TfLiteTypeGetName(updates->type));
175       return kTfLiteError;
176   }
177   if (status != kTfLiteOk) {
178     TF_LITE_KERNEL_LOG(context, "scatter_nd index out of bounds");
179   }
180   return status;
181 }
182 
Eval(TfLiteContext * context,TfLiteNode * node)183 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
184   const TfLiteTensor* indices;
185   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
186   const TfLiteTensor* updates;
187   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kUpdates, &updates));
188   const TfLiteTensor* shape;
189   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShape, &shape));
190   TfLiteTensor* output;
191   TF_LITE_ENSURE_OK(context,
192                     GetOutputSafe(context, node, kOutputTensor, &output));
193 
194   switch (indices->type) {
195     case kTfLiteInt32:
196       return EvalScatterNd<int32_t>(context, indices, updates, shape, output);
197     default:
198       TF_LITE_KERNEL_LOG(
199           context, "Indices of type '%s' are not supported by scatter_nd.",
200           TfLiteTypeGetName(indices->type));
201       return kTfLiteError;
202   }
203 }
204 
205 }  // namespace scatter_nd
206 
Register_SCATTER_ND()207 TfLiteRegistration* Register_SCATTER_ND() {
208   static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
209                                  scatter_nd::Prepare, scatter_nd::Eval};
210   return &r;
211 }
212 }  // namespace builtin
213 }  // namespace ops
214 }  // namespace tflite
215