//  Copyright © 2022 Apple Inc.

#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <torch/library.h>
#include <ATen/mps/EmptyTensor.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/native/Resize.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/native/mps/Copy.h>
#include <ATen/native/mps/TensorFactory.h>
#include <ATen/Dispatch.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#endif
#include <ATen/ops/_efficientzerotensor_native.h>

#include <utility>

namespace at::native {

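// Ensures the tensor's MPS storage can hold `new_size` elements past the storage
// offset. When the current buffer is too small, a larger one is allocated, the old
// contents are blit-copied into it, and it replaces the previous data pointer.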
static inline void maybe_resize_storage_mps(TensorImpl* self, uint64_t new_size) {
  if (new_size == 0) {
    return;
  }

  auto storage = self->storage().unsafeGetStorageImpl();
  if (!storage) {
    TORCH_CHECK(false, "Tensor: invalid null storage");
  }
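  // The storage has to cover the tensor's element offset plus the requested number of elements.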
  uint64_t new_size_bytes = (new_size + self->storage_offset()) * self->dtype().itemsize();
  if (new_size_bytes > self->storage().nbytes()) {
    if (new_size_bytes == 0) {
      storage->set_data_ptr_noswap(at::DataPtr(nullptr, at::Device(at::DeviceType::MPS, 0)));
      storage->set_nbytes(0);
    } else {
      at::DataPtr new_data = storage->allocator()->allocate(new_size_bytes);
      size_t copy_capacity = std::min<size_t>(new_size_bytes, storage->nbytes());
      if (storage->data() && copy_capacity > 0) {
        at::native::mps::copy_blit_mps(new_data.get(), storage->data(), copy_capacity);
      }
      // Destructively overwrite data_ptr
      storage->set_data_ptr_noswap(std::move(new_data));
      storage->set_nbytes(new_size_bytes);
    }
  }
}

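// Updates the tensor's sizes (and strides, when provided) in place, growing the
// underlying storage through maybe_resize_storage_mps if the new geometry needs
// more elements. Returns `self` unchanged when the requested shape already matches.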
inline TensorImpl* resize_impl_mps_(
    TensorImpl* self,
    IntArrayRef size,
    std::optional<IntArrayRef> stride,
    bool device_guard = true) {
  if (self->sizes() == size && (!stride || self->strides() == stride)) {
    return self;
  }

  int64_t storage_size = 1;
  if (stride) {
    self->set_sizes_and_strides(size, *stride);
    // NB: storage size can be different from numel.
    storage_size = storage_size_for(size, *stride);
  } else {
    self->set_sizes_contiguous(size);
    storage_size = self->numel();
  }
  maybe_resize_storage_mps(self, storage_size);

  return self;
}

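// MPS backend implementation of at::empty; a thin wrapper over the shared
// at::detail::empty_mps factory.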
Tensor empty_mps(
    IntArrayRef size,
    std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt,
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
  return at::detail::empty_mps(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
}

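// MPS backend implementation of at::empty_strided, e.g. what
// torch.empty_strided(size, stride, device="mps") ultimately dispatches to.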
Tensor empty_strided_mps(
    IntArrayRef size,
    IntArrayRef stride,
    std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt,
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt) {
  check_size_nonnegative(size);
  // Allocate a size-0 tensor first, then resize it in place to the requested size and stride.
  auto t = at::native::empty_mps(
      {0},
      dtype_opt,
      layout_opt,
      device_opt,
      pin_memory_opt);
  resize_impl_mps_(t.unsafeGetTensorImpl(), size, stride);
  return t;
}

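// In-place resize for MPS tensors (backs at::resize_). Named tensors are delegated
// to resize_named_tensor_; otherwise the storage is resized, an explicit memory
// format (anything but Preserve) is applied, and freshly allocated bytes are filled
// deterministically when deterministic algorithms request it.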
const Tensor& resize_mps_(
    const Tensor& self,
    IntArrayRef size,
    std::optional<MemoryFormat> optional_memory_format) {
  if (self.has_names()) {
    return resize_named_tensor_(self, size, optional_memory_format);
  }
  auto* self_ = self.unsafeGetTensorImpl();
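  // Remember the old storage size so that, under deterministic mode, only the newly
  // allocated bytes are filled below.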
  int64_t old_storage_nbytes = self_->unsafe_storage() ? self_->unsafe_storage().nbytes() : 0;
  resize_impl_mps_(self_, size, /*stride=*/std::nullopt);
  if (optional_memory_format.has_value()) {
    auto memory_format = optional_memory_format.value();
    TORCH_CHECK(
        memory_format != MemoryFormat::Preserve,
        "Unsupported memory format ",
        memory_format);
    self_->empty_tensor_restride(memory_format);
  }
  // See Note [Enabling Deterministic Operations]
  if (C10_UNLIKELY(at::globalContext().deterministicAlgorithms() && at::globalContext().deterministicFillUninitializedMemory())) {
    at::native::fill_resize_deterministic_(self, old_storage_nbytes);
  }
  return self;
}

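// Resets the tensor to a fresh, zero-sized MPS storage while keeping its dtype
// (the overload of set_() that takes no source).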
Tensor& set_mps_(Tensor& result) {
  caffe2::TypeMeta dtype = result.dtype();
  Storage storage(
      Storage::use_byte_size_t(),
      0,
      at::mps::GetMPSAllocator(),
      true);
  result.set_(storage, 0, {0}, {});
  TORCH_INTERNAL_ASSERT(dtype == result.dtype());
  return result;
}

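// Binds the tensor to the given storage at `storage_offset` and applies the requested
// size and stride via resize_impl_mps_ (which may grow the storage if needed).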
Tensor& set_storage_mps_(Tensor& result, Storage storage, int64_t storage_offset, IntArrayRef size, IntArrayRef stride) {
  checkSetStorage(result, std::move(storage), storage_offset, size, stride);
  result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
  std::optional<IntArrayRef> stride_opt = stride.data() != nullptr ?
                                          std::optional<IntArrayRef>(stride) : std::nullopt;
  at::native::resize_impl_mps_(result.unsafeGetTensorImpl(), size, stride_opt);
  return result;
}

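// Creates an MPS ZeroTensor: the result carries the ZeroTensor dispatch key and is
// backed by ZeroTensorAllocator, so no real device memory is materialized for its
// all-zero contents.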
Tensor _efficientzerotensor_mps(IntArrayRef size,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory) {
  auto device_ = device_or_default(device);
  auto allocator = at::native::ZeroTensorAllocator(device_);
  auto dtype_ = dtype_or_default(dtype);
  auto zero_ks = at::DispatchKeySet(c10::DispatchKey::MPS) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
  auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, std::nullopt);
  return out;
}

} // namespace at::native