1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17
18 using namespace ::testing;
19
20 class OpAvgPool2DOutTest : public OperatorTest {
21 protected:
op_avg_pool2d_out(const exec_aten::Tensor & self,exec_aten::ArrayRef<int64_t> kernel_size,exec_aten::ArrayRef<int64_t> stride,exec_aten::ArrayRef<int64_t> padding,bool ceil_mode,bool count_include_pad,exec_aten::optional<int64_t> divisor_override,exec_aten::Tensor & out)22 exec_aten::Tensor& op_avg_pool2d_out(
23 const exec_aten::Tensor& self,
24 exec_aten::ArrayRef<int64_t> kernel_size,
25 exec_aten::ArrayRef<int64_t> stride,
26 exec_aten::ArrayRef<int64_t> padding,
27 bool ceil_mode,
28 bool count_include_pad,
29 exec_aten::optional<int64_t> divisor_override,
30 exec_aten::Tensor& out) {
31 return torch::executor::aten::avg_pool2d_outf(
32 context_,
33 self,
34 kernel_size,
35 stride,
36 padding,
37 ceil_mode,
38 count_include_pad,
39 divisor_override,
40 out);
41 }
42 };
43
// avg_pool2d on a 2x3x8x8 Float input with a 2x3 kernel, stride (3, 2),
// padding (1, 1), floor rounding (ceil_mode = false), padded positions counted
// in the divisor (count_include_pad = true) and no divisor override.
// Expected values are hard-coded references (presumably generated with
// PyTorch's avg_pool2d — TODO confirm).
TEST_F(OpAvgPool2DOutTest, SanityCheck4D) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor self = tfFloat.make(
      {2, 3, 8, 8},
      {-81.875, -37.25, -99.75, -60.5, -1.375, 25.625, -54.25, -95.875,
       48.0, 26.125, -14.625, -5.0, 98.25, 1.75, 5.875, -71.25,
       -26.5, 36.0, 32.375, -87.875, -43.125, -40.875, -49.0, -33.0,
       16.0, -1.875, -48.0, 77.0, 87.0, 58.25, 20.0, 10.875,
       -37.875, -88.25, 97.75, -98.25, -52.0, -92.75, -89.375, -31.5,
       -1.625, -50.0, -6.625, -62.5, -86.5, -96.5, 85.0, -94.375,
       -3.625, 94.125, 12.25, -33.875, -25.625, -20.625, -56.5, -78.5,
       8.25, 38.875, 74.5, -22.75, 14.125, 46.5, -28.625, -52.0,
       -36.75, -3.75, 48.875, 2.875, 76.125, -42.125, -71.125, 22.5,
       9.25, 98.5, 41.5, 65.625, -82.25, -85.75, -20.75, 82.375,
       3.25, -74.5, -14.25, 18.75, 60.25, -21.875, 7.0, -44.0,
       -97.0, -71.625, 65.75, -89.5, 26.375, 14.125, -99.75, 42.625,
       -11.25, 56.625, -63.0, -34.25, -76.125, 69.25, 22.875, 43.875,
       -7.125, -16.875, 93.875, -48.25, -57.375, -30.625, 11.75, 73.0,
       -96.5, 14.5, 50.375, -65.875, 73.625, 88.625, -11.375, 41.75,
       7.5, 62.125, -1.25, 81.125, -78.375, 62.875, 75.125, -34.75,
       50.5, -2.125, 64.125, -74.375, 22.875, 39.625, -91.125, 96.125,
       -31.75, -70.875, 12.75, 6.5, 30.125, 12.75, -77.375, -56.5,
       -24.5, -59.25, 88.125, -26.375, -46.0, 62.75, 39.25, -78.375,
       67.875, -57.75, 15.5, 80.5, -70.5, 96.875, -18.625, -53.625,
       70.75, 70.375, -68.125, -41.0, -59.875, 74.625, -82.125, 35.125,
       -24.25, 33.25, 45.875, -74.75, 37.0, -0.875, 86.625, -14.375,
       45.25, 41.0, -10.625, -49.875, 61.5, -10.375, -39.0, 75.5,
       -72.625, -29.0, -53.875, -31.375, -57.375, 79.5, -25.0, -83.0,
       -49.875, -74.875, 50.25, 8.25, -35.5, -0.875, 27.125, 85.875,
       -95.625, 17.875, 45.25, -56.75, 60.5, -61.75, 47.625, 50.5,
       83.0, -63.25, 40.375, -61.75, -89.0, -66.5, 6.0, 46.0,
       -20.625, -15.625, -58.0, 88.25, 31.25, -37.5, -38.25, 68.5,
       11.625, 34.125, -39.5, 84.25, 27.0, -85.625, 7.625, 86.375,
       41.625, 28.0, -41.125, 44.625, 19.625, -43.125, -23.875, 54.5,
       50.0, -88.25, 29.0, -77.5, -82.125, -84.0, 8.75, -32.875,
       27.75, -88.75, -33.25, -58.25, 41.5, -84.0, -53.375, -85.5,
       -15.625, 39.75, 29.375, -45.375, 96.5, 65.125, 34.875, 75.375,
       -32.75, -9.75, -55.0, 38.0, 31.125, -35.0, -74.375, -61.0,
       -63.75, 4.75, 88.75, -83.25, -19.75, 5.875, 88.375, 52.25,
       70.125, -81.5, -56.375, -98.375, 97.625, 88.375, 22.625, -100.0,
       -75.0, 10.0, 41.0, 40.375, 12.0, 72.125, 31.875, -22.25,
       -63.875, 10.5, -81.25, 4.25, 43.5, -44.0, 71.5, -29.625,
       -3.5, -91.5, 45.375, 88.875, -93.125, -50.25, -6.375, -88.875,
       -2.375, -17.75, 49.25, -14.75, 14.75, -2.25, 51.25, -57.875,
       43.875, 87.0, 86.25, -95.125, 11.75, -26.5, 29.875, 89.25,
       -18.5, 55.375, 74.25, -64.0, 51.875, 78.0, 82.625, -34.0,
       -0.875, -69.375, -90.875, -83.5, 13.625, 46.875, 8.375, 16.875,
       96.75, 31.25, 45.625, -2.625, -71.0, -62.375, 31.25, -23.5,
       66.0, 51.375, -45.25, 43.375, 49.5, 12.625, -73.875, 26.375,
       29.0, -86.5, -55.375, 88.75, 20.0, 90.0, 28.75, -12.875,
       -37.75, -1.875, -28.125, 96.75, -66.75, 48.375, -79.25, 8.0,
       -14.25, -8.0, 51.75, 28.375, 32.0, -50.875, 53.0, -81.75});
  ::std::vector<int64_t> kernel_size_vec = {2, 3};
  exec_aten::ArrayRef<int64_t> kernel_size = exec_aten::ArrayRef<int64_t>(
      kernel_size_vec.data(), kernel_size_vec.size());
  ::std::vector<int64_t> stride_vec = {3, 2};
  exec_aten::ArrayRef<int64_t> stride =
      exec_aten::ArrayRef<int64_t>(stride_vec.data(), stride_vec.size());
  ::std::vector<int64_t> padding_vec = {1, 1};
  exec_aten::ArrayRef<int64_t> padding =
      exec_aten::ArrayRef<int64_t>(padding_vec.data(), padding_vec.size());
  bool ceil_mode = false;
  bool count_include_pad = true;
  exec_aten::optional<int64_t> divisor_override;
  // Output rows: floor((8 + 2*1 - 2) / 3) + 1 = 3;
  // output cols: floor((8 + 2*1 - 3) / 2) + 1 = 4.
  exec_aten::Tensor out = tfFloat.zeros({2, 3, 3, 4});
  exec_aten::Tensor out_expected = tfFloat.make(
      {2, 3, 3, 4},
      {-19.85416603088379,
       -32.91666793823242,
       -6.041666507720947,
       -20.75,
       3.9375,
       1.2708333730697632,
       8.395833015441895,
       -5.625,
       6.479166507720947,
       -7.770833492279053,
       -54.27083206176758,
       -43.58333206176758,
       -6.75,
       8.0,
       6.145833492279053,
       -15.125,
       -39.97916793823242,
       -27.5625,
       1.3541666269302368,
       -16.97916603088379,
       -17.66666603088379,
       4.625,
       -6.645833492279053,
       28.85416603088379,
       8.0625,
       -2.0625,
       -1.9791666269302368,
       7.4375,
       -12.270833015441895,
       6.791666507720947,
       16.20833396911621,
       8.041666984558105,
       15.875,
       -2.5208332538604736,
       -6.229166507720947,
       16.25,
       -20.79166603088379,
       -2.7291667461395264,
       -4.6875,
       18.6875,
       -2.75,
       -11.666666984558105,
       -22.54166603088379,
       -3.625,
       5.229166507720947,
       -17.54166603088379,
       -37.08333206176758,
       -20.10416603088379,
       4.020833492279053,
       3.9583332538604736,
       19.375,
       29.22916603088379,
       -11.729166984558105,
       -37.66666793823242,
       -1.5833333730697632,
       26.25,
       -24.72916603088379,
       -3.9583332538604736,
       -8.458333015441895,
       -24.60416603088379,
       21.8125,
       13.020833015441895,
       -18.3125,
       15.4375,
       9.625,
       -28.25,
       -26.5,
       2.9166667461395264,
       -16.1875,
       2.2708332538604736,
       46.1875,
       13.833333015441895});
  exec_aten::Tensor& ret = op_avg_pool2d_out(
      self,
      kernel_size,
      stride,
      padding,
      ceil_mode,
      count_include_pad,
      divisor_override,
      out);
  // An out-variant should return the tensor it wrote into.
  EXPECT_TENSOR_EQ(ret, out);
  EXPECT_TENSOR_CLOSE(out, out_expected);
}
195
// Same geometry as SanityCheck4D (2x3x8x8 Float input, 2x3 kernel,
// stride (3, 2), padding (1, 1), floor rounding, count_include_pad = true),
// but with divisor_override = 10, so each output is its window sum / 10
// regardless of how many elements the window covers.
// Expected values are hard-coded references (presumably generated with
// PyTorch's avg_pool2d — TODO confirm).
TEST_F(OpAvgPool2DOutTest, SanityCheck4DDivisorOverride) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor self = tfFloat.make(
      {2, 3, 8, 8},
      {13.25, 87.125, -61.875, 52.875, -74.5, 37.5, -62.125, 25.375,
       -34.375, 68.5, 87.875, 91.125, -22.75, 96.875, 79.25, 38.125,
       6.5, 27.625, 58.875, 41.5, -82.875, -77.875, -3.25, 22.0,
       -73.375, -56.875, 56.25, -6.375, -16.375, 40.0, 22.625, 61.0,
       -78.375, -87.5, 23.25, -55.625, 92.75, -90.0, -55.625, 16.25,
       -46.5, 91.75, 74.0, -45.0, -34.875, -46.375, 38.0, 63.375,
       22.625, 16.25, 91.125, 68.375, 98.25, 34.0, 91.25, 89.375,
       1.0, 22.875, -93.75, 7.5, -81.375, 1.125, -15.875, 16.75,
       29.25, -14.0, 77.375, -64.625, -22.375, -97.25, -26.75, 9.125,
       -3.5, -64.5, 35.625, 58.25, 4.125, 85.0, -39.125, -98.25,
       -0.625, -29.5, -35.375, 90.375, 37.625, 57.875, -59.625, -81.875,
       73.5, -37.375, 56.75, 88.25, -52.75, -4.75, 80.625, -38.375,
       -55.375, -76.625, -59.875, -41.5, 76.875, 45.375, 24.625, 57.5,
       58.25, 64.375, -38.125, -39.875, 71.125, 2.5, -14.75, 21.75,
       -8.375, 5.0, 71.625, 29.875, 79.375, -13.125, -22.625, 96.875,
       7.5, 64.5, 1.0, -40.75, -56.0, -94.75, -89.5, -71.25,
       78.125, -5.875, 93.125, -20.125, -73.875, -34.625, 51.5, 46.5,
       -88.25, 3.25, 26.0, -32.375, -77.625, -51.875, 0.75, 85.75,
       -15.875, 84.625, -44.75, -87.625, 17.25, -59.25, 40.25, -0.5,
       14.625, -80.125, -58.125, 17.25, -7.5, 68.875, -99.5, -82.75,
       75.5, -95.25, -63.25, 55.25, -17.125, 64.875, -56.25, 39.75,
       -35.25, 18.0, 21.125, -61.125, -35.5, 12.75, -94.125, -17.125,
       -41.375, 22.625, -69.25, 51.375, 7.25, 14.625, 70.25, -63.5,
       -10.0, 53.625, -46.75, 77.5, -78.5, -67.25, -53.375, -33.5,
       23.375, 77.625, 64.125, -56.375, -59.25, -89.125, 16.25, -2.125,
       -82.125, 72.75, 45.5, -36.375, -1.125, -7.25, -44.5, 27.125,
       -1.75, 16.5, 71.0, 81.25, 4.125, 74.0, 45.0, -4.75,
       68.125, 30.625, 11.625, -88.5, 48.125, 36.875, -32.0, 9.0,
       -67.5, -17.125, -87.5, 29.875, 23.875, -33.0, -6.625, 0.5,
       28.25, 54.75, 87.5, 28.0, -48.5, -98.0, -55.0, 62.25,
       -93.125, -41.5, 94.25, -28.375, -23.75, -50.625, 90.375, -91.0,
       -67.75, -11.625, 24.5, 49.5, 21.0, 70.25, 66.0, 69.75,
       43.25, -14.875, -93.875, -15.75, 94.5, 69.125, -67.625, 15.875,
       -26.875, 41.625, -64.375, 23.0, 13.375, 78.5, -88.875, 17.625,
       -52.375, 55.25, 41.875, 62.625, 29.0, 53.875, -93.625, -93.375,
       -70.375, -56.125, 9.0, 60.625, 4.375, 66.625, -46.25, 43.125,
       93.625, 85.125, -91.375, 49.375, 40.25, -19.75, 13.0, 69.5,
       -42.375, 47.25, 7.0, -93.625, -80.0, -57.875, -94.875, 9.375,
       -2.5, -42.25, -8.0, 92.5, 6.875, -11.0, 80.625, 27.125,
       -75.875, -57.375, -36.625, 17.375, 33.125, -49.0, -84.0, -45.5,
       -9.625, -16.125, 8.75, -1.0, -16.625, 2.5, -56.0, -52.0,
       85.0, 63.5, -50.625, 7.375, -90.0, -75.875, -82.0, 82.75,
       16.125, -35.25, -42.875, -1.375, 39.875, -39.625, -66.875, 29.875,
       -80.75, 66.5, -50.125, 13.5, -77.875, 6.125, -91.625, -44.125,
       -50.125, 59.625, -32.5, -2.0, -83.375, 21.0, 49.875, 48.125,
       -93.375, -54.875, -61.125, 96.375, 91.25, -2.375, -33.25, 48.125,
       -58.125, -50.75, -50.875, 8.375, 35.625, -72.5, -76.125, -33.25,
       -18.75, -71.0, 76.625, -11.25, -3.0, -38.625, -66.375, -25.0});
  ::std::vector<int64_t> kernel_size_vec = {2, 3};
  exec_aten::ArrayRef<int64_t> kernel_size = exec_aten::ArrayRef<int64_t>(
      kernel_size_vec.data(), kernel_size_vec.size());
  ::std::vector<int64_t> stride_vec = {3, 2};
  exec_aten::ArrayRef<int64_t> stride =
      exec_aten::ArrayRef<int64_t>(stride_vec.data(), stride_vec.size());
  ::std::vector<int64_t> padding_vec = {1, 1};
  exec_aten::ArrayRef<int64_t> padding =
      exec_aten::ArrayRef<int64_t>(padding_vec.data(), padding_vec.size());
  bool ceil_mode = false;
  bool count_include_pad = true;
  exec_aten::optional<int64_t> divisor_override =
      exec_aten::optional<int64_t>(10);
  // Output rows: floor((8 + 2*1 - 2) / 3) + 1 = 3;
  // output cols: floor((8 + 2*1 - 3) / 2) + 1 = 4.
  exec_aten::Tensor out = tfFloat.zeros({2, 3, 3, 4});
  exec_aten::Tensor out_expected = tfFloat.make(
      {2, 3, 3, 4},
      {10.037500381469727,
       7.8125,
       1.587499976158142,
       0.07500000298023224,
       -9.612500190734863,
       12.100000381469727,
       -10.199999809265137,
       6.449999809265137,
       8.412500381469727,
       29.649999618530273,
       7.4375,
       26.962499618530273,
       1.524999976158142,
       -0.125,
       -18.424999237060547,
       -11.487500190734863,
       0.6000000238418579,
       13.3125,
       21.662500381469727,
       -4.612500190734863,
       11.925000190734863,
       9.287500381469727,
       12.987500190734863,
       7.0625,
       7.224999904632568,
       6.712500095367432,
       -12.862500190734863,
       6.337500095367432,
       0.32499998807907104,
       -16.875,
       -5.099999904632568,
       -13.287500381469727,
       -3.5999999046325684,
       -1.725000023841858,
       -1.0625,
       -7.712500095367432,
       10.100000381469727,
       8.537500381469727,
       -20.475000381469727,
       -7.5,
       11.350000381469727,
       12.25,
       15.587499618530273,
       12.8125,
       -5.162499904632568,
       19.462499618530273,
       -22.125,
       -14.199999809265137,
       2.8375000953674316,
       -12.449999809265137,
       14.787500381469727,
       1.7374999523162842,
       -12.362500190734863,
       17.325000762939453,
       27.712499618530273,
       -6.962500095367432,
       -3.987499952316284,
       0.2874999940395355,
       -14.3125,
       -4.662499904632568,
       -2.575000047683716,
       -0.8374999761581421,
       -1.5125000476837158,
       -10.550000190734863,
       -3.3375000953674316,
       -4.962500095367432,
       -5.9375,
       -20.625,
       -25.712499618530273,
       -11.287500381469727,
       15.675000190734863,
       -16.9375});
  exec_aten::Tensor& ret = op_avg_pool2d_out(
      self,
      kernel_size,
      stride,
      padding,
      ceil_mode,
      count_include_pad,
      divisor_override,
      out);
  // An out-variant should return the tensor it wrote into.
  EXPECT_TENSOR_EQ(ret, out);
  EXPECT_TENSOR_CLOSE(out, out_expected);
}
348
// avg_pool2d on a 2x3x14x12 Float input with a 4x2 kernel, stride (1, 2),
// padding (1, 1), ceil rounding (ceil_mode = true), padded positions excluded
// from the divisor (count_include_pad = false) and no divisor override.
// Expected values are hard-coded references (presumably generated with
// PyTorch's avg_pool2d — TODO confirm).
TEST_F(OpAvgPool2DOutTest, SanityCheck4DCeilModeNoIncludePadding) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;

  exec_aten::Tensor self = tfFloat.make(
      {2, 3, 14, 12},
      {26.375, -17.0, 63.5, 83.0, 21.375, -46.5, -69.125, 99.875,
       -67.125, -76.0, -1.125, -2.625, -48.0, -1.5, -92.75, 95.5,
       91.5, 32.625, -35.25, 12.125, 55.0, -64.875, -53.125, -20.0,
       -5.875, -0.25, -3.5, 4.875, 81.625, 92.25, 98.25, -47.375,
       78.875, -97.5, 10.375, 55.75, 65.375, 97.375, -93.625, 51.5,
       62.75, -37.625, 1.875, -73.125, 78.75, -62.25, 20.625, 89.375,
       16.75, 43.0, 98.375, 17.0, 40.125, -78.0, -24.375, 9.875,
       2.0, -57.125, -65.375, -86.0, -98.375, 68.75, 70.375, 16.25,
       -24.625, 12.0, -55.0, -17.0, 81.375, 12.0, 62.875, -79.0,
       46.0, -12.625, -42.0, 3.125, -6.0, 88.5, -42.25, 82.625,
       -96.875, 51.375, 41.875, 36.125, -27.375, 33.375, -81.75, 6.875,
       -39.5, -57.75, 24.0, 35.375, -22.875, 96.375, -65.0, -67.0,
       -64.875, 86.625, 70.625, 16.5, -45.0, 55.375, 88.5, -42.0,
       -38.5, -71.5, -66.125, -30.5, 59.375, -8.25, 8.375, 85.5,
       20.625, -58.875, -35.5, -80.125, 93.25, -15.5, 40.125, 54.25,
       -40.25, 85.125, 61.375, 84.25, 58.5, -60.375, -58.375, -81.5,
       -80.125, -84.0, 31.75, 55.125, 8.125, -56.625, -7.0, 12.625,
       75.75, -88.125, -5.375, 60.0, 21.25, 62.125, 2.125, -3.375,
       -26.75, 44.25, -96.5, -87.5, -36.625, -36.125, 1.125, -42.0,
       -48.125, 24.0, 80.375, -38.5, 50.75, -0.625, 35.625, 31.875,
       -32.75, 70.375, -41.0, -20.75, 27.5, 60.75, -10.5, 54.875,
       68.25, 40.0, 0.875, 39.375, 67.375, 31.25, 79.25, 37.625,
       33.0, 10.625, -47.5, -21.0, -2.625, 81.0, -42.25, 23.625,
       10.375, 1.625, 11.0, 14.625, 19.5, -36.75, 99.875, 86.125,
       -70.25, 26.375, -86.375, -21.0, -25.0, 51.75, -11.375, 58.875,
       74.75, 77.625, 90.0, -40.5, -32.625, 40.5, 33.0, -80.625,
       -33.875, 38.5, 30.625, -55.375, 53.875, -1.25, 75.0, -55.375,
       71.125, -34.625, -29.75, -92.75, -34.125, -17.25, -72.0, 79.5,
       81.5, -14.375, -35.625, -38.875, 97.625, -3.0, -31.125, 31.25,
       13.5, -3.625, -64.0, 72.0, 64.5, 6.875, -4.875, 19.875,
       50.25, 80.75, 76.875, 91.75, -26.375, 94.125, 19.0, 4.375,
       -48.625, -69.375, 83.0, 6.75, -85.5, -64.75, -99.875, -88.375,
       36.875, 24.375, 96.0, 24.25, -10.0, 39.75, -8.75, -76.25,
       -20.75, -96.25, 75.5, 23.625, 50.875, 76.375, -74.25, 2.0,
       57.125, 31.875, -67.375, 78.375, -25.125, -27.75, -82.375, 37.25,
       84.875, 8.25, 53.75, -14.375, -68.875, -7.875, -5.25, 50.25,
       -92.75, -90.625, 64.0, -32.0, -20.75, 47.625, 35.5, -17.875,
       44.0, 35.625, -38.0, -80.5, 74.375, -11.125, 85.625, 47.25,
       -92.375, 2.375, -5.375, 31.875, -83.5, -41.0, -72.25, 49.75,
       72.125, -92.875, -39.125, -19.75, -79.25, 90.375, -67.625, 71.375,
       -9.75, -80.125, -56.875, 27.0, 93.75, 55.75, 85.0, -60.5,
       -11.625, 11.5, 66.0, 7.875, 72.125, -15.25, 60.25, 58.0,
       -16.125, -41.875, -16.875, 32.25, -69.75, 81.5, -65.125, -28.0,
       59.25, 2.375, -8.25, 60.0, -69.625, 3.0, 59.875, 93.25,
       41.875, -38.625, -86.375, -18.25, 86.375, 29.875, 5.375, -63.0,
       -40.875, 42.0, 51.125, -35.125, 3.125, 31.875, 29.75, 6.5,
       -12.125, -86.875, 62.5, 6.75, -43.125, 37.0, 58.625, -89.5,
       -62.125, -21.0, -95.875, -42.125, -42.125, -52.125, -22.875, 78.375,
       -53.375, -85.625, 19.375, -28.0, 47.375, 63.875, 21.75, 85.75,
       74.375, 27.75, 46.0, 41.375, 55.125, -4.375, -59.375, -93.375,
       -95.375, 98.125, -30.875, 69.25, 69.75, -78.375, 71.125, -71.25,
       -73.75, 69.0, -68.125, -49.0, 50.0, 51.25, -9.625, -4.625,
       -85.75, -99.125, -99.875, -65.5, 28.75, 95.25, 79.625, 16.5,
       37.125, -62.75, -62.0, 65.25, 95.5, -15.5, -42.375, 35.75,
       18.5, 89.25, 35.5, -90.375, 3.625, -2.875, 55.25, 63.5,
       1.25, 2.625, -18.375, -69.0, 9.5, -31.375, -44.375, 35.125,
       -37.0, 20.875, 98.5, -43.0, -35.625, -8.125, 40.25, -89.125,
       41.5, -100.0, -39.625, -76.5, 39.0, 6.125, 22.75, 83.75,
       -62.375, 31.75, -86.625, -26.125, 87.25, -69.0, 47.0, 99.125,
       73.5, 24.875, 84.875, 13.625, -57.875, 67.125, -10.375, -67.75,
       -92.0, 52.125, 66.875, -62.25, 61.125, -46.875, -35.875, 35.625,
       -91.5, -95.25, 11.625, 51.125, 38.875, -1.875, 1.75, 11.125,
       61.125, -58.25, 60.125, -62.25, -83.75, -41.375, -14.75, 48.25,
       -83.25, -91.75, 13.625, -12.75, 30.25, -54.125, 60.0, -25.875,
       -45.0, -33.75, 85.875, 46.0, -74.75, -93.375, 64.25, -31.5,
       -30.125, 85.25, 87.0, -1.75, -55.5, 32.0, -67.25, 3.125,
       -66.0, 16.75, -62.25, -59.375, 91.375, -49.625, -74.125, 52.5,
       -9.125, 70.5, 69.375, -86.25, -67.75, 82.625, 37.5, 1.5,
       73.125, -68.875, -1.5, 42.875, 94.125, 98.125, 9.5, 99.375,
       35.0, -17.625, -22.375, 75.0, 67.5, 47.25, -12.625, 30.375,
       -16.625, 97.125, 86.5, 61.375, -94.875, 99.125, 62.25, -83.875,
       -65.0, 69.125, 84.125, -84.75, 90.25, -77.0, 29.625, 58.375,
       8.625, 76.875, 67.375, -89.875, 62.5, 51.0, 77.75, -0.375,
       45.5, 86.875, -49.5, -71.375, -32.75, -76.375, -29.375, 69.25,
       -13.375, 95.875, -58.375, 12.75, 95.375, 83.375, 51.0, 40.75,
       58.125, -26.875, 29.375, 81.625, 37.25, -91.375, 1.25, 20.625,
       -46.5, 24.875, 12.25, -73.25, -1.625, -69.625, -79.0, 60.125,
       -0.125, 41.875, -24.375, 66.5, 9.25, -72.625, -66.375, -29.5,
       -51.5, 28.75, -79.75, 16.125, -55.5, 80.25, -14.25, 37.875,
       98.625, -6.875, 67.75, 90.875, -80.375, 71.625, 58.875, -63.0,
       -71.625, 33.5, -30.125, -33.25, 94.75, -30.5, -45.75, 62.875,
       67.75, -22.125, 22.0, 48.75, -35.75, 83.625, -53.375, 55.875,
       53.875, 38.0, 0.375, 15.125, -29.25, 55.0, -79.875, 40.5,
       65.25, -67.625, -43.125, -47.625, -81.375, 40.5, -94.375, 26.0,
       70.0, 50.25, 53.0, 26.75, 0.875, 4.0, 17.625, -74.25,
       -28.5, 56.5, 71.25, -67.0, -9.875, 15.25, -38.125, -61.375,
       -22.25, 14.625, -47.25, -7.125, 34.75, 10.625, -10.25, 0.5,
       77.5, 36.0, -51.875, -59.5, 13.375, 45.75, -71.625, -29.75,
       -90.75, -38.25, 52.0, 77.875, 21.125, 56.625, 17.875, -32.625,
       -39.0, -68.75, -82.875, -1.25, 96.75, -91.375, -22.5, 27.5,
       -30.125, -23.875, -76.0, -93.5, 37.5, 78.375, -12.5, -84.875,
       -21.5, 29.25, -4.375, -48.875, 90.75, -77.375, -23.625, -5.75,
       -75.125, -92.125, 9.0, 96.5, 76.125, 47.5, 15.875, 21.0,
       18.625, 94.25, 8.5, -67.75, 69.25, -72.0, -14.5, 90.5,
       -31.875, -20.875, -93.625, 42.875, -54.375, -70.5, -75.375, -6.75,
       -48.875, 21.75, 1.75, 45.5, 60.75, -9.75, 13.625, -87.125,
       12.625, 71.125, -46.5, -31.25, -15.375, -42.75, 13.0, 26.25,
       61.625, 81.0, 18.125, 84.5, 31.75, -14.125, -80.875, 18.875,
       51.125, 0.625, 41.5, -81.625, 32.625, -19.375, 76.875, 54.0,
       -40.375, -78.25, 48.0, 77.5, 21.0, 71.25, 34.875, 73.25,
       71.75, 87.25, -89.375, 93.625, 64.25, -81.375, -15.0, -58.125,
       64.75, -49.125, 72.125, -62.0, 0.125, -56.0, -6.875, 98.25,
       -24.75, 48.375, -85.25, -62.25, 98.25, 12.75, -67.25, 72.125,
       9.625, 8.5, -43.375, 15.25, -50.875, 7.875, -21.625, 97.25,
       48.5, 33.25, -15.125, 41.125, 69.25, -35.875, 64.5, 98.75,
       -78.125, 44.625, 75.75, 41.625, -39.375, -4.875, 87.5, -67.25,
       -23.25, -52.875, -44.875, -70.75, -63.75, 86.25, -44.0, -92.375,
       -22.75, -70.75, -84.875, -98.125, -86.75, 35.25, -77.5, -78.0,
       -62.5, -88.5, 58.25, -65.125, 70.25, -5.875, 11.25, 47.125,
       10.25, -91.75, -50.125, 45.125, 49.625, 57.25, 26.375, -74.75,
       70.25, 38.625, 51.25, 12.375, -79.875, -76.25, 12.0, -15.625,
       50.25, 72.125, 30.5, -68.625, -22.875, 5.125, -59.5, -82.75,
       -64.5, 46.75, -87.75, 17.875, 50.875, 69.625, 24.125, -25.0,
       -25.0, -5.0, 21.0, 51.875, 91.625, 50.0, 77.375, -51.0,
       64.25, 41.0, -20.875, -89.75, -74.75, 48.375, -41.75, 99.25,
       36.125, 10.75, -49.625, 57.375, 19.625, -43.75, 16.625, -48.375,
       84.375, 31.125, 93.375, 76.125, -12.25, 98.5, 15.0, 71.625,
       -87.75, 98.625, -74.5, -42.75, -73.5, -30.875, -44.625, 70.5,
       96.0, 21.5, 48.375, 12.75, -65.625, 56.5, 97.375, 29.5,
       -19.375, 94.625, 11.25, -0.375, -96.5, -48.75, -40.375, 98.0,
       -43.0, 20.75, 50.375, 97.875, -72.875, 92.125, 98.5, -74.0,
       -57.875, 56.375, -34.0, 6.125, -66.125, 71.75, -67.125, -79.0,
       -85.875, -88.625, 97.375, 16.875, -18.5, -75.125, 22.625, -21.125,
       3.875, 57.25, 97.75, -11.0, -46.875, 16.75, -69.25, 99.75,
       -68.25, 99.75, -17.25, -3.125, 34.25, -54.125, -93.125, 65.0,
       -76.375, -20.625, -77.875, -65.625, -79.875, 28.75, 58.25, -25.25});
  ::std::vector<int64_t> kernel_size_vec = {4, 2};
  exec_aten::ArrayRef<int64_t> kernel_size = exec_aten::ArrayRef<int64_t>(
      kernel_size_vec.data(), kernel_size_vec.size());
  ::std::vector<int64_t> stride_vec = {1, 2};
  exec_aten::ArrayRef<int64_t> stride =
      exec_aten::ArrayRef<int64_t>(stride_vec.data(), stride_vec.size());
  ::std::vector<int64_t> padding_vec = {1, 1};
  exec_aten::ArrayRef<int64_t> padding =
      exec_aten::ArrayRef<int64_t>(padding_vec.data(), padding_vec.size());
  bool ceil_mode = true;
  bool count_include_pad = false;
  exec_aten::optional<int64_t> divisor_override;
  // With ceil rounding, output rows: ceil((14 + 2*1 - 4) / 1) + 1 = 13;
  // output cols: ceil((12 + 2*1 - 2) / 2) + 1 = 7.
  exec_aten::Tensor out = tfFloat.zeros({2, 3, 13, 7});
  exec_aten::Tensor out_expected = tfFloat.make(
      {2, 3, 13, 7},
      {-9.166666984558105,
       -8.583333015441895,
       62.97916793823242,
       12.041666984558105,
       21.89583396911621,
       -47.04166793823242,
       11.041666984558105,
       9.46875,
       -5.96875,
       61.515625,
       4.5625,
       17.125,
       -40.484375,
       30.625,
       7.0625,
       5.890625,
       55.609375,
       6.21875,
       14.515625,
       -46.15625,
       9.78125,
       -5.53125,
       35.0625,
       31.1875,
       1.171875,
       14.171875,
       -22.046875,
       -4.96875,
       7.4375,
       28.703125,
       20.015625,
       -16.859375,
       8.453125,
       0.5,
       -9.875,
       -15.75,
       22.1875,
       1.65625,
       -16.609375,
       9.3125,
       9.625,
       -48.96875,
       -36.15625,
       24.171875,
       -9.046875,
       14.171875,
       -2.234375,
       7.734375,
       -35.09375,
       3.28125,
       6.796875,
       5.265625,
       7.75,
       -8.640625,
       1.453125,
       -1.78125,
       -18.28125,
       31.9375,
       23.46875,
       -12.875,
       -27.0625,
       -16.734375,
       2.96875,
       -9.40625,
       30.03125,
       38.59375,
       -20.34375,
       -18.46875,
       -12.625,
       18.875,
       0.125,
       3.84375,
       26.640625,
       -42.703125,
       -19.671875,
       17.625,
       16.875,
       -2.03125,
       8.203125,
       13.265625,
       -27.234375,
       -20.46875,
       20.828125,
       17.03125,
       10.708333015441895,
       -13.479166984558105,
       -6.104166507720947,
       -16.52083396911621,
       -0.3541666567325592,
       36.47916793823242,
       4.333333492279053,
       -1.5416666269302368,
       3.2708332538604736,
       15.791666984558105,
       27.25,
       39.72916793823242,
       32.3125,
       8.208333015441895,
       -9.3125,
       11.640625,
       -2.46875,
       29.078125,
       29.609375,
       33.453125,
       -7.6875,
       -8.59375,
       -1.515625,
       -31.671875,
       4.109375,
       40.90625,
       31.8125,
       -12.15625,
       16.46875,
       -10.625,
       -30.328125,
       -5.921875,
       53.703125,
       24.171875,
       -28.71875,
       46.59375,
       16.578125,
       -16.40625,
       3.171875,
       31.46875,
       4.921875,
       -16.90625,
       33.375,
       -13.1875,
       -8.53125,
       9.578125,
       33.4375,
       -0.421875,
       -22.125,
       10.40625,
       -7.734375,
       16.640625,
       21.0,
       20.703125,
       1.390625,
       7.1875,
       -20.28125,
       -17.234375,
       26.3125,
       37.203125,
       -6.765625,
       -0.5,
       14.78125,
       -56.03125,
       -40.265625,
       11.546875,
       33.453125,
       2.03125,
       -2.5,
       -7.03125,
       -16.0625,
       -10.375,
       12.34375,
       18.03125,
       -6.203125,
       -20.53125,
       24.46875,
       7.15625,
       -24.28125,
       -9.34375,
       20.609375,
       -5.890625,
       -33.21875,
       11.625,
       36.875,
       7.078125,
       -33.625,
       22.546875,
       14.515625,
       -25.953125,
       13.5625,
       80.08333587646484,
       13.875,
       -36.04166793823242,
       16.20833396911621,
       15.0,
       -34.20833206176758,
       44.91666793823242,
       -42.20833206176758,
       16.20833396911621,
       10.9375,
       -7.833333492279053,
       15.625,
       0.8333333134651184,
       1.25,
       -42.4375,
       24.109375,
       -10.75,
       -20.484375,
       1.1875,
       -8.75,
       20.53125,
       -51.75,
       23.171875,
       -3.640625,
       -11.828125,
       17.296875,
       1.203125,
       15.875,
       -20.5625,
       7.34375,
       -44.125,
       12.203125,
       26.15625,
       -4.109375,
       13.8125,
       -28.78125,
       -4.1875,
       -40.0,
       9.703125,
       15.5625,
       -25.9375,
       -4.25,
       -10.8125,
       5.71875,
       -14.34375,
       8.71875,
       46.1875,
       -23.796875,
       -14.90625,
       7.15625,
       29.59375,
       -27.609375,
       4.5625,
       34.265625,
       -34.984375,
       -42.5,
       -4.25,
       28.09375,
       -4.25,
       11.078125,
       7.0625,
       -30.0625,
       -46.96875,
       24.5625,
       10.53125,
       -9.0625,
       9.484375,
       21.03125,
       -12.046875,
       -37.125,
       39.1875,
       -14.078125,
       5.8125,
       38.796875,
       -4.59375,
       2.28125,
       -63.0,
       11.5625,
       -14.796875,
       16.515625,
       21.90625,
       -19.671875,
       -6.203125,
       -32.96875,
       18.90625,
       -5.34375,
       25.78125,
       7.21875,
       -28.09375,
       -17.234375,
       1.375,
       11.375,
       16.14583396911621,
       40.625,
       4.8125,
       -41.02083206176758,
       -13.833333015441895,
       10.541666984558105,
       -74.66666412353516,
       -25.45833396911621,
       -2.0208332538604736,
       9.0625,
       -14.125,
       26.125,
       -12.375,
       -72.9375,
       -4.078125,
       7.8125,
       -2.0,
       6.53125,
       33.046875,
       15.5625,
       -43.375,
       0.6875,
       23.4375,
       1.59375,
       17.109375,
       49.484375,
       19.40625,
       -48.40625,
       24.5,
       12.53125,
       -0.78125,
       24.953125,
       47.96875,
       33.21875,
       -29.75,
       48.21875,
       5.109375,
       30.78125,
       25.171875,
       35.15625,
       36.9375,
       -21.0,
       19.984375,
       2.765625,
       44.265625,
       21.5625,
       38.5,
       22.28125,
       -15.21875,
       25.296875,
       -0.1875,
       28.671875,
       16.609375,
       20.1875,
       -11.375,
       8.09375,
       -13.453125,
       25.921875,
       11.703125,
       25.390625,
       8.734375,
       -33.34375,
       -6.9375,
       -37.859375,
       24.421875,
       3.859375,
       36.8125,
       11.671875,
       7.21875,
       -18.84375,
       -8.328125,
       0.609375,
       -0.40625,
       30.984375,
       -14.65625,
       12.75,
       -16.4375,
       -8.65625,
       -12.625,
       14.640625,
       47.9375,
       -14.5,
       34.84375,
       -23.34375,
       6.8125,
       -6.90625,
       -1.390625,
       22.34375,
       -3.859375,
       48.71875,
       -13.958333015441895,
       17.58333396911621,
       -2.6458332538604736,
       -12.854166984558105,
       7.041666507720947,
       -15.291666984558105,
       34.66666793823242,
       45.875,
       10.75,
       -17.02083396911621,
       -6.145833492279053,
       -32.60416793823242,
       23.64583396911621,
       3.7916667461395264,
       39.6875,
       17.375,
       -21.71875,
       -23.5625,
       -12.515625,
       3.5,
       9.71875,
       14.65625,
       -8.015625,
       -32.171875,
       -18.03125,
       -12.96875,
       -9.359375,
       14.25,
       39.8125,
       -17.78125,
       -31.828125,
       -24.34375,
       5.15625,
       -1.484375,
       19.375,
       25.09375,
       -2.953125,
       -25.875,
       -31.921875,
       27.546875,
       -17.515625,
       10.625,
       6.21875,
       -30.5,
       -23.875,
       -10.03125,
       28.890625,
       -2.796875,
       -18.03125,
       16.90625,
       -14.9375,
       -22.703125,
       -21.984375,
       53.171875,
       6.484375,
       15.3125,
       2.15625,
       -14.1875,
       -3.84375,
       -6.328125,
       25.46875,
       5.75,
       23.5625,
       -12.59375,
       -30.8125,
       8.28125,
       17.75,
       36.265625,
       19.796875,
       36.25,
       17.0625,
       -24.625,
       16.0625,
       17.6875,
       15.25,
       11.453125,
       82.59375,
       7.71875,
       -32.3125,
       26.390625,
       14.59375,
       14.484375,
       -5.296875,
       65.28125,
       -12.9375,
       -22.15625,
       35.859375,
       11.59375,
       34.40625,
       -8.90625,
       76.46875,
       -3.7916667461395264,
       -24.5,
       31.39583396911621,
       -2.2291667461395264,
       21.70833396911621,
       -11.520833015441895,
       70.75,
       -68.125,
       22.0625,
       -17.95833396911621,
       -11.270833015441895,
       -36.33333206176758,
       -46.97916793823242,
       -34.54166793823242,
       -38.6875,
       27.0,
       -14.03125,
       2.78125,
       -35.6875,
       -43.265625,
       -29.8125,
       -6.59375,
       24.78125,
       -25.75,
       -14.34375,
       -42.78125,
       -36.171875,
       -7.65625,
       22.0625,
       31.21875,
       -17.609375,
       7.109375,
       -1.734375,
       -14.96875,
       -0.90625,
       53.75,
       37.515625,
       -38.8125,
       7.265625,
       8.015625,
       -2.09375,
       2.15625,
       46.25,
       23.671875,
       -33.75,
       11.59375,
       24.4375,
       20.125,
       23.96875,
       11.75,
       13.859375,
       -36.84375,
       8.953125,
       63.65625,
       33.984375,
       22.6875,
       -17.375,
       21.375,
       -29.328125,
       20.1875,
       33.609375,
       6.921875,
       59.9375,
       -44.1875,
       27.75,
       -5.640625,
       43.1875,
       0.203125,
       14.578125,
       47.125,
       -65.625,
       31.71875,
       -30.75,
       28.71875,
       -7.984375,
       -6.171875,
       23.9375,
       -42.71875,
       48.078125,
       -23.453125,
       31.59375,
       -24.859375,
       -4.59375,
       19.96875,
       -17.75,
       10.4375,
       -26.140625,
       6.046875,
       -30.9375,
       17.421875,
       -10.84375,
       -9.333333015441895,
       2.0625,
       -39.02083206176758,
       -23.70833396911621,
       -19.27083396911621,
       19.5,
       -16.5});
  exec_aten::Tensor& ret = op_avg_pool2d_out(
      self,
      kernel_size,
      stride,
      padding,
      ceil_mode,
      count_include_pad,
      divisor_override,
      out);
  // An out-variant should return the tensor it wrote into.
  EXPECT_TENSOR_EQ(ret, out);
  EXPECT_TENSOR_CLOSE(out, out_expected);
}
1052