• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #include <stdint.h>
16 
17 #include <vector>
18 
19 #include <gtest/gtest.h>
20 #include "tensorflow/lite/kernels/test_util.h"
21 #include "tensorflow/lite/schema/schema_generated.h"
22 
23 namespace tflite {
24 namespace {
25 
26 using ::testing::ElementsAre;
27 using ::testing::ElementsAreArray;
28 
29 template <typename T>
30 class ReverseOpModel : public SingleOpModel {
31  public:
ReverseOpModel(const TensorData & input,const TensorData & axis)32   ReverseOpModel(const TensorData& input, const TensorData& axis) {
33     input_ = AddInput(input);
34     axis_ = AddInput(axis);
35 
36     output_ = AddOutput({input.type, {}});
37 
38     SetBuiltinOp(BuiltinOperator_REVERSE_V2, BuiltinOptions_ReverseV2Options,
39                  CreateReverseV2Options(builder_).Union());
40     BuildInterpreter({GetShape(input_)});
41   }
42 
input()43   int input() { return input_; }
axis()44   int axis() { return axis_; }
45 
GetOutput()46   std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
GetOutputShape()47   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
48 
49  private:
50   int input_;
51   int axis_;
52   int output_;
53 };
54 
// float32 tests.
// Reversing a rank-1 float tensor along axis 0 flips the element order and
// leaves the shape unchanged.
TEST(ReverseOpTest, FloatOneDimension) {
  ReverseOpModel<float> op({TensorType_FLOAT32, {4}}, {TensorType_INT32, {1}});
  op.PopulateTensor<float>(op.input(), {1, 2, 3, 4});
  op.PopulateTensor<int32_t>(op.axis(), {0});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {4, 3, 2, 1};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
66 
// Reverses a {4, 3, 2} float tensor along axis 1: inside each of the four
// 3x2 slices the rows come out in reverse order while each row's pair of
// values stays intact.
TEST(ReverseOpTest, FloatMultiDimensions) {
  ReverseOpModel<float> op({TensorType_FLOAT32, {4, 3, 2}},
                           {TensorType_INT32, {1}});
  op.PopulateTensor<float>(op.input(), {1,  2,  3,  4,  5,  6,   //
                                        7,  8,  9,  10, 11, 12,  //
                                        13, 14, 15, 16, 17, 18,  //
                                        19, 20, 21, 22, 23, 24});
  op.PopulateTensor<int32_t>(op.axis(), {1});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<float> expected = {5,  6,  3,  4,  1,  2,   //
                                       11, 12, 9,  10, 7,  8,   //
                                       17, 18, 15, 16, 13, 14,  //
                                       23, 24, 21, 22, 19, 20};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
82 
// int32 tests.
// Reversing a rank-1 int32 tensor along axis 0 flips the element order and
// leaves the shape unchanged.
TEST(ReverseOpTest, Int32OneDimension) {
  ReverseOpModel<int32_t> op({TensorType_INT32, {4}}, {TensorType_INT32, {1}});
  op.PopulateTensor<int32_t>(op.input(), {1, 2, 3, 4});
  op.PopulateTensor<int32_t>(op.axis(), {0});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<int32_t> expected = {4, 3, 2, 1};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
94 
// Reverses a {4, 3, 2} int32 tensor along axis 1: inside each of the four
// 3x2 slices the rows come out in reverse order while each row's pair of
// values stays intact.
TEST(ReverseOpTest, Int32MultiDimensions) {
  ReverseOpModel<int32_t> op({TensorType_INT32, {4, 3, 2}},
                             {TensorType_INT32, {1}});
  op.PopulateTensor<int32_t>(op.input(), {1,  2,  3,  4,  5,  6,   //
                                          7,  8,  9,  10, 11, 12,  //
                                          13, 14, 15, 16, 17, 18,  //
                                          19, 20, 21, 22, 23, 24});
  op.PopulateTensor<int32_t>(op.axis(), {1});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<int32_t> expected = {5,  6,  3,  4,  1,  2,   //
                                         11, 12, 9,  10, 7,  8,   //
                                         17, 18, 15, 16, 13, 14,  //
                                         23, 24, 21, 22, 19, 20};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
110 
// int64 tests.
// Reversing a rank-1 int64 tensor along axis 0 flips the element order and
// leaves the shape unchanged.
TEST(ReverseOpTest, Int64OneDimension) {
  ReverseOpModel<int64_t> op({TensorType_INT64, {4}}, {TensorType_INT32, {1}});
  op.PopulateTensor<int64_t>(op.input(), {1, 2, 3, 4});
  op.PopulateTensor<int32_t>(op.axis(), {0});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<int64_t> expected = {4, 3, 2, 1};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
122 
// Reverses a {4, 3, 2} int64 tensor along axis 1: inside each of the four
// 3x2 slices the rows come out in reverse order while each row's pair of
// values stays intact.
TEST(ReverseOpTest, Int64MultiDimensions) {
  ReverseOpModel<int64_t> op({TensorType_INT64, {4, 3, 2}},
                             {TensorType_INT32, {1}});
  op.PopulateTensor<int64_t>(op.input(), {1,  2,  3,  4,  5,  6,   //
                                          7,  8,  9,  10, 11, 12,  //
                                          13, 14, 15, 16, 17, 18,  //
                                          19, 20, 21, 22, 23, 24});
  op.PopulateTensor<int32_t>(op.axis(), {1});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<int64_t> expected = {5,  6,  3,  4,  1,  2,   //
                                         11, 12, 9,  10, 7,  8,   //
                                         17, 18, 15, 16, 13, 14,  //
                                         23, 24, 21, 22, 19, 20};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
138 
// uint8 tests.
// Reversing a rank-1 uint8 tensor along axis 0 flips the element order and
// leaves the shape unchanged.
TEST(ReverseOpTest, Uint8OneDimension) {
  ReverseOpModel<uint8_t> op({TensorType_UINT8, {4}}, {TensorType_INT32, {1}});
  op.PopulateTensor<uint8_t>(op.input(), {1, 2, 3, 4});
  op.PopulateTensor<int32_t>(op.axis(), {0});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<uint8_t> expected = {4, 3, 2, 1};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
150 
// Reverses a {4, 3, 2} uint8 tensor along axis 1: inside each of the four
// 3x2 slices the rows come out in reverse order while each row's pair of
// values stays intact.
TEST(ReverseOpTest, Uint8MultiDimensions) {
  ReverseOpModel<uint8_t> op({TensorType_UINT8, {4, 3, 2}},
                             {TensorType_INT32, {1}});
  op.PopulateTensor<uint8_t>(op.input(), {1,  2,  3,  4,  5,  6,   //
                                          7,  8,  9,  10, 11, 12,  //
                                          13, 14, 15, 16, 17, 18,  //
                                          19, 20, 21, 22, 23, 24});
  op.PopulateTensor<int32_t>(op.axis(), {1});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<uint8_t> expected = {5,  6,  3,  4,  1,  2,   //
                                         11, 12, 9,  10, 7,  8,   //
                                         17, 18, 15, 16, 13, 14,  //
                                         23, 24, 21, 22, 19, 20};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
166 
// int8 tests.
// Reversing a rank-1 int8 tensor along axis 0 flips the element order; the
// mix of positive and negative values checks that signs survive the copy.
TEST(ReverseOpTest, Int8OneDimension) {
  ReverseOpModel<int8_t> op({TensorType_INT8, {4}}, {TensorType_INT32, {1}});
  op.PopulateTensor<int8_t>(op.input(), {1, 2, -1, -2});
  op.PopulateTensor<int32_t>(op.axis(), {0});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<int8_t> expected = {-2, -1, 2, 1};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
177 
// Reverses a {4, 3, 2} int8 tensor along axis 1: inside each of the four
// 3x2 slices the rows come out in reverse order while each row's pair of
// values stays intact. Negative values in the first and last slices check
// that signs are preserved.
TEST(ReverseOpTest, Int8MultiDimensions) {
  ReverseOpModel<int8_t> op({TensorType_INT8, {4, 3, 2}},
                            {TensorType_INT32, {1}});
  op.PopulateTensor<int8_t>(op.input(), {-1, -2, -3,  -4,  5,   6,    //
                                         7,  8,  9,   10,  11,  12,   //
                                         13, 14, 15,  16,  17,  18,   //
                                         19, 20, -21, -22, -23, -24});
  op.PopulateTensor<int32_t>(op.axis(), {1});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<int8_t> expected = {5,   6,   -3,  -4,  -1, -2,  //
                                        11,  12,  9,   10,  7,  8,   //
                                        17,  18,  15,  16,  13, 14,  //
                                        -23, -24, -21, -22, 19, 20};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
193 
// int16 tests.
// Reversing a rank-1 int16 tensor along axis 0 flips the element order and
// leaves the shape unchanged.
TEST(ReverseOpTest, Int16OneDimension) {
  ReverseOpModel<int16_t> op({TensorType_INT16, {4}}, {TensorType_INT32, {1}});
  op.PopulateTensor<int16_t>(op.input(), {1, 2, 3, 4});
  op.PopulateTensor<int32_t>(op.axis(), {0});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<int16_t> expected = {4, 3, 2, 1};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
205 
// Reverses a {4, 3, 2} int16 tensor along axis 1: inside each of the four
// 3x2 slices the rows come out in reverse order while each row's pair of
// values stays intact.
TEST(ReverseOpTest, Int16MultiDimensions) {
  ReverseOpModel<int16_t> op({TensorType_INT16, {4, 3, 2}},
                             {TensorType_INT32, {1}});
  op.PopulateTensor<int16_t>(op.input(), {1,  2,  3,  4,  5,  6,   //
                                          7,  8,  9,  10, 11, 12,  //
                                          13, 14, 15, 16, 17, 18,  //
                                          19, 20, 21, 22, 23, 24});
  op.PopulateTensor<int32_t>(op.axis(), {1});
  ASSERT_EQ(op.Invoke(), kTfLiteOk);

  const std::vector<int16_t> expected = {5,  6,  3,  4,  1,  2,   //
                                         11, 12, 9,  10, 7,  8,   //
                                         17, 18, 15, 16, 13, 14,  //
                                         23, 24, 21, 22, 19, 20};
  EXPECT_THAT(op.GetOutputShape(), ElementsAre(4, 3, 2));
  EXPECT_THAT(op.GetOutput(), ElementsAreArray(expected));
}
221 
222 }  // namespace
223 }  // namespace tflite
224