// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/fill_native.h>
#include <ATen/ops/view_as_real.h>
#include <ATen/ops/zero_native.h>
#endif

namespace at::native {

static Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
  using namespace mps;

  if (self.numel() == 0) {
    return self;
  }
  // If `self` is a non-contiguous view that would require a gather, fill a
  // contiguous scratch tensor instead and copy it back into `self` at the end.
  Tensor output = self;
  bool needsCopyToOutput = false;
  if (needsGather(self)) {
    output = at::empty(self.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
    needsCopyToOutput = true;
  }

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
    std::string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble());

    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      // The graph is just an identity over a scalar placeholder; the scalar
      // value is fed in at run time, so the compiled graph is reusable for
      // any fill value of the same dtype and shape key.
      MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()));
      MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    auto mpsScalar = getMPSScalar(value, self.scalar_type());
    auto mpsScalarData = getMPSGraphTensorFromScalar(getCurrentMPSStream(), mpsScalar);
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_ : mpsScalarData};

    Placeholder outputPlaceholder =
        Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
        @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);

    if (needsCopyToOutput) {
      self.copy_(output);
    }
  }

  return self;
}

// Returns false if the tensor cannot be filled with fillBuffer().
static bool fill_mps_tensor_(Tensor& self, uint8_t value) {
  if (self.is_contiguous()) {
    MPSStream* stream = getCurrentMPSStream();
    auto storage_byte_offset = self.storage_offset() * self.itemsize();
    stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset);
    return true;
  }
  return false;
}

Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) {
  // Complex tensors are filled as two real planes: view the tensor as real,
  // then fill the real and imaginary components separately.
  if (isComplexType(self.scalar_type())) {
    auto self_as_real = at::view_as_real(self);
    auto self_as_real_real = self_as_real.select(self.dim(), 0);
    auto self_as_real_imag = self_as_real.select(self.dim(), 1);
    if (value.isComplex()) {
      auto value_cdouble = value.to<c10::complex<double>>();
      fill_scalar_mps_impl(self_as_real_real, value_cdouble.real());
      fill_scalar_mps_impl(self_as_real_imag, value_cdouble.imag());
      return self;
    }
    fill_scalar_mps_impl(self_as_real_real, value);
    fill_scalar_mps_impl(self_as_real_imag, 0.0f);
    return self;
  }
  // Check if it's possible to use fillBuffer() to fill the tensor's storage directly.
  if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0))
    return self;

  return fill_scalar_mps_impl(self, value);
}

Tensor& fill_tensor_mps_(Tensor& self, const Tensor& value) {
  TORCH_CHECK(value.dim() == 0,
              "fill_ only supports 0-dimension value tensor but got tensor with ",
              value.dim(),
              " dimensions.");
  Scalar scalar_value = value.item();
  if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0))
    return self;
  return fill_scalar_mps(self, scalar_value);
}

Tensor& zero_mps_(Tensor& self) {
  return fill_scalar_mps(self, 0.0f);
}

} // namespace at::native
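
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, kept as a comment so it is not compiled
// into this file). A minimal standalone libtorch program, assuming a build
// with MPS enabled, that exercises the three paths above: the fillBuffer()
// fast path, the cached MPSGraph kernel, and the complex decomposition.
//
//   #include <ATen/ATen.h>
//
//   int main() {
//     // Contiguous zero fill: handled by fill_mps_tensor_() via fillBuffer().
//     auto t = at::empty({4, 4}, at::TensorOptions().device(at::kMPS));
//     t.zero_();
//
//     // Non-zero scalar fill: runs the cached identity graph in
//     // fill_scalar_mps_impl().
//     t.fill_(3.5);
//
//     // Complex fill: fill_scalar_mps() splits the tensor into real and
//     // imaginary planes via at::view_as_real() and fills each separately.
//     auto c = at::empty({2, 2}, at::TensorOptions().device(at::kMPS).dtype(at::kComplexFloat));
//     c.fill_(c10::complex<double>(1.0, -2.0));
//     return 0;
//   }
// ---------------------------------------------------------------------------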