• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 
20 class OpAvgPool2DOutTest : public OperatorTest {
21  protected:
op_avg_pool2d_out(const exec_aten::Tensor & self,exec_aten::ArrayRef<int64_t> kernel_size,exec_aten::ArrayRef<int64_t> stride,exec_aten::ArrayRef<int64_t> padding,bool ceil_mode,bool count_include_pad,exec_aten::optional<int64_t> divisor_override,exec_aten::Tensor & out)22   exec_aten::Tensor& op_avg_pool2d_out(
23       const exec_aten::Tensor& self,
24       exec_aten::ArrayRef<int64_t> kernel_size,
25       exec_aten::ArrayRef<int64_t> stride,
26       exec_aten::ArrayRef<int64_t> padding,
27       bool ceil_mode,
28       bool count_include_pad,
29       exec_aten::optional<int64_t> divisor_override,
30       exec_aten::Tensor& out) {
31     return torch::executor::aten::avg_pool2d_outf(
32         context_,
33         self,
34         kernel_size,
35         stride,
36         padding,
37         ceil_mode,
38         count_include_pad,
39         divisor_override,
40         out);
41   }
42 };
43 
TEST_F(OpAvgPool2DOutTest,SanityCheck4D)44 TEST_F(OpAvgPool2DOutTest, SanityCheck4D) {
45   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
46 
47   exec_aten::Tensor self = tfFloat.make(
48       {2, 3, 8, 8},
49       {-81.875, -37.25,  -99.75,  -60.5,   -1.375,  25.625,  -54.25,  -95.875,
50        48.0,    26.125,  -14.625, -5.0,    98.25,   1.75,    5.875,   -71.25,
51        -26.5,   36.0,    32.375,  -87.875, -43.125, -40.875, -49.0,   -33.0,
52        16.0,    -1.875,  -48.0,   77.0,    87.0,    58.25,   20.0,    10.875,
53        -37.875, -88.25,  97.75,   -98.25,  -52.0,   -92.75,  -89.375, -31.5,
54        -1.625,  -50.0,   -6.625,  -62.5,   -86.5,   -96.5,   85.0,    -94.375,
55        -3.625,  94.125,  12.25,   -33.875, -25.625, -20.625, -56.5,   -78.5,
56        8.25,    38.875,  74.5,    -22.75,  14.125,  46.5,    -28.625, -52.0,
57        -36.75,  -3.75,   48.875,  2.875,   76.125,  -42.125, -71.125, 22.5,
58        9.25,    98.5,    41.5,    65.625,  -82.25,  -85.75,  -20.75,  82.375,
59        3.25,    -74.5,   -14.25,  18.75,   60.25,   -21.875, 7.0,     -44.0,
60        -97.0,   -71.625, 65.75,   -89.5,   26.375,  14.125,  -99.75,  42.625,
61        -11.25,  56.625,  -63.0,   -34.25,  -76.125, 69.25,   22.875,  43.875,
62        -7.125,  -16.875, 93.875,  -48.25,  -57.375, -30.625, 11.75,   73.0,
63        -96.5,   14.5,    50.375,  -65.875, 73.625,  88.625,  -11.375, 41.75,
64        7.5,     62.125,  -1.25,   81.125,  -78.375, 62.875,  75.125,  -34.75,
65        50.5,    -2.125,  64.125,  -74.375, 22.875,  39.625,  -91.125, 96.125,
66        -31.75,  -70.875, 12.75,   6.5,     30.125,  12.75,   -77.375, -56.5,
67        -24.5,   -59.25,  88.125,  -26.375, -46.0,   62.75,   39.25,   -78.375,
68        67.875,  -57.75,  15.5,    80.5,    -70.5,   96.875,  -18.625, -53.625,
69        70.75,   70.375,  -68.125, -41.0,   -59.875, 74.625,  -82.125, 35.125,
70        -24.25,  33.25,   45.875,  -74.75,  37.0,    -0.875,  86.625,  -14.375,
71        45.25,   41.0,    -10.625, -49.875, 61.5,    -10.375, -39.0,   75.5,
72        -72.625, -29.0,   -53.875, -31.375, -57.375, 79.5,    -25.0,   -83.0,
73        -49.875, -74.875, 50.25,   8.25,    -35.5,   -0.875,  27.125,  85.875,
74        -95.625, 17.875,  45.25,   -56.75,  60.5,    -61.75,  47.625,  50.5,
75        83.0,    -63.25,  40.375,  -61.75,  -89.0,   -66.5,   6.0,     46.0,
76        -20.625, -15.625, -58.0,   88.25,   31.25,   -37.5,   -38.25,  68.5,
77        11.625,  34.125,  -39.5,   84.25,   27.0,    -85.625, 7.625,   86.375,
78        41.625,  28.0,    -41.125, 44.625,  19.625,  -43.125, -23.875, 54.5,
79        50.0,    -88.25,  29.0,    -77.5,   -82.125, -84.0,   8.75,    -32.875,
80        27.75,   -88.75,  -33.25,  -58.25,  41.5,    -84.0,   -53.375, -85.5,
81        -15.625, 39.75,   29.375,  -45.375, 96.5,    65.125,  34.875,  75.375,
82        -32.75,  -9.75,   -55.0,   38.0,    31.125,  -35.0,   -74.375, -61.0,
83        -63.75,  4.75,    88.75,   -83.25,  -19.75,  5.875,   88.375,  52.25,
84        70.125,  -81.5,   -56.375, -98.375, 97.625,  88.375,  22.625,  -100.0,
85        -75.0,   10.0,    41.0,    40.375,  12.0,    72.125,  31.875,  -22.25,
86        -63.875, 10.5,    -81.25,  4.25,    43.5,    -44.0,   71.5,    -29.625,
87        -3.5,    -91.5,   45.375,  88.875,  -93.125, -50.25,  -6.375,  -88.875,
88        -2.375,  -17.75,  49.25,   -14.75,  14.75,   -2.25,   51.25,   -57.875,
89        43.875,  87.0,    86.25,   -95.125, 11.75,   -26.5,   29.875,  89.25,
90        -18.5,   55.375,  74.25,   -64.0,   51.875,  78.0,    82.625,  -34.0,
91        -0.875,  -69.375, -90.875, -83.5,   13.625,  46.875,  8.375,   16.875,
92        96.75,   31.25,   45.625,  -2.625,  -71.0,   -62.375, 31.25,   -23.5,
93        66.0,    51.375,  -45.25,  43.375,  49.5,    12.625,  -73.875, 26.375,
94        29.0,    -86.5,   -55.375, 88.75,   20.0,    90.0,    28.75,   -12.875,
95        -37.75,  -1.875,  -28.125, 96.75,   -66.75,  48.375,  -79.25,  8.0,
96        -14.25,  -8.0,    51.75,   28.375,  32.0,    -50.875, 53.0,    -81.75});
97   ::std::vector<int64_t> kernel_size_vec = {2, 3};
98   exec_aten::ArrayRef<int64_t> kernel_size = exec_aten::ArrayRef<int64_t>(
99       kernel_size_vec.data(), kernel_size_vec.size());
100   ::std::vector<int64_t> stride_vec = {3, 2};
101   exec_aten::ArrayRef<int64_t> stride =
102       exec_aten::ArrayRef<int64_t>(stride_vec.data(), stride_vec.size());
103   ::std::vector<int64_t> padding_vec = {1, 1};
104   exec_aten::ArrayRef<int64_t> padding =
105       exec_aten::ArrayRef<int64_t>(padding_vec.data(), padding_vec.size());
106   bool ceil_mode = false;
107   bool count_include_pad = true;
108   exec_aten::optional<int64_t> divisor_override;
109   exec_aten::Tensor out = tfFloat.zeros({2, 3, 3, 4});
110   exec_aten::Tensor out_expected = tfFloat.make(
111       {2, 3, 3, 4},
112       {-19.85416603088379,
113        -32.91666793823242,
114        -6.041666507720947,
115        -20.75,
116        3.9375,
117        1.2708333730697632,
118        8.395833015441895,
119        -5.625,
120        6.479166507720947,
121        -7.770833492279053,
122        -54.27083206176758,
123        -43.58333206176758,
124        -6.75,
125        8.0,
126        6.145833492279053,
127        -15.125,
128        -39.97916793823242,
129        -27.5625,
130        1.3541666269302368,
131        -16.97916603088379,
132        -17.66666603088379,
133        4.625,
134        -6.645833492279053,
135        28.85416603088379,
136        8.0625,
137        -2.0625,
138        -1.9791666269302368,
139        7.4375,
140        -12.270833015441895,
141        6.791666507720947,
142        16.20833396911621,
143        8.041666984558105,
144        15.875,
145        -2.5208332538604736,
146        -6.229166507720947,
147        16.25,
148        -20.79166603088379,
149        -2.7291667461395264,
150        -4.6875,
151        18.6875,
152        -2.75,
153        -11.666666984558105,
154        -22.54166603088379,
155        -3.625,
156        5.229166507720947,
157        -17.54166603088379,
158        -37.08333206176758,
159        -20.10416603088379,
160        4.020833492279053,
161        3.9583332538604736,
162        19.375,
163        29.22916603088379,
164        -11.729166984558105,
165        -37.66666793823242,
166        -1.5833333730697632,
167        26.25,
168        -24.72916603088379,
169        -3.9583332538604736,
170        -8.458333015441895,
171        -24.60416603088379,
172        21.8125,
173        13.020833015441895,
174        -18.3125,
175        15.4375,
176        9.625,
177        -28.25,
178        -26.5,
179        2.9166667461395264,
180        -16.1875,
181        2.2708332538604736,
182        46.1875,
183        13.833333015441895});
184   op_avg_pool2d_out(
185       self,
186       kernel_size,
187       stride,
188       padding,
189       ceil_mode,
190       count_include_pad,
191       divisor_override,
192       out);
193   EXPECT_TENSOR_CLOSE(out, out_expected);
194 }
195 
TEST_F(OpAvgPool2DOutTest,SanityCheck4DDivisorOverride)196 TEST_F(OpAvgPool2DOutTest, SanityCheck4DDivisorOverride) {
197   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
198 
199   exec_aten::Tensor self = tfFloat.make(
200       {2, 3, 8, 8},
201       {13.25,   87.125,  -61.875, 52.875,  -74.5,   37.5,    -62.125, 25.375,
202        -34.375, 68.5,    87.875,  91.125,  -22.75,  96.875,  79.25,   38.125,
203        6.5,     27.625,  58.875,  41.5,    -82.875, -77.875, -3.25,   22.0,
204        -73.375, -56.875, 56.25,   -6.375,  -16.375, 40.0,    22.625,  61.0,
205        -78.375, -87.5,   23.25,   -55.625, 92.75,   -90.0,   -55.625, 16.25,
206        -46.5,   91.75,   74.0,    -45.0,   -34.875, -46.375, 38.0,    63.375,
207        22.625,  16.25,   91.125,  68.375,  98.25,   34.0,    91.25,   89.375,
208        1.0,     22.875,  -93.75,  7.5,     -81.375, 1.125,   -15.875, 16.75,
209        29.25,   -14.0,   77.375,  -64.625, -22.375, -97.25,  -26.75,  9.125,
210        -3.5,    -64.5,   35.625,  58.25,   4.125,   85.0,    -39.125, -98.25,
211        -0.625,  -29.5,   -35.375, 90.375,  37.625,  57.875,  -59.625, -81.875,
212        73.5,    -37.375, 56.75,   88.25,   -52.75,  -4.75,   80.625,  -38.375,
213        -55.375, -76.625, -59.875, -41.5,   76.875,  45.375,  24.625,  57.5,
214        58.25,   64.375,  -38.125, -39.875, 71.125,  2.5,     -14.75,  21.75,
215        -8.375,  5.0,     71.625,  29.875,  79.375,  -13.125, -22.625, 96.875,
216        7.5,     64.5,    1.0,     -40.75,  -56.0,   -94.75,  -89.5,   -71.25,
217        78.125,  -5.875,  93.125,  -20.125, -73.875, -34.625, 51.5,    46.5,
218        -88.25,  3.25,    26.0,    -32.375, -77.625, -51.875, 0.75,    85.75,
219        -15.875, 84.625,  -44.75,  -87.625, 17.25,   -59.25,  40.25,   -0.5,
220        14.625,  -80.125, -58.125, 17.25,   -7.5,    68.875,  -99.5,   -82.75,
221        75.5,    -95.25,  -63.25,  55.25,   -17.125, 64.875,  -56.25,  39.75,
222        -35.25,  18.0,    21.125,  -61.125, -35.5,   12.75,   -94.125, -17.125,
223        -41.375, 22.625,  -69.25,  51.375,  7.25,    14.625,  70.25,   -63.5,
224        -10.0,   53.625,  -46.75,  77.5,    -78.5,   -67.25,  -53.375, -33.5,
225        23.375,  77.625,  64.125,  -56.375, -59.25,  -89.125, 16.25,   -2.125,
226        -82.125, 72.75,   45.5,    -36.375, -1.125,  -7.25,   -44.5,   27.125,
227        -1.75,   16.5,    71.0,    81.25,   4.125,   74.0,    45.0,    -4.75,
228        68.125,  30.625,  11.625,  -88.5,   48.125,  36.875,  -32.0,   9.0,
229        -67.5,   -17.125, -87.5,   29.875,  23.875,  -33.0,   -6.625,  0.5,
230        28.25,   54.75,   87.5,    28.0,    -48.5,   -98.0,   -55.0,   62.25,
231        -93.125, -41.5,   94.25,   -28.375, -23.75,  -50.625, 90.375,  -91.0,
232        -67.75,  -11.625, 24.5,    49.5,    21.0,    70.25,   66.0,    69.75,
233        43.25,   -14.875, -93.875, -15.75,  94.5,    69.125,  -67.625, 15.875,
234        -26.875, 41.625,  -64.375, 23.0,    13.375,  78.5,    -88.875, 17.625,
235        -52.375, 55.25,   41.875,  62.625,  29.0,    53.875,  -93.625, -93.375,
236        -70.375, -56.125, 9.0,     60.625,  4.375,   66.625,  -46.25,  43.125,
237        93.625,  85.125,  -91.375, 49.375,  40.25,   -19.75,  13.0,    69.5,
238        -42.375, 47.25,   7.0,     -93.625, -80.0,   -57.875, -94.875, 9.375,
239        -2.5,    -42.25,  -8.0,    92.5,    6.875,   -11.0,   80.625,  27.125,
240        -75.875, -57.375, -36.625, 17.375,  33.125,  -49.0,   -84.0,   -45.5,
241        -9.625,  -16.125, 8.75,    -1.0,    -16.625, 2.5,     -56.0,   -52.0,
242        85.0,    63.5,    -50.625, 7.375,   -90.0,   -75.875, -82.0,   82.75,
243        16.125,  -35.25,  -42.875, -1.375,  39.875,  -39.625, -66.875, 29.875,
244        -80.75,  66.5,    -50.125, 13.5,    -77.875, 6.125,   -91.625, -44.125,
245        -50.125, 59.625,  -32.5,   -2.0,    -83.375, 21.0,    49.875,  48.125,
246        -93.375, -54.875, -61.125, 96.375,  91.25,   -2.375,  -33.25,  48.125,
247        -58.125, -50.75,  -50.875, 8.375,   35.625,  -72.5,   -76.125, -33.25,
248        -18.75,  -71.0,   76.625,  -11.25,  -3.0,    -38.625, -66.375, -25.0});
249   ::std::vector<int64_t> kernel_size_vec = {2, 3};
250   exec_aten::ArrayRef<int64_t> kernel_size = exec_aten::ArrayRef<int64_t>(
251       kernel_size_vec.data(), kernel_size_vec.size());
252   ::std::vector<int64_t> stride_vec = {3, 2};
253   exec_aten::ArrayRef<int64_t> stride =
254       exec_aten::ArrayRef<int64_t>(stride_vec.data(), stride_vec.size());
255   ::std::vector<int64_t> padding_vec = {1, 1};
256   exec_aten::ArrayRef<int64_t> padding =
257       exec_aten::ArrayRef<int64_t>(padding_vec.data(), padding_vec.size());
258   bool ceil_mode = false;
259   bool count_include_pad = true;
260   exec_aten::optional<int64_t> divisor_override =
261       exec_aten::optional<int64_t>(10);
262   exec_aten::Tensor out = tfFloat.zeros({2, 3, 3, 4});
263   exec_aten::Tensor out_expected = tfFloat.make(
264       {2, 3, 3, 4},
265       {10.037500381469727,
266        7.8125,
267        1.587499976158142,
268        0.07500000298023224,
269        -9.612500190734863,
270        12.100000381469727,
271        -10.199999809265137,
272        6.449999809265137,
273        8.412500381469727,
274        29.649999618530273,
275        7.4375,
276        26.962499618530273,
277        1.524999976158142,
278        -0.125,
279        -18.424999237060547,
280        -11.487500190734863,
281        0.6000000238418579,
282        13.3125,
283        21.662500381469727,
284        -4.612500190734863,
285        11.925000190734863,
286        9.287500381469727,
287        12.987500190734863,
288        7.0625,
289        7.224999904632568,
290        6.712500095367432,
291        -12.862500190734863,
292        6.337500095367432,
293        0.32499998807907104,
294        -16.875,
295        -5.099999904632568,
296        -13.287500381469727,
297        -3.5999999046325684,
298        -1.725000023841858,
299        -1.0625,
300        -7.712500095367432,
301        10.100000381469727,
302        8.537500381469727,
303        -20.475000381469727,
304        -7.5,
305        11.350000381469727,
306        12.25,
307        15.587499618530273,
308        12.8125,
309        -5.162499904632568,
310        19.462499618530273,
311        -22.125,
312        -14.199999809265137,
313        2.8375000953674316,
314        -12.449999809265137,
315        14.787500381469727,
316        1.7374999523162842,
317        -12.362500190734863,
318        17.325000762939453,
319        27.712499618530273,
320        -6.962500095367432,
321        -3.987499952316284,
322        0.2874999940395355,
323        -14.3125,
324        -4.662499904632568,
325        -2.575000047683716,
326        -0.8374999761581421,
327        -1.5125000476837158,
328        -10.550000190734863,
329        -3.3375000953674316,
330        -4.962500095367432,
331        -5.9375,
332        -20.625,
333        -25.712499618530273,
334        -11.287500381469727,
335        15.675000190734863,
336        -16.9375});
337   op_avg_pool2d_out(
338       self,
339       kernel_size,
340       stride,
341       padding,
342       ceil_mode,
343       count_include_pad,
344       divisor_override,
345       out);
346   EXPECT_TENSOR_CLOSE(out, out_expected);
347 }
348 
TEST_F(OpAvgPool2DOutTest,SanityCheck4DCeilModeNoIncludePadding)349 TEST_F(OpAvgPool2DOutTest, SanityCheck4DCeilModeNoIncludePadding) {
350   torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
351 
352   exec_aten::Tensor self = tfFloat.make(
353       {2, 3, 14, 12},
354       {26.375,  -17.0,   63.5,    83.0,    21.375,  -46.5,   -69.125, 99.875,
355        -67.125, -76.0,   -1.125,  -2.625,  -48.0,   -1.5,    -92.75,  95.5,
356        91.5,    32.625,  -35.25,  12.125,  55.0,    -64.875, -53.125, -20.0,
357        -5.875,  -0.25,   -3.5,    4.875,   81.625,  92.25,   98.25,   -47.375,
358        78.875,  -97.5,   10.375,  55.75,   65.375,  97.375,  -93.625, 51.5,
359        62.75,   -37.625, 1.875,   -73.125, 78.75,   -62.25,  20.625,  89.375,
360        16.75,   43.0,    98.375,  17.0,    40.125,  -78.0,   -24.375, 9.875,
361        2.0,     -57.125, -65.375, -86.0,   -98.375, 68.75,   70.375,  16.25,
362        -24.625, 12.0,    -55.0,   -17.0,   81.375,  12.0,    62.875,  -79.0,
363        46.0,    -12.625, -42.0,   3.125,   -6.0,    88.5,    -42.25,  82.625,
364        -96.875, 51.375,  41.875,  36.125,  -27.375, 33.375,  -81.75,  6.875,
365        -39.5,   -57.75,  24.0,    35.375,  -22.875, 96.375,  -65.0,   -67.0,
366        -64.875, 86.625,  70.625,  16.5,    -45.0,   55.375,  88.5,    -42.0,
367        -38.5,   -71.5,   -66.125, -30.5,   59.375,  -8.25,   8.375,   85.5,
368        20.625,  -58.875, -35.5,   -80.125, 93.25,   -15.5,   40.125,  54.25,
369        -40.25,  85.125,  61.375,  84.25,   58.5,    -60.375, -58.375, -81.5,
370        -80.125, -84.0,   31.75,   55.125,  8.125,   -56.625, -7.0,    12.625,
371        75.75,   -88.125, -5.375,  60.0,    21.25,   62.125,  2.125,   -3.375,
372        -26.75,  44.25,   -96.5,   -87.5,   -36.625, -36.125, 1.125,   -42.0,
373        -48.125, 24.0,    80.375,  -38.5,   50.75,   -0.625,  35.625,  31.875,
374        -32.75,  70.375,  -41.0,   -20.75,  27.5,    60.75,   -10.5,   54.875,
375        68.25,   40.0,    0.875,   39.375,  67.375,  31.25,   79.25,   37.625,
376        33.0,    10.625,  -47.5,   -21.0,   -2.625,  81.0,    -42.25,  23.625,
377        10.375,  1.625,   11.0,    14.625,  19.5,    -36.75,  99.875,  86.125,
378        -70.25,  26.375,  -86.375, -21.0,   -25.0,   51.75,   -11.375, 58.875,
379        74.75,   77.625,  90.0,    -40.5,   -32.625, 40.5,    33.0,    -80.625,
380        -33.875, 38.5,    30.625,  -55.375, 53.875,  -1.25,   75.0,    -55.375,
381        71.125,  -34.625, -29.75,  -92.75,  -34.125, -17.25,  -72.0,   79.5,
382        81.5,    -14.375, -35.625, -38.875, 97.625,  -3.0,    -31.125, 31.25,
383        13.5,    -3.625,  -64.0,   72.0,    64.5,    6.875,   -4.875,  19.875,
384        50.25,   80.75,   76.875,  91.75,   -26.375, 94.125,  19.0,    4.375,
385        -48.625, -69.375, 83.0,    6.75,    -85.5,   -64.75,  -99.875, -88.375,
386        36.875,  24.375,  96.0,    24.25,   -10.0,   39.75,   -8.75,   -76.25,
387        -20.75,  -96.25,  75.5,    23.625,  50.875,  76.375,  -74.25,  2.0,
388        57.125,  31.875,  -67.375, 78.375,  -25.125, -27.75,  -82.375, 37.25,
389        84.875,  8.25,    53.75,   -14.375, -68.875, -7.875,  -5.25,   50.25,
390        -92.75,  -90.625, 64.0,    -32.0,   -20.75,  47.625,  35.5,    -17.875,
391        44.0,    35.625,  -38.0,   -80.5,   74.375,  -11.125, 85.625,  47.25,
392        -92.375, 2.375,   -5.375,  31.875,  -83.5,   -41.0,   -72.25,  49.75,
393        72.125,  -92.875, -39.125, -19.75,  -79.25,  90.375,  -67.625, 71.375,
394        -9.75,   -80.125, -56.875, 27.0,    93.75,   55.75,   85.0,    -60.5,
395        -11.625, 11.5,    66.0,    7.875,   72.125,  -15.25,  60.25,   58.0,
396        -16.125, -41.875, -16.875, 32.25,   -69.75,  81.5,    -65.125, -28.0,
397        59.25,   2.375,   -8.25,   60.0,    -69.625, 3.0,     59.875,  93.25,
398        41.875,  -38.625, -86.375, -18.25,  86.375,  29.875,  5.375,   -63.0,
399        -40.875, 42.0,    51.125,  -35.125, 3.125,   31.875,  29.75,   6.5,
400        -12.125, -86.875, 62.5,    6.75,    -43.125, 37.0,    58.625,  -89.5,
401        -62.125, -21.0,   -95.875, -42.125, -42.125, -52.125, -22.875, 78.375,
402        -53.375, -85.625, 19.375,  -28.0,   47.375,  63.875,  21.75,   85.75,
403        74.375,  27.75,   46.0,    41.375,  55.125,  -4.375,  -59.375, -93.375,
404        -95.375, 98.125,  -30.875, 69.25,   69.75,   -78.375, 71.125,  -71.25,
405        -73.75,  69.0,    -68.125, -49.0,   50.0,    51.25,   -9.625,  -4.625,
406        -85.75,  -99.125, -99.875, -65.5,   28.75,   95.25,   79.625,  16.5,
407        37.125,  -62.75,  -62.0,   65.25,   95.5,    -15.5,   -42.375, 35.75,
408        18.5,    89.25,   35.5,    -90.375, 3.625,   -2.875,  55.25,   63.5,
409        1.25,    2.625,   -18.375, -69.0,   9.5,     -31.375, -44.375, 35.125,
410        -37.0,   20.875,  98.5,    -43.0,   -35.625, -8.125,  40.25,   -89.125,
411        41.5,    -100.0,  -39.625, -76.5,   39.0,    6.125,   22.75,   83.75,
412        -62.375, 31.75,   -86.625, -26.125, 87.25,   -69.0,   47.0,    99.125,
413        73.5,    24.875,  84.875,  13.625,  -57.875, 67.125,  -10.375, -67.75,
414        -92.0,   52.125,  66.875,  -62.25,  61.125,  -46.875, -35.875, 35.625,
415        -91.5,   -95.25,  11.625,  51.125,  38.875,  -1.875,  1.75,    11.125,
416        61.125,  -58.25,  60.125,  -62.25,  -83.75,  -41.375, -14.75,  48.25,
417        -83.25,  -91.75,  13.625,  -12.75,  30.25,   -54.125, 60.0,    -25.875,
418        -45.0,   -33.75,  85.875,  46.0,    -74.75,  -93.375, 64.25,   -31.5,
419        -30.125, 85.25,   87.0,    -1.75,   -55.5,   32.0,    -67.25,  3.125,
420        -66.0,   16.75,   -62.25,  -59.375, 91.375,  -49.625, -74.125, 52.5,
421        -9.125,  70.5,    69.375,  -86.25,  -67.75,  82.625,  37.5,    1.5,
422        73.125,  -68.875, -1.5,    42.875,  94.125,  98.125,  9.5,     99.375,
423        35.0,    -17.625, -22.375, 75.0,    67.5,    47.25,   -12.625, 30.375,
424        -16.625, 97.125,  86.5,    61.375,  -94.875, 99.125,  62.25,   -83.875,
425        -65.0,   69.125,  84.125,  -84.75,  90.25,   -77.0,   29.625,  58.375,
426        8.625,   76.875,  67.375,  -89.875, 62.5,    51.0,    77.75,   -0.375,
427        45.5,    86.875,  -49.5,   -71.375, -32.75,  -76.375, -29.375, 69.25,
428        -13.375, 95.875,  -58.375, 12.75,   95.375,  83.375,  51.0,    40.75,
429        58.125,  -26.875, 29.375,  81.625,  37.25,   -91.375, 1.25,    20.625,
430        -46.5,   24.875,  12.25,   -73.25,  -1.625,  -69.625, -79.0,   60.125,
431        -0.125,  41.875,  -24.375, 66.5,    9.25,    -72.625, -66.375, -29.5,
432        -51.5,   28.75,   -79.75,  16.125,  -55.5,   80.25,   -14.25,  37.875,
433        98.625,  -6.875,  67.75,   90.875,  -80.375, 71.625,  58.875,  -63.0,
434        -71.625, 33.5,    -30.125, -33.25,  94.75,   -30.5,   -45.75,  62.875,
435        67.75,   -22.125, 22.0,    48.75,   -35.75,  83.625,  -53.375, 55.875,
436        53.875,  38.0,    0.375,   15.125,  -29.25,  55.0,    -79.875, 40.5,
437        65.25,   -67.625, -43.125, -47.625, -81.375, 40.5,    -94.375, 26.0,
438        70.0,    50.25,   53.0,    26.75,   0.875,   4.0,     17.625,  -74.25,
439        -28.5,   56.5,    71.25,   -67.0,   -9.875,  15.25,   -38.125, -61.375,
440        -22.25,  14.625,  -47.25,  -7.125,  34.75,   10.625,  -10.25,  0.5,
441        77.5,    36.0,    -51.875, -59.5,   13.375,  45.75,   -71.625, -29.75,
442        -90.75,  -38.25,  52.0,    77.875,  21.125,  56.625,  17.875,  -32.625,
443        -39.0,   -68.75,  -82.875, -1.25,   96.75,   -91.375, -22.5,   27.5,
444        -30.125, -23.875, -76.0,   -93.5,   37.5,    78.375,  -12.5,   -84.875,
445        -21.5,   29.25,   -4.375,  -48.875, 90.75,   -77.375, -23.625, -5.75,
446        -75.125, -92.125, 9.0,     96.5,    76.125,  47.5,    15.875,  21.0,
447        18.625,  94.25,   8.5,     -67.75,  69.25,   -72.0,   -14.5,   90.5,
448        -31.875, -20.875, -93.625, 42.875,  -54.375, -70.5,   -75.375, -6.75,
449        -48.875, 21.75,   1.75,    45.5,    60.75,   -9.75,   13.625,  -87.125,
450        12.625,  71.125,  -46.5,   -31.25,  -15.375, -42.75,  13.0,    26.25,
451        61.625,  81.0,    18.125,  84.5,    31.75,   -14.125, -80.875, 18.875,
452        51.125,  0.625,   41.5,    -81.625, 32.625,  -19.375, 76.875,  54.0,
453        -40.375, -78.25,  48.0,    77.5,    21.0,    71.25,   34.875,  73.25,
454        71.75,   87.25,   -89.375, 93.625,  64.25,   -81.375, -15.0,   -58.125,
455        64.75,   -49.125, 72.125,  -62.0,   0.125,   -56.0,   -6.875,  98.25,
456        -24.75,  48.375,  -85.25,  -62.25,  98.25,   12.75,   -67.25,  72.125,
457        9.625,   8.5,     -43.375, 15.25,   -50.875, 7.875,   -21.625, 97.25,
458        48.5,    33.25,   -15.125, 41.125,  69.25,   -35.875, 64.5,    98.75,
459        -78.125, 44.625,  75.75,   41.625,  -39.375, -4.875,  87.5,    -67.25,
460        -23.25,  -52.875, -44.875, -70.75,  -63.75,  86.25,   -44.0,   -92.375,
461        -22.75,  -70.75,  -84.875, -98.125, -86.75,  35.25,   -77.5,   -78.0,
462        -62.5,   -88.5,   58.25,   -65.125, 70.25,   -5.875,  11.25,   47.125,
463        10.25,   -91.75,  -50.125, 45.125,  49.625,  57.25,   26.375,  -74.75,
464        70.25,   38.625,  51.25,   12.375,  -79.875, -76.25,  12.0,    -15.625,
465        50.25,   72.125,  30.5,    -68.625, -22.875, 5.125,   -59.5,   -82.75,
466        -64.5,   46.75,   -87.75,  17.875,  50.875,  69.625,  24.125,  -25.0,
467        -25.0,   -5.0,    21.0,    51.875,  91.625,  50.0,    77.375,  -51.0,
468        64.25,   41.0,    -20.875, -89.75,  -74.75,  48.375,  -41.75,  99.25,
469        36.125,  10.75,   -49.625, 57.375,  19.625,  -43.75,  16.625,  -48.375,
470        84.375,  31.125,  93.375,  76.125,  -12.25,  98.5,    15.0,    71.625,
471        -87.75,  98.625,  -74.5,   -42.75,  -73.5,   -30.875, -44.625, 70.5,
472        96.0,    21.5,    48.375,  12.75,   -65.625, 56.5,    97.375,  29.5,
473        -19.375, 94.625,  11.25,   -0.375,  -96.5,   -48.75,  -40.375, 98.0,
474        -43.0,   20.75,   50.375,  97.875,  -72.875, 92.125,  98.5,    -74.0,
475        -57.875, 56.375,  -34.0,   6.125,   -66.125, 71.75,   -67.125, -79.0,
476        -85.875, -88.625, 97.375,  16.875,  -18.5,   -75.125, 22.625,  -21.125,
477        3.875,   57.25,   97.75,   -11.0,   -46.875, 16.75,   -69.25,  99.75,
478        -68.25,  99.75,   -17.25,  -3.125,  34.25,   -54.125, -93.125, 65.0,
479        -76.375, -20.625, -77.875, -65.625, -79.875, 28.75,   58.25,   -25.25});
480   ::std::vector<int64_t> kernel_size_vec = {4, 2};
481   exec_aten::ArrayRef<int64_t> kernel_size = exec_aten::ArrayRef<int64_t>(
482       kernel_size_vec.data(), kernel_size_vec.size());
483   ::std::vector<int64_t> stride_vec = {1, 2};
484   exec_aten::ArrayRef<int64_t> stride =
485       exec_aten::ArrayRef<int64_t>(stride_vec.data(), stride_vec.size());
486   ::std::vector<int64_t> padding_vec = {1, 1};
487   exec_aten::ArrayRef<int64_t> padding =
488       exec_aten::ArrayRef<int64_t>(padding_vec.data(), padding_vec.size());
489   bool ceil_mode = true;
490   bool count_include_pad = false;
491   exec_aten::optional<int64_t> divisor_override;
492   exec_aten::Tensor out = tfFloat.zeros({2, 3, 13, 7});
493   exec_aten::Tensor out_expected = tfFloat.make(
494       {2, 3, 13, 7},
495       {-9.166666984558105,
496        -8.583333015441895,
497        62.97916793823242,
498        12.041666984558105,
499        21.89583396911621,
500        -47.04166793823242,
501        11.041666984558105,
502        9.46875,
503        -5.96875,
504        61.515625,
505        4.5625,
506        17.125,
507        -40.484375,
508        30.625,
509        7.0625,
510        5.890625,
511        55.609375,
512        6.21875,
513        14.515625,
514        -46.15625,
515        9.78125,
516        -5.53125,
517        35.0625,
518        31.1875,
519        1.171875,
520        14.171875,
521        -22.046875,
522        -4.96875,
523        7.4375,
524        28.703125,
525        20.015625,
526        -16.859375,
527        8.453125,
528        0.5,
529        -9.875,
530        -15.75,
531        22.1875,
532        1.65625,
533        -16.609375,
534        9.3125,
535        9.625,
536        -48.96875,
537        -36.15625,
538        24.171875,
539        -9.046875,
540        14.171875,
541        -2.234375,
542        7.734375,
543        -35.09375,
544        3.28125,
545        6.796875,
546        5.265625,
547        7.75,
548        -8.640625,
549        1.453125,
550        -1.78125,
551        -18.28125,
552        31.9375,
553        23.46875,
554        -12.875,
555        -27.0625,
556        -16.734375,
557        2.96875,
558        -9.40625,
559        30.03125,
560        38.59375,
561        -20.34375,
562        -18.46875,
563        -12.625,
564        18.875,
565        0.125,
566        3.84375,
567        26.640625,
568        -42.703125,
569        -19.671875,
570        17.625,
571        16.875,
572        -2.03125,
573        8.203125,
574        13.265625,
575        -27.234375,
576        -20.46875,
577        20.828125,
578        17.03125,
579        10.708333015441895,
580        -13.479166984558105,
581        -6.104166507720947,
582        -16.52083396911621,
583        -0.3541666567325592,
584        36.47916793823242,
585        4.333333492279053,
586        -1.5416666269302368,
587        3.2708332538604736,
588        15.791666984558105,
589        27.25,
590        39.72916793823242,
591        32.3125,
592        8.208333015441895,
593        -9.3125,
594        11.640625,
595        -2.46875,
596        29.078125,
597        29.609375,
598        33.453125,
599        -7.6875,
600        -8.59375,
601        -1.515625,
602        -31.671875,
603        4.109375,
604        40.90625,
605        31.8125,
606        -12.15625,
607        16.46875,
608        -10.625,
609        -30.328125,
610        -5.921875,
611        53.703125,
612        24.171875,
613        -28.71875,
614        46.59375,
615        16.578125,
616        -16.40625,
617        3.171875,
618        31.46875,
619        4.921875,
620        -16.90625,
621        33.375,
622        -13.1875,
623        -8.53125,
624        9.578125,
625        33.4375,
626        -0.421875,
627        -22.125,
628        10.40625,
629        -7.734375,
630        16.640625,
631        21.0,
632        20.703125,
633        1.390625,
634        7.1875,
635        -20.28125,
636        -17.234375,
637        26.3125,
638        37.203125,
639        -6.765625,
640        -0.5,
641        14.78125,
642        -56.03125,
643        -40.265625,
644        11.546875,
645        33.453125,
646        2.03125,
647        -2.5,
648        -7.03125,
649        -16.0625,
650        -10.375,
651        12.34375,
652        18.03125,
653        -6.203125,
654        -20.53125,
655        24.46875,
656        7.15625,
657        -24.28125,
658        -9.34375,
659        20.609375,
660        -5.890625,
661        -33.21875,
662        11.625,
663        36.875,
664        7.078125,
665        -33.625,
666        22.546875,
667        14.515625,
668        -25.953125,
669        13.5625,
670        80.08333587646484,
671        13.875,
672        -36.04166793823242,
673        16.20833396911621,
674        15.0,
675        -34.20833206176758,
676        44.91666793823242,
677        -42.20833206176758,
678        16.20833396911621,
679        10.9375,
680        -7.833333492279053,
681        15.625,
682        0.8333333134651184,
683        1.25,
684        -42.4375,
685        24.109375,
686        -10.75,
687        -20.484375,
688        1.1875,
689        -8.75,
690        20.53125,
691        -51.75,
692        23.171875,
693        -3.640625,
694        -11.828125,
695        17.296875,
696        1.203125,
697        15.875,
698        -20.5625,
699        7.34375,
700        -44.125,
701        12.203125,
702        26.15625,
703        -4.109375,
704        13.8125,
705        -28.78125,
706        -4.1875,
707        -40.0,
708        9.703125,
709        15.5625,
710        -25.9375,
711        -4.25,
712        -10.8125,
713        5.71875,
714        -14.34375,
715        8.71875,
716        46.1875,
717        -23.796875,
718        -14.90625,
719        7.15625,
720        29.59375,
721        -27.609375,
722        4.5625,
723        34.265625,
724        -34.984375,
725        -42.5,
726        -4.25,
727        28.09375,
728        -4.25,
729        11.078125,
730        7.0625,
731        -30.0625,
732        -46.96875,
733        24.5625,
734        10.53125,
735        -9.0625,
736        9.484375,
737        21.03125,
738        -12.046875,
739        -37.125,
740        39.1875,
741        -14.078125,
742        5.8125,
743        38.796875,
744        -4.59375,
745        2.28125,
746        -63.0,
747        11.5625,
748        -14.796875,
749        16.515625,
750        21.90625,
751        -19.671875,
752        -6.203125,
753        -32.96875,
754        18.90625,
755        -5.34375,
756        25.78125,
757        7.21875,
758        -28.09375,
759        -17.234375,
760        1.375,
761        11.375,
762        16.14583396911621,
763        40.625,
764        4.8125,
765        -41.02083206176758,
766        -13.833333015441895,
767        10.541666984558105,
768        -74.66666412353516,
769        -25.45833396911621,
770        -2.0208332538604736,
771        9.0625,
772        -14.125,
773        26.125,
774        -12.375,
775        -72.9375,
776        -4.078125,
777        7.8125,
778        -2.0,
779        6.53125,
780        33.046875,
781        15.5625,
782        -43.375,
783        0.6875,
784        23.4375,
785        1.59375,
786        17.109375,
787        49.484375,
788        19.40625,
789        -48.40625,
790        24.5,
791        12.53125,
792        -0.78125,
793        24.953125,
794        47.96875,
795        33.21875,
796        -29.75,
797        48.21875,
798        5.109375,
799        30.78125,
800        25.171875,
801        35.15625,
802        36.9375,
803        -21.0,
804        19.984375,
805        2.765625,
806        44.265625,
807        21.5625,
808        38.5,
809        22.28125,
810        -15.21875,
811        25.296875,
812        -0.1875,
813        28.671875,
814        16.609375,
815        20.1875,
816        -11.375,
817        8.09375,
818        -13.453125,
819        25.921875,
820        11.703125,
821        25.390625,
822        8.734375,
823        -33.34375,
824        -6.9375,
825        -37.859375,
826        24.421875,
827        3.859375,
828        36.8125,
829        11.671875,
830        7.21875,
831        -18.84375,
832        -8.328125,
833        0.609375,
834        -0.40625,
835        30.984375,
836        -14.65625,
837        12.75,
838        -16.4375,
839        -8.65625,
840        -12.625,
841        14.640625,
842        47.9375,
843        -14.5,
844        34.84375,
845        -23.34375,
846        6.8125,
847        -6.90625,
848        -1.390625,
849        22.34375,
850        -3.859375,
851        48.71875,
852        -13.958333015441895,
853        17.58333396911621,
854        -2.6458332538604736,
855        -12.854166984558105,
856        7.041666507720947,
857        -15.291666984558105,
858        34.66666793823242,
859        45.875,
860        10.75,
861        -17.02083396911621,
862        -6.145833492279053,
863        -32.60416793823242,
864        23.64583396911621,
865        3.7916667461395264,
866        39.6875,
867        17.375,
868        -21.71875,
869        -23.5625,
870        -12.515625,
871        3.5,
872        9.71875,
873        14.65625,
874        -8.015625,
875        -32.171875,
876        -18.03125,
877        -12.96875,
878        -9.359375,
879        14.25,
880        39.8125,
881        -17.78125,
882        -31.828125,
883        -24.34375,
884        5.15625,
885        -1.484375,
886        19.375,
887        25.09375,
888        -2.953125,
889        -25.875,
890        -31.921875,
891        27.546875,
892        -17.515625,
893        10.625,
894        6.21875,
895        -30.5,
896        -23.875,
897        -10.03125,
898        28.890625,
899        -2.796875,
900        -18.03125,
901        16.90625,
902        -14.9375,
903        -22.703125,
904        -21.984375,
905        53.171875,
906        6.484375,
907        15.3125,
908        2.15625,
909        -14.1875,
910        -3.84375,
911        -6.328125,
912        25.46875,
913        5.75,
914        23.5625,
915        -12.59375,
916        -30.8125,
917        8.28125,
918        17.75,
919        36.265625,
920        19.796875,
921        36.25,
922        17.0625,
923        -24.625,
924        16.0625,
925        17.6875,
926        15.25,
927        11.453125,
928        82.59375,
929        7.71875,
930        -32.3125,
931        26.390625,
932        14.59375,
933        14.484375,
934        -5.296875,
935        65.28125,
936        -12.9375,
937        -22.15625,
938        35.859375,
939        11.59375,
940        34.40625,
941        -8.90625,
942        76.46875,
943        -3.7916667461395264,
944        -24.5,
945        31.39583396911621,
946        -2.2291667461395264,
947        21.70833396911621,
948        -11.520833015441895,
949        70.75,
950        -68.125,
951        22.0625,
952        -17.95833396911621,
953        -11.270833015441895,
954        -36.33333206176758,
955        -46.97916793823242,
956        -34.54166793823242,
957        -38.6875,
958        27.0,
959        -14.03125,
960        2.78125,
961        -35.6875,
962        -43.265625,
963        -29.8125,
964        -6.59375,
965        24.78125,
966        -25.75,
967        -14.34375,
968        -42.78125,
969        -36.171875,
970        -7.65625,
971        22.0625,
972        31.21875,
973        -17.609375,
974        7.109375,
975        -1.734375,
976        -14.96875,
977        -0.90625,
978        53.75,
979        37.515625,
980        -38.8125,
981        7.265625,
982        8.015625,
983        -2.09375,
984        2.15625,
985        46.25,
986        23.671875,
987        -33.75,
988        11.59375,
989        24.4375,
990        20.125,
991        23.96875,
992        11.75,
993        13.859375,
994        -36.84375,
995        8.953125,
996        63.65625,
997        33.984375,
998        22.6875,
999        -17.375,
1000        21.375,
1001        -29.328125,
1002        20.1875,
1003        33.609375,
1004        6.921875,
1005        59.9375,
1006        -44.1875,
1007        27.75,
1008        -5.640625,
1009        43.1875,
1010        0.203125,
1011        14.578125,
1012        47.125,
1013        -65.625,
1014        31.71875,
1015        -30.75,
1016        28.71875,
1017        -7.984375,
1018        -6.171875,
1019        23.9375,
1020        -42.71875,
1021        48.078125,
1022        -23.453125,
1023        31.59375,
1024        -24.859375,
1025        -4.59375,
1026        19.96875,
1027        -17.75,
1028        10.4375,
1029        -26.140625,
1030        6.046875,
1031        -30.9375,
1032        17.421875,
1033        -10.84375,
1034        -9.333333015441895,
1035        2.0625,
1036        -39.02083206176758,
1037        -23.70833396911621,
1038        -19.27083396911621,
1039        19.5,
1040        -16.5});
1041   op_avg_pool2d_out(
1042       self,
1043       kernel_size,
1044       stride,
1045       padding,
1046       ceil_mode,
1047       count_include_pad,
1048       divisor_override,
1049       out);
1050   EXPECT_TENSOR_CLOSE(out, out_expected);
1051 }
1052