1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16
17 #include <initializer_list>
18 #include <string>
19 #include <vector>
20
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "flatbuffers/flatbuffers.h" // from @flatbuffers
24 #include "tensorflow/lite/kernels/test_util.h"
25 #include "tensorflow/lite/schema/schema_generated.h"
26 #include "tensorflow/lite/string_type.h"
27
28 namespace tflite {
29 namespace {
30
31 using ::testing::ElementsAreArray;
32
33 class GatherOpModel : public SingleOpModel {
34 public:
GatherOpModel(const TensorData & input,const TensorData & positions,int axis=0,int batch_dims=0)35 GatherOpModel(const TensorData& input, const TensorData& positions,
36 int axis = 0, int batch_dims = 0) {
37 input_ = AddInput(input);
38 positions_ = AddInput(positions);
39 output_ = AddOutput(input.type);
40 SetBuiltinOp(BuiltinOperator_GATHER, BuiltinOptions_GatherOptions,
41 CreateGatherOptions(builder_, axis, batch_dims).Union());
42 BuildInterpreter({GetShape(input_), GetShape(positions_)});
43 }
44
45 template <typename T>
SetInput(std::initializer_list<T> data)46 void SetInput(std::initializer_list<T> data) {
47 PopulateTensor<T>(input_, data);
48 }
49
SetStringInput(std::initializer_list<string> data)50 void SetStringInput(std::initializer_list<string> data) {
51 PopulateStringTensor(input_, data);
52 }
53
54 template <typename T>
SetPositions(std::initializer_list<T> data)55 void SetPositions(std::initializer_list<T> data) {
56 PopulateTensor<T>(positions_, data);
57 }
58
59 template <typename T>
GetOutput()60 std::vector<T> GetOutput() {
61 return ExtractVector<T>(output_);
62 }
63
GetStringOutput()64 std::vector<string> GetStringOutput() {
65 return ExtractVector<string>(output_);
66 }
GetOutputShape()67 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
68
69 protected:
70 int input_;
71 int positions_;
72 int output_;
73 };
74
// Positions {1, 0} on a 2x2 float input: the two rows come back swapped.
TEST(GatherOpTest, Shuffle) {
  const std::vector<float> expected_values = {0.7, 0.8, -2, 0.2};

  GatherOpModel gather({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2}});
  gather.SetInput<float>({-2.0, 0.2, 0.7, 0.8});
  gather.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
}
83
// A scalar (shape {}) position into a 2x2 input yields a single row of
// shape {2}.
TEST(GatherOpTest, Test0DIndex) {
  const std::vector<float> expected_row = {0.7, 0.8};
  const std::vector<int> expected_shape = {2};

  GatherOpModel gather({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {}});
  gather.SetInput<float>({-2.0, 0.2, 0.7, 0.8});
  gather.SetPositions<int32_t>({1});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_row)));
  EXPECT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
}
93
// Scalar position into a 1-D input: the result is itself a 0-D tensor.
TEST(GatherOpTest, Test0DIndexWith0DResult) {
  // 0-D tensors are a special case in current TFLite. Exercise one here to
  // make sure the existing workarounds cope with it.
  const std::vector<float> expected_scalar = {2.0};

  GatherOpModel gather({TensorType_FLOAT32, {3}}, {TensorType_INT32, {}});
  gather.SetInput<float>({1.0, 2.0, 3.0});
  gather.SetPositions<int32_t>({1});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_scalar)));
  EXPECT_TRUE(gather.GetOutputShape().empty());
}
104
// One-element position vector into a 1-D input keeps a rank-1 output of
// shape {1}.
TEST(GatherOpTest, Test1DInput1DIndex) {
  const std::vector<float> expected_values = {3.0};
  const std::vector<int> expected_shape = {1};

  GatherOpModel gather({TensorType_FLOAT32, {3}}, {TensorType_INT32, {1}});
  gather.SetInput<float>({1.0, 3.0, 5.0});
  gather.SetPositions<int32_t>({1});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
}
113
// Positions of shape {1, 2} into a 1-D input produce an output of shape
// {1, 2}.
TEST(GatherOpTest, Test2DIndexWith2DResult) {
  const std::vector<float> expected_values = {2.0, 1.0};
  const std::vector<int> expected_shape = {1, 2};

  GatherOpModel gather({TensorType_FLOAT32, {3}}, {TensorType_INT32, {1, 2}});
  gather.SetInput<float>({1.0, 2.0, 3.0});
  gather.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
}
123
// Position 0 requested twice on a {1, 2, 2} input: the four input values
// appear twice in the output.
TEST(FloatGatherOpTest, Duplicate) {
  const std::vector<float> expected_values = {-2,  0.2, 0.7, 0.8,
                                              -2,  0.2, 0.7, 0.8};

  GatherOpModel gather({TensorType_FLOAT32, {1, 2, 2}},
                       {TensorType_INT32, {2}});
  gather.SetInput<float>({-2.0, 0.2, 0.7, 0.8});
  gather.SetPositions<int32_t>({0, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
}
133
// Picks entries 1 and 3 out of a {4, 1} float input.
TEST(FloatGatherOpTest, Slice) {
  const std::vector<float> expected_values = {0.2, 0.8};

  GatherOpModel gather({TensorType_FLOAT32, {4, 1}}, {TensorType_INT32, {2}});
  gather.SetInput<float>({-2.0, 0.2, 0.7, 0.8});
  gather.SetPositions<int32_t>({1, 3});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
}
142
// Gathers along axis 1 of a {1, 2, 3} input with positions {1, 0}.
TEST(FloatGatherOpTest, Axis1) {
  constexpr int kGatherAxis = 1;
  const std::vector<float> expected_values = {4, 5, 6, 1, 2, 3};
  const std::vector<int> expected_shape = {1, 2, 3};

  GatherOpModel gather({TensorType_FLOAT32, {1, 2, 3}},
                       {TensorType_INT32, {2}}, kGatherAxis);
  gather.SetInput<float>({1, 2, 3, 4, 5, 6});
  gather.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
}
154
// Axis 1 with a scalar position: the gathered axis is absent from the
// expected output shape ({1, 3, 2} -> {1, 2}).
TEST(FloatGatherOpTest, Axis10DIndex) {
  constexpr int kGatherAxis = 1;
  const std::vector<float> expected_values = {3, 4};
  const std::vector<int> expected_shape = {1, 2};

  GatherOpModel gather({TensorType_FLOAT32, {1, 3, 2}},
                       {TensorType_INT32, {}}, kGatherAxis);
  gather.SetInput<float>({1, 2, 3, 4, 5, 6});
  gather.SetPositions<int32_t>({1});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
}
165
// Axis 1 of a {1, 4, 2} input, selecting entries 3 and 1 in that order.
TEST(FloatGatherOpTest, Axis1Slice) {
  constexpr int kGatherAxis = 1;
  const std::vector<float> expected_values = {7, 8, 3, 4};
  const std::vector<int> expected_shape = {1, 2, 2};

  GatherOpModel gather({TensorType_FLOAT32, {1, 4, 2}},
                       {TensorType_INT32, {2}}, kGatherAxis);
  gather.SetInput<float>({1, 2, 3, 4, 5, 6, 7, 8});
  gather.SetPositions<int32_t>({3, 1});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
}
177
// Negative axis (-1) on a {1, 2, 3} input with positions {2, 0}.
TEST(FloatGatherOpTest, LastAxis) {
  constexpr int kGatherAxis = -1;
  const std::vector<float> expected_values = {3, 1, 6, 4};
  const std::vector<int> expected_shape = {1, 2, 2};

  GatherOpModel gather({TensorType_FLOAT32, {1, 2, 3}},
                       {TensorType_INT32, {2}}, kGatherAxis);
  gather.SetInput<float>({1, 2, 3, 4, 5, 6});
  gather.SetPositions<int32_t>({2, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
}
189
// Negative axis (-1) combined with a scalar position on a {1, 2, 3} input.
TEST(FloatGatherOpTest, LastAxis0DIndex) {
  constexpr int kGatherAxis = -1;
  const std::vector<float> expected_values = {3, 6};
  const std::vector<int> expected_shape = {1, 2};

  GatherOpModel gather({TensorType_FLOAT32, {1, 2, 3}},
                       {TensorType_INT32, {}}, kGatherAxis);
  gather.SetInput<float>({1, 2, 3, 4, 5, 6});
  gather.SetPositions<int32_t>({2});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
}
200
// FLOAT32 data gathered with INT32 positions; values are compared exactly
// (no tolerance matcher is used here).
TEST(TypesGatherOpTest, Float32Int32) {
  GatherOpModel gather({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2}});
  gather.SetInput<float>({13.3, -13.4, -1.4, 1.5});
  gather.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  const std::vector<float> gathered = gather.GetOutput<float>();
  EXPECT_THAT(gathered, ElementsAreArray({-1.4, 1.5, 13.3, -13.4}));
}
209
// FLOAT32 data gathered with INT64 positions; values are compared exactly
// (no tolerance matcher is used here).
TEST(TypesGatherOpTest, Float32Int64) {
  GatherOpModel gather({TensorType_FLOAT32, {2, 2}}, {TensorType_INT64, {2}});
  gather.SetInput<float>({13.3, -13.4, -1.4, 1.5});
  gather.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  const std::vector<float> gathered = gather.GetOutput<float>();
  EXPECT_THAT(gathered, ElementsAreArray({-1.4, 1.5, 13.3, -13.4}));
}
218
// INT32 data gathered with INT32 positions.
TEST(TypesGatherOpTest, Int32Int32) {
  const std::vector<int32_t> expected_values = {140, -150, -1330, 1340};

  GatherOpModel gather({TensorType_INT32, {2, 2}}, {TensorType_INT32, {2}});
  gather.SetInput<int32_t>({-1330, 1340, 140, -150});
  gather.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<int32_t>(), ElementsAreArray(expected_values));
}
228
// INT32 data gathered with INT64 positions.
TEST(TypesGatherOpTest, Int32Int64) {
  const std::vector<int32_t> expected_values = {140, -150, -1330, 1340};

  GatherOpModel gather({TensorType_INT32, {2, 2}}, {TensorType_INT64, {2}});
  gather.SetInput<int32_t>({-1330, 1340, 140, -150});
  gather.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<int32_t>(), ElementsAreArray(expected_values));
}
238
// UINT8 data gathered with INT32 positions.
TEST(TypesGatherOpTest, Uint8Int32) {
  const std::vector<uint8_t> expected_values = {14, 15, 133, 134};

  GatherOpModel gather({TensorType_UINT8, {2, 2}}, {TensorType_INT32, {2}});
  gather.SetInput<uint8_t>({133, 134, 14, 15});
  gather.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<uint8_t>(), ElementsAreArray(expected_values));
}
247
// UINT8 data gathered with INT64 positions.
TEST(TypesGatherOpTest, Uint8Int64) {
  const std::vector<uint8_t> expected_values = {14, 15, 133, 134};

  GatherOpModel gather({TensorType_UINT8, {2, 2}}, {TensorType_INT64, {2}});
  gather.SetInput<uint8_t>({133, 134, 14, 15});
  gather.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<uint8_t>(), ElementsAreArray(expected_values));
}
256
// INT8 data gathered with INT32 positions.
TEST(TypesGatherOpTest, Int8Int32) {
  const std::vector<int8_t> expected_values = {14, 15, -13, -120};

  GatherOpModel gather({TensorType_INT8, {2, 2}}, {TensorType_INT32, {2}});
  gather.SetInput<int8_t>({-13, -120, 14, 15});
  gather.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<int8_t>(), ElementsAreArray(expected_values));
}
265
// INT8 data gathered with INT64 positions.
TEST(TypesGatherOpTest, Int8Int64) {
  const std::vector<int8_t> expected_values = {14, 15, -13, -120};

  GatherOpModel gather({TensorType_INT8, {2, 2}}, {TensorType_INT64, {2}});
  gather.SetInput<int8_t>({-13, -120, 14, 15});
  gather.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<int8_t>(), ElementsAreArray(expected_values));
}
274
// INT16 data gathered with INT32 positions.
TEST(TypesGatherOpTest, Int16Int32) {
  const std::vector<int16_t> expected_values = {0, 32500, -13, -32000};

  GatherOpModel gather({TensorType_INT16, {2, 2}}, {TensorType_INT32, {2}});
  gather.SetInput<int16_t>({-13, -32000, 0, 32500});
  gather.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<int16_t>(), ElementsAreArray(expected_values));
}
284
// INT16 data gathered with INT64 positions.
TEST(TypesGatherOpTest, Int16Int64) {
  const std::vector<int16_t> expected_values = {0, 32500, -13, -32000};

  GatherOpModel gather({TensorType_INT16, {2, 2}}, {TensorType_INT64, {2}});
  gather.SetInput<int16_t>({-13, -32000, 0, 32500});
  gather.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<int16_t>(), ElementsAreArray(expected_values));
}
294
// INT64 data (including a value that does not fit in 32 bits) gathered with
// INT32 positions.
TEST(TypesGatherOpTest, Int64Int32) {
  const std::vector<int64_t> expected_values = {14LL, 15LL, -(1LL << 34),
                                                134LL};

  GatherOpModel gather({TensorType_INT64, {2, 2}}, {TensorType_INT32, {2}});
  gather.SetInput<int64_t>({-(1LL << 34), 134LL, 14LL, 15LL});
  gather.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<int64_t>(), ElementsAreArray(expected_values));
}
304
// INT64 data (including a value that does not fit in 32 bits) gathered with
// INT64 positions.
TEST(TypesGatherOpTest, Int64Int64) {
  const std::vector<int64_t> expected_values = {14LL, 15LL, -(1LL << 34),
                                                134LL};

  GatherOpModel gather({TensorType_INT64, {2, 2}}, {TensorType_INT64, {2}});
  gather.SetInput<int64_t>({-(1LL << 34), 134LL, 14LL, 15LL});
  gather.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  EXPECT_THAT(gather.GetOutput<int64_t>(), ElementsAreArray(expected_values));
}
314
// STRING data: picks elements 0 and 2 out of a three-element string tensor.
TEST(GatherOpTest, SimpleString) {
  const std::vector<int> expected_shape = {2};
  const std::vector<string> expected_strings = {"A", "C"};

  GatherOpModel gather({TensorType_STRING, {3}}, {TensorType_INT32, {2}});
  gather.SetStringInput({"A", "B", "C"});
  gather.SetPositions<int32_t>({0, 2});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  // The shape check is fatal: the contents are only inspected once the shape
  // is as expected.
  ASSERT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(gather.GetStringOutput(), ElementsAreArray(expected_strings));
}
323
324 TEST(GatherOpTest, 2DIndexString) {
325 GatherOpModel m({TensorType_STRING, {3}}, {TensorType_INT32, {2, 3}});
326 m.SetStringInput({"A", "B", "C"});
327 m.SetPositions<int32_t>({0, 2, 1, 1, 0, 2});
328 ASSERT_EQ(m.Invoke(), kTfLiteOk);
329 ASSERT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
330 EXPECT_THAT(m.GetStringOutput(),
331 ElementsAreArray({"A", "C", "B", "B", "A", "C"}));
332 }
333
// INT32 input of shape {2, 2, 3, 5} holding 0..59, gathered along axis 2
// with batch_dims = 2 and positions of shape {2, 2, 2}.
TEST(TypesGatherOpTest, BatchDims2) {
  const std::vector<int> expected_shape = {2, 2, 2, 5};
  const std::vector<int32_t> expected_values = {
      5,  6,  7,  8,  9,  0,  1,  2,  3,  4,   //
      15, 16, 17, 18, 19, 20, 21, 22, 23, 24,  //
      35, 36, 37, 38, 39, 30, 31, 32, 33, 34,  //
      45, 46, 47, 48, 49, 50, 51, 52, 53, 54};

  GatherOpModel gather({TensorType_INT32, {2, 2, 3, 5}},
                       {TensorType_INT32, {2, 2, 2}}, /*axis=*/2,
                       /*batch_dims=*/2);
  gather.SetInput<int32_t>(
      {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14,
       15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
       30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
       45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59});
  gather.SetPositions<int32_t>({1, 0, 0, 1, 1, 0, 0, 1});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  // Fatal shape check first, then the element-wise comparison.
  ASSERT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(gather.GetOutput<int32_t>(), ElementsAreArray(expected_values));
}
352
// INT8 input of shape {2, 2, 3, 5} holding 0..59, gathered along axis 2
// with batch_dims = 1 and positions of shape {2, 2, 2}.
TEST(TypesGatherOpTest, BatchDims1) {
  const std::vector<int> expected_shape = {2, 2, 2, 2, 5};
  const std::vector<int8_t> expected_values = {
      5,  6,  7,  8,  9,  0,  1,  2,  3,  4,   //
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,   //
      20, 21, 22, 23, 24, 15, 16, 17, 18, 19,  //
      15, 16, 17, 18, 19, 20, 21, 22, 23, 24,  //
      35, 36, 37, 38, 39, 30, 31, 32, 33, 34,  //
      30, 31, 32, 33, 34, 35, 36, 37, 38, 39,  //
      50, 51, 52, 53, 54, 45, 46, 47, 48, 49,  //
      45, 46, 47, 48, 49, 50, 51, 52, 53, 54};

  GatherOpModel gather({TensorType_INT8, {2, 2, 3, 5}},
                       {TensorType_INT32, {2, 2, 2}}, /*axis=*/2,
                       /*batch_dims=*/1);
  gather.SetInput<int8_t>(
      {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14,
       15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
       30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
       45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59});
  gather.SetPositions<int32_t>({1, 0, 0, 1, 1, 0, 0, 1});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  // Fatal shape check first, then the element-wise comparison.
  ASSERT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(gather.GetOutput<int8_t>(), ElementsAreArray(expected_values));
}
374
// INT8 input of shape {2, 2, 3, 5} holding 0..59, gathered along axis 2
// with a negative batch_dims (-2) and positions of shape {2, 2, 2}.
// NOTE(review): -2 is presumably resolved relative to the rank of the
// positions tensor (i.e. to 1) -- confirm against the kernel.
TEST(TypesGatherOpTest, NegativeBatchDims) {
  const std::vector<int> expected_shape = {2, 2, 2, 2, 5};
  const std::vector<int8_t> expected_values = {
      5,  6,  7,  8,  9,  0,  1,  2,  3,  4,   //
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,   //
      20, 21, 22, 23, 24, 15, 16, 17, 18, 19,  //
      15, 16, 17, 18, 19, 20, 21, 22, 23, 24,  //
      35, 36, 37, 38, 39, 30, 31, 32, 33, 34,  //
      30, 31, 32, 33, 34, 35, 36, 37, 38, 39,  //
      50, 51, 52, 53, 54, 45, 46, 47, 48, 49,  //
      45, 46, 47, 48, 49, 50, 51, 52, 53, 54};

  GatherOpModel gather({TensorType_INT8, {2, 2, 3, 5}},
                       {TensorType_INT32, {2, 2, 2}}, /*axis=*/2,
                       /*batch_dims=*/-2);
  gather.SetInput<int8_t>(
      {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14,
       15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
       30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
       45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59});
  gather.SetPositions<int32_t>({1, 0, 0, 1, 1, 0, 0, 1});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  // Fatal shape check first, then the element-wise comparison.
  ASSERT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(gather.GetOutput<int8_t>(), ElementsAreArray(expected_values));
}
396
// batch_dims (3) equals the rank of the {2, 2, 2} positions tensor: each
// position selects one value along axis 3 of the {2, 2, 2, 5} input, so the
// expected output has shape {2, 2, 2}.
TEST(TypesGatherOpTest, BatchDimsEqualIndiceDims) {
  const std::vector<int> expected_shape = {2, 2, 2};
  const std::vector<int8_t> expected_values = {1, 5, 10, 16, 21, 25, 30, 36};

  GatherOpModel gather({TensorType_INT8, {2, 2, 2, 5}},
                       {TensorType_INT32, {2, 2, 2}}, /*axis=*/3,
                       /*batch_dims=*/3);
  gather.SetInput<int8_t>({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,   //
                           10, 11, 12, 13, 14, 15, 16, 17, 18, 19,  //
                           20, 21, 22, 23, 24, 25, 26, 27, 28, 29,  //
                           30, 31, 32, 33, 34, 35, 36, 37, 38, 39});
  gather.SetPositions<int32_t>({1, 0, 0, 1, 1, 0, 0, 1});
  ASSERT_EQ(gather.Invoke(), kTfLiteOk);

  // Fatal shape check first, then the element-wise comparison.
  ASSERT_THAT(gather.GetOutputShape(), ElementsAreArray(expected_shape));
  EXPECT_THAT(gather.GetOutput<int8_t>(), ElementsAreArray(expected_values));
}
410
411 } // namespace
412 } // namespace tflite
413