/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include <initializer_list>
#include <memory>
#include <string>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/kernels/custom_ops_register.h"
#include "tensorflow/lite/kernels/test_util.h"
namespace tflite {

using ::testing::ElementsAreArray;

enum PoolType {
  kAverage,
  kMax,
};

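// Test harness for the AveragePool3D / MaxPool3D custom ops. Builds a
// single-op model with the pooling parameters encoded as flexbuffer custom
// options; T is the tensor element type (float or a quantized integer type).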
template <typename T>
class BasePoolingOpModel : public SingleOpModel {
 public:
  BasePoolingOpModel(PoolType pool_type, TensorData input, int filter_d,
                     int filter_h, int filter_w, TensorData output,
                     TfLitePadding padding = kTfLitePaddingValid,
                     int stride_d = 2, int stride_h = 2, int stride_w = 2) {
    if (input.type == TensorType_FLOAT32) {
      // Clear quantization params.
      input.min = input.max = 0.f;
      output.min = output.max = 0.f;
    }
    input_ = AddInput(input);
    output_ = AddOutput(output);

    std::vector<uint8_t> custom_option = CreateCustomOptions(
        stride_d, stride_h, stride_w, filter_d, filter_h, filter_w, padding);
    if (pool_type == kAverage) {
      SetCustomOp("AveragePool3D", custom_option,
                  ops::custom::Register_AVG_POOL_3D);
    } else {
      SetCustomOp("MaxPool3D", custom_option,
                  ops::custom::Register_MAX_POOL_3D);
    }
    BuildInterpreter({GetShape(input_)});
  }

  void SetInput(const std::vector<float>& data) {
    QuantizeAndPopulate<T>(input_, data);
  }

  std::vector<float> GetOutput() {
    return Dequantize<T>(ExtractVector<T>(output_), GetScale(output_),
                         GetZeroPoint(output_));
  }

 protected:
  int input_;
  int output_;

 private:
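  // Serializes the pooling parameters into a flexbuffer map laid out like the
  // TensorFlow *Pool3D attributes: "data_format", "padding", and 5-element
  // "ksize"/"strides" vectors ordered {1, depth, height, width, 1}.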
  std::vector<uint8_t> CreateCustomOptions(int stride_depth, int stride_height,
                                           int stride_width, int filter_depth,
                                           int filter_height, int filter_width,
                                           TfLitePadding padding) {
    auto flex_builder = std::make_unique<flexbuffers::Builder>();
    size_t map_start = flex_builder->StartMap();
    flex_builder->String("data_format", "NDHWC");
    if (padding == kTfLitePaddingValid) {
      flex_builder->String("padding", "VALID");
    } else {
      flex_builder->String("padding", "SAME");
    }

    auto start = flex_builder->StartVector("ksize");
    flex_builder->Add(1);
    flex_builder->Add(filter_depth);
    flex_builder->Add(filter_height);
    flex_builder->Add(filter_width);
    flex_builder->Add(1);
    flex_builder->EndVector(start, /*typed=*/true, /*fixed=*/false);

    auto strides_start = flex_builder->StartVector("strides");
    flex_builder->Add(1);
    flex_builder->Add(stride_depth);
    flex_builder->Add(stride_height);
    flex_builder->Add(stride_width);
    flex_builder->Add(1);
    flex_builder->EndVector(strides_start, /*typed=*/true, /*fixed=*/false);

    flex_builder->EndMap(map_start);
    flex_builder->Finish();
    return flex_builder->GetBuffer();
  }
};

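// Float specializations: populate and read the tensors directly, with no
// quantization round trip.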
template <>
void BasePoolingOpModel<float>::SetInput(const std::vector<float>& data) {
  PopulateTensor(input_, data);
}

template <>
std::vector<float> BasePoolingOpModel<float>::GetOutput() {
  return ExtractVector<float>(output_);
}

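// Death tests: building the model must fail for a non-5D input and for zero
// strides.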
#ifdef GTEST_HAS_DEATH_TEST
TEST(AveragePoolingOpTest, InvalidDimSize) {
  EXPECT_DEATH(BasePoolingOpModel<float> m(
                   kAverage,
                   /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}},
                   /*filter_d=*/2,
                   /*filter_h=*/2, /*filter_w=*/2,
                   /*output=*/{TensorType_FLOAT32, {}},
                   /*padding=*/kTfLitePaddingValid, /*stride_d=*/1,
                   /*stride_h=*/1, /*stride_w=*/1),
               "NumDimensions.input. != 5 .4 != 5.");
}

TEST(AveragePoolingOpTest, ZeroStride) {
  EXPECT_DEATH(BasePoolingOpModel<float> m(
                   kAverage,
                   /*input=*/{TensorType_FLOAT32, {1, 2, 2, 4, 1}},
                   /*filter_d=*/2,
                   /*filter_h=*/2, /*filter_w=*/2,
                   /*output=*/{TensorType_FLOAT32, {}},
                   /*padding=*/kTfLitePaddingValid, /*stride_d=*/0,
                   /*stride_h=*/0, /*stride_w=*/0),
               "Cannot allocate tensors");
}
#endif

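// Typed fixtures: each pooling test below runs once per element type listed in
// DataTypes (float plus the quantized int8/int16 paths).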
template <typename T>
class AveragePoolingOpTest : public ::testing::Test {};

template <typename T>
class MaxPoolingOpTest : public ::testing::Test {};

using DataTypes = ::testing::Types<float, int8_t, int16_t>;
TYPED_TEST_SUITE(AveragePoolingOpTest, DataTypes);
TYPED_TEST_SUITE(MaxPoolingOpTest, DataTypes);

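// Note: the quantized tests use the range [0, 15.9375] (i.e. 255/16), which
// puts the 8-bit scale at 1/16; every expected output below is a multiple of
// 0.0625, so it stays representable after quantization.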
TYPED_TEST(AveragePoolingOpTest, AveragePool) {
  BasePoolingOpModel<TypeParam> m(
      kAverage,
      /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
      /*filter_d=*/2,
      /*filter_h=*/2, /*filter_w=*/2,
      /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375});
  m.SetInput({0, 6, 2, 4, 4, 5, 1, 4, 3, 2, 10, 7, 2, 3, 5, 1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.125, 4.25}));
}

TYPED_TEST(AveragePoolingOpTest, AveragePoolFilterH1) {
  BasePoolingOpModel<TypeParam> m(
      kAverage,
      /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
      /*filter_d=*/2,
      /*filter_h=*/1, /*filter_w=*/2,
      /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375});
  m.SetInput({0, 6, 2, 4, 4, 5, 1, 4, 3, 2, 10, 7, 2, 3, 5, 1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75}));
}

TYPED_TEST(AveragePoolingOpTest, AveragePoolPaddingSameStride1) {
  BasePoolingOpModel<TypeParam> m(
      kAverage,
      /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
      /*filter_d=*/2,
      /*filter_h=*/2, /*filter_w=*/2,
      /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375},
      kTfLitePaddingSame,
      /*stride_d=*/1, /*stride_h=*/1,
      /*stride_w=*/1);
  m.SetInput({0, 6, 2, 4, 2, 5, 4, 3, 3, 2, 10, 7, 3, 2, 2, 4});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(),
              ElementsAreArray({2.875, 4.125, 4.5, 4.5, 3.0, 3.25, 3.25, 3.5,
                                2.5, 4.0, 5.75, 5.5, 2.5, 2.0, 3.0, 4.0}));
}

TYPED_TEST(AveragePoolingOpTest, AveragePoolPaddingValidStride1) {
  BasePoolingOpModel<TypeParam> m(
      kAverage,
      /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
      /*filter_d=*/2,
      /*filter_h=*/2, /*filter_w=*/2,
      /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375},
      kTfLitePaddingValid,
      /*stride_d=*/1, /*stride_h=*/1,
      /*stride_w=*/1);
  m.SetInput({0, 6, 2, 4, 2, 5, 4, 3, 3, 2, 10, 7, 3, 2, 2, 4});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.875, 4.125, 4.5}));
}

TYPED_TEST(MaxPoolingOpTest, MaxPool) {
  BasePoolingOpModel<TypeParam> m(
      kMax,
      /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
      /*filter_d=*/2,
      /*filter_h=*/2, /*filter_w=*/2,
      /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375});
  m.SetInput({0, 6, 2, 4, 4, 5, 1, 4, 3, 2, 10, 7, 2, 3, 5, 1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({6.0, 10.0}));
}

TYPED_TEST(MaxPoolingOpTest, MaxPoolFilterH1) {
  BasePoolingOpModel<TypeParam> m(
      kMax,
      /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
      /*filter_d=*/2,
      /*filter_h=*/1, /*filter_w=*/2,
      /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375});
  m.SetInput({0, 6, 2, 4, 4, 5, 1, 4, 3, 2, 10, 7, 2, 3, 5, 1});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10}));
}

TYPED_TEST(MaxPoolingOpTest, MaxPoolPaddingSameStride1) {
  BasePoolingOpModel<TypeParam> m(
      kMax,
      /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
      /*filter_d=*/2,
      /*filter_h=*/2, /*filter_w=*/2,
      /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375},
      kTfLitePaddingSame,
      /*stride_d=*/1, /*stride_h=*/1,
      /*stride_w=*/1);
  m.SetInput({0, 6, 2, 4, 2, 5, 4, 3, 3, 2, 10, 7, 3, 2, 2, 4});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10, 10, 7, 5, 5, 4, 4, 3, 10,
                                               10, 7, 3, 2, 4, 4}));
}

TYPED_TEST(MaxPoolingOpTest, MaxPoolPaddingValidStride1) {
  BasePoolingOpModel<TypeParam> m(
      kMax,
      /*input=*/{GetTensorType<TypeParam>(), {1, 2, 2, 4, 1}, 0, 15.9375},
      /*filter_d=*/2,
      /*filter_h=*/2, /*filter_w=*/2,
      /*output=*/{GetTensorType<TypeParam>(), {}, 0, 15.9375},
      kTfLitePaddingValid,
      /*stride_d=*/1, /*stride_h=*/1,
      /*stride_w=*/1);
  m.SetInput({0, 6, 2, 4, 2, 5, 4, 3, 3, 2, 10, 7, 3, 2, 2, 4});
  ASSERT_EQ(m.Invoke(), kTfLiteOk);
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({6.0, 10.0, 10.0}));
}

}  // namespace tflite