//  Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/constant_pad_nd_native.h>
#include <ATen/ops/reflection_pad1d_backward_native.h>
#include <ATen/ops/reflection_pad1d_native.h>
#include <ATen/ops/reflection_pad2d_backward_native.h>
#include <ATen/ops/reflection_pad2d_native.h>
#include <ATen/ops/reflection_pad3d_backward_native.h>
#include <ATen/ops/reflection_pad3d_native.h>
#include <ATen/ops/replication_pad1d_backward_native.h>
#include <ATen/ops/replication_pad1d_native.h>
#include <ATen/ops/replication_pad2d_backward_native.h>
#include <ATen/ops/replication_pad2d_native.h>
#include <ATen/ops/replication_pad3d_backward_native.h>
#include <ATen/ops/replication_pad3d_native.h>
#endif

namespace at::native {
namespace mps {

// Pad operations (1D/2D/3D forward and backward)
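// `padding` follows PyTorch's convention: pairs are ordered from the last
// dimension inward, i.e. (left, right) for 1D, (left, right, top, bottom)
// for 2D, and (left, right, top, bottom, front, back) for 3D. The same
// template serves both directions: the forward pass when `grad_output_opt`
// is empty, and the backward pass when it holds a defined tensor.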
static Tensor& pad_out_template(Tensor& output,
                                const Tensor& input_,
                                IntArrayRef padding,
                                const std::optional<Tensor>& grad_output_opt,
                                MPSGraphPaddingMode mode,
                                double constantValue,
                                const std::string& op_name) {
  using CachedGraph = MPSUnaryGradCachedGraph;
  const int padding_size = (int)padding.size();
  int padding_dim = padding_size / 2; // either 1D, 2D, or 3D

  TORCH_CHECK(
      padding_size == 2 || padding_size == 4 || padding_size == 6, "invalid padding argument of size ", padding_size);

  const Tensor& grad_output_ = *(at::borrow_from_optional_tensor(grad_output_opt));
  const bool is_backward_pass = grad_output_.defined();

  int64_t nbatch = 1;
  int64_t ndims = input_.ndimension();

  TORCH_CHECK(ndims >= (int64_t)padding_dim,
              "Length of pad should be no more than twice the number of "
              "dimensions of the input. Pad length is ",
              padding_size,
              " while the input has ",
              ndims,
              " dimensions.");

  // number of input dims with ConstantPad could be less than 2
  int dim_w = padding_dim;
  int dim_h = padding_dim - 1;
  int dim_d = padding_dim - 2;
  int dim_slices = 0;

  if (!is_backward_pass && mode != MPSGraphPaddingModeConstant && ndims > padding_dim) {
    bool valid_dims = input_.size(1) != 0 && input_.size(padding_dim) != 0;
    TORCH_CHECK((ndims == 1 + padding_dim && valid_dims) ||
                    (ndims == 2 + padding_dim && valid_dims && input_.size(1 + padding_dim) != 0),
                "3D or 4D (batch mode) tensor expected for input, but got: ",
                input_);
  }

  if (ndims == padding_dim) {
    dim_w--;
    dim_h--;
    dim_d--;
  } else if (ndims > padding_dim + 1) {
    const int dim_diff = (int)ndims - padding_dim - 1;
    // this virtually inflates the padding with zeros if ndims > padding_dim + 2
    padding_dim += dim_diff - 1;
    dim_w += dim_diff;
    dim_h += dim_diff;
    dim_d += dim_diff;
    dim_slices++;
    nbatch = input_.size(0);
  }
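
  // Example: 2D padding (padding_size == 4) on an NCHW input (ndims == 4)
  // gives dim_diff == 1, so dim_w == 3 (W), dim_h == 2 (H), dim_slices == 1 (C),
  // and nbatch is taken from dim 0.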
  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding_size > 2 ? padding[2] : 0;
  int64_t pad_b = padding_size > 2 ? padding[3] : 0;
  int64_t pad_front = padding_size > 4 ? padding[4] : 0;
  int64_t pad_back = padding_size > 4 ? padding[5] : 0;

  int64_t nplane = input_.size(dim_slices);
  int64_t input_w = input_.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;
  int64_t input_h = padding_dim > 1 ? input_.size(dim_h) : 0;
  int64_t output_h = padding_dim > 1 ? input_h + pad_t + pad_b : 0;
  int64_t input_d = padding_dim > 2 ? input_.size(dim_d) : 0;
  int64_t output_d = padding_dim > 2 ? input_d + pad_front + pad_back : 0;
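  // e.g. reflection_pad2d with padding (1, 2, 3, 4) maps an (N, C, H, W) input
  // to an (N, C, H + 3 + 4, W + 1 + 2) output.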

  Tensor grad_output, input = input_;

  if (!is_backward_pass) {
    TORCH_CHECK(output_w >= 1 || output_h >= padding_dim - 1,
                "input (H: ",
                input_h,
                ", W: ",
                input_w,
                ") is too small. Calculated "
                "output H: ",
                output_h,
                " W: ",
                output_w);

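    // Compute the padded output shape: constant padding supports arbitrary
    // input rank, while reflection/replication build the sizes back to front.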
    std::vector<int64_t> outputSizes;
    if (mode == MPSGraphPaddingModeConstant) {
      // support arbitrary input dimensions for constant pad.
      auto input_sizes = input_.sizes();
      auto ori_padding_dim = padding_size / 2;
      auto l_diff = ndims - ori_padding_dim;

      for (size_t i = 0; i < (size_t)l_diff; i++) {
        outputSizes.emplace_back(input_sizes[i]);
      }
      for (const auto i : c10::irange((size_t)ori_padding_dim)) {
        auto pad_idx = padding.size() - ((i + 1) * 2);
        auto new_dim = input_sizes[l_diff + i] + padding[pad_idx] + padding[pad_idx + 1];
        outputSizes.emplace_back(new_dim);
      }
    } else {
      // these checks are only relevant for reflection padding (code taken from ReflectionPad.cpp)
      if (mode == MPSGraphPaddingModeReflect) {
        TORCH_CHECK(pad_l < input_w && pad_r < input_w,
                    "Argument #4: Padding size should be less than the corresponding "
                    "input dimension, but got: padding (",
                    pad_l,
                    ", ",
                    pad_r,
                    ") at dimension ",
                    dim_w,
                    " of input ",
                    input_.sizes());

        if (padding_dim > 1) {
          TORCH_CHECK(pad_t < input_h && pad_b < input_h,
                      "Argument #6: Padding size should be less than the corresponding "
                      "input dimension, but got: padding (",
                      pad_t,
                      ", ",
                      pad_b,
                      ") at dimension ",
                      dim_h,
                      " of input ",
                      input_.sizes());
        }
        if (padding_dim > 2) {
          TORCH_CHECK(pad_front < input_d && pad_back < input_d,
                      "Argument #8: Padding size should be less than the corresponding "
                      "input dimension, but got: padding (",
                      pad_front,
                      ", ",
                      pad_back,
                      ") at dimension ",
                      dim_d,
                      " of input ",
                      input_.sizes());
        }
      }
      outputSizes.insert(outputSizes.begin(), output_w);
      if (padding_dim >= 2)
        outputSizes.insert(outputSizes.begin(), output_h);
      if (padding_dim >= 3)
        outputSizes.insert(outputSizes.begin(), output_d);
      if (ndims >= 1 + padding_dim)
        outputSizes.insert(outputSizes.begin(), nplane);
      if (ndims >= 2 + padding_dim)
        outputSizes.insert(outputSizes.begin(), nbatch);
    }

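    // Resize the out tensor to the computed shape; empty inputs or outputs
    // can skip graph construction entirely.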
    output.resize_(outputSizes);

    if (output.numel() == 0) {
      return output;
    }
    if (input_.numel() == 0) {
      output.fill_(constantValue);
      return output;
    }
    input = input_.contiguous();
  } else {
    TORCH_CHECK(output_w == grad_output_.size(dim_w),
                "gradOutput width unexpected. Expected: ",
                output_w,
                ", Got: ",
                grad_output_.size(dim_w));
    if (padding_dim > 1) {
      TORCH_CHECK(output_h == grad_output_.size(dim_h),
                  "gradOutput height unexpected. Expected: ",
                  output_h,
                  ", Got: ",
                  grad_output_.size(dim_h));
    }
    output.resize_as_(input);
    if (output.numel() == 0 || grad_output_.numel() == 0)
      return output;
    grad_output = grad_output_.contiguous();
  }

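  // startMask/endMask carry one bit per dimension; a set bit tells MPSGraph's
  // slice ops to ignore the explicit start/end for that dim. Bits are cleared
  // below only where negative padding forces an actual slice.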
  const uint32_t dims_mask = (1U << ndims) - 1;
  uint32_t startMask = dims_mask, endMask = dims_mask;
  std::vector<NSNumber*> leftPadVec(ndims, @(0));
  std::vector<NSNumber*> rightPadVec(ndims, @(0));
  std::vector<NSNumber*> startsVec(ndims, @(0));
  std::vector<NSNumber*> endsVec(ndims, @(0));
  std::vector<NSNumber*> stridesVec(ndims, @(1));

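  // Map PyTorch's padding pairs (ordered from the last dimension backward)
  // onto MPSGraph's per-dimension arrays, which run in regular dim order.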
  for (int64_t pdim = 0; pdim < padding_size / 2; pdim++) {
    const int64_t leftIdx = pdim * 2;
    const int64_t rightIdx = pdim * 2 + 1;
    const int64_t padIdx = ndims - pdim - 1;

    leftPadVec[padIdx] = @(padding[leftIdx]);
    rightPadVec[padIdx] = @(padding[rightIdx]);
    // workaround for negative padding issue in backward pass
    if (is_backward_pass) {
      if (padding[leftIdx] < 0) {
        leftPadVec[padIdx] = @(0);
        startsVec[padIdx] = @(-padding[leftIdx]);
        startMask &= ~(1U << padIdx);
      }
      if (padding[rightIdx] < 0) {
        rightPadVec[padIdx] = @(0);
        endsVec[padIdx] = @(input.size(padIdx) + padding[rightIdx]);
        endMask &= ~(1U << padIdx);
      }
    }
  }
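  // MPSShape is an NSArray<NSNumber*>, so the NSNumber vectors convert directly.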
  MPSShape* leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims];
  MPSShape* rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims];

  MPSDataType dataType = getMPSScalarType(input.scalar_type());
  // workaround for Bool type assert with Constant padding
  if (input.scalar_type() == kBool) {
    dataType = MPSDataTypeInt8;
  }
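  // Int8 has the same 1-byte element size as Bool, so the buffer can be
  // viewed as Int8 for the graph ops without changing its contents.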

  @autoreleasepool {
    std::string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
        "]:" + std::to_string(constantValue);

    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(input));
      const bool needsSlice = startMask != dims_mask || endMask != dims_mask;

      if (!is_backward_pass) {
        MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
                                        withPaddingMode:mode
                                            leftPadding:leftPadding
                                           rightPadding:rightPadding
                                          constantValue:constantValue
                                                   name:nil];
        // workaround for the right padding bug in Monterey
        if (needsSlice) {
          newCachedGraph->gradInputTensor_ =
              [mpsGraph sliceTensor:padTensor
                             starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
                               ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
                            strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
                          startMask:startMask
                            endMask:endMask
                        squeezeMask:0
                               name:nil];
        } else {
          newCachedGraph->gradInputTensor_ = padTensor;
        }
      } else {
        newCachedGraph->gradOutputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output));
        MPSGraphTensor* padGradTensor =
            [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
                                               sourceTensor:newCachedGraph->inputTensor_
                                                paddingMode:mode
                                                leftPadding:leftPadding
                                               rightPadding:rightPadding
                                                       name:nil];
        // workaround for negative padding issue with padGradientWithIncomingGradientTensor()
        if (needsSlice) {
          newCachedGraph->gradInputTensor_ =
              [mpsGraph sliceGradientTensor:padGradTensor
                           fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
                                     starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
                                       ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
                                    strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
                                  startMask:startMask
                                    endMask:endMask
                                squeezeMask:0
                                       name:nil];
        } else {
          newCachedGraph->gradInputTensor_ = padGradTensor;
        }
      }
    });

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, nullptr, true, dataType);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output, nullptr, true, dataType);
    Placeholder gradOutputPlaceholder = !is_backward_pass
        ? Placeholder()
        : Placeholder(cachedGraph->gradOutputTensor_, grad_output, nullptr, true, dataType);

    NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
    if (is_backward_pass) {
      feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
    }
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return output;
}
} // namespace mps

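// The structured-kernel wrappers below (TORCH_IMPL_FUNC) receive their output
// as `const Tensor&`, hence the const_cast before delegating to
// mps::pad_out_template. From Python these are reached via
// torch.nn.functional.pad, e.g. F.pad(x, (1, 2), mode="reflect") on a 3-D MPS
// tensor dispatches to reflection_pad1d.
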
// 1D Reflection and Replication Padding
TORCH_IMPL_FUNC(reflection_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeReflect,
                        0.0,
                        "reflection_pad1d_out_mps");
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  mps::pad_out_template(const_cast<Tensor&>(grad_input),
                        input,
                        padding,
                        grad_output,
                        MPSGraphPaddingModeReflect,
                        0.0,
                        "reflection_pad1d_backward_out_mps");
}

TORCH_IMPL_FUNC(replication_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeClampToEdge,
                        0.0,
                        "replication_pad1d_out_mps");
}

TORCH_IMPL_FUNC(replication_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  mps::pad_out_template(const_cast<Tensor&>(grad_input),
                        input,
                        padding,
                        grad_output,
                        MPSGraphPaddingModeClampToEdge,
                        0.0,
                        "replication_pad1d_backward_out_mps");
}

// 2D Reflection and Replication Padding
Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output) {
  return mps::pad_out_template(output, input, padding, std::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
}

Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding) {
  Tensor output = at::empty({0}, input.options());
  return mps::pad_out_template(output, input, padding, std::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
}

Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output,
                                          const Tensor& input,
                                          IntArrayRef padding,
                                          Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
}

Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
}

TORCH_IMPL_FUNC(replication_pad2d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeClampToEdge,
                        0.0,
                        "replication_pad2d_out_mps");
}

Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output,
                                           const Tensor& input,
                                           IntArrayRef padding,
                                           Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}

Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}

// 3D Reflection and Replication Padding
TORCH_IMPL_FUNC(reflection_pad3d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeReflect,
                        0.0,
                        "reflection_pad3d_out_mps");
}

TORCH_IMPL_FUNC(reflection_pad3d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  mps::pad_out_template(const_cast<Tensor&>(grad_input),
                        input,
                        padding,
                        grad_output,
                        MPSGraphPaddingModeReflect,
                        0.0,
                        "reflection_pad3d_backward_out_mps");
}

TORCH_IMPL_FUNC(replication_pad3d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeClampToEdge,
                        0.0,
                        "replication_pad3d_out_mps");
}

Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output,
                                           const Tensor& input,
                                           IntArrayRef padding,
                                           Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}

Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}

// backward pass is explicitly handled in autograd by negating the "pad" argument
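// Padding more than three trailing dimensions (pad.size() > 6) falls back to
// the generic implementation; e.g. F.pad(x, (1,) * 8) on a 4-D input pads all
// four dimensions and takes the fallback path below.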
Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value) {
  if (pad.size() > 6) {
    TORCH_WARN_ONCE("MPS: constant padding of more than 3 dimensions is not currently supported natively; ",
                    "falling back to the default View Ops implementation, which may affect performance.");
    return at::native::constant_pad_nd(self, pad, value);
  }
  Tensor output = at::empty({0}, self.options());
  return mps::pad_out_template(
      output, self, pad, std::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
}

} // namespace at::native