// Copyright © 2022 Apple Inc.

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <torch/library.h>
#include <ATen/mps/EmptyTensor.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/native/Resize.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/mps/TensorFactory.h>
#include <ATen/Dispatch.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#endif
#include <ATen/ops/_efficientzerotensor_native.h>

#include <utility>

namespace at::native {

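// Grows the storage behind `self` so it can hold `new_size` elements. A
// `new_size` of zero is a no-op, and shrinking never reallocates; when the
// storage must grow, the old bytes are blit-copied into the new allocation.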
static inline void maybe_resize_storage_mps(TensorImpl* self, uint64_t new_size) {
  if (new_size == 0) {
    return;
  }

  auto storage = self->storage().unsafeGetStorageImpl();
  TORCH_CHECK(storage, "Tensor: invalid null storage");
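  // The new allocation must cover the storage offset as well as the requested elements.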
  uint64_t new_size_bytes = (new_size + self->storage_offset()) * self->dtype().itemsize();
  if (new_size_bytes > self->storage().nbytes()) {
    at::DataPtr new_data = storage->allocator()->allocate(new_size_bytes);
    size_t copy_capacity = std::min<size_t>(new_size_bytes, storage->nbytes());
    if (storage->data() && copy_capacity > 0) {
      at::native::mps::copy_blit_mps(new_data.get(), storage->data(), copy_capacity);
    }
    // Destructively overwrite data_ptr
    storage->set_data_ptr_noswap(std::move(new_data));
    storage->set_nbytes(new_size_bytes);
  }
}

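// Updates the size/stride metadata on `self` and grows the backing storage if
// the new layout needs more room. `device_guard` is unused here; presumably it
// is kept for signature parity with the other backends' resize_impl_ helpers.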
inline TensorImpl* resize_impl_mps_(
    TensorImpl* self,
    IntArrayRef size,
    std::optional<IntArrayRef> stride,
    bool device_guard = true) {
  if (self->sizes() == size && (!stride || self->strides() == stride)) {
    return self;
  }

  int64_t storage_size = 1;
  if (stride) {
    self->set_sizes_and_strides(size, *stride);
    // NB: storage size can be different from numel.
    storage_size = storage_size_for(size, *stride);
  } else {
    self->set_sizes_contiguous(size);
    storage_size = self->numel();
  }
  maybe_resize_storage_mps(self, storage_size);

  return self;
}

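// Allocates an uninitialized MPS tensor; a thin wrapper over at::detail::empty_mps.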
Tensor empty_mps(
    IntArrayRef size,
    std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt,
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
  return at::detail::empty_mps(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

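// Allocates an uninitialized MPS tensor with explicit strides by creating a
// zero-element tensor and resizing it to the requested geometry.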
Tensor empty_strided_mps(
    IntArrayRef size,
    IntArrayRef stride,
    std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt,
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt) {
  check_size_nonnegative(size);
  // Start from an empty (zero-element) tensor, then resize it to the requested
  // size and strides; no memory format applies to a strided allocation.
  auto t = at::native::empty_mps(
      {0},
      dtype_opt,
      layout_opt,
      device_opt,
      pin_memory_opt);
  resize_impl_mps_(t.unsafeGetTensorImpl(), size, stride);
  return t;
}

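// In-place resize. Named tensors are delegated to resize_named_tensor_; when
// deterministic algorithms are enabled, any newly allocated bytes are filled
// with a known pattern (see Note [Enabling Deterministic Operations]).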
const Tensor& resize_mps_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  if (self.has_names()) {
    return resize_named_tensor_(self, size, optional_memory_format);
  }
  auto* self_ = self.unsafeGetTensorImpl();
  int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().nbytes() : 0;
  resize_impl_mps_(self_, size, /*stride=*/std::nullopt);
  if (optional_memory_format.has_value()) {
    auto memory_format = optional_memory_format.value();
    TORCH_CHECK(
        memory_format != MemoryFormat::Preserve,
        "Unsupported memory format ",
        memory_format);
    self_->empty_tensor_restride(memory_format);
  }
  // See Note [Enabling Deterministic Operations]
  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
    at::native::fill_resize_deterministic_(self, old_storage_nbytes);
  }
  return self;
}

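// Resets `result` to an empty, zero-byte MPS storage while keeping its dtype.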
Tensor& set_mps_(Tensor& result) {
  caffe2::TypeMeta dtype = result.dtype();
  Storage storage(
      Storage::use_byte_size_t(),
      0,
      at::mps::GetMPSAllocator(),
      /*resizable=*/true);
  result.set_(storage, 0, {0}, {});
  TORCH_INTERNAL_ASSERT(dtype == result.dtype());
  return result;
}

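// Rebinds `result` to `storage` at the given offset/size/strides, growing the
// storage through resize_impl_mps_ if the new view needs more bytes.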
Tensor& set_storage_mps_(Tensor& result, Storage storage, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) {
  checkSetStorage(result, std::move(storage), storage_offset, size, stride);
  result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
  std::optional<IntArrayRef> stride_opt = stride.data() != nullptr ?
      std::optional<IntArrayRef>(stride) : std::nullopt;
  at::native::resize_impl_mps_(result.unsafeGetTensorImpl(), size, stride_opt);
  return result;
}

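// Creates a "zero tensor": it is backed by ZeroTensorAllocator (no real device
// memory) and carries the ZeroTensor dispatch key, so kernels treat it as an
// all-zeros tensor without materializing storage.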
Tensor _efficientzerotensor_mps(IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  auto device_ = device_or_default(device);
  auto allocator = at::native::ZeroTensorAllocator(device_);
  auto dtype_ = dtype_or_default(dtype);
  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::MPS) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, std::nullopt);
  return out;
}

} // namespace at::native