/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include // Declares the operator #include #include #include #include #include #include using namespace ::testing; class OpMaxPool2DWithIndicesOutTest : public OperatorTest { protected: ::std::tuple op_max_pool2d_with_indices_out( const exec_aten::Tensor& self, exec_aten::ArrayRef kernel_size, exec_aten::ArrayRef stride, exec_aten::ArrayRef padding, exec_aten::ArrayRef dilation, bool ceil_mode, exec_aten::Tensor& out, exec_aten::Tensor& indices) { return torch::executor::aten::max_pool2d_with_indices_outf( context_, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); } }; TEST_F(OpMaxPool2DWithIndicesOutTest, SanityTest4D) { torch::executor::testing::TensorFactory tfFloat; torch::executor::testing::TensorFactory tfLong; exec_aten::Tensor self = tfFloat.make( {2, 3, 5, 5}, {28.75, -38.875, -7.0, -13.5, 70.75, 53.75, 69.625, 97.375, 25.375, 99.5, -72.125, -87.25, 79.25, 42.0, -24.75, -15.5, 12.5, -86.0, 85.5, -0.25, 67.125, 77.0, 53.375, -61.125, 50.0, 3.875, 42.25, -37.375, 51.0, -60.875, 87.0, 32.25, 73.5, 68.875, -84.375, -98.75, -30.125, 94.25, 1.625, -86.25, -56.5, -68.0, 74.25, -51.25, 8.125, 71.375, -53.125, 4.875, 77.5, -89.875, 4.5, -46.5, -46.375, -92.625, -85.5, -23.0, -8.875, -12.0, -46.625, -88.625, 66.75, 87.75, 90.25, -45.0, -78.125, 63.25, 28.75, 28.125, -30.375, 17.75, -16.0, 5.0, 11.125, 88.625, -47.625, 72.25, 32.0, -7.625, 61.625, -63.125, -22.75, 83.125, -40.375, -78.25, 49.5, -39.125, -89.625, 47.875, -61.375, 7.75, 16.875, -96.375, -22.5, 8.5, 74.25, 12.75, 90.125, 73.875, -71.75, -10.0, 41.25, 1.125, 10.375, -34.625, 29.75, -27.5, 26.625, 81.0, -8.875, 17.625, 84.375, -23.625, -53.875, -26.0, -67.375, -90.75, 16.375, 45.625, 99.5, 56.25, -87.625, -65.5, -79.75, 31.875, 79.75, 6.375, 44.625, -55.25, -5.5, -68.875, -38.625, 54.125, -3.125, 5.75, 29.25, -39.5, 26.75, 68.25, -24.625, -53.0, 51.0, 90.625, 65.375, 43.875, 90.875, -41.625, 99.875, 6.375, -31.25, -94.0}); ::std::vector kernel_size_vec = {2, 2}; exec_aten::ArrayRef kernel_size = exec_aten::ArrayRef( kernel_size_vec.data(), kernel_size_vec.size()); ::std::vector stride_vec = {1, 1}; exec_aten::ArrayRef stride = exec_aten::ArrayRef(stride_vec.data(), stride_vec.size()); ::std::vector padding_vec = {0, 0}; exec_aten::ArrayRef padding = exec_aten::ArrayRef(padding_vec.data(), padding_vec.size()); ::std::vector dilation_vec = {1, 1}; exec_aten::ArrayRef dilation = exec_aten::ArrayRef(dilation_vec.data(), dilation_vec.size()); bool ceil_mode = false; exec_aten::Tensor out = tfFloat.zeros({2, 3, 4, 4}); exec_aten::Tensor indices = tfLong.zeros({2, 3, 4, 4}); exec_aten::Tensor out_expected = tfFloat.make( {2, 3, 4, 4}, {69.625, 97.375, 97.375, 99.5, 69.625, 97.375, 97.375, 99.5, 12.5, 79.25, 85.5, 85.5, 77.0, 77.0, 85.5, 85.5, 87.0, 73.5, 73.5, 68.875, 87.0, 94.25, 94.25, 68.875, -30.125, 94.25, 94.25, 8.125, 71.375, 74.25, 77.5, 77.5, 4.5, -8.875, -12.0, -46.625, 87.75, 90.25, 90.25, -45.0, 87.75, 90.25, 90.25, 17.75, 63.25, 28.75, 88.625, 88.625, 83.125, 83.125, 61.625, 61.625, 83.125, 83.125, 47.875, 49.5, 16.875, 47.875, 47.875, 74.25, 90.125, 90.125, 73.875, 74.25, 41.25, 81.0, 81.0, 29.75, 84.375, 81.0, 81.0, 17.625, 84.375, 45.625, 99.5, 99.5, 16.375, 45.625, 99.5, 99.5, 54.125, 54.125, 5.75, 29.25, 54.125, 68.25, 68.25, 29.25, 90.625, 90.625, 68.25, 90.875, 99.875, 99.875, 65.375, 90.875}); exec_aten::Tensor indices_expected = tfLong.make( {2, 3, 4, 4}, {6, 7, 7, 9, 6, 7, 7, 9, 16, 12, 18, 18, 21, 21, 18, 18, 5, 7, 7, 8, 5, 12, 12, 8, 11, 12, 12, 19, 20, 17, 23, 23, 0, 6, 7, 8, 11, 12, 12, 13, 11, 12, 12, 19, 15, 16, 23, 23, 6, 6, 3, 3, 6, 6, 12, 9, 15, 12, 12, 19, 21, 21, 22, 19, 0, 7, 7, 4, 10, 7, 7, 9, 10, 17, 18, 18, 16, 17, 18, 18, 6, 6, 8, 9, 6, 12, 12, 9, 16, 16, 12, 19, 21, 21, 17, 19}); op_max_pool2d_with_indices_out( self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); EXPECT_TENSOR_CLOSE(out, out_expected); EXPECT_TENSOR_CLOSE(indices, indices_expected); } TEST_F(OpMaxPool2DWithIndicesOutTest, SanityTest4D_2) { torch::executor::testing::TensorFactory tfFloat; torch::executor::testing::TensorFactory tfLong; exec_aten::Tensor self = tfFloat.make( {2, 3, 8, 8}, {47.375, -18.625, 12.0, 5.375, -40.375, -75.875, 51.0, -48.25, -5.0, -50.625, -96.875, -53.25, 82.25, 0.125, -13.125, 89.75, 53.0, 67.125, 79.625, 55.0, 77.75, -85.5, 35.25, -44.625, 25.625, -44.625, -27.75, 75.875, -96.375, 37.375, 67.0, -73.75, -77.25, 86.5, -27.5, 93.25, -40.0, -66.125, 57.5, 16.0, -94.125, 18.125, -22.625, -94.25, 41.5, 52.875, 36.875, 0.875, 92.75, -94.375, -80.375, -55.625, -11.0, -23.75, -8.375, 33.625, -27.25, -2.25, 70.25, -60.375, -96.125, -33.25, 43.375, -38.75, -58.5, -32.375, -83.75, -15.375, -4.375, 7.625, 58.875, 59.25, -12.5, 85.75, 67.125, -32.0, 32.125, 41.0, 47.625, 72.25, -36.5, 5.125, -30.75, -37.625, -76.25, 96.25, 55.125, -88.0, -78.0, -18.125, -23.75, -8.75, -3.25, 82.0, 10.5, 48.625, -64.0, 11.5, -66.75, 40.75, 38.25, 44.875, 21.0, 7.0, 30.125, -59.125, -31.5, -56.0, 75.0, -31.375, -25.0, -23.75, 48.625, -20.5, 52.375, -27.25, 84.75, -39.125, -56.5, -80.875, -60.5, -26.125, 42.375, 82.625, 43.375, -5.5, 73.375, 71.25, -12.25, 58.125, -1.375, -97.625, -43.375, 64.125, -71.125, -55.25, 5.375, 56.0, 20.0, -70.125, 78.625, 20.625, 62.5, 71.375, 93.0, -50.25, 26.125, -16.875, 50.375, -96.25, 32.75, 53.375, -61.5, -48.375, 72.375, -15.5, 2.875, -66.0, -31.75, -1.0, 65.625, -46.75, 69.875, 74.5, 53.25, 13.25, -41.625, 96.625, 75.875, 97.625, -98.25, -32.875, -75.0, -31.75, 72.875, -89.625, 56.625, 10.375, -40.5, -53.625, 5.75, -52.75, 19.5, 71.125, 80.875, 16.25, -32.625, 88.25, -98.5, 92.5, -68.25, -83.25, 43.375, 20.25, 1.875, -69.25, -92.75, 80.0, -17.125, -52.125, -82.0, 64.5, -7.875, -0.125, -74.5, 46.375, -68.875, 63.125, 88.0, 12.0, 87.75, 58.125, -83.375, 21.125, -34.0, -60.125, 37.75, 17.125, -44.875, -22.875, 72.5, -99.0, -30.25, 81.125, -73.875, -90.25, -30.0, 64.75, -57.875, -4.0, -33.125, -22.875, -74.0, -46.0, 4.125, -33.375, -72.625, 83.25, 25.125, 54.25, 43.0, -22.5, -22.25, -72.375, 80.25, 5.125, -7.0, 73.5, -28.625, -33.75, -91.125, -76.75, -67.625, 79.25, -56.625, 85.0, -53.125, 11.0, -44.0, -75.5, -23.375, 54.5, 21.5, 36.5, 33.0, 83.375, 81.625, 3.5, -7.5, 7.125, -87.5, 12.125, 9.625, -90.125, -3.75, 60.375, -80.75, -8.375, 59.25, 30.25, -75.75, 78.25, 13.5, -69.125, 48.0, -56.5, -10.5, -65.5, -80.875, 7.0, 64.375, -81.0, 48.625, -33.25, -53.75, -68.75, -12.0, 3.0, -56.625, 17.0, 36.0, 28.25, 36.75, -39.5, -23.5, -53.875, -17.375, -84.375, 0.625, 12.125, -45.5, 14.25, 75.5, -92.875, 9.125, 38.25, 20.375, -46.625, 59.875, -79.375, 52.625, 6.75, 11.875, -57.5, 91.75, -55.5, -86.875, 11.375, -41.0, 1.25, 41.375, 70.5, 43.625, -65.625, 7.625, -90.875, 84.5, 26.0, -52.25, 4.125, 27.75, 17.875, -95.75, 66.5, -88.875, -67.125, 21.5, -0.875, 35.875, 53.0, 62.0, -78.5, -65.125, -57.375, 5.625, 54.75, 70.0, -94.25, 50.625, -0.25, 89.5, -32.0, -83.125, -48.625, -23.0, 75.5, 44.0, 75.0, 58.0, 18.125, -94.75, -69.375, 70.375, -51.75, -86.75, 81.5, 75.75, 61.625, -14.5, -60.75, -58.125, -3.25, 36.25, -95.125}); ::std::vector kernel_size_vec = {2, 3}; exec_aten::ArrayRef kernel_size = exec_aten::ArrayRef( kernel_size_vec.data(), kernel_size_vec.size()); ::std::vector stride_vec = {2, 1}; exec_aten::ArrayRef stride = exec_aten::ArrayRef(stride_vec.data(), stride_vec.size()); ::std::vector padding_vec = {0, 1}; exec_aten::ArrayRef padding = exec_aten::ArrayRef(padding_vec.data(), padding_vec.size()); ::std::vector dilation_vec = {2, 1}; exec_aten::ArrayRef dilation = exec_aten::ArrayRef(dilation_vec.data(), dilation_vec.size()); bool ceil_mode = false; exec_aten::Tensor out = tfFloat.zeros({2, 3, 3, 8}); exec_aten::Tensor indices = tfLong.zeros({2, 3, 3, 8}); exec_aten::Tensor out_expected = tfFloat.make( {2, 3, 3, 8}, {67.125, 79.625, 79.625, 79.625, 77.75, 77.75, 51.0, 51.0, 86.5, 86.5, 93.25, 93.25, 93.25, 77.75, 57.5, 57.5, 92.75, 92.75, 93.25, 93.25, 93.25, 57.5, 57.5, 57.5, 5.125, 5.125, 5.125, -4.375, 96.25, 96.25, 96.25, 59.25, 11.5, 11.5, 40.75, 40.75, 96.25, 96.25, 96.25, 55.125, 48.625, 52.375, 52.375, 84.75, 84.75, 84.75, 44.875, 21.0, 93.0, 93.0, 58.125, 50.375, 64.125, 64.125, 64.125, 53.375, 93.0, 93.0, 74.5, 74.5, 74.5, 53.25, 96.625, 96.625, 65.625, 69.875, 74.5, 74.5, 74.5, 53.25, 96.625, 96.625, 88.0, 88.0, 87.75, 87.75, 80.0, 80.0, 80.0, -17.125, 88.0, 88.0, 87.75, 87.75, 64.75, 21.125, 21.125, -22.875, 43.0, 43.0, 64.75, 80.25, 80.25, 80.25, 73.5, 73.5, 11.0, 11.0, 60.375, 60.375, 60.375, 59.25, 59.25, 59.25, 9.625, 64.375, 64.375, 64.375, 60.375, 59.25, 59.25, 59.25, 7.0, 64.375, 64.375, 64.375, 48.625, 48.625, 14.25, 14.25, 84.5, 84.5, 26.0, 91.75, 91.75, 91.75, 66.5, 66.5, 84.5, 84.5, 54.75, 70.0, 70.0, 70.0, 66.5, 66.5, 58.0, 58.0, 54.75, 70.375, 70.375, 70.375, 81.5, 81.5}); exec_aten::Tensor indices_expected = tfLong.make( {2, 3, 3, 8}, {17, 18, 18, 18, 20, 20, 6, 6, 33, 33, 35, 35, 35, 20, 38, 38, 48, 48, 35, 35, 35, 38, 38, 38, 17, 17, 17, 4, 21, 21, 21, 7, 33, 33, 35, 35, 21, 21, 21, 22, 48, 50, 50, 52, 52, 52, 37, 38, 16, 16, 1, 20, 5, 5, 5, 23, 16, 16, 35, 35, 35, 36, 39, 39, 32, 34, 35, 35, 35, 36, 39, 39, 16, 16, 18, 18, 5, 5, 5, 6, 16, 16, 18, 18, 35, 21, 21, 39, 48, 48, 35, 52, 52, 52, 55, 55, 1, 1, 19, 19, 19, 22, 22, 22, 16, 34, 34, 34, 19, 22, 22, 22, 33, 34, 34, 34, 36, 36, 55, 55, 16, 16, 17, 4, 4, 4, 23, 23, 16, 16, 35, 36, 36, 36, 23, 23, 48, 48, 35, 52, 52, 52, 55, 55}); op_max_pool2d_with_indices_out( self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); EXPECT_TENSOR_CLOSE(out, out_expected); EXPECT_TENSOR_CLOSE(indices, indices_expected); } TEST_F(OpMaxPool2DWithIndicesOutTest, SanityTest3D) { torch::executor::testing::TensorFactory tfFloat; torch::executor::testing::TensorFactory tfLong; exec_aten::Tensor self = tfFloat.make( {2, 12, 12}, {73.625, 15.5, 30.875, 89.25, -55.625, -62.875, 25.0, -50.75, -47.125, 12.125, -73.125, -89.875, 53.625, -63.125, -44.375, 86.0, 53.625, -84.125, -6.75, 20.125, -24.25, -43.5, -11.125, -34.625, -7.5, -13.0, 74.375, 33.75, -44.875, 49.125, -59.5, -88.5, -46.5, -33.0, 48.125, 80.875, 38.875, -58.875, 0.875, -48.625, -46.125, -87.25, -66.625, 14.375, -68.25, -77.0, -50.5, -15.625, 86.875, 89.875, -37.25, 7.5, -16.75, -6.625, 55.875, 40.5, -83.875, -77.625, -55.375, 32.25, -17.5, -83.125, 43.375, 17.5, 2.75, -51.25, 20.25, -77.375, -68.0, -72.625, -47.5, -78.875, -49.375, -52.125, -7.125, -25.125, -77.5, -8.625, -3.125, 99.375, 71.875, 19.625, 21.125, -47.0, 44.5, -90.625, -75.75, -87.25, 79.75, -42.125, -90.0, 22.5, 2.5, 73.5, -65.125, -50.375, -71.625, 19.25, -60.125, -91.75, -43.375, -60.875, 16.375, 86.875, -93.25, -78.375, 82.5, 14.75, 20.125, 19.625, 33.875, 84.875, 60.625, 41.5, 2.0, -4.875, -52.5, 74.375, -40.125, -60.125, 88.5, 51.875, -59.75, 49.5, -81.0, -93.5, 43.0, -99.625, 40.375, -84.0, 76.5, 27.5, 59.125, -19.5, -55.25, -50.5, 81.875, 86.0, -19.75, 51.5, 70.875, -90.5, 74.375, 62.5, -0.625, -31.375, -71.25, 42.75, 42.5, 67.125, 26.125, 85.375, -11.75, -34.375, -97.125, 5.875, -45.25, -50.125, 74.5, -62.0, -81.5, -84.875, 70.75, 33.375, -27.5, -54.25, 94.25, 74.625, -30.0, 16.875, 39.875, -0.5, 0.25, -80.125, 85.375, 42.5, 13.125, -82.375, -30.75, -95.75, 34.75, -60.125, -51.625, -10.375, -30.75, -65.5, -96.0, -95.25, 60.125, -33.125, 67.125, -88.0, -26.125, 75.875, 29.5, -27.75, -28.875, 21.375, -2.0, -29.125, 11.0, -68.5, -36.75, -85.375, -4.625, 9.0, -31.75, -63.5, 98.75, -1.375, 17.125, 61.25, -50.5, 41.375, -18.375, -92.25, -50.375, -40.625, 14.0, 18.5, 22.625, 10.375, 58.875, -86.0, -9.625, 5.125, -69.625, -50.25, -26.875, 26.0, 57.25, -94.5, -53.125, 98.0, 37.375, 35.0, -20.125, -9.375, -13.375, -41.125, 41.75, 95.125, 82.75, -71.5, -43.125, -37.75, -91.25, -14.0, -55.5, 52.5, -30.125, 93.875, -26.75, 83.125, 2.625, -63.75, 52.875, 31.25, 57.625, 42.75, -2.875, -45.5, 99.0, -18.625, 38.375, 88.125, 36.625, -36.875, -35.25, 13.125, -31.875, -50.875, 10.5, -38.75, 1.625, 67.125, 3.0, -87.0, 42.0, -31.25, -77.875, -7.125, -94.0, -99.0, 24.75, -21.625, -98.375, 15.875}); ::std::vector kernel_size_vec = {4, 3}; exec_aten::ArrayRef kernel_size = exec_aten::ArrayRef( kernel_size_vec.data(), kernel_size_vec.size()); ::std::vector stride_vec = {3, 2}; exec_aten::ArrayRef stride = exec_aten::ArrayRef(stride_vec.data(), stride_vec.size()); ::std::vector padding_vec = {2, 1}; exec_aten::ArrayRef padding = exec_aten::ArrayRef(padding_vec.data(), padding_vec.size()); ::std::vector dilation_vec = {1, 2}; exec_aten::ArrayRef dilation = exec_aten::ArrayRef(dilation_vec.data(), dilation_vec.size()); bool ceil_mode = false; exec_aten::Tensor out = tfFloat.zeros({2, 5, 5}); exec_aten::Tensor indices = tfLong.zeros({2, 5, 5}); exec_aten::Tensor out_expected = tfFloat.make( {2, 5, 5}, {89.25, 89.25, 89.25, 20.125, 20.125, 89.875, 89.875, 86.0, 49.125, 80.875, 89.875, 89.875, 99.375, 99.375, 99.375, 84.875, 84.875, 86.875, 86.875, 86.875, 51.875, 86.0, 86.0, 86.0, 62.5, 42.75, 67.125, 85.375, 85.375, 85.375, 75.875, 75.875, 42.5, 42.5, 74.625, 75.875, 98.0, 98.0, 98.0, 61.25, 95.125, 98.0, 98.0, 98.0, 93.875, 88.125, 88.125, 13.125, 13.125, 67.125}); exec_aten::Tensor indices_expected = tfLong.make( {2, 5, 5}, {3, 3, 3, 19, 19, 49, 49, 15, 29, 35, 49, 49, 79, 79, 79, 111, 111, 103, 103, 103, 121, 137, 137, 137, 143, 3, 5, 7, 7, 7, 49, 49, 31, 31, 23, 49, 89, 89, 89, 67, 97, 89, 89, 89, 107, 121, 121, 125, 125, 131}); op_max_pool2d_with_indices_out( self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); EXPECT_TENSOR_CLOSE(out, out_expected); EXPECT_TENSOR_CLOSE(indices, indices_expected); }