• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <stdint.h>
16 
17 #include <initializer_list>
18 #include <string>
19 #include <vector>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
24 #include "tensorflow/lite/kernels/test_util.h"
25 #include "tensorflow/lite/schema/schema_generated.h"
26 #include "tensorflow/lite/string_type.h"
27 
28 namespace tflite {
29 namespace {
30 
31 using ::testing::ElementsAreArray;
32 
33 class GatherOpModel : public SingleOpModel {
34  public:
GatherOpModel(const TensorData & input,const TensorData & positions,int axis=0,int batch_dims=0)35   GatherOpModel(const TensorData& input, const TensorData& positions,
36                 int axis = 0, int batch_dims = 0) {
37     input_ = AddInput(input);
38     positions_ = AddInput(positions);
39     output_ = AddOutput(input.type);
40     SetBuiltinOp(BuiltinOperator_GATHER, BuiltinOptions_GatherOptions,
41                  CreateGatherOptions(builder_, axis, batch_dims).Union());
42     BuildInterpreter({GetShape(input_), GetShape(positions_)});
43   }
44 
45   template <typename T>
SetInput(std::initializer_list<T> data)46   void SetInput(std::initializer_list<T> data) {
47     PopulateTensor<T>(input_, data);
48   }
49 
SetStringInput(std::initializer_list<string> data)50   void SetStringInput(std::initializer_list<string> data) {
51     PopulateStringTensor(input_, data);
52   }
53 
54   template <typename T>
SetPositions(std::initializer_list<T> data)55   void SetPositions(std::initializer_list<T> data) {
56     PopulateTensor<T>(positions_, data);
57   }
58 
59   template <typename T>
GetOutput()60   std::vector<T> GetOutput() {
61     return ExtractVector<T>(output_);
62   }
63 
GetStringOutput()64   std::vector<string> GetStringOutput() {
65     return ExtractVector<string>(output_);
66   }
GetOutputShape()67   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
68 
69  protected:
70   int input_;
71   int positions_;
72   int output_;
73 };
74 
// Gathers the rows of a 2x2 matrix in reverse order along axis 0.
TEST(GatherOpTest, Shuffle) {
  GatherOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2}});
  m.SetInput<float>({-2.0, 0.2, 0.7, 0.8});
  m.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2})));
  // Output shape = positions shape {2} + remaining input dims {2}.
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
}
83 
// A scalar (0-D) position selects one row and drops the gathered axis.
TEST(GatherOpTest, Test0DIndex) {
  GatherOpModel model({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {}});
  model.SetInput<float>({-2.0, 0.2, 0.7, 0.8});
  model.SetPositions<int32_t>({1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_row = {0.7, 0.8};
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_row)));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
}
93 
TEST(GatherOpTest, Test0DIndexWith0DResult) {
  // A 0-D position applied to a 1-D input yields a 0-D output. TFLite treats
  // 0-D tensors as a special case, so this path is exercised once here.
  GatherOpModel model({TensorType_FLOAT32, {3}}, {TensorType_INT32, {}});
  model.SetInput<float>({1.0, 2.0, 3.0});
  model.SetPositions<int32_t>({1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_scalar = {2.0};
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_scalar)));
  const std::vector<int> output_shape = model.GetOutputShape();
  EXPECT_TRUE(output_shape.empty());
}
104 
// A single-element 1-D positions tensor over a 1-D input keeps rank 1.
TEST(GatherOpTest, Test1DInput1DIndex) {
  GatherOpModel model({TensorType_FLOAT32, {3}}, {TensorType_INT32, {1}});
  model.SetInput<float>({1.0, 3.0, 5.0});
  model.SetPositions<int32_t>({1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {3.0};
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1}));
}
113 
// A 1x2 positions tensor over a 1-D input produces a 1x2 output.
TEST(GatherOpTest, Test2DIndexWith2DResult) {
  GatherOpModel model({TensorType_FLOAT32, {3}}, {TensorType_INT32, {1, 2}});
  model.SetInput<float>({1.0, 2.0, 3.0});
  model.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {2.0, 1.0};
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
}
123 
// Gathering the same position twice duplicates the selected 2x2 slice.
TEST(FloatGatherOpTest, Duplicate) {
  GatherOpModel m({TensorType_FLOAT32, {1, 2, 2}}, {TensorType_INT32, {2}});
  m.SetInput<float>({-2.0, 0.2, 0.7, 0.8});
  m.SetPositions<int32_t>({0, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(
      m.GetOutput<float>(),
      ElementsAreArray(ArrayFloatNear({-2, 0.2, 0.7, 0.8, -2, 0.2, 0.7, 0.8})));
  // Output shape = positions shape {2} + remaining input dims {2, 2}.
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
133 
// Picks two of four single-element rows along axis 0.
TEST(FloatGatherOpTest, Slice) {
  GatherOpModel m({TensorType_FLOAT32, {4, 1}}, {TensorType_INT32, {2}});
  m.SetInput<float>({-2.0, 0.2, 0.7, 0.8});
  m.SetPositions<int32_t>({1, 3});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({0.2, 0.8})));
  // Output shape = positions shape {2} + remaining input dims {1}.
  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 1}));
}
142 
// Gathers along axis 1 of a 1x2x3 tensor, swapping its two rows.
TEST(FloatGatherOpTest, Axis1) {
  constexpr int kAxis = 1;
  GatherOpModel model({TensorType_FLOAT32, {1, 2, 3}},
                      {TensorType_INT32, {2}}, kAxis);
  model.SetInput<float>({1, 2, 3, 4, 5, 6});
  model.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {4, 5, 6, 1, 2, 3};
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 3}));
}
154 
// A 0-D position on axis 1 removes that axis: {1, 3, 2} -> {1, 2}.
TEST(FloatGatherOpTest, Axis10DIndex) {
  constexpr int kAxis = 1;
  GatherOpModel model({TensorType_FLOAT32, {1, 3, 2}},
                      {TensorType_INT32, {}}, kAxis);
  model.SetInput<float>({1, 2, 3, 4, 5, 6});
  model.SetPositions<int32_t>({1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {3, 4};
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
}
165 
// Selects rows 3 and 1 (in that order) of a 1x4x2 tensor along axis 1.
TEST(FloatGatherOpTest, Axis1Slice) {
  constexpr int kAxis = 1;
  GatherOpModel model({TensorType_FLOAT32, {1, 4, 2}},
                      {TensorType_INT32, {2}}, kAxis);
  model.SetInput<float>({1, 2, 3, 4, 5, 6, 7, 8});
  model.SetPositions<int32_t>({3, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {7, 8, 3, 4};
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2}));
}
177 
// axis = -1 addresses the innermost dimension (size 3 here); columns 2 and 0
// are picked from each row.
TEST(FloatGatherOpTest, LastAxis) {
  constexpr int kAxis = -1;
  GatherOpModel model({TensorType_FLOAT32, {1, 2, 3}},
                      {TensorType_INT32, {2}}, kAxis);
  model.SetInput<float>({1, 2, 3, 4, 5, 6});
  model.SetPositions<int32_t>({2, 0});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {3, 1, 6, 4};
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2, 2}));
}
189 
// A 0-D position on the last axis removes it: {1, 2, 3} -> {1, 2}.
TEST(FloatGatherOpTest, LastAxis0DIndex) {
  constexpr int kAxis = -1;
  GatherOpModel model({TensorType_FLOAT32, {1, 2, 3}},
                      {TensorType_INT32, {}}, kAxis);
  model.SetInput<float>({1, 2, 3, 4, 5, 6});
  model.SetPositions<int32_t>({2});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  const std::vector<float> expected_values = {3, 6};
  EXPECT_THAT(model.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear(expected_values)));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 2}));
}
200 
// float32 data gathered with int32 positions.
TEST(TypesGatherOpTest, Float32Int32) {
  GatherOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2}});
  m.SetInput<float>({13.3, -13.4, -1.4, 1.5});
  m.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  // Compare with a tolerance, as the other float tests in this file do, rather
  // than exact equality against double literals.
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({-1.4, 1.5, 13.3, -13.4})));
}
209 
// float32 data gathered with int64 positions.
TEST(TypesGatherOpTest, Float32Int64) {
  GatherOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT64, {2}});
  m.SetInput<float>({13.3, -13.4, -1.4, 1.5});
  m.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  // Compare with a tolerance, as the other float tests in this file do, rather
  // than exact equality against double literals.
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({-1.4, 1.5, 13.3, -13.4})));
}
218 
// int32 data gathered with int32 positions.
TEST(TypesGatherOpTest, Int32Int32) {
  GatherOpModel m({TensorType_INT32, {2, 2}}, {TensorType_INT32, {2}});
  m.SetInput<int32_t>({-1330, 1340, 140, -150});
  m.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<int32_t>(),
              ElementsAreArray({140, -150, -1330, 1340}));
}
228 
// int32 data gathered with int64 positions.
TEST(TypesGatherOpTest, Int32Int64) {
  GatherOpModel m({TensorType_INT32, {2, 2}}, {TensorType_INT64, {2}});
  m.SetInput<int32_t>({-1330, 1340, 140, -150});
  m.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<int32_t>(),
              ElementsAreArray({140, -150, -1330, 1340}));
}
238 
// uint8 data gathered with int32 positions.
TEST(TypesGatherOpTest, Uint8Int32) {
  GatherOpModel m({TensorType_UINT8, {2, 2}}, {TensorType_INT32, {2}});
  m.SetInput<uint8_t>({133, 134, 14, 15});
  m.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({14, 15, 133, 134}));
}
247 
// uint8 data gathered with int64 positions.
TEST(TypesGatherOpTest, Uint8Int64) {
  GatherOpModel m({TensorType_UINT8, {2, 2}}, {TensorType_INT64, {2}});
  m.SetInput<uint8_t>({133, 134, 14, 15});
  m.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({14, 15, 133, 134}));
}
256 
// int8 data gathered with int32 positions.
TEST(TypesGatherOpTest, Int8Int32) {
  GatherOpModel m({TensorType_INT8, {2, 2}}, {TensorType_INT32, {2}});
  m.SetInput<int8_t>({-13, -120, 14, 15});
  m.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({14, 15, -13, -120}));
}
265 
// int8 data gathered with int64 positions.
TEST(TypesGatherOpTest, Int8Int64) {
  GatherOpModel m({TensorType_INT8, {2, 2}}, {TensorType_INT64, {2}});
  m.SetInput<int8_t>({-13, -120, 14, 15});
  m.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({14, 15, -13, -120}));
}
274 
// int16 data gathered with int32 positions.
TEST(TypesGatherOpTest, Int16Int32) {
  GatherOpModel m({TensorType_INT16, {2, 2}}, {TensorType_INT32, {2}});
  m.SetInput<int16_t>({-13, -32000, 0, 32500});
  m.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<int16_t>(),
              ElementsAreArray({0, 32500, -13, -32000}));
}
284 
// int16 data gathered with int64 positions.
TEST(TypesGatherOpTest, Int16Int64) {
  GatherOpModel m({TensorType_INT16, {2, 2}}, {TensorType_INT64, {2}});
  m.SetInput<int16_t>({-13, -32000, 0, 32500});
  m.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<int16_t>(),
              ElementsAreArray({0, 32500, -13, -32000}));
}
294 
// int64 data (including a value outside the 32-bit range) gathered with int32
// positions.
TEST(TypesGatherOpTest, Int64Int32) {
  GatherOpModel m({TensorType_INT64, {2, 2}}, {TensorType_INT32, {2}});
  m.SetInput<int64_t>({-(1LL << 34), 134LL, 14LL, 15LL});
  m.SetPositions<int32_t>({1, 0});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<int64_t>(),
              ElementsAreArray({14LL, 15LL, -(1LL << 34), 134LL}));
}
304 
// int64 data (including a value outside the 32-bit range) gathered with int64
// positions.
TEST(TypesGatherOpTest, Int64Int64) {
  GatherOpModel m({TensorType_INT64, {2, 2}}, {TensorType_INT64, {2}});
  m.SetInput<int64_t>({-(1LL << 34), 134LL, 14LL, 15LL});
  m.SetPositions<int64_t>({1LL, 0LL});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);

  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2}));
  EXPECT_THAT(m.GetOutput<int64_t>(),
              ElementsAreArray({14LL, 15LL, -(1LL << 34), 134LL}));
}
314 
// Gathers two entries from a 1-D string tensor.
TEST(GatherOpTest, SimpleString) {
  GatherOpModel model({TensorType_STRING, {3}}, {TensorType_INT32, {2}});
  model.SetStringInput({"A", "B", "C"});
  model.SetPositions<int32_t>({0, 2});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  // ASSERT (not EXPECT) on the shape so a size mismatch stops the test before
  // the element-wise string comparison.
  ASSERT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
  EXPECT_THAT(model.GetStringOutput(), ElementsAreArray({"A", "C"}));
}
323 
324 TEST(GatherOpTest, 2DIndexString) {
325   GatherOpModel m({TensorType_STRING, {3}}, {TensorType_INT32, {2, 3}});
326   m.SetStringInput({"A", "B", "C"});
327   m.SetPositions<int32_t>({0, 2, 1, 1, 0, 2});
328   ASSERT_EQ(m.Invoke(), kTfLiteOk);
329   ASSERT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
330   EXPECT_THAT(m.GetStringOutput(),
331               ElementsAreArray({"A", "C", "B", "B", "A", "C"}));
332 }
333 
// With batch_dims = 2, each of the 2x2 batches gathers two rows (along axis 2)
// from its own 3x5 slice of the input, using its own pair of positions.
TEST(TypesGatherOpTest, BatchDims2) {
  GatherOpModel model({TensorType_INT32, {2, 2, 3, 5}},
                      {TensorType_INT32, {2, 2, 2}}, /*axis=*/2,
                      /*batch_dims=*/2);
  model.SetInput<int32_t>({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                           24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                           36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
                           48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59});
  model.SetPositions<int32_t>({1, 0, 0, 1, 1, 0, 0, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ASSERT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 2, 5}));
  const std::vector<int32_t> expected = {
      5,  6,  7,  8,  9,  0,  1,  2,  3,  4,   // batch (0,0): rows 1, 0
      15, 16, 17, 18, 19, 20, 21, 22, 23, 24,  // batch (0,1): rows 0, 1
      35, 36, 37, 38, 39, 30, 31, 32, 33, 34,  // batch (1,0): rows 1, 0
      45, 46, 47, 48, 49, 50, 51, 52, 53, 54,  // batch (1,1): rows 0, 1
  };
  EXPECT_THAT(model.GetOutput<int32_t>(), ElementsAreArray(expected));
}
352 
// With batch_dims = 1, only the leading dim is a batch dim. Each input slice
// [b, j] (a 3x5 block) is gathered with both position pairs of batch b, giving
// output shape {2, 2} + {2, 2} + {5}.
TEST(TypesGatherOpTest, BatchDims1) {
  GatherOpModel model({TensorType_INT8, {2, 2, 3, 5}},
                      {TensorType_INT32, {2, 2, 2}}, /*axis=*/2,
                      /*batch_dims=*/1);
  model.SetInput<int8_t>({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                          12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                          24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                          36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
                          48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59});
  model.SetPositions<int32_t>({1, 0, 0, 1, 1, 0, 0, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ASSERT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 2, 2, 5}));
  const std::vector<int8_t> expected = {
      // Slice [0, 0] (values 0..14):
      5,  6,  7,  8,  9,  0,  1,  2,  3,  4,   // positions {1, 0}
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,   // positions {0, 1}
      // Slice [0, 1] (values 15..29):
      20, 21, 22, 23, 24, 15, 16, 17, 18, 19,  // positions {1, 0}
      15, 16, 17, 18, 19, 20, 21, 22, 23, 24,  // positions {0, 1}
      // Slice [1, 0] (values 30..44):
      35, 36, 37, 38, 39, 30, 31, 32, 33, 34,  // positions {1, 0}
      30, 31, 32, 33, 34, 35, 36, 37, 38, 39,  // positions {0, 1}
      // Slice [1, 1] (values 45..59):
      50, 51, 52, 53, 54, 45, 46, 47, 48, 49,  // positions {1, 0}
      45, 46, 47, 48, 49, 50, 51, 52, 53, 54,  // positions {0, 1}
  };
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray(expected));
}
374 
// batch_dims = -2 with 3-D positions is expected to behave exactly like
// batch_dims = 1, i.e. it is counted back from the positions rank. The
// expectations below are the same as in the batch_dims = 1 test.
TEST(TypesGatherOpTest, NegativeBatchDims) {
  GatherOpModel model({TensorType_INT8, {2, 2, 3, 5}},
                      {TensorType_INT32, {2, 2, 2}}, /*axis=*/2,
                      /*batch_dims=*/-2);
  model.SetInput<int8_t>({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                          12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
                          24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
                          36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
                          48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59});
  model.SetPositions<int32_t>({1, 0, 0, 1, 1, 0, 0, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ASSERT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 2, 2, 5}));
  const std::vector<int8_t> expected = {
      5,  6,  7,  8,  9,  0,  1,  2,  3,  4,
      0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
      20, 21, 22, 23, 24, 15, 16, 17, 18, 19,
      15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
      35, 36, 37, 38, 39, 30, 31, 32, 33, 34,
      30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
      50, 51, 52, 53, 54, 45, 46, 47, 48, 49,
      45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
  };
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray(expected));
}
396 
// When batch_dims equals the positions rank (3), each position picks a single
// element from the innermost dim of its own batch, so the output shape equals
// the positions shape.
TEST(TypesGatherOpTest, BatchDimsEqualIndiceDims) {
  GatherOpModel model({TensorType_INT8, {2, 2, 2, 5}},
                      {TensorType_INT32, {2, 2, 2}}, /*axis=*/3,
                      /*batch_dims=*/3);
  model.SetInput<int8_t>({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
                          10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
                          20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
                          30, 31, 32, 33, 34, 35, 36, 37, 38, 39});
  model.SetPositions<int32_t>({1, 0, 0, 1, 1, 0, 0, 1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  ASSERT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 2}));
  // Batch k starts at input value 5 * k; add that batch's position.
  const std::vector<int8_t> expected = {1, 5, 10, 16, 21, 25, 30, 36};
  EXPECT_THAT(model.GetOutput<int8_t>(), ElementsAreArray(expected));
}
410 
411 }  // namespace
412 }  // namespace tflite
413