// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Pool.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_adaptive_avg_pool2d_backward_native.h>
#include <ATen/ops/_adaptive_avg_pool2d_native.h>
#include <ATen/ops/adaptive_avg_pool2d.h>
#include <ATen/ops/adaptive_avg_pool2d_native.h>
#include <ATen/ops/adaptive_max_pool2d_backward_native.h>
#include <ATen/ops/adaptive_max_pool2d_native.h>
#include <ATen/ops/avg_pool2d.h>
#include <ATen/ops/avg_pool2d_backward.h>
#include <ATen/ops/max_pool2d_with_indices.h>
#include <ATen/ops/max_pool2d_with_indices_backward.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones_like.h>
#endif

namespace at::native {
namespace mps {

// Derives the stride and kernel size of the regular pooling op that is equivalent to the
// requested adaptive pooling, assuming the larger size is divisible by the smaller one.
// Example: isizeH = 8, osizeH = 4 gives strideH = 2 and kernel_sizeH = 8 - 3 * 2 = 2.
static void set_kernel_params(int64_t isizeH,
                              int64_t isizeW,
                              int64_t osizeH,
                              int64_t osizeW,
                              int64_t& strideH,
                              int64_t& strideW,
                              int64_t& kernel_sizeH,
                              int64_t& kernel_sizeW,
                              bool check_avg_pooling = false) {
  TORCH_CHECK((isizeH >= osizeH && isizeW >= osizeW) || (isizeH <= osizeH && isizeW <= osizeW),
              "Adaptive pool MPS: input height and width must both be greater than or equal to, "
              "or both less than, the output height and width");

  if (isizeH >= osizeH) {
    if (check_avg_pooling) {
      TORCH_CHECK(
          (isizeH % osizeH == 0 && isizeW % osizeW == 0),
          "Adaptive pool MPS: input sizes must be divisible by output sizes. Non-divisible input sizes are not implemented on the MPS device yet. For now, you can manually move the tensor to the CPU in this case. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/96056)");
    }
    strideH = (int64_t)(isizeH / osizeH);
    strideW = (int64_t)(isizeW / osizeW);
    kernel_sizeH = isizeH - (osizeH - 1) * strideH;
    kernel_sizeW = isizeW - (osizeW - 1) * strideW;
  } else {
    if (check_avg_pooling) {
      TORCH_CHECK(
          (osizeH % isizeH == 0 && osizeW % isizeW == 0),
          "Adaptive pool MPS: output sizes must be divisible by input sizes. Non-divisible output sizes are not implemented on the MPS device yet. For now, you can manually move the tensor to the CPU in this case. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/96056)");
    }
    strideH = (int64_t)(osizeH / isizeH);
    strideW = (int64_t)(osizeW / isizeW);
    kernel_sizeH = osizeH - (isizeH - 1) * strideH;
    kernel_sizeW = osizeW - (isizeW - 1) * strideW;
  }
}
} // namespace mps

// Adaptive average pooling
Tensor& adaptive_avg_pool2d_out_mps(const Tensor& input, IntArrayRef output_size, Tensor& output) {
  for (int64_t i = 1; i < input.ndimension(); i++) {
    TORCH_CHECK(input.size(i) > 0,
                "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
                "but input has sizes ",
                input.sizes(),
                " with dimension ",
                i,
                " being empty");
  }

  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeH = output_size[0];
  int64_t osizeW = output_size[1];

  int64_t strideH = 0, strideW = 0;
  int64_t kernel_sizeH = 0, kernel_sizeW = 0;

  mps::set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);

  if (isizeH >= osizeH) {
    // Downsampling: adaptive average pooling reduces to a plain avg_pool2d with the derived
    // kernel size and stride.
    output = at::avg_pool2d(input,
                            IntArrayRef({kernel_sizeH, kernel_sizeW}),
                            IntArrayRef({strideH, strideW}),
                            IntArrayRef({0, 0}),
                            false,
                            true,
                            std::nullopt);
  } else {
    // Upsampling: express the op as the backward pass of avg_pool2d, using a phony gradient
    // shaped like the requested output, then undo the backward pass's 1/(kH*kW) scaling below.
    Tensor phony_grad = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto input_sizes = input.sizes();
    std::vector<int64_t> phony_shape{input_sizes.begin(), input_sizes.end() - 2};
    phony_shape.push_back(output_size[0]);
    phony_shape.push_back(output_size[1]);
    phony_grad.resize_(IntArrayRef(phony_shape));
    output = at::avg_pool2d_backward(input,
                                     phony_grad,
                                     IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                     IntArrayRef({strideH, strideW}),
                                     IntArrayRef({0, 0}),
                                     false,
                                     true,
                                     std::nullopt);
    // Multiply output by kernel size
    output = at::mul(output, kernel_sizeH * kernel_sizeW);
  }

  return output;
}

Tensor adaptive_avg_pool2d_mps(at::Tensor const& input, IntArrayRef output_size) {
  IntArrayRef output_shape;

  auto osizeH = output_size[0];
  auto osizeW = output_size[1];

  std::vector<int64_t> out_dims = {};

  if (input.ndimension() == 4) {
    auto sizeB = input.size(0);
    auto sizeD = input.size(1);

    out_dims.push_back(sizeB);
    out_dims.push_back(sizeD);
    out_dims.push_back(osizeH);
    out_dims.push_back(osizeW);
    output_shape = IntArrayRef(out_dims);
  } else {
    auto sizeD = input.size(0);
    out_dims.push_back(sizeD);
    out_dims.push_back(osizeH);
    out_dims.push_back(osizeW);
    output_shape = IntArrayRef(out_dims);
  }

  const auto memory_format = input.suggest_memory_format();
  Tensor output = at::empty(output_shape, input.scalar_type(), std::nullopt, kMPS, std::nullopt, memory_format);
  return adaptive_avg_pool2d_out_mps(input, output_size, output);
}

Tensor adaptive_avg_pool2d_backward_mps(const Tensor& gradOutput, const Tensor& input) {
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeH = gradOutput.size(-2);
  int64_t osizeW = gradOutput.size(-1);

  int64_t strideH = 0, strideW = 0;
  int64_t kernel_sizeH = 0, kernel_sizeW = 0;

  mps::set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);

  auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  if (gradInput.numel() != 0) {
    if (isizeH >= osizeH) {
      // The forward pass used avg_pool2d, so its gradient is avg_pool2d_backward.
      gradInput = at::avg_pool2d_backward(gradOutput,
                                          input,
                                          IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                          IntArrayRef({strideH, strideW}),
                                          IntArrayRef({0, 0}),
                                          false,
                                          true,
                                          std::nullopt);
    } else {
      // The forward pass was expressed via avg_pool2d_backward, so its gradient is a forward
      // avg_pool2d, rescaled by the kernel area.
      gradInput = at::avg_pool2d(gradOutput,
                                 IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                 IntArrayRef({strideH, strideW}),
                                 IntArrayRef({0, 0}),
                                 false,
                                 true,
                                 std::nullopt);
      gradInput = at::mul(gradInput, kernel_sizeH * kernel_sizeW);
    }
  }

  return gradInput;
}

// Adaptive max pooling
TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
(const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) {
  for (int64_t i = 1; i < input.ndimension(); i++) {
    TORCH_CHECK(input.size(i) > 0,
                "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
                "but input has sizes ",
                input.sizes(),
                " with dimension ",
                i,
                " being empty");
  }

  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeH = output_size[0];
  int64_t osizeW = output_size[1];

  int64_t strideH = 0, strideW = 0;
  int64_t kernel_sizeH = 0, kernel_sizeW = 0;

  mps::set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);

  at::max_pool2d_with_indices_out(const_cast<Tensor&>(output),
                                  const_cast<Tensor&>(indices),
                                  input,
                                  IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                  IntArrayRef({strideH, strideW}),
                                  IntArrayRef({0, 0}),
                                  IntArrayRef({1, 1}),
                                  false);
}

TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
(const Tensor& gradOutput, const Tensor& input, const Tensor& indices, const Tensor& gradInput) {
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeH = gradOutput.size(-2);
  int64_t osizeW = gradOutput.size(-1);

  int64_t strideH = 0, strideW = 0;
  int64_t kernel_sizeH = 0, kernel_sizeW = 0;

  mps::set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);

  at::max_pool2d_with_indices_backward_out(const_cast<Tensor&>(gradInput),
                                           gradOutput,
                                           input,
                                           IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                           IntArrayRef({strideH, strideW}),
                                           IntArrayRef({0, 0}),
                                           IntArrayRef({1, 1}),
                                           false,
                                           indices);
}

} // namespace at::native
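
// Illustrative usage sketch (not part of this file; the example tensor and sizes are
// hypothetical). When the input size divides evenly by the output size, adaptive average
// pooling on MPS maps to a regular avg_pool2d with the kernel and stride derived in
// mps::set_kernel_params:
//
//   auto x = at::rand({1, 3, 8, 8}, at::TensorOptions().device(at::kMPS));
//   auto y = at::adaptive_avg_pool2d(x, {4, 4});
//   // equivalent to at::avg_pool2d(x, /*kernel_size=*/{2, 2}, /*stride=*/{2, 2})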