/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <stdint.h>
16
17 #include <vector>
18
19 #include <gtest/gtest.h>
20 #include "tensorflow/lite/kernels/test_util.h"
21 #include "tensorflow/lite/schema/schema_generated.h"
22
23 namespace tflite {
24 namespace {
25
26 using ::testing::ElementsAre;
27 using ::testing::ElementsAreArray;
28
29 template <typename T>
30 class ReverseOpModel : public SingleOpModel {
31 public:
ReverseOpModel(const TensorData & input,const TensorData & axis)32 ReverseOpModel(const TensorData& input, const TensorData& axis) {
33 input_ = AddInput(input);
34 axis_ = AddInput(axis);
35
36 output_ = AddOutput({input.type, {}});
37
38 SetBuiltinOp(BuiltinOperator_REVERSE_V2, BuiltinOptions_ReverseV2Options,
39 CreateReverseV2Options(builder_).Union());
40 BuildInterpreter({GetShape(input_)});
41 }
42
input()43 int input() { return input_; }
axis()44 int axis() { return axis_; }
45
GetOutput()46 std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
GetOutputShape()47 std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
48
49 private:
50 int input_;
51 int axis_;
52 int output_;
53 };
54
// float32 tests.
TEST(ReverseOpTest, FloatOneDimension) {
  // Reversing a 4-element float vector along axis 0 flips the element order
  // and keeps the shape at {4}.
  ReverseOpModel<float> reverse_op({TensorType_FLOAT32, {4}},
                                   {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {0});
  reverse_op.PopulateTensor<float>(reverse_op.input(), {1, 2, 3, 4});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  // The expected values are the exact inputs in reversed order, so exact
  // float equality is used rather than an approximate matcher.
  const std::vector<float> expected_values = {4, 3, 2, 1};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
66
// Reverses a {4, 3, 2} float tensor along its middle axis: within each of the
// 4 outer slices, the 3 rows of 2 elements appear in reverse order.
TEST(ReverseOpTest, FloatMultiDimensions) {
  ReverseOpModel<float> model({TensorType_FLOAT32, {4, 3, 2}},
                              {TensorType_INT32, {1}});
  model.PopulateTensor<float>(model.input(),
                              {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                               13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  model.PopulateTensor<int32_t>(model.axis(), {1});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({5,  6,  3,  4,  1,  2,  11, 12, 9,  10, 7,  8,
                        17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20}));
}

// Reverses the same {4, 3, 2} tensor along the outermost axis, so whole
// {3, 2} slices swap places. This is the case where no dimensions precede
// the reversed axis, which the middle-axis test above does not cover.
TEST(ReverseOpTest, FloatMultiDimensionsFirstAxis) {
  ReverseOpModel<float> model({TensorType_FLOAT32, {4, 3, 2}},
                              {TensorType_INT32, {1}});
  model.PopulateTensor<float>(model.input(),
                              {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                               13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  model.PopulateTensor<int32_t>(model.axis(), {0});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18,
                        7,  8,  9,  10, 11, 12, 1,  2,  3,  4,  5,  6}));
}

// Reverses the same {4, 3, 2} tensor along the innermost axis, so the two
// elements of every pair swap. This is the case where no dimensions follow
// the reversed axis.
TEST(ReverseOpTest, FloatMultiDimensionsLastAxis) {
  ReverseOpModel<float> model({TensorType_FLOAT32, {4, 3, 2}},
                              {TensorType_INT32, {1}});
  model.PopulateTensor<float>(model.input(),
                              {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                               13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  model.PopulateTensor<int32_t>(model.axis(), {2});
  ASSERT_EQ(model.Invoke(), kTfLiteOk);

  EXPECT_THAT(model.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(
      model.GetOutput(),
      ElementsAreArray({2,  1,  4,  3,  6,  5,  8,  7,  10, 9,  12, 11,
                        14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23}));
}
82
// int32 tests.
TEST(ReverseOpTest, Int32OneDimension) {
  // Reversing a 4-element int32 vector along axis 0 flips the element order
  // and keeps the shape at {4}.
  ReverseOpModel<int32_t> reverse_op({TensorType_INT32, {4}},
                                     {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {0});
  reverse_op.PopulateTensor<int32_t>(reverse_op.input(), {1, 2, 3, 4});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<int32_t> expected_values = {4, 3, 2, 1};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
94
TEST(ReverseOpTest, Int32MultiDimensions) {
  // A {4, 3, 2} int32 tensor reversed along axis 1: inside each outer slice
  // the three 2-element rows come back in reverse order.
  ReverseOpModel<int32_t> reverse_op({TensorType_INT32, {4, 3, 2}},
                                     {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {1});
  reverse_op.PopulateTensor<int32_t>(
      reverse_op.input(), {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<int32_t> expected_values = {
      5,  6,  3,  4,  1,  2,  11, 12, 9,  10, 7,  8,
      17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
110
// int64 tests.
TEST(ReverseOpTest, Int64OneDimension) {
  // Reversing a 4-element int64 vector along axis 0 flips the element order
  // and keeps the shape at {4}.
  ReverseOpModel<int64_t> reverse_op({TensorType_INT64, {4}},
                                     {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {0});
  reverse_op.PopulateTensor<int64_t>(reverse_op.input(), {1, 2, 3, 4});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<int64_t> expected_values = {4, 3, 2, 1};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
122
TEST(ReverseOpTest, Int64MultiDimensions) {
  // A {4, 3, 2} int64 tensor reversed along axis 1: inside each outer slice
  // the three 2-element rows come back in reverse order.
  ReverseOpModel<int64_t> reverse_op({TensorType_INT64, {4, 3, 2}},
                                     {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {1});
  reverse_op.PopulateTensor<int64_t>(
      reverse_op.input(), {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<int64_t> expected_values = {
      5,  6,  3,  4,  1,  2,  11, 12, 9,  10, 7,  8,
      17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
138
// uint8 tests.
TEST(ReverseOpTest, Uint8OneDimension) {
  // Reversing a 4-element uint8 vector along axis 0 flips the element order
  // and keeps the shape at {4}.
  ReverseOpModel<uint8_t> reverse_op({TensorType_UINT8, {4}},
                                     {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {0});
  reverse_op.PopulateTensor<uint8_t>(reverse_op.input(), {1, 2, 3, 4});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<uint8_t> expected_values = {4, 3, 2, 1};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
150
TEST(ReverseOpTest, Uint8MultiDimensions) {
  // A {4, 3, 2} uint8 tensor reversed along axis 1: inside each outer slice
  // the three 2-element rows come back in reverse order.
  ReverseOpModel<uint8_t> reverse_op({TensorType_UINT8, {4, 3, 2}},
                                     {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {1});
  reverse_op.PopulateTensor<uint8_t>(
      reverse_op.input(), {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<uint8_t> expected_values = {
      5,  6,  3,  4,  1,  2,  11, 12, 9,  10, 7,  8,
      17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
166
// int8 tests.
TEST(ReverseOpTest, Int8OneDimension) {
  // Reversing a 4-element int8 vector (mixing positive and negative values)
  // along axis 0 flips the element order and keeps the shape at {4}.
  ReverseOpModel<int8_t> reverse_op({TensorType_INT8, {4}},
                                    {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {0});
  reverse_op.PopulateTensor<int8_t>(reverse_op.input(), {1, 2, -1, -2});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<int8_t> expected_values = {-2, -1, 2, 1};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
177
TEST(ReverseOpTest, Int8MultiDimensions) {
  // A {4, 3, 2} int8 tensor reversed along axis 1: inside each outer slice
  // the three 2-element rows come back in reverse order. Negative values in
  // the first and last slices check that signed elements are moved intact.
  ReverseOpModel<int8_t> reverse_op({TensorType_INT8, {4, 3, 2}},
                                    {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {1});
  reverse_op.PopulateTensor<int8_t>(
      reverse_op.input(), {-1, -2, -3, -4, 5,  6,  7,   8,   9,   10,  11, 12,
                           13, 14, 15, 16, 17, 18, 19,  20,  -21, -22, -23,
                           -24});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<int8_t> expected_values = {
      5,  6,  -3, -4, -1, -2, 11,  12,  9,   10,  7,  8,
      17, 18, 15, 16, 13, 14, -23, -24, -21, -22, 19, 20};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
193
// int16 tests.
TEST(ReverseOpTest, Int16OneDimension) {
  // Reversing a 4-element int16 vector along axis 0 flips the element order
  // and keeps the shape at {4}.
  ReverseOpModel<int16_t> reverse_op({TensorType_INT16, {4}},
                                     {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {0});
  reverse_op.PopulateTensor<int16_t>(reverse_op.input(), {1, 2, 3, 4});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<int16_t> expected_values = {4, 3, 2, 1};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
205
TEST(ReverseOpTest, Int16MultiDimensions) {
  // A {4, 3, 2} int16 tensor reversed along axis 1: inside each outer slice
  // the three 2-element rows come back in reverse order.
  ReverseOpModel<int16_t> reverse_op({TensorType_INT16, {4, 3, 2}},
                                     {TensorType_INT32, {1}});
  reverse_op.PopulateTensor<int32_t>(reverse_op.axis(), {1});
  reverse_op.PopulateTensor<int16_t>(
      reverse_op.input(), {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
                           13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
  ASSERT_EQ(reverse_op.Invoke(), kTfLiteOk);

  const std::vector<int16_t> expected_values = {
      5,  6,  3,  4,  1,  2,  11, 12, 9,  10, 7,  8,
      17, 18, 15, 16, 13, 14, 23, 24, 21, 22, 19, 20};
  EXPECT_THAT(reverse_op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(reverse_op.GetOutput(), ElementsAreArray(expected_values));
}
221
222 } // namespace
223 } // namespace tflite
224