• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2#include <ATen/native/mps/MPSGraphVenturaOps.h>
3#include <ATen/native/mps/OperationUtils.h>
4
5#ifndef AT_PER_OPERATOR_HEADERS
6#include <ATen/Functions.h>
7#include <ATen/NativeFunctions.h>
8#else
9#include <ATen/ops/linalg_inv_ex.h>
10#include <ATen/ops/linalg_inv_ex_native.h>
11#endif
12
namespace at::native {

// MPS implementation of the structured out= kernel for linalg_inv_ex.
//
// Args:
//   A            - matrix (or batch of matrices) to invert; complex dtypes are rejected.
//   check_errors - forwarded to the CPU kernel on the fallback path only. NOTE: the
//                  MPSGraph path below never reads this flag and never inspects A for
//                  singularity, so on that path it has no effect.
//   result       - pre-allocated MPS output that receives the inverse of A.
//   info         - pre-allocated status output. On the MPS path it is zeroed and not
//                  written again, i.e. it always reports success there.
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
  TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");

  // The MPSGraph path is gated on the MACOS_VER_13_3_PLUS check; when it fails, the
  // inverse is computed on the CPU and the results are copied into the MPS outputs.
  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
    TORCH_WARN_ONCE("torch.linalg.inv_ex is supported by MPS on MacOS 13.3+, please upgrade. Falling back to CPU.");
    auto cpu_info = at::empty({0}, kInt, std::nullopt, kCPU, std::nullopt, std::nullopt);
    // `result` is fully overwritten below, so one device->host copy is enough to get
    // a CPU buffer of the right shape (no intermediate clone on the MPS device).
    auto cpu_result = result.to("cpu");
    // Forward check_errors so the CPU kernel honors the caller's request.
    at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"), check_errors);
    info.copy_(cpu_info);
    result.copy_(cpu_result);
    return;
  }

  using namespace mps;
  using CachedGraph = MPSUnaryCachedGraph;

  MPSStream* stream = getCurrentMPSStream();
  // inverseOfTensor below produces no status, so info is cleared and left at zero.
  info.zero_();

  // Nothing to invert (info has already been cleared above).
  if (A.numel() == 0) {
    return;
  }

  // Force result to a C-contiguous layout before binding it as the graph output
  // (presumably so the graph can write into it directly — confirm against Placeholder).
  if (!result.is_contiguous()) {
    result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
  }

  @autoreleasepool {
    // Graphs are cached per input signature (as encoded by getTensorsStringKey).
    string key = "inv_out_mps" + getTensorsStringKey({A});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
      MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil];

      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }
}

} // namespace at::native
62