• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <initializer_list>
18 #include <vector>
19 
20 #include <gtest/gtest.h>
21 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
22 #include "tensorflow/lite/kernels/test_util.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24 
25 namespace tflite {
26 namespace {
27 
28 using ::testing::ElementsAreArray;
29 
30 class ScatterNdOpModel : public SingleOpModel {
31  public:
ScatterNdOpModel(const TensorData & indices,const TensorData & updates,const TensorData & shape)32   ScatterNdOpModel(const TensorData& indices, const TensorData& updates,
33                    const TensorData& shape) {
34     indices_ = AddInput(indices);
35     updates_ = AddInput(updates);
36     shape_ = AddInput(shape);
37     output_ = AddOutput(updates.type);
38     SetBuiltinOp(BuiltinOperator_SCATTER_ND, BuiltinOptions_ScatterNdOptions,
39                  CreateScatterNdOptions(builder_).Union());
40     BuildInterpreter(
41         {GetShape(indices_), GetShape(updates_), GetShape(shape_)});
42   }
43 
44   template <typename T>
SetIndices(std::initializer_list<T> data)45   void SetIndices(std::initializer_list<T> data) {
46     PopulateTensor<T>(indices_, data);
47   }
48 
49   template <typename T>
SetUpdates(std::initializer_list<T> data)50   void SetUpdates(std::initializer_list<T> data) {
51     PopulateTensor<T>(updates_, data);
52   }
53 
54   template <typename T>
SetShape(std::initializer_list<T> data)55   void SetShape(std::initializer_list<T> data) {
56     PopulateTensor<T>(shape_, data);
57   }
58 
59   template <typename T>
GetOutput()60   std::vector<T> GetOutput() {
61     return ExtractVector<T>(output_);
62   }
63 
GetOutputShape()64   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
65 
66  protected:
67   int indices_;
68   int updates_;
69   int shape_;
70   int output_;
71 };
72 
// Scatters four scalars into an 8-element vector. Positions not named in
// `indices` receive no update and must remain zero.
TEST(ScatterNdOpTest, ScatterElementIntoVector) {
  const TensorData indices_spec{TensorType_INT32, {4, 1}};
  const TensorData updates_spec{TensorType_FLOAT32, {4}};
  const TensorData shape_spec{TensorType_INT32, {1}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>({4, 3, 1, 7});
  model.SetUpdates<float>({9, 10, 11, 12});
  model.SetShape<int32_t>({8});
  model.Invoke();

  const std::vector<int> expected_shape = {8};
  const std::vector<float> expected_output = {0, 11, 0, 10, 9, 0, 0, 12};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
85 
// Scatters two 4x4 slices into positions 0 and 2 along the first axis of a
// 4x4x4 output. Slices 1 and 3 receive nothing and stay zero.
TEST(ScatterNdOpTest, ScatterMatrixIntoRank3Tensor) {
  const TensorData indices_spec{TensorType_INT32, {2, 1}};
  const TensorData updates_spec{TensorType_FLOAT32, {2, 4, 4}};
  const TensorData shape_spec{TensorType_INT32, {3}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>({0, 2});
  model.SetUpdates<float>({5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
                           5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8});
  model.SetShape<int32_t>({4, 4, 4});
  model.Invoke();

  const std::vector<int> expected_shape = {4, 4, 4};
  const std::vector<float> expected_output = {
      /*slice 0*/ 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
      /*slice 1*/ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      /*slice 2*/ 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
      /*slice 3*/ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
103 
// Scatters four length-4 rows into rows 9, 8, 0 and 1 of a 10x4 matrix.
// Rows 2..7 receive no update and stay zero.
TEST(ScatterNdOpTest, ScatterVectorIntoMatrix) {
  const TensorData indices_spec{TensorType_INT32, {4, 1}};
  const TensorData updates_spec{TensorType_FLOAT32, {4, 4}};
  const TensorData shape_spec{TensorType_INT32, {2}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>({/*0*/ 9, /*1*/ 8, /*2*/ 0, /*3*/ 1});
  model.SetUpdates<float>({/*0*/ 1, 2, 3, 4,
                           /*1*/ 5, 6, 7, 8,
                           /*2*/ 9, 10, 11, 12,
                           /*3*/ 13, 14, 15, 16});
  model.SetShape<int32_t>({10, 4});
  model.Invoke();

  const std::vector<int> expected_shape = {10, 4};
  const std::vector<float> expected_output = {/*0*/ 9,  10, 11, 12,
                                              /*1*/ 13, 14, 15, 16,
                                              /*2*/ 0,  0,  0,  0,
                                              /*3*/ 0,  0,  0,  0,
                                              /*4*/ 0,  0,  0,  0,
                                              /*5*/ 0,  0,  0,  0,
                                              /*6*/ 0,  0,  0,  0,
                                              /*7*/ 0,  0,  0,  0,
                                              /*8*/ 5,  6,  7,  8,
                                              /*9*/ 1,  2,  3,  4};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
128 
// Uses a 2x2 grid of 2-D indices to place four 2x2 blocks into a 2x2x2x2
// output, permuting their positions.
TEST(ScatterNdOpTest, ScatterMatricesIntoRank4Tensor) {
  const TensorData indices_spec{TensorType_INT32, {2, 2, 2}};
  const TensorData updates_spec{TensorType_FLOAT32, {2, 2, 2, 2}};
  const TensorData shape_spec{TensorType_INT32, {4}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>(
      {/*0,0*/ 1, 1, /*0,1*/ 0, 1, /*1,0*/ 0, 0, /*1,1*/ 1, 0});
  model.SetUpdates<float>({/*0,0*/ 1, 2, 3, 4, /*0,1*/ 5, 6, 7, 8,
                           /*1,0*/ 9, 10, 11, 12, /*1,1*/ 13, 14, 15, 16});
  model.SetShape<int32_t>({2, 2, 2, 2});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 2, 2, 2};
  const std::vector<float> expected_output = {/*0, 0*/ 9, 10, 11, 12,
                                              /*0, 1*/ 5, 6, 7, 8,
                                              /*1, 0*/ 13, 14, 15, 16,
                                              /*1, 1*/ 1, 2, 3, 4};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
146 
// Scatters four length-5 vectors into a 3x3x3x5 output. The indices form a
// 2x2 grid of 3-D coordinates, and index entry (i,j) pairs with update row
// (i,j). All other positions stay zero.
TEST(ScatterNdOpTest, ScatterVectorIntoRank4Tensor) {
  ScatterNdOpModel m({TensorType_INT32, {2, 2, 3}},
                     {TensorType_FLOAT32, {2, 2, 5}}, {TensorType_INT32, {4}});
  // The fourth coordinate is index entry (1,1); it was previously mislabelled
  // as a second (1,0).
  m.SetIndices<int32_t>(
      {/*0,0*/ 2, 2, 2, /*0,1*/ 1, 0, 1, /*1,0*/ 0, 2, 0, /*1,1*/ 2, 2, 0});
  m.SetUpdates<float>(
      {/*0,0*/ 1,  2,  3,  4,  5,  /*0,1*/ 6,  7,  8,  9,  10,
       /*1,0*/ 11, 12, 13, 14, 15, /*1,1*/ 16, 17, 18, 19, 20});
  m.SetShape<int32_t>({3, 3, 3, 5});
  m.Invoke();

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3, 5}));
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray({
                  /*0, 0, 0*/ 0,  0,  0,  0,  0,
                  /*0, 0, 1*/ 0,  0,  0,  0,  0,
                  /*0, 0, 2*/ 0,  0,  0,  0,  0,
                  /*0, 1, 0*/ 0,  0,  0,  0,  0,
                  /*0, 1, 1*/ 0,  0,  0,  0,  0,
                  /*0, 1, 2*/ 0,  0,  0,  0,  0,
                  /*0, 2, 0*/ 11, 12, 13, 14, 15,  // from update (1,0)
                  /*0, 2, 1*/ 0,  0,  0,  0,  0,
                  /*0, 2, 2*/ 0,  0,  0,  0,  0,
                  /*1, 0, 0*/ 0,  0,  0,  0,  0,
                  /*1, 0, 1*/ 6,  7,  8,  9,  10,  // from update (0,1)
                  /*1, 0, 2*/ 0,  0,  0,  0,  0,
                  /*1, 1, 0*/ 0,  0,  0,  0,  0,
                  /*1, 1, 1*/ 0,  0,  0,  0,  0,
                  /*1, 1, 2*/ 0,  0,  0,  0,  0,
                  /*1, 2, 0*/ 0,  0,  0,  0,  0,
                  /*1, 2, 1*/ 0,  0,  0,  0,  0,
                  /*1, 2, 2*/ 0,  0,  0,  0,  0,
                  /*2, 0, 0*/ 0,  0,  0,  0,  0,
                  /*2, 0, 1*/ 0,  0,  0,  0,  0,
                  /*2, 0, 2*/ 0,  0,  0,  0,  0,
                  /*2, 1, 0*/ 0,  0,  0,  0,  0,
                  /*2, 1, 1*/ 0,  0,  0,  0,  0,
                  /*2, 1, 2*/ 0,  0,  0,  0,  0,
                  /*2, 2, 0*/ 16, 17, 18, 19, 20,  // from update (1,1)
                  /*2, 2, 1*/ 0,  0,  0,  0,  0,
                  /*2, 2, 2*/ 1,  2,  3,  4,  5,   // from update (0,0)
              }));
}
190 
// Scatters four length-5 float vectors into a 2x3x5 output at coordinates
// (0,0), (1,0), (0,2) and (1,2). The (*,1) rows stay zero.
TEST(ScatterNdOpTest, ScatterVectorIntoRank3Tensor) {
  const TensorData indices_spec{TensorType_INT32, {4, 2}};
  const TensorData updates_spec{TensorType_FLOAT32, {4, 5}};
  const TensorData shape_spec{TensorType_INT32, {3}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
  model.SetUpdates<float>(
      {/*0*/ 1,  2,  3,  4,  5,  /*1*/ 6,  7,  8,  9,  10,
       /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<float> expected_output = {/*0, 0*/ 1,  2,  3,  4,  5,
                                              /*0, 1*/ 0,  0,  0,  0,  0,
                                              /*0, 2*/ 11, 12, 13, 14, 15,
                                              /*1, 0*/ 6,  7,  8,  9,  10,
                                              /*1, 1*/ 0,  0,  0,  0,  0,
                                              /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
210 
// Coordinates (1,0) and (0,2) each appear twice in `indices`. The expected
// output holds the element-wise sum of the two updates aimed at each of them.
TEST(ScatterNdOpTest, OverlappedIndicesSummed) {
  const TensorData indices_spec{TensorType_INT32, {4, 2}};
  const TensorData updates_spec{TensorType_FLOAT32, {4, 5}};
  const TensorData shape_spec{TensorType_INT32, {3}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>({/*0*/ 1, 0, /*1*/ 0, 2, /*2*/ 0, 2, /*3*/ 1, 0});
  model.SetUpdates<float>(
      {/*0*/ 1,  2,  3,  4,  5,  /*1*/ 6,  7,  8,  9,  10,
       /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  // (0,2) = row 1 + row 2; (1,0) = row 0 + row 3.
  const std::vector<float> expected_output = {/*0, 0*/ 0,  0,  0,  0,  0,
                                              /*0, 1*/ 0,  0,  0,  0,  0,
                                              /*0, 2*/ 17, 19, 21, 23, 25,
                                              /*1, 0*/ 17, 19, 21, 23, 25,
                                              /*1, 1*/ 0,  0,  0,  0,  0,
                                              /*1, 2*/ 0,  0,  0,  0,  0};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
230 
// Same scatter pattern as ScatterVectorIntoRank3Tensor, with uint8 updates
// (and therefore uint8 output).
TEST(ScatterNdOpTest, Int32IndicesUint8Updates) {
  const TensorData indices_spec{TensorType_INT32, {4, 2}};
  const TensorData updates_spec{TensorType_UINT8, {4, 5}};
  const TensorData shape_spec{TensorType_INT32, {3}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
  model.SetUpdates<uint8_t>(
      {/*0*/ 1,  2,  3,  4,  5,  /*1*/ 6,  7,  8,  9,  10,
       /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<uint8_t> expected_output = {/*0, 0*/ 1,  2,  3,  4,  5,
                                                /*0, 1*/ 0,  0,  0,  0,  0,
                                                /*0, 2*/ 11, 12, 13, 14, 15,
                                                /*1, 0*/ 6,  7,  8,  9,  10,
                                                /*1, 1*/ 0,  0,  0,  0,  0,
                                                /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray(expected_output));
}
250 
// Same scatter pattern as ScatterVectorIntoRank3Tensor, with int8 updates
// (and therefore int8 output).
TEST(ScatterNdOpTest, Int32IndicesInt8Updates) {
  const TensorData indices_spec{TensorType_INT32, {4, 2}};
  const TensorData updates_spec{TensorType_INT8, {4, 5}};
  const TensorData shape_spec{TensorType_INT32, {3}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
  model.SetUpdates<int8_t>(
      {/*0*/ 1,  2,  3,  4,  5,  /*1*/ 6,  7,  8,  9,  10,
       /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<int8_t> expected_output = {/*0, 0*/ 1,  2,  3,  4,  5,
                                               /*0, 1*/ 0,  0,  0,  0,  0,
                                               /*0, 2*/ 11, 12, 13, 14, 15,
                                               /*1, 0*/ 6,  7,  8,  9,  10,
                                               /*1, 1*/ 0,  0,  0,  0,  0,
                                               /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray(expected_output));
}
270 
// Same scatter pattern as ScatterVectorIntoRank3Tensor, with int32 updates
// (and therefore int32 output).
TEST(ScatterNdOpTest, Int32IndicesInt32Updates) {
  const TensorData indices_spec{TensorType_INT32, {4, 2}};
  const TensorData updates_spec{TensorType_INT32, {4, 5}};
  const TensorData shape_spec{TensorType_INT32, {3}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
  model.SetUpdates<int32_t>(
      {/*0*/ 1,  2,  3,  4,  5,  /*1*/ 6,  7,  8,  9,  10,
       /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<int32_t> expected_output = {/*0, 0*/ 1,  2,  3,  4,  5,
                                                /*0, 1*/ 0,  0,  0,  0,  0,
                                                /*0, 2*/ 11, 12, 13, 14, 15,
                                                /*1, 0*/ 6,  7,  8,  9,  10,
                                                /*1, 1*/ 0,  0,  0,  0,  0,
                                                /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray(expected_output));
}
290 
// Same scatter pattern as ScatterVectorIntoRank3Tensor, with int64 updates
// (and therefore int64 output).
TEST(ScatterNdOpTest, Int32IndicesInt64Updates) {
  const TensorData indices_spec{TensorType_INT32, {4, 2}};
  const TensorData updates_spec{TensorType_INT64, {4, 5}};
  const TensorData shape_spec{TensorType_INT32, {3}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  model.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
  model.SetUpdates<int64_t>(
      {/*0*/ 1,  2,  3,  4,  5,  /*1*/ 6,  7,  8,  9,  10,
       /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<int64_t> expected_output = {/*0, 0*/ 1,  2,  3,  4,  5,
                                                /*0, 1*/ 0,  0,  0,  0,  0,
                                                /*0, 2*/ 11, 12, 13, 14, 15,
                                                /*1, 0*/ 6,  7,  8,  9,  10,
                                                /*1, 1*/ 0,  0,  0,  0,  0,
                                                /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<int64_t>(), ElementsAreArray(expected_output));
}
310 
// Invokes the same model twice with different shape-tensor values. The
// output first has shape {2,3,5}, then {3,4,5} after new indices and a new
// shape are supplied.
TEST(ScatterNdOpTest, DynamicShape) {
  const TensorData indices_spec{TensorType_INT32, {4, 2}};
  const TensorData updates_spec{TensorType_INT64, {4, 5}};
  const TensorData shape_spec{TensorType_INT32, {3}};
  ScatterNdOpModel model(indices_spec, updates_spec, shape_spec);

  // Phase 1: output shape {2, 3, 5}.
  model.SetIndices<int32_t>({/*0*/ 0, 0, /*1*/ 1, 0, /*2*/ 0, 2, /*3*/ 1, 2});
  model.SetUpdates<int64_t>(
      {/*0*/ 1,  2,  3,  4,  5,  /*1*/ 6,  7,  8,  9,  10,
       /*2*/ 11, 12, 13, 14, 15, /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> first_shape = {2, 3, 5};
  const std::vector<int64_t> first_output = {/*0, 0*/ 1,  2,  3,  4,  5,
                                             /*0, 1*/ 0,  0,  0,  0,  0,
                                             /*0, 2*/ 11, 12, 13, 14, 15,
                                             /*1, 0*/ 6,  7,  8,  9,  10,
                                             /*1, 1*/ 0,  0,  0,  0,  0,
                                             /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(first_shape));
  EXPECT_THAT(model.GetOutput<int64_t>(), ElementsAreArray(first_output));

  // Phase 2: new indices and a larger output shape. The updates tensor is
  // not re-populated, so the phase-1 update rows are scattered again.
  model.SetIndices<int32_t>({/*0*/ 2, 3, /*1*/ 1, 0, /*2*/ 2, 0, /*3*/ 1, 2});
  model.SetShape<int32_t>({3, 4, 5});
  model.Invoke();

  const std::vector<int> second_shape = {3, 4, 5};
  const std::vector<int64_t> second_output = {/*0, 0*/ 0,  0,  0,  0,  0,
                                              /*0, 1*/ 0,  0,  0,  0,  0,
                                              /*0, 2*/ 0,  0,  0,  0,  0,
                                              /*0, 3*/ 0,  0,  0,  0,  0,
                                              /*1, 0*/ 6,  7,  8,  9,  10,
                                              /*1, 1*/ 0,  0,  0,  0,  0,
                                              /*1, 2*/ 16, 17, 18, 19, 20,
                                              /*1, 3*/ 0,  0,  0,  0,  0,
                                              /*2, 0*/ 11, 12, 13, 14, 15,
                                              /*2, 1*/ 0,  0,  0,  0,  0,
                                              /*2, 2*/ 0,  0,  0,  0,  0,
                                              /*2, 3*/ 1,  2,  3,  4,  5};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(second_shape));
  EXPECT_THAT(model.GetOutput<int64_t>(), ElementsAreArray(second_output));
}
349 
350 }  // namespace
351 }  // namespace tflite
352