// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/constant_pad_nd_native.h>
#include <ATen/ops/reflection_pad1d_backward_native.h>
#include <ATen/ops/reflection_pad1d_native.h>
#include <ATen/ops/reflection_pad2d_backward_native.h>
#include <ATen/ops/reflection_pad2d_native.h>
#include <ATen/ops/reflection_pad3d_backward_native.h>
#include <ATen/ops/reflection_pad3d_native.h>
#include <ATen/ops/replication_pad1d_backward_native.h>
#include <ATen/ops/replication_pad1d_native.h>
#include <ATen/ops/replication_pad2d_backward_native.h>
#include <ATen/ops/replication_pad2d_native.h>
#include <ATen/ops/replication_pad3d_backward_native.h>
#include <ATen/ops/replication_pad3d_native.h>
#endif

namespace at::native {
namespace mps {

// Pad operations (1D/2D/3D forward and backward)
// Shared template: fills `output` with the padded input in the forward pass, or with the
// gradient w.r.t. the input when `grad_output_opt` holds a defined tensor (backward pass).
static Tensor& pad_out_template(Tensor& output,
                                const Tensor& input_,
                                IntArrayRef padding,
                                const std::optional<Tensor>& grad_output_opt,
                                MPSGraphPaddingMode mode,
                                double constantValue,
                                const string op_name) {
  using CachedGraph = MPSUnaryGradCachedGraph;
  const int padding_size = (int)padding.size();
  int padding_dim = padding_size / 2; // either 1D, 2D, or 3D

  TORCH_CHECK(
      padding_size == 2 || padding_size == 4 || padding_size == 6, "invalid padding argument of size ", padding_size);

  const Tensor& grad_output_ = *(at::borrow_from_optional_tensor(grad_output_opt));
  const bool is_backward_pass = grad_output_.defined();

  int64_t nbatch = 1;
  int64_t ndims = input_.ndimension();

  TORCH_CHECK(ndims >= (int64_t)padding_dim,
              "Length of pad should be no more than twice the number of "
              "dimensions of the input. Pad length is ",
              padding_size,
              " while the input has ",
              ndims,
              " dimensions.");

  // number of input dims with ConstantPad could be less than 2
  int dim_w = padding_dim;
  int dim_h = padding_dim - 1;
  int dim_d = padding_dim - 2;
  int dim_slices = 0;

  if (!is_backward_pass && mode != MPSGraphPaddingModeConstant && ndims > padding_dim) {
    bool valid_dims = input_.size(1) != 0 && input_.size(padding_dim) != 0;
    TORCH_CHECK((ndims == 1 + padding_dim && valid_dims) ||
                    (ndims == 2 + padding_dim && valid_dims && input_.size(1 + padding_dim) != 0),
                "3D or 4D (batch mode) tensor expected for input, but got: ",
                input_);
  }

  if (ndims == padding_dim) {
    dim_w--;
    dim_h--;
    dim_d--;
  } else if (ndims > padding_dim + 1) {
    const int dim_diff = (int)ndims - padding_dim - 1;
    // this virtually inflates the padding with zeros if ndims > padding_dim + 2
    padding_dim += dim_diff - 1;
    dim_w += dim_diff;
    dim_h += dim_diff;
    dim_d += dim_diff;
    dim_slices++;
    nbatch = input_.size(0);
  }

  int64_t pad_l = padding[0];
  int64_t pad_r = padding[1];
  int64_t pad_t = padding_size > 2 ? padding[2] : 0;
  int64_t pad_b = padding_size > 2 ? padding[3] : 0;
  int64_t pad_front = padding_size > 4 ? padding[4] : 0;
  int64_t pad_back = padding_size > 4 ? padding[5] : 0;

  int64_t nplane = input_.size(dim_slices);
  int64_t input_w = input_.size(dim_w);
  int64_t output_w = input_w + pad_l + pad_r;
  int64_t input_h = padding_dim > 1 ? input_.size(dim_h) : 0;
  int64_t output_h = padding_dim > 1 ? input_h + pad_t + pad_b : 0;
  int64_t input_d = padding_dim > 2 ? input_.size(dim_d) : 0;
  int64_t output_d = padding_dim > 2 ? input_d + pad_front + pad_back : 0;

  Tensor grad_output, input = input_;

  if (!is_backward_pass) {
    TORCH_CHECK(output_w >= 1 || output_h >= padding_dim - 1,
                "input (H: ",
                input_h,
                ", W: ",
                input_w,
                ") is too small. Calculated "
                "output H: ",
                output_h,
                " W: ",
                output_w);

    std::vector<int64_t> outputSizes;
    if (mode == MPSGraphPaddingModeConstant) {
      // support arbitrary input dimensions for constant pad.
      auto input_sizes = input_.sizes();
      auto ori_padding_dim = padding_size / 2;
      auto l_diff = ndims - ori_padding_dim;

      for (size_t i = 0; i < (size_t)l_diff; i++) {
        outputSizes.emplace_back(input_sizes[i]);
      }
      for (const auto i : c10::irange((size_t)ori_padding_dim)) {
        auto pad_idx = padding.size() - ((i + 1) * 2);
        auto new_dim = input_sizes[l_diff + i] + padding[pad_idx] + padding[pad_idx + 1];
        outputSizes.emplace_back(new_dim);
      }
    } else {
      // these checks are only relevant for reflection padding (code taken from ReflectionPad.cpp)
      if (mode == MPSGraphPaddingModeReflect) {
        TORCH_CHECK(pad_l < input_w && pad_r < input_w,
                    "Argument #4: Padding size should be less than the corresponding "
                    "input dimension, but got: padding (",
                    pad_l,
                    ", ",
                    pad_r,
                    ") at dimension ",
                    dim_w,
                    " of input ",
                    input_.sizes());

        if (padding_dim > 1) {
          TORCH_CHECK(pad_t < input_h && pad_b < input_h,
                      "Argument #6: Padding size should be less than the corresponding "
                      "input dimension, but got: padding (",
                      pad_t,
                      ", ",
                      pad_b,
                      ") at dimension ",
                      dim_h,
                      " of input ",
                      input_.sizes());
        }
        if (padding_dim > 2) {
          TORCH_CHECK(pad_front < input_d && pad_back < input_d,
                      "Argument #8: Padding size should be less than the corresponding "
                      "input dimension, but got: padding (",
                      pad_front,
                      ", ",
                      pad_back,
                      ") at dimension ",
                      dim_d,
                      " of input ",
                      input_.sizes());
        }
      }
      outputSizes.insert(outputSizes.begin(), output_w);
      if (padding_dim >= 2)
        outputSizes.insert(outputSizes.begin(), output_h);
      if (padding_dim >= 3)
        outputSizes.insert(outputSizes.begin(), output_d);
      if (ndims >= 1 + padding_dim)
        outputSizes.insert(outputSizes.begin(), nplane);
      if (ndims >= 2 + padding_dim)
        outputSizes.insert(outputSizes.begin(), nbatch);
    }

    output.resize_(outputSizes);

    if (output.numel() == 0) {
      return output;
    }
    if (input_.numel() == 0) {
      output.fill_(constantValue);
      return output;
    }
    input = input_.contiguous();
  } else {
    TORCH_CHECK(output_w == grad_output_.size(dim_w),
                "gradOutput width unexpected. Expected: ",
                output_w,
                ", Got: ",
                grad_output_.size(dim_w));
    if (padding_dim > 1) {
      TORCH_CHECK(output_h == grad_output_.size(dim_h),
                  "gradOutput height unexpected. Expected: ",
                  output_h,
                  ", Got: ",
                  grad_output_.size(dim_h));
    }
    output.resize_as_(input);
    if (output.numel() == 0 || grad_output_.numel() == 0)
      return output;
    grad_output = grad_output_.contiguous();
  }

  const uint32_t dims_mask = (1U << ndims) - 1;
  uint32_t startMask = dims_mask, endMask = dims_mask;
  std::vector<NSNumber*> leftPadVec(ndims, @(0));
  std::vector<NSNumber*> rightPadVec(ndims, @(0));
  std::vector<NSNumber*> startsVec(ndims, @(0));
  std::vector<NSNumber*> endsVec(ndims, @(0));
  std::vector<NSNumber*> stridesVec(ndims, @(1));

  for (int64_t pdim = 0; pdim < padding_size / 2; pdim++) {
    const int64_t leftIdx = pdim * 2;
    const int64_t rightIdx = pdim * 2 + 1;
    const int64_t padIdx = ndims - pdim - 1;

    leftPadVec[padIdx] = @(padding[leftIdx]);
    rightPadVec[padIdx] = @(padding[rightIdx]);
    // workaround for negative padding issue in backward pass
    if (is_backward_pass) {
      if (padding[leftIdx] < 0) {
        leftPadVec[padIdx] = @(0);
        startsVec[padIdx] = @(-padding[leftIdx]);
        startMask &= ~(1U << padIdx);
      }
      if (padding[rightIdx] < 0) {
        rightPadVec[padIdx] = @(0);
        endsVec[padIdx] = @(input.size(padIdx) + padding[rightIdx]);
        endMask &= ~(1U << padIdx);
      }
    }
  }
  MPSShape* leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims];
  MPSShape* rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims];

  MPSDataType dataType = getMPSScalarType(input.scalar_type());
  // workaround for Bool type assert with Constant padding
  if (input.scalar_type() == kBool) {
    dataType = MPSDataTypeInt8;
  }

  @autoreleasepool {
    string key = op_name + getTensorsStringKey({input, grad_output, output}) + ":[" + getArrayRefString(padding) +
        "]:" + std::to_string(constantValue);

    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(input));
      const bool needsSlice = startMask != dims_mask || endMask != dims_mask;

      if (!is_backward_pass) {
        MPSGraphTensor* padTensor = [mpsGraph padTensor:newCachedGraph->inputTensor_
                                        withPaddingMode:mode
                                            leftPadding:leftPadding
                                           rightPadding:rightPadding
                                          constantValue:constantValue
                                                   name:nil];
        // workaround for the right padding bug in Monterey
        if (needsSlice) {
          newCachedGraph->gradInputTensor_ =
              [mpsGraph sliceTensor:padTensor
                             starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
                               ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
                            strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
                          startMask:startMask
                            endMask:endMask
                        squeezeMask:0
                               name:nil];
        } else {
          newCachedGraph->gradInputTensor_ = padTensor;
        }
      } else {
        newCachedGraph->gradOutputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, dataType, getMPSShape(grad_output));
        MPSGraphTensor* padGradTensor =
            [mpsGraph padGradientWithIncomingGradientTensor:newCachedGraph->gradOutputTensor_
                                               sourceTensor:newCachedGraph->inputTensor_
                                                paddingMode:mode
                                                leftPadding:leftPadding
                                               rightPadding:rightPadding
                                                       name:nil];
        // workaround for negative padding issue with padGradientWithIncomingGradientTensor()
        if (needsSlice) {
          newCachedGraph->gradInputTensor_ =
              [mpsGraph sliceGradientTensor:padGradTensor
                           fwdInShapeTensor:[mpsGraph shapeOfTensor:newCachedGraph->inputTensor_ name:nil]
                                     starts:[NSArray arrayWithObjects:startsVec.data() count:ndims]
                                       ends:[NSArray arrayWithObjects:endsVec.data() count:ndims]
                                    strides:[NSArray arrayWithObjects:stridesVec.data() count:ndims]
                                  startMask:startMask
                                    endMask:endMask
                                squeezeMask:0
                                       name:nil];
        } else {
          newCachedGraph->gradInputTensor_ = padGradTensor;
        }
      }
    });

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input, nullptr, true, dataType);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, output, nullptr, true, dataType);
    Placeholder gradOutputPlaceholder = !is_backward_pass
        ? Placeholder()
        : Placeholder(cachedGraph->gradOutputTensor_, grad_output, nullptr, true, dataType);

    NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
    if (is_backward_pass) {
      feeds[gradOutputPlaceholder.getMPSGraphTensor()] = gradOutputPlaceholder.getMPSGraphTensorData();
    }
    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return output;
}
} // namespace mps

// 1D Reflection and Replication Padding
TORCH_IMPL_FUNC(reflection_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeReflect,
                        0.0,
                        "reflection_pad1d_out_mps");
}

TORCH_IMPL_FUNC(reflection_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  mps::pad_out_template(const_cast<Tensor&>(grad_input),
                        input,
                        padding,
                        grad_output,
                        MPSGraphPaddingModeReflect,
                        0.0,
                        "reflection_pad1d_backward_out_mps");
}

TORCH_IMPL_FUNC(replication_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeClampToEdge,
                        0.0,
                        "replication_pad1d_out_mps");
}

TORCH_IMPL_FUNC(replication_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  mps::pad_out_template(const_cast<Tensor&>(grad_input),
                        input,
                        padding,
                        grad_output,
                        MPSGraphPaddingModeClampToEdge,
                        0.0,
                        "replication_pad1d_backward_out_mps");
}

// 2D Reflection and Replication Padding
Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output) {
  return mps::pad_out_template(output, input, padding, std::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
}

Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding) {
  Tensor output = at::empty({0}, input.options());
  return mps::pad_out_template(output, input, padding, std::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
}

Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output,
                                          const Tensor& input,
                                          IntArrayRef padding,
                                          Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
}

Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
}

TORCH_IMPL_FUNC(replication_pad2d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeClampToEdge,
                        0.0,
                        "replication_pad2d_out_mps");
}

Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output,
                                           const Tensor& input,
                                           IntArrayRef padding,
                                           Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}

Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}

// 3D Reflection and Replication Padding
TORCH_IMPL_FUNC(reflection_pad3d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeReflect,
                        0.0,
                        "reflection_pad3d_out_mps");
}

TORCH_IMPL_FUNC(reflection_pad3d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  mps::pad_out_template(const_cast<Tensor&>(grad_input),
                        input,
                        padding,
                        grad_output,
                        MPSGraphPaddingModeReflect,
                        0.0,
                        "reflection_pad3d_backward_out_mps");
}

TORCH_IMPL_FUNC(replication_pad3d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output) {
  mps::pad_out_template(const_cast<Tensor&>(output),
                        input,
                        padding,
                        std::nullopt,
                        MPSGraphPaddingModeClampToEdge,
                        0.0,
                        "replication_pad3d_out_mps");
}

Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output,
                                           const Tensor& input,
                                           IntArrayRef padding,
                                           Tensor& grad_input) {
  grad_input.resize_as_(input).zero_();
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}

Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding) {
  auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}

// backward pass is explicitly handled in autograd by negating the "pad" argument
Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value) {
  if (pad.size() > 6) {
    TORCH_WARN_ONCE("MPS: The constant padding of more than 3 dimensions is not currently supported natively. ",
                    "It uses View Ops default implementation to run. This may have performance implications.");
    return at::native::constant_pad_nd(self, pad, value);
  }
  Tensor output = at::empty({0}, self.options());
  return mps::pad_out_template(
      output, self, pad, std::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
}

} // namespace at::native