• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <gtest/gtest.h>
16 #include "tensorflow/contrib/lite/interpreter.h"
17 #include "tensorflow/contrib/lite/kernels/register.h"
18 #include "tensorflow/contrib/lite/kernels/test_util.h"
19 #include "tensorflow/contrib/lite/model.h"
20 
21 namespace tflite {
22 namespace {
23 
24 using ::testing::ElementsAreArray;
25 
26 class PadOpModel : public SingleOpModel {
27  public:
SetInput(std::initializer_list<float> data)28   void SetInput(std::initializer_list<float> data) {
29     PopulateTensor<float>(input_, data);
30   }
31 
SetPaddings(std::initializer_list<int> paddings)32   void SetPaddings(std::initializer_list<int> paddings) {
33     PopulateTensor<int>(paddings_, paddings);
34   }
35 
GetOutput()36   std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
GetOutputShape()37   std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
38 
39  protected:
40   int input_;
41   int output_;
42   int paddings_;
43 };
44 
45 // Tests case where paddings is a const tensor.
46 //
47 // Example usage is as follows:
48 //    PadOpDynamicModel m(input_shape, paddings_shape, paddings_data);
49 //    m.SetInput(input_data);
50 //    m.Invoke();
51 class PadOpConstModel : public PadOpModel {
52  public:
PadOpConstModel(std::initializer_list<int> input_shape,std::initializer_list<int> paddings_shape,std::initializer_list<int> paddings)53   PadOpConstModel(std::initializer_list<int> input_shape,
54                   std::initializer_list<int> paddings_shape,
55                   std::initializer_list<int> paddings) {
56     input_ = AddInput(TensorType_FLOAT32);
57     paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape);
58     output_ = AddOutput(TensorType_FLOAT32);
59 
60     SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
61                  CreatePadOptions(builder_).Union());
62     BuildInterpreter({input_shape});
63   }
64 };
65 
66 // Test case where paddings is a non-const tensor.
67 //
68 // Example usage is as follows:
69 //    PadOpDynamicModel m(input_shape, paddings_shape);
70 //    m.SetInput(input_data);
71 //    m.SetPaddings(paddings_data);
72 //    m.Invoke();
73 class PadOpDynamicModel : public PadOpModel {
74  public:
PadOpDynamicModel(std::initializer_list<int> input_shape,std::initializer_list<int> paddings_shape)75   PadOpDynamicModel(std::initializer_list<int> input_shape,
76                     std::initializer_list<int> paddings_shape) {
77     input_ = AddInput(TensorType_FLOAT32);
78     paddings_ = AddInput(TensorType_INT32);
79     output_ = AddOutput(TensorType_FLOAT32);
80 
81     SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions,
82                  CreatePadOptions(builder_).Union());
83     BuildInterpreter({input_shape, paddings_shape});
84   }
85 };
86 
TEST(PadOpTest, TooManyDimensions) {
  // A rank-9 input with a matching 9x2 paddings tensor. The expected death
  // message shows the kernel rejects any input whose rank is not 4.
  auto build_rank9_model = [] {
    PadOpConstModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9, 2},
                    {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9});
  };
  EXPECT_DEATH(build_rank9_model(), "dims != 4");
}
93 
TEST(PadOpTest, UnequalDimensions) {
  // Rank-4 input paired with paddings for only 3 dimensions.
  auto build_mismatched_model = [] {
    PadOpConstModel({1, 1, 2, 1}, {3, 2}, {1, 1, 2, 2, 3, 3});
  };
  EXPECT_DEATH(build_mismatched_model(), "3 != 4");
}
98 
TEST(PadOpTest, InvalidPadValue) {
  // Dimensions 1 and 2 request a negative (-1) trailing pad, which the kernel
  // must refuse.
  auto build_negative_pad_model = [] {
    PadOpConstModel({1, 1, 2, 1}, {4, 2}, {0, 0, 1, -1, 2, -1, 0, 0});
  };
  EXPECT_DEATH(build_negative_pad_model(),
               "Pad value has to be greater than equal to 0.");
}
104 
TEST(PadOpTest, SimpleConstTest) {
  // The {4, 2} paddings tensor holds one {before, after} pair per input
  // dimension: {{0, 0}, {1, 1}, {1, 1}, {0, 0}}. This surrounds the 2x2
  // spatial block with a one-element border of zeros.
  PadOpConstModel model({1, 2, 2, 1}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0});
  model.SetInput({1, 2, 3, 4});
  model.Invoke();

  const std::vector<float> expected_output = {0, 0, 0, 0,  //
                                              0, 1, 2, 0,  //
                                              0, 3, 4, 0,  //
                                              0, 0, 0, 0};
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
115 
TEST(PadOpTest, SimpleDynamicTest) {
  // One-element zero border around a 2x2 block, with the paddings supplied
  // at runtime rather than at construction.
  PadOpDynamicModel model({1, 2, 2, 1}, {4, 2});
  model.SetInput({1, 2, 3, 4});
  model.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0});
  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 0, 0, 0,  //
                                                   0, 1, 2, 0,  //
                                                   0, 3, 4, 0,  //
                                                   0, 0, 0, 0}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
125 
TEST(PadOpTest, AdvancedConstTest) {
  // Asymmetric padding of a 2x3 block: dimension 1 gets {0 before, 2 after}
  // and dimension 2 gets {1 before, 3 after}, giving a 4x7 result.
  PadOpConstModel model({1, 2, 3, 1}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.Invoke();

  const std::vector<float> expected_output = {0, 1, 2, 3, 0, 0, 0,  //
                                              0, 4, 5, 6, 0, 0, 0,  //
                                              0, 0, 0, 0, 0, 0, 0,  //
                                              0, 0, 0, 0, 0, 0, 0};
  EXPECT_THAT(model.GetOutput(), ElementsAreArray(expected_output));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
135 
TEST(PadOpTest, AdvancedDynamicTest) {
  // Asymmetric padding of a 2x3 block ({0, 2} on dimension 1, {1, 3} on
  // dimension 2), with the paddings supplied at runtime.
  PadOpDynamicModel model({1, 2, 3, 1}, {4, 2});
  model.SetInput({1, 2, 3, 4, 5, 6});
  model.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0});
  model.Invoke();

  EXPECT_THAT(model.GetOutput(), ElementsAreArray({0, 1, 2, 3, 0, 0, 0,  //
                                                   0, 4, 5, 6, 0, 0, 0,  //
                                                   0, 0, 0, 0, 0, 0, 0,  //
                                                   0, 0, 0, 0, 0, 0, 0}));
  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 4, 7, 1}));
}
146 
147 }  // namespace
148 }  // namespace tflite
149 
main(int argc,char ** argv)150 int main(int argc, char** argv) {
151   ::tflite::LogToStderr();
152   ::testing::InitGoogleTest(&argc, argv);
153   return RUN_ALL_TESTS();
154 }
155