1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include <initializer_list>
18 #include <vector>
19
20 #include <gtest/gtest.h>
21 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
22 #include "tensorflow/lite/kernels/test_util.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24
25 namespace tflite {
26 namespace {
27
28 using ::testing::ElementsAreArray;
29
30 class ScatterNdOpModel : public SingleOpModel {
31 public:
ScatterNdOpModel(const TensorData & indices,const TensorData & updates,const TensorData & shape)32 ScatterNdOpModel(const TensorData& indices, const TensorData& updates,
33 const TensorData& shape) {
34 indices_ = AddInput(indices);
35 updates_ = AddInput(updates);
36 shape_ = AddInput(shape);
37 output_ = AddOutput(updates.type);
38 SetBuiltinOp(BuiltinOperator_SCATTER_ND, BuiltinOptions_ScatterNdOptions,
39 CreateScatterNdOptions(builder_).Union());
40 BuildInterpreter(
41 {GetShape(indices_), GetShape(updates_), GetShape(shape_)});
42 }
43
44 template <typename T>
SetIndices(std::initializer_list<T> data)45 void SetIndices(std::initializer_list<T> data) {
46 PopulateTensor<T>(indices_, data);
47 }
48
49 template <typename T>
SetUpdates(std::initializer_list<T> data)50 void SetUpdates(std::initializer_list<T> data) {
51 PopulateTensor<T>(updates_, data);
52 }
53
54 template <typename T>
SetShape(std::initializer_list<T> data)55 void SetShape(std::initializer_list<T> data) {
56 PopulateTensor<T>(shape_, data);
57 }
58
59 template <typename T>
GetOutput()60 std::vector<T> GetOutput() {
61 return ExtractVector<T>(output_);
62 }
63
GetOutputShape()64 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
65
66 protected:
67 int indices_;
68 int updates_;
69 int shape_;
70 int output_;
71 };
72
// Scatters four scalar updates into a length-8 vector; every position that no
// index refers to is expected to read back as 0.
TEST(ScatterNdOpTest, ScatterElementIntoVector) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {4, 1}},
                         /*updates=*/{TensorType_FLOAT32, {4}},
                         /*shape=*/{TensorType_INT32, {1}});
  model.SetIndices<int32_t>({4, 3, 1, 7});
  model.SetUpdates<float>({9, 10, 11, 12});
  model.SetShape<int32_t>({8});
  model.Invoke();

  const std::vector<int> expected_shape = {8};
  const std::vector<float> expected_output = {0, 11, 0, 10, 9, 0, 0, 12};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
85
// Scatters two 4x4 matrices into slices 0 and 2 of a 4x4x4 tensor; slices 1
// and 3 are expected to stay all-zero.
TEST(ScatterNdOpTest, ScatterMatrixIntoRank3Tensor) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {2, 1}},
                         /*updates=*/{TensorType_FLOAT32, {2, 4, 4}},
                         /*shape=*/{TensorType_INT32, {3}});
  model.SetIndices<int32_t>({0, 2});
  model.SetUpdates<float>({
      /*0*/ 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
      /*1*/ 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8});
  model.SetShape<int32_t>({4, 4, 4});
  model.Invoke();

  const std::vector<int> expected_shape = {4, 4, 4};
  const std::vector<float> expected_output = {
      /*0*/ 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
      /*1*/ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
      /*2*/ 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8,
      /*3*/ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
103
// Scatters four length-4 rows into rows 9, 8, 0 and 1 of a 10x4 matrix; the
// remaining rows are expected to be zero.
TEST(ScatterNdOpTest, ScatterVectorIntoMatrix) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {4, 1}},
                         /*updates=*/{TensorType_FLOAT32, {4, 4}},
                         /*shape=*/{TensorType_INT32, {2}});
  model.SetIndices<int32_t>({/*0*/ 9, /*1*/ 8, /*2*/ 0, /*3*/ 1});
  model.SetUpdates<float>({
      /*0*/ 1,  2,  3,  4,
      /*1*/ 5,  6,  7,  8,
      /*2*/ 9,  10, 11, 12,
      /*3*/ 13, 14, 15, 16});
  model.SetShape<int32_t>({10, 4});
  model.Invoke();

  const std::vector<int> expected_shape = {10, 4};
  const std::vector<float> expected_output = {
      /*0*/ 9,  10, 11, 12,
      /*1*/ 13, 14, 15, 16,
      /*2*/ 0,  0,  0,  0,
      /*3*/ 0,  0,  0,  0,
      /*4*/ 0,  0,  0,  0,
      /*5*/ 0,  0,  0,  0,
      /*6*/ 0,  0,  0,  0,
      /*7*/ 0,  0,  0,  0,
      /*8*/ 5,  6,  7,  8,
      /*9*/ 1,  2,  3,  4};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
128
// Uses a rank-3 indices tensor (2x2 index pairs) to place four 2x2 matrices
// into a 2x2x2x2 output; every output slot is written exactly once.
TEST(ScatterNdOpTest, ScatterMatricesIntoRank4Tensor) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {2, 2, 2}},
                         /*updates=*/{TensorType_FLOAT32, {2, 2, 2, 2}},
                         /*shape=*/{TensorType_INT32, {4}});
  model.SetIndices<int32_t>({
      /*0,0*/ 1, 1,
      /*0,1*/ 0, 1,
      /*1,0*/ 0, 0,
      /*1,1*/ 1, 0});
  model.SetUpdates<float>({
      /*0,0*/ 1,  2,  3,  4,
      /*0,1*/ 5,  6,  7,  8,
      /*1,0*/ 9,  10, 11, 12,
      /*1,1*/ 13, 14, 15, 16});
  model.SetShape<int32_t>({2, 2, 2, 2});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 2, 2, 2};
  const std::vector<float> expected_output = {
      /*0, 0*/ 9,  10, 11, 12,
      /*0, 1*/ 5,  6,  7,  8,
      /*1, 0*/ 13, 14, 15, 16,
      /*1, 1*/ 1,  2,  3,  4};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
146
// Uses a rank-3 indices tensor of index triples to place four length-5
// vectors into a 3x3x3x5 output; all other rows are expected to be zero.
TEST(ScatterNdOpTest, ScatterVectorIntoRank4Tensor) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {2, 2, 3}},
                         /*updates=*/{TensorType_FLOAT32, {2, 2, 5}},
                         /*shape=*/{TensorType_INT32, {4}});
  model.SetIndices<int32_t>({
      /*0,0*/ 2, 2, 2,
      /*0,1*/ 1, 0, 1,
      /*1,0*/ 0, 2, 0,
      /*1,1*/ 2, 2, 0});
  model.SetUpdates<float>({
      /*0,0*/ 1,  2,  3,  4,  5,
      /*0,1*/ 6,  7,  8,  9,  10,
      /*1,0*/ 11, 12, 13, 14, 15,
      /*1,1*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({3, 3, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {3, 3, 3, 5};
  const std::vector<float> expected_output = {
      /*0, 0, 0*/ 0,  0,  0,  0,  0,
      /*0, 0, 1*/ 0,  0,  0,  0,  0,
      /*0, 0, 2*/ 0,  0,  0,  0,  0,
      /*0, 1, 0*/ 0,  0,  0,  0,  0,
      /*0, 1, 1*/ 0,  0,  0,  0,  0,
      /*0, 1, 2*/ 0,  0,  0,  0,  0,
      /*0, 2, 0*/ 11, 12, 13, 14, 15,
      /*0, 2, 1*/ 0,  0,  0,  0,  0,
      /*0, 2, 2*/ 0,  0,  0,  0,  0,
      /*1, 0, 0*/ 0,  0,  0,  0,  0,
      /*1, 0, 1*/ 6,  7,  8,  9,  10,
      /*1, 0, 2*/ 0,  0,  0,  0,  0,
      /*1, 1, 0*/ 0,  0,  0,  0,  0,
      /*1, 1, 1*/ 0,  0,  0,  0,  0,
      /*1, 1, 2*/ 0,  0,  0,  0,  0,
      /*1, 2, 0*/ 0,  0,  0,  0,  0,
      /*1, 2, 1*/ 0,  0,  0,  0,  0,
      /*1, 2, 2*/ 0,  0,  0,  0,  0,
      /*2, 0, 0*/ 0,  0,  0,  0,  0,
      /*2, 0, 1*/ 0,  0,  0,  0,  0,
      /*2, 0, 2*/ 0,  0,  0,  0,  0,
      /*2, 1, 0*/ 0,  0,  0,  0,  0,
      /*2, 1, 1*/ 0,  0,  0,  0,  0,
      /*2, 1, 2*/ 0,  0,  0,  0,  0,
      /*2, 2, 0*/ 16, 17, 18, 19, 20,
      /*2, 2, 1*/ 0,  0,  0,  0,  0,
      /*2, 2, 2*/ 1,  2,  3,  4,  5,
  };
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
190
// Places four length-5 float vectors at distinct (row, column) positions of a
// 2x3x5 output; the two positions never indexed are expected to be zero.
TEST(ScatterNdOpTest, ScatterVectorIntoRank3Tensor) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {4, 2}},
                         /*updates=*/{TensorType_FLOAT32, {4, 5}},
                         /*shape=*/{TensorType_INT32, {3}});
  model.SetIndices<int32_t>({
      /*0*/ 0, 0,
      /*1*/ 1, 0,
      /*2*/ 0, 2,
      /*3*/ 1, 2});
  model.SetUpdates<float>({
      /*0*/ 1,  2,  3,  4,  5,
      /*1*/ 6,  7,  8,  9,  10,
      /*2*/ 11, 12, 13, 14, 15,
      /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<float> expected_output = {
      /*0, 0*/ 1,  2,  3,  4,  5,
      /*0, 1*/ 0,  0,  0,  0,  0,
      /*0, 2*/ 11, 12, 13, 14, 15,
      /*1, 0*/ 6,  7,  8,  9,  10,
      /*1, 1*/ 0,  0,  0,  0,  0,
      /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
210
// Each of the two target positions is addressed by two index pairs; the
// expected output holds the element-wise sum of the colliding update vectors.
TEST(ScatterNdOpTest, OverlappedIndicesSummed) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {4, 2}},
                         /*updates=*/{TensorType_FLOAT32, {4, 5}},
                         /*shape=*/{TensorType_INT32, {3}});
  model.SetIndices<int32_t>({
      /*0*/ 1, 0,
      /*1*/ 0, 2,
      /*2*/ 0, 2,
      /*3*/ 1, 0});
  model.SetUpdates<float>({
      /*0*/ 1,  2,  3,  4,  5,
      /*1*/ 6,  7,  8,  9,  10,
      /*2*/ 11, 12, 13, 14, 15,
      /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<float> expected_output = {
      /*0, 0*/ 0,  0,  0,  0,  0,
      /*0, 1*/ 0,  0,  0,  0,  0,
      /*0, 2*/ 17, 19, 21, 23, 25,  // updates 1 + 2
      /*1, 0*/ 17, 19, 21, 23, 25,  // updates 0 + 3
      /*1, 1*/ 0,  0,  0,  0,  0,
      /*1, 2*/ 0,  0,  0,  0,  0};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<float>(), ElementsAreArray(expected_output));
}
230
// Same scatter pattern as the rank-3 float case, exercised with uint8 updates
// and int32 indices.
TEST(ScatterNdOpTest, Int32IndicesUint8Updates) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {4, 2}},
                         /*updates=*/{TensorType_UINT8, {4, 5}},
                         /*shape=*/{TensorType_INT32, {3}});
  model.SetIndices<int32_t>({
      /*0*/ 0, 0,
      /*1*/ 1, 0,
      /*2*/ 0, 2,
      /*3*/ 1, 2});
  model.SetUpdates<uint8_t>({
      /*0*/ 1,  2,  3,  4,  5,
      /*1*/ 6,  7,  8,  9,  10,
      /*2*/ 11, 12, 13, 14, 15,
      /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<uint8_t> expected_output = {
      /*0, 0*/ 1,  2,  3,  4,  5,
      /*0, 1*/ 0,  0,  0,  0,  0,
      /*0, 2*/ 11, 12, 13, 14, 15,
      /*1, 0*/ 6,  7,  8,  9,  10,
      /*1, 1*/ 0,  0,  0,  0,  0,
      /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<uint8_t>(), ElementsAreArray(expected_output));
}
250
// Same scatter pattern as the rank-3 float case, exercised with int8 updates
// and int32 indices.
TEST(ScatterNdOpTest, Int32IndicesInt8Updates) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {4, 2}},
                         /*updates=*/{TensorType_INT8, {4, 5}},
                         /*shape=*/{TensorType_INT32, {3}});
  model.SetIndices<int32_t>({
      /*0*/ 0, 0,
      /*1*/ 1, 0,
      /*2*/ 0, 2,
      /*3*/ 1, 2});
  model.SetUpdates<int8_t>({
      /*0*/ 1,  2,  3,  4,  5,
      /*1*/ 6,  7,  8,  9,  10,
      /*2*/ 11, 12, 13, 14, 15,
      /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<int8_t> expected_output = {
      /*0, 0*/ 1,  2,  3,  4,  5,
      /*0, 1*/ 0,  0,  0,  0,  0,
      /*0, 2*/ 11, 12, 13, 14, 15,
      /*1, 0*/ 6,  7,  8,  9,  10,
      /*1, 1*/ 0,  0,  0,  0,  0,
      /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray(expected_output));
}
270
// Same scatter pattern as the rank-3 float case, exercised with int32 updates
// and int32 indices.
TEST(ScatterNdOpTest, Int32IndicesInt32Updates) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {4, 2}},
                         /*updates=*/{TensorType_INT32, {4, 5}},
                         /*shape=*/{TensorType_INT32, {3}});
  model.SetIndices<int32_t>({
      /*0*/ 0, 0,
      /*1*/ 1, 0,
      /*2*/ 0, 2,
      /*3*/ 1, 2});
  model.SetUpdates<int32_t>({
      /*0*/ 1,  2,  3,  4,  5,
      /*1*/ 6,  7,  8,  9,  10,
      /*2*/ 11, 12, 13, 14, 15,
      /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<int32_t> expected_output = {
      /*0, 0*/ 1,  2,  3,  4,  5,
      /*0, 1*/ 0,  0,  0,  0,  0,
      /*0, 2*/ 11, 12, 13, 14, 15,
      /*1, 0*/ 6,  7,  8,  9,  10,
      /*1, 1*/ 0,  0,  0,  0,  0,
      /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray(expected_output));
}
290
// Same scatter pattern as the rank-3 float case, exercised with int64 updates
// and int32 indices.
TEST(ScatterNdOpTest, Int32IndicesInt64Updates) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {4, 2}},
                         /*updates=*/{TensorType_INT64, {4, 5}},
                         /*shape=*/{TensorType_INT32, {3}});
  model.SetIndices<int32_t>({
      /*0*/ 0, 0,
      /*1*/ 1, 0,
      /*2*/ 0, 2,
      /*3*/ 1, 2});
  model.SetUpdates<int64_t>({
      /*0*/ 1,  2,  3,  4,  5,
      /*1*/ 6,  7,  8,  9,  10,
      /*2*/ 11, 12, 13, 14, 15,
      /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> expected_shape = {2, 3, 5};
  const std::vector<int64_t> expected_output = {
      /*0, 0*/ 1,  2,  3,  4,  5,
      /*0, 1*/ 0,  0,  0,  0,  0,
      /*0, 2*/ 11, 12, 13, 14, 15,
      /*1, 0*/ 6,  7,  8,  9,  10,
      /*1, 1*/ 0,  0,  0,  0,  0,
      /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(model.GetOutput<int64_t>(), ElementsAreArray(expected_output));
}
310
// Invokes the same model twice with different contents in the shape input
// tensor and checks that the output dimensions follow the new shape value.
TEST(ScatterNdOpTest, DynamicShape) {
  ScatterNdOpModel model(/*indices=*/{TensorType_INT32, {4, 2}},
                         /*updates=*/{TensorType_INT64, {4, 5}},
                         /*shape=*/{TensorType_INT32, {3}});

  // First run: scatter into a 2x3x5 output.
  model.SetIndices<int32_t>({
      /*0*/ 0, 0,
      /*1*/ 1, 0,
      /*2*/ 0, 2,
      /*3*/ 1, 2});
  model.SetUpdates<int64_t>({
      /*0*/ 1,  2,  3,  4,  5,
      /*1*/ 6,  7,  8,  9,  10,
      /*2*/ 11, 12, 13, 14, 15,
      /*3*/ 16, 17, 18, 19, 20});
  model.SetShape<int32_t>({2, 3, 5});
  model.Invoke();

  const std::vector<int> first_shape = {2, 3, 5};
  const std::vector<int64_t> first_output = {
      /*0, 0*/ 1,  2,  3,  4,  5,
      /*0, 1*/ 0,  0,  0,  0,  0,
      /*0, 2*/ 11, 12, 13, 14, 15,
      /*1, 0*/ 6,  7,  8,  9,  10,
      /*1, 1*/ 0,  0,  0,  0,  0,
      /*1, 2*/ 16, 17, 18, 19, 20};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(first_shape));
  EXPECT_THAT(model.GetOutput<int64_t>(), ElementsAreArray(first_output));

  // Second run: only the indices and the shape are rewritten (the updates
  // tensor is not repopulated); the output is now expected to be 3x4x5.
  model.SetIndices<int32_t>({
      /*0*/ 2, 3,
      /*1*/ 1, 0,
      /*2*/ 2, 0,
      /*3*/ 1, 2});
  model.SetShape<int32_t>({3, 4, 5});
  model.Invoke();

  const std::vector<int> second_shape = {3, 4, 5};
  const std::vector<int64_t> second_output = {
      /*0, 0*/ 0,  0,  0,  0,  0,
      /*0, 1*/ 0,  0,  0,  0,  0,
      /*0, 2*/ 0,  0,  0,  0,  0,
      /*0, 3*/ 0,  0,  0,  0,  0,
      /*1, 0*/ 6,  7,  8,  9,  10,
      /*1, 1*/ 0,  0,  0,  0,  0,
      /*1, 2*/ 16, 17, 18, 19, 20,
      /*1, 3*/ 0,  0,  0,  0,  0,
      /*2, 0*/ 11, 12, 13, 14, 15,
      /*2, 1*/ 0,  0,  0,  0,  0,
      /*2, 2*/ 0,  0,  0,  0,  0,
      /*2, 3*/ 1,  2,  3,  4,  5};
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray(second_shape));
  EXPECT_THAT(model.GetOutput<int64_t>(), ElementsAreArray(second_output));
}
349
350 } // namespace
351 } // namespace tflite
352