#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/linalg_inv_ex.h>
#include <ATen/ops/linalg_inv_ex_native.h>
#endif

namespace at::native {

// MPS kernel for the structured out-variant of `linalg_inv_ex`.
//
// Writes the inverse of `A` into `result` and zeroes / fills `info`.
// `result` and `info` are presumably already shaped by the structured meta
// function before this body runs (TODO confirm); this body only fills them.
//
// @param A            Matrix (or batch of matrices) to invert. Complex dtypes are rejected.
// @param check_errors Forwarded to the CPU kernel on the macOS < 13.3 fallback path.
//                     On the MPS path `info` is only ever zeroed, so this flag cannot
//                     flag non-invertible inputs there.
// @param result       Output tensor for the inverse; must be an MPS tensor.
// @param info         Output status tensor.
TORCH_IMPL_FUNC(linalg_inv_ex_out_mps)(const Tensor& A, bool check_errors, const Tensor& result, const Tensor& info) {
  TORCH_CHECK(result.is_mps(), "Output tensor is not MPS");
  TORCH_CHECK(!A.is_complex(), "linalg_inv: not supported for complex types yet!");

  // The MPSGraph path below is only taken on macOS 13.3+; older systems compute on CPU.
  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS)) {
    TORCH_WARN_ONCE(
        "torch.linalg_inv_ex.inverse is supported by MPS on MacOS 13.3+, please upgrade. Falling back to CPU.");
    // Empty CPU buffers; the CPU out-kernel resizes them. There is no need to
    // clone or download the current contents of `result`, they are overwritten.
    auto cpu_info = at::empty({0}, kInt, std::nullopt, kCPU, std::nullopt, std::nullopt);
    auto cpu_result = at::empty({0}, result.scalar_type(), std::nullopt, kCPU, std::nullopt, std::nullopt);
    // Forward check_errors so the CPU kernel performs its own error checking
    // instead of silently using the default (false).
    at::linalg_inv_ex_out(cpu_result, cpu_info, A.to("cpu"), check_errors);
    info.copy_(cpu_info);
    result.copy_(cpu_result);
    return;
  }

  using namespace mps;
  using CachedGraph = MPSUnaryCachedGraph;

  MPSStream* stream = getCurrentMPSStream();
  // Nothing below writes `info`, so it stays all zeros even for singular
  // inputs; check_errors therefore has no effect on this path.
  info.zero_();

  // Nothing to compute for empty inputs (info has already been zeroed).
  if (A.numel() == 0) {
    return;
  }

  // Re-stride a non-contiguous output to contiguous in place so the graph
  // result can be written straight into it (presumably what Placeholder
  // expects for outputs -- TODO confirm).
  if (!result.is_contiguous()) {
    result.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
  }

  @autoreleasepool {
    // One cached graph per input description, built from A's tensor key.
    string key = "inv_out_mps" + getTensorsStringKey({A});
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, A);
      MPSGraphTensor* outputTensor = [mpsGraph inverseOfTensor:inputTensor name:nil];

      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor_, A);
    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);

    auto feeds = dictionaryFromPlaceholders(inputPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }
}

} // namespace at::native