//  Copyright © 2022 Apple Inc.

#include <ATen/mps/MPSProfiler.h>
#include <ATen/native/CPUFallback.h>

namespace at {

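// Boxed CPU fallback used when PYTORCH_ENABLE_MPS_FALLBACK=1: emits a one-time warning,
// then delegates to at::native::cpu_fallback, which copies the MPS input tensors to the CPU,
// runs the CPU kernel, and copies the results back to the MPS device. When the MPSProfiler
// has CPU-fallback tracing/logging enabled, the call is bracketed with profiling events.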
static void mps_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_WARN_ONCE("The operator '",
                  op.schema().operator_name(),
                  "' is not currently supported ",
                  "on the MPS backend and will fall back to run on the CPU.",
                  " This may have performance implications.");

  auto& profiler = mps::getMPSProfiler();
  const bool isCPUFallbackProfilingEnabled = profiler.isCPUFallbackProfilingEnabled();

  // only do profiling if CPU Fallback op execution tracing or logging is enabled
  if (isCPUFallbackProfilingEnabled) {
    // we create a Tensors list to compute the size of copies required to convert
    // the input MPS tensors to CPU, and the CPU results back to MPS
    std::vector<at::Tensor> tensor_args;
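    // torch::jit::last() returns a view of the top-most IValues on the stack
    // (the boxed arguments of this op) without popping them.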
    for (const auto& ivalue : torch::jit::last(stack, op.schema().arguments().size())) {
      if (ivalue.isTensor()) {
        tensor_args.push_back(ivalue.toTensor());
      }
    }
    // TODO: check if any returns exist at this stage
    for (const auto& ivalue : torch::jit::last(stack, op.schema().returns().size())) {
      if (ivalue.isTensor()) {
        tensor_args.push_back(ivalue.toTensor());
      }
    }
    profiler.beginProfileCPUFallback(op.schema().name(), tensor_args);
  }

  native::cpu_fallback(op, stack);

  if (isCPUFallbackProfilingEnabled) {
    profiler.endProfileCPUFallback(op.schema().name());
  }
}

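// Default boxed fallback for the MPS key when PYTORCH_ENABLE_MPS_FALLBACK is unset or 0:
// raises a NotImplementedError that points the user at the tracking issue and the
// environment variable that enables the CPU fallback.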
static void mps_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "The operator '",
      op.schema().operator_name(),
      "' is not currently implemented ",
      "for the MPS device. If you want this op to be added in priority during the prototype ",
      "phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. ",
      "As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
      "to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ",
      "on MPS.");
}

// This dispatch should never be called when both tensors are on the MPS device, but it is
// frequently called if one of them is on the CPU.
static Tensor slow_conv2d_forward_mps(const Tensor& self,
                                      const Tensor& weight,
                                      IntArrayRef kernel_size,
                                      const std::optional<Tensor>& bias,
                                      IntArrayRef stride,
                                      IntArrayRef padding) {
  TORCH_CHECK(self.device() == weight.device(),
              __func__,
              ": input(device='",
              self.device(),
              "') and weight(device='",
              weight.device(),
              "') must be on the same device");
  TORCH_INTERNAL_ASSERT(false, __func__, " should not be called when both tensors are on the MPS device");
}

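// Registers the backend-wide boxed fallback for the MPS dispatch key. The choice between
// erroring out and falling back to the CPU is made once, when the library is loaded, based
// on the PYTORCH_ENABLE_MPS_FALLBACK environment variable, e.g. (example invocation):
//   PYTORCH_ENABLE_MPS_FALLBACK=1 python your_script.py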
TORCH_LIBRARY_IMPL(_, MPS, m) {
  static const char* enable_mps_fallback = getenv("PYTORCH_ENABLE_MPS_FALLBACK");
  if (!enable_mps_fallback || std::stoi(enable_mps_fallback) == 0) {
    m.fallback(torch::CppFunction::makeFromBoxedFunction<&mps_error_fallback>());
  } else {
    m.fallback(torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
  }
}

TORCH_LIBRARY_IMPL(aten, MPS, m) {
  // These ops are not currently supported on the MPS backend, so we always fall back to the CPU.
  // For all other unsupported ops, the user needs to set 'PYTORCH_ENABLE_MPS_FALLBACK=1'
  // to fall back to the CPU; otherwise we error out.
  m.impl("embedding_renorm_", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
  m.impl("linalg_svd", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
  m.impl("linalg_svd.U", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
  m.impl("im2col", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>()); // Used in preprocessing by nn.Unfold
  m.impl("col2im", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
  m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps);
  m.impl("upsample_nearest3d.vec", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
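  // A minimal sketch of how another op could be routed through the CPU fallback here;
  // "some_unsupported_op" is a hypothetical name, not a real ATen operator:
  //   m.impl("some_unsupported_op", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());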
}

} // namespace at