• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <stdint.h>
16 
17 #include <initializer_list>
18 #include <string>
19 #include <vector>
20 
21 #include <gmock/gmock.h>
22 #include <gtest/gtest.h>
23 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
24 #include "tensorflow/lite/kernels/custom_ops_register.h"
25 #include "tensorflow/lite/kernels/test_util.h"
26 
27 namespace tflite {
28 
29 using ::testing::ElementsAreArray;
30 
31 enum PoolType {
32   kAverage,
33   kMax,
34 };
35 
36 template <typename T>
37 class BasePoolingOpModel : public SingleOpModel {
38  public:
BasePoolingOpModel(PoolType pool_type,TensorData input,int filter_d,int filter_h,int filter_w,TensorData output,TfLitePadding padding=kTfLitePaddingValid,int stride_d=2,int stride_h=2,int stride_w=2)39   BasePoolingOpModel(PoolType pool_type, TensorData input, int filter_d,
40                      int filter_h, int filter_w, TensorData output,
41                      TfLitePadding padding = kTfLitePaddingValid,
42                      int stride_d = 2, int stride_h = 2, int stride_w = 2) {
43     if (input.type == TensorType_FLOAT32) {
44       // Clear quantization params.
45       input.min = input.max = 0.f;
46       output.min = output.max = 0.f;
47     }
48     input_ = AddInput(input);
49     output_ = AddOutput(output);
50 
51     std::vector<uint8_t> custom_option = CreateCustomOptions(
52         stride_d, stride_h, stride_w, filter_d, filter_h, filter_w, padding);
53     if (pool_type == kAverage) {
54       SetCustomOp("AveragePool3D", custom_option,
55                   ops::custom::Register_AVG_POOL_3D);
56     } else {
57       SetCustomOp("MaxPool3D", custom_option,
58                   ops::custom::Register_MAX_POOL_3D);
59     }
60     BuildInterpreter({GetShape(input_)});
61   }
62 
SetInput(const std::vector<float> & data)63   void SetInput(const std::vector<float>& data) {
64     QuantizeAndPopulate<T>(input_, data);
65   }
66 
GetOutput()67   std::vector<float> GetOutput() {
68     return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
69                          GetZeroPoint(output_));
70   }
71 
72  protected:
73   int input_;
74   int output_;
75 
76  private:
CreateCustomOptions(int stride_depth,int stride_height,int stride_width,int filter_depth,int filter_height,int filter_width,TfLitePadding padding)77   std::vector<uint8_t> CreateCustomOptions(int stride_depth, int stride_height,
78                                            int stride_width, int filter_depth,
79                                            int filter_height, int filter_width,
80                                            TfLitePadding padding) {
81     auto flex_builder = std::make_unique<flexbuffers::Builder>();
82     size_t map_start = flex_builder->StartMap();
83     flex_builder->String("data_format", "NDHWC");
84     if (padding == kTfLitePaddingValid) {
85       flex_builder->String("padding", "VALID");
86     } else {
87       flex_builder->String("padding", "SAME");
88     }
89 
90     auto start = flex_builder->StartVector("ksize");
91     flex_builder->Add(1);
92     flex_builder->Add(filter_depth);
93     flex_builder->Add(filter_height);
94     flex_builder->Add(filter_width);
95     flex_builder->Add(1);
96     flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);
97 
98     auto strides_start = flex_builder->StartVector("strides");
99     flex_builder->Add(1);
100     flex_builder->Add(stride_depth);
101     flex_builder->Add(stride_height);
102     flex_builder->Add(stride_width);
103     flex_builder->Add(1);
104     flex_builder->EndVector(strides_start, /*typed=*/true, /*fixed=*/false);
105 
106     flex_builder->EndMap(map_start);
107     flex_builder->Finish();
108     return flex_builder->GetBuffer();
109   }
110 };
111 
112 template <>
SetInput(const std::vector<float> & data)113 void BasePoolingOpModel<float>::SetInput(const std::vector<float>& data) {
114   PopulateTensor(input_, data);
115 }
116 
117 template <>
GetOutput()118 std::vector<float> BasePoolingOpModel<float>::GetOutput() {
119   return ExtractVector<float>(output_);
120 }
121 
122 #ifdef GTEST_HAS_DEATH_TEST
TEST(AveragePoolingOpTest,InvalidDimSize)123 TEST(AveragePoolingOpTest, InvalidDimSize) {
124   EXPECT_DEATH(BasePoolingOpModel<float> m(
125                    kAverage,
126                    /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
127                    /*filter_d=*/2,
128                    /*filter_h=*/2, /*filter_w=*/2,
129                    /*output=*/{TensorType_FLOAT32, {}},
130                    /*padding=*/kTfLitePaddingValid, /*stride_d=*/1,
131                    /*stride_h=*/1, /*stride_w=*/1),
132                "NumDimensions.input. != 5 .4 != 5.");
133 }
134 
TEST(AveragePoolingOpTest,ZeroStride)135 TEST(AveragePoolingOpTest, ZeroStride) {
136   EXPECT_DEATH(BasePoolingOpModel<float> m(
137                    kAverage,
138                    /*input=*/{TensorType_FLOAT32, {1, 2, 2, 4, 1}},
139                    /*filter_d=*/2,
140                    /*filter_h=*/2, /*filter_w=*/2,
141                    /*output=*/{TensorType_FLOAT32, {}},
142                    /*padding=*/kTfLitePaddingValid, /*stride_d=*/0,
143                    /*stride_h=*/0, /*stride_w=*/0),
144                "Cannot allocate tensors");
145 }
146 #endif
147 
148 template <typename T>
149 class AveragePoolingOpTest : public ::testing::Test {};
150 
151 template <typename T>
152 class MaxPoolingOpTest : public ::testing::Test {};
153 
154 using DataTypes = ::testing::Types<float, int8_t, int16_t>;
155 TYPED_TEST_SUITE(AveragePoolingOpTest, DataTypes);
156 TYPED_TEST_SUITE(MaxPoolingOpTest, DataTypes);
157 
TYPED_TEST(AveragePoolingOpTest,AveragePool)158 TYPED_TEST(AveragePoolingOpTest, AveragePool) {
159   BasePoolingOpModel<TypeParam> m(
160       kAverage,
161       /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
162       /*filter_d=*/2,
163       /*filter_h=*/2, /*filter_w=*/2,
164       /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375});
165   m.SetInput({0, 6, 2, 4, 4, 5, 1, 4, 3, 2, 10, 7, 2, 3, 5, 1});
166   ASSERT_EQ(m.Invoke(), kTfLiteOk);
167   EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.125, 4.25}));
168 }
169 
TYPED_TEST(AveragePoolingOpTest,AveragePoolFilterH1)170 TYPED_TEST(AveragePoolingOpTest, AveragePoolFilterH1) {
171   BasePoolingOpModel<TypeParam> m(
172       kAverage,
173       /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
174       /*filter_d=*/2,
175       /*filter_h=*/1, /*filter_w=*/2,
176       /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375});
177   m.SetInput({0, 6, 2, 4, 4, 5, 1, 4, 3, 2, 10, 7, 2, 3, 5, 1});
178   ASSERT_EQ(m.Invoke(), kTfLiteOk);
179   EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75}));
180 }
181 
TYPED_TEST(AveragePoolingOpTest,AveragePoolPaddingSameStride1)182 TYPED_TEST(AveragePoolingOpTest, AveragePoolPaddingSameStride1) {
183   BasePoolingOpModel<TypeParam> m(
184       kAverage,
185       /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
186       /*filter_d=*/2,
187       /*filter_h=*/2, /*filter_w=*/2,
188       /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375},
189       kTfLitePaddingSame,
190       /*stride_d=*/1, /*stride_h=*/1,
191       /*stride_w=*/1);
192   m.SetInput({0, 6, 2, 4, 2, 5, 4, 3, 3, 2, 10, 7, 3, 2, 2, 4});
193   ASSERT_EQ(m.Invoke(), kTfLiteOk);
194   EXPECT_THAT(m.GetOutput(),
195               ElementsAreArray({2.875, 4.125, 4.5, 4.5, 3.0, 3.25, 3.25, 3.5,
196                                 2.5, 4.0, 5.75, 5.5, 2.5, 2.0, 3.0, 4.0}));
197 }
198 
TYPED_TEST(AveragePoolingOpTest,AveragePoolPaddingValidStride1)199 TYPED_TEST(AveragePoolingOpTest, AveragePoolPaddingValidStride1) {
200   BasePoolingOpModel<TypeParam> m(
201       kAverage,
202       /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
203       /*filter_d=*/2,
204       /*filter_h=*/2, /*filter_w=*/2,
205       /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375},
206       kTfLitePaddingValid,
207       /*stride_d=*/1, /*stride_h=*/1,
208       /*stride_w=*/1);
209   m.SetInput({0, 6, 2, 4, 2, 5, 4, 3, 3, 2, 10, 7, 3, 2, 2, 4});
210   ASSERT_EQ(m.Invoke(), kTfLiteOk);
211   EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.875, 4.125, 4.5}));
212 }
213 
TYPED_TEST(MaxPoolingOpTest,MaxPool)214 TYPED_TEST(MaxPoolingOpTest, MaxPool) {
215   BasePoolingOpModel<TypeParam> m(
216       kMax,
217       /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
218       /*filter_d=*/2,
219       /*filter_h=*/2, /*filter_w=*/2,
220       /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375});
221   m.SetInput({0, 6, 2, 4, 4, 5, 1, 4, 3, 2, 10, 7, 2, 3, 5, 1});
222   ASSERT_EQ(m.Invoke(), kTfLiteOk);
223   EXPECT_THAT(m.GetOutput(), ElementsAreArray({6.0, 10.0}));
224 }
225 
TYPED_TEST(MaxPoolingOpTest,MaxPoolFilterH1)226 TYPED_TEST(MaxPoolingOpTest, MaxPoolFilterH1) {
227   BasePoolingOpModel<TypeParam> m(
228       kMax,
229       /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
230       /*filter_d=*/2,
231       /*filter_h=*/1, /*filter_w=*/2,
232       /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375});
233   m.SetInput({0, 6, 2, 4, 4, 5, 1, 4, 3, 2, 10, 7, 2, 3, 5, 1});
234   ASSERT_EQ(m.Invoke(), kTfLiteOk);
235   EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10}));
236 }
237 
TYPED_TEST(MaxPoolingOpTest,MaxPoolPaddingSameStride1)238 TYPED_TEST(MaxPoolingOpTest, MaxPoolPaddingSameStride1) {
239   BasePoolingOpModel<TypeParam> m(
240       kMax,
241       /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
242       /*filter_d=*/2,
243       /*filter_h=*/2, /*filter_w=*/2,
244       /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375},
245       kTfLitePaddingSame,
246       /*stride_d=*/1, /*stride_h=*/1,
247       /*stride_w=*/1);
248   m.SetInput({0, 6, 2, 4, 2, 5, 4, 3, 3, 2, 10, 7, 3, 2, 2, 4});
249   ASSERT_EQ(m.Invoke(), kTfLiteOk);
250   EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10, 10, 7, 5, 5, 4, 4, 3, 10,
251                                                10, 7, 3, 2, 4, 4}));
252 }
253 
TYPED_TEST(MaxPoolingOpTest,MaxPoolPaddingValidStride1)254 TYPED_TEST(MaxPoolingOpTest, MaxPoolPaddingValidStride1) {
255   BasePoolingOpModel<TypeParam> m(
256       kMax,
257       /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
258       /*filter_d=*/2,
259       /*filter_h=*/2, /*filter_w=*/2,
260       /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375},
261       kTfLitePaddingValid,
262       /*stride_d=*/1, /*stride_h=*/1,
263       /*stride_w=*/1);
264   m.SetInput({0, 6, 2, 4, 2, 5, 4, 3, 3, 2, 10, 7, 3, 2, 2, 4});
265   ASSERT_EQ(m.Invoke(), kTfLiteOk);
266   EXPECT_THAT(m.GetOutput(), ElementsAreArray({6.0, 10.0, 10.0}));
267 }
268 
269 }  // namespace tflite
270