//  Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Pool.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_adaptive_avg_pool2d_backward_native.h>
#include <ATen/ops/_adaptive_avg_pool2d_native.h>
#include <ATen/ops/adaptive_avg_pool2d.h>
#include <ATen/ops/adaptive_avg_pool2d_native.h>
#include <ATen/ops/adaptive_max_pool2d_backward_native.h>
#include <ATen/ops/adaptive_max_pool2d_native.h>
#include <ATen/ops/avg_pool2d.h>
#include <ATen/ops/avg_pool2d_backward.h>
#include <ATen/ops/max_pool2d_with_indices.h>
#include <ATen/ops/max_pool2d_with_indices_backward.h>
#include <ATen/ops/mul.h>
#include <ATen/ops/ones_like.h>
#endif
namespace at::native {
namespace mps {
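// Derives the stride and kernel size that turn an adaptive pooling problem
// into a regular pooling call. When the sizes divide evenly,
//   stride      = larger_size / smaller_size
//   kernel_size = larger_size - (smaller_size - 1) * stride
// e.g. pooling an 8x8 input down to a 4x4 output uses stride 2 and kernel 2.
// For average pooling the divisibility is checked explicitly, because the
// non-divisible case is not implemented on MPS (see issue #96056 below).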
static void set_kernel_params(int64_t isizeH,
                              int64_t isizeW,
                              int64_t osizeH,
                              int64_t osizeW,
                              int64_t& strideH,
                              int64_t& strideW,
                              int64_t& kernel_sizeH,
                              int64_t& kernel_sizeW,
                              bool check_avg_pooling = false) {
  TORCH_CHECK((isizeH >= osizeH && isizeW >= osizeW) || (isizeH <= osizeH && isizeW <= osizeW),
              "Adaptive pool MPS: input height and width must both be greater than or equal to, "
              "or both less than or equal to, the output height and width");

  if (isizeH >= osizeH) {
    if (check_avg_pooling) {
      TORCH_CHECK(
          (isizeH % osizeH == 0 && isizeW % osizeW == 0),
          "Adaptive pool MPS: input sizes must be divisible by output sizes. Non-divisible sizes are not implemented on the MPS device yet. As a workaround, transfer the tensor to the CPU for this operation. See https://github.com/pytorch/pytorch/issues/96056");
    }
    strideH = (int64_t)(isizeH / osizeH);
    strideW = (int64_t)(isizeW / osizeW);
    kernel_sizeH = isizeH - (osizeH - 1) * strideH;
    kernel_sizeW = isizeW - (osizeW - 1) * strideW;
  } else {
    if (check_avg_pooling) {
      TORCH_CHECK(
          (osizeH % isizeH == 0 && osizeW % isizeW == 0),
          "Adaptive pool MPS: output sizes must be divisible by input sizes. Non-divisible sizes are not implemented on the MPS device yet. As a workaround, transfer the tensor to the CPU for this operation. See https://github.com/pytorch/pytorch/issues/96056");
    }
    strideH = (int64_t)(osizeH / isizeH);
    strideW = (int64_t)(osizeW / isizeW);
    kernel_sizeH = osizeH - (isizeH - 1) * strideH;
    kernel_sizeW = osizeW - (isizeW - 1) * strideW;
  }
}
} // namespace mps

// Adaptive average pooling
Tensor& adaptive_avg_pool2d_out_mps(const Tensor& input, IntArrayRef output_size, Tensor& output) {
  for (int64_t i = 1; i < input.ndimension(); i++) {
    TORCH_CHECK(input.size(i) > 0,
                "adaptive_avg_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
                "but input has sizes ",
                input.sizes(),
                " with dimension ",
                i,
                " being empty");
  }

  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeH = output_size[0];
  int64_t osizeW = output_size[1];

  int64_t strideH = 0, strideW = 0;
  int64_t kernel_sizeH = 0, kernel_sizeW = 0;

  mps::set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);

  if (isizeH >= osizeH) {
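    // Downsampling (or same-size) case: with evenly dividing sizes, adaptive
    // average pooling is exactly avg_pool2d with the derived kernel/stride and
    // no padding.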
    output = at::avg_pool2d(input,
                            IntArrayRef({kernel_sizeH, kernel_sizeW}),
                            IntArrayRef({strideH, strideW}),
                            IntArrayRef({0, 0}),
                            false,
                            true,
                            std::nullopt);
  } else {
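    // Upsampling case (output larger than input): avg_pool2d_backward, given a
    // dummy "self" tensor whose shape is the desired output (only the shape is
    // used), replicates each input value over its kernel_sizeH x kernel_sizeW
    // window but also divides by the kernel area; the at::mul below undoes
    // that division.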
    Tensor phony_grad = at::ones_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    auto input_sizes = input.sizes();
    std::vector<int64_t> phony_shape{input_sizes.begin(), input_sizes.end() - 2};
    phony_shape.push_back(output_size[0]);
    phony_shape.push_back(output_size[1]);
    phony_grad.resize_(IntArrayRef(phony_shape));
    output = at::avg_pool2d_backward(input,
                                     phony_grad,
                                     IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                     IntArrayRef({strideH, strideW}),
                                     IntArrayRef({0, 0}),
                                     false,
                                     true,
                                     std::nullopt);
    // Undo the division by the kernel area performed by avg_pool2d_backward
    output = at::mul(output, kernel_sizeH * kernel_sizeW);
  }

  return output;
}

Tensor adaptive_avg_pool2d_mps(at::Tensor const& input, IntArrayRef output_size) {
  IntArrayRef output_shape;

  auto osizeH = output_size[0];
  auto osizeW = output_size[1];

  std::vector<int64_t> out_dims = {};

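  // The output keeps the leading batch/channel dimensions and replaces the two
  // spatial dimensions with the requested output size:
  // (N, C, osizeH, osizeW) for 4-D inputs, (C, osizeH, osizeW) for 3-D inputs.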
  if (input.ndimension() == 4) {
    auto sizeB = input.size(0);
    auto sizeD = input.size(1);

    out_dims.push_back(sizeB);
    out_dims.push_back(sizeD);
    out_dims.push_back(osizeH);
    out_dims.push_back(osizeW);
    output_shape = IntArrayRef(out_dims);
  } else {
    auto sizeD = input.size(0);
    out_dims.push_back(sizeD);
    out_dims.push_back(osizeH);
    out_dims.push_back(osizeW);
    output_shape = IntArrayRef(out_dims);
  }

  const auto memory_format = input.suggest_memory_format();
  Tensor output = at::empty(output_shape, input.scalar_type(), std::nullopt, kMPS, std::nullopt, memory_format);
  return adaptive_avg_pool2d_out_mps(input, output_size, output);
}

Tensor adaptive_avg_pool2d_backward_mps(const Tensor& gradOutput, const Tensor& input) {
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeH = gradOutput.size(-2);
  int64_t osizeW = gradOutput.size(-1);

  int64_t strideH = 0, strideW = 0;
  int64_t kernel_sizeH = 0, kernel_sizeW = 0;

  mps::set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW, true);

  auto gradInput = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  if (gradInput.numel() != 0) {
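    // The backward mirrors the forward lowering: when the forward pass used a
    // plain avg_pool2d (input >= output), its gradient is avg_pool2d_backward;
    // when the forward pass used the avg_pool2d_backward trick (output > input),
    // the gradient is a forward avg_pool2d rescaled by the kernel area.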
    if (isizeH >= osizeH) {
      gradInput = at::avg_pool2d_backward(gradOutput,
                                          input,
                                          IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                          IntArrayRef({strideH, strideW}),
                                          IntArrayRef({0, 0}),
                                          false,
                                          true,
                                          std::nullopt);
    } else {
      gradInput = at::avg_pool2d(gradOutput,
                                 IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                 IntArrayRef({strideH, strideW}),
                                 IntArrayRef({0, 0}),
                                 false,
                                 true,
                                 std::nullopt);
      gradInput = at::mul(gradInput, kernel_sizeH * kernel_sizeW);
    }
  }

  return gradInput;
}

// Adaptive max pooling
TORCH_IMPL_FUNC(adaptive_max_pool2d_out_mps)
(const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) {
  for (int64_t i = 1; i < input.ndimension(); i++) {
    TORCH_CHECK(input.size(i) > 0,
                "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
                "but input has sizes ",
                input.sizes(),
                " with dimension ",
                i,
                " being empty");
  }

  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeH = output_size[0];
  int64_t osizeW = output_size[1];

  int64_t strideH = 0, strideW = 0;
  int64_t kernel_sizeH = 0, kernel_sizeW = 0;

  mps::set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);

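  // With the derived kernel/stride, no padding, dilation 1, and ceil_mode off,
  // adaptive max pooling reduces to a regular max_pool2d_with_indices call.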
  at::max_pool2d_with_indices_out(const_cast<Tensor&>(output),
                                  const_cast<Tensor&>(indices),
                                  input,
                                  IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                  IntArrayRef({strideH, strideW}),
                                  IntArrayRef({0, 0}),
                                  IntArrayRef({1, 1}),
                                  false);
}

TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_mps)
(const Tensor& gradOutput, const Tensor& input, const Tensor& indices, const Tensor& gradInput) {
  int64_t isizeH = input.size(-2);
  int64_t isizeW = input.size(-1);
  int64_t osizeH = gradOutput.size(-2);
  int64_t osizeW = gradOutput.size(-1);

  int64_t strideH = 0, strideW = 0;
  int64_t kernel_sizeH = 0, kernel_sizeW = 0;

  mps::set_kernel_params(isizeH, isizeW, osizeH, osizeW, strideH, strideW, kernel_sizeH, kernel_sizeW);

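  // Route the gradient back through the indices saved by the forward pass,
  // using the regular max_pool2d_with_indices_backward kernel with the same
  // derived parameters.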
  at::max_pool2d_with_indices_backward_out(const_cast<Tensor&>(gradInput),
                                           gradOutput,
                                           input,
                                           IntArrayRef({kernel_sizeH, kernel_sizeW}),
                                           IntArrayRef({strideH, strideW}),
                                           IntArrayRef({0, 0}),
                                           IntArrayRef({1, 1}),
                                           false,
                                           indices);
}

} // namespace at::native