//  Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/OperationUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/fill_native.h>
#include <ATen/ops/view_as_real.h>
#include <ATen/ops/zero_native.h>
#endif

namespace at::native {

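// Graph-based fill, used when the byte-pattern fast path below does not apply:
// builds (or reuses) a small MPSGraph that forwards the fill value through a
// scalar placeholder and writes it into the tensor's MPS storage.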
static Tensor& fill_scalar_mps_impl(Tensor& self, const Scalar& value) {
  using namespace mps;

  if (self.numel() == 0) {
    return self;
  }
  Tensor output = self;
  bool needsCopyToOutput = false;
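  // If `self` is a view that would need a gather to address its elements, fill a
  // temporary contiguous tensor instead and copy the result back into `self` below.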
  if (needsGather(self)) {
    output = at::empty(self.sizes(), self.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
    needsCopyToOutput = true;
  }

  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  @autoreleasepool {
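    // The cached graph is keyed on the tensor's string key plus the fill value, so
    // repeated fills with the same signature reuse an already-compiled graph.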
    string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + std::to_string(value.toDouble());

    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()));
      MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

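    // Wrap the fill value in MPSGraphTensorData and feed it through the scalar placeholder.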
    auto mpsScalar = getMPSScalar(value, self.scalar_type());
    auto mpsScalarData = getMPSGraphTensorFromScalar(getCurrentMPSStream(), mpsScalar);
    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{cachedGraph->inputTensor_ : mpsScalarData};

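    // When no copy-back is needed, bind the output placeholder directly to `self`'s
    // storage; otherwise target the temporary `output` tensor allocated above.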
    Placeholder outputPlaceholder =
        Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);

    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
        @{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);

    if (needsCopyToOutput) {
      self.copy_(output);
    }
  }

  return self;
}

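// Fast path: when the tensor is contiguous, fill its storage with a single repeated
// byte on the current MPS stream; in this file it is only used for zero fills.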
// returns false if tensor cannot be filled with fillBuffer()
static bool fill_mps_tensor_(Tensor& self, uint8_t value) {
  if (self.is_contiguous()) {
    MPSStream* stream = getCurrentMPSStream();
    auto storage_byte_offset = self.storage_offset() * self.itemsize();
    stream->fill(mps::getMTLBufferStorage(self), value, self.nbytes(), storage_byte_offset);
    return true;
  }
  return false;
}

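// Complex tensors are filled through at::view_as_real: the real and imaginary planes
// are filled separately, with the imaginary plane set to zero when the fill value is
// not complex.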
Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) {
  if (isComplexType(self.scalar_type())) {
    auto self_as_real = at::view_as_real(self);
    auto self_as_real_real = self_as_real.select(self.dim(), 0);
    auto self_as_real_imag = self_as_real.select(self.dim(), 1);
    if (value.isComplex()) {
      auto value_cdouble = value.to<c10::complex<double>>();
      fill_scalar_mps_impl(self_as_real_real, value_cdouble.real());
      fill_scalar_mps_impl(self_as_real_imag, value_cdouble.imag());
      return self;
    }
    fill_scalar_mps_impl(self_as_real_real, value);
    fill_scalar_mps_impl(self_as_real_imag, 0.0f);
    return self;
  }
  // check if it's possible to use fillBuffer() to fill the Tensor's storage
  if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
    return self;

  return fill_scalar_mps_impl(self, value);
}

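// fill_ overload taking a 0-dim value tensor: validate the shape, extract the scalar,
// and reuse the scalar path (including the zero fast path).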
Tensor& fill_tensor_mps_(Tensor& self, const Tensor& value) {
  TORCH_CHECK(value.dim() == 0,
              "fill_ only supports 0-dimension value tensor but got tensor with ",
              value.dim(),
              " dimensions.");
  Scalar scalar_value = value.item();
  if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
    return self;
  return fill_scalar_mps(self, scalar_value);
}

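// zero_ is fill_ with 0.0, which takes the fillBuffer() fast path for contiguous tensors.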
Tensor& zero_mps_(Tensor& self) {
  return fill_scalar_mps(self, 0.0f);
}

} // namespace at::native