/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_DEVICE_TENSOR_H
#define MINDSPORE_DEVICE_TENSOR_H

#include <string>
#include <vector>
#include <memory>
#include <map>
#include <unordered_map>
#include <utility>
#include <mutex>
#include "ir/tensor.h"
#include "ir/dtype.h"
#include "ir/device_sync.h"
#include "utils/shape_utils.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "include/backend/device_type.h"
#include "kernel/kernel.h"

namespace mindspore {
namespace device {
namespace cpu {
class CPUSimpleMemPlan;
class CPUMemoryManager;
class CPUKernelRuntime;
class CPUDeviceContext;
}  // namespace cpu
namespace ascend {
class AscendKernelRuntime;
class AscendRuntimeCore;
class AscendMemoryManager;
class AscendDeviceContext;
#ifndef ENABLE_SECURITY
class DataDumper;
#endif
namespace tasksink {
class TaskGenerator;
}  // namespace tasksink
}  // namespace ascend
namespace gpu {
class GPUKernelRuntime;
class GPUMemoryManager;
class GPUDeviceContext;
}  // namespace gpu
}  // namespace device
class SingleOpInferSession;
class RuntimeUtils;
}  // namespace mindspore

namespace mindspore {
namespace device {
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
using kernel::AddressCommon;
using kernel::AddressCommonPtr;
using kernel::KernelTensor;
using kernel::KernelTensorPtr;

struct StorageInfo {
  void *host_ptr_{nullptr};
  std::string file_name_{""};
  bool host_ptr_mutable_{true};
  bool file_name_mutable_{true};
};

enum class StorageType { kDevice, kHost, kFile };

enum class DeviceAddressStatus {
  kInDevice,
  kInHost,
  kInFile,
  kInDeviceToHost,
  kInHostToDevice,
  kInHostToFile,
  kInFileToHost
};

// The flags of a device address.
constexpr size_t kDeviceAddressFlagInit = 0;
// Indicates that it is the device address of a ref node.
constexpr size_t kDeviceAddressFlagRefNode = 1;
// Indicates that it is the device address of a node which has no user.
constexpr size_t kDeviceAddressFlagNotUsed = 2;
// Indicates that it is the device address of a node which has an init arg and does not need a device address.
constexpr size_t kDeviceAddressFlagIgnoreDevicePtr = 4;
// Indicates that the ptr of the device address is nullptr.
constexpr size_t kDeviceAddressFlagNullptr = 8;

class DeviceAddress : public mindspore::DeviceSync {
 public:
  explicit DeviceAddress(const KernelTensorPtr &kernel_tensor)
      : kernel_tensor_(kernel_tensor), address_common_(kernel_tensor_->address_common()) {}

  explicit DeviceAddress(void *ptr, size_t size) {
    address_common_ = std::make_shared<AddressCommon>(ptr, size);
    kernel_tensor_ = std::make_shared<KernelTensor>();
  }

  explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->dtype_id_ = type_id;
    kernel_tensor_->SetStringFormat(format);
  }

  explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
                         const KernelWithIndex &node_index)
      : node_index_(node_index) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->dtype_id_ = type_id;
    kernel_tensor_->SetStringFormat(format);
  }

  explicit DeviceAddress(void *ptr, size_t size, const std::string &device_name, uint32_t device_id) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->device_name_ = device_name;
    kernel_tensor_->set_device_id(device_id);
  }

  explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id, const std::string &device_name,
                         uint32_t device_id) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->device_name_ = device_name;
    address_common_->dtype_id_ = type_id;
    kernel_tensor_->SetStringFormat(format);
    kernel_tensor_->set_device_id(device_id);
  }

  explicit DeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector, const Format &format, TypeId type_id,
                         const std::string &device_name, uint32_t device_id, uint32_t stream_id) {
    address_common_ =
      std::make_shared<AddressCommon>(ptr, size, shape_vector, format, type_id, device_name, device_id, stream_id);
  }

  explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
                         const KernelWithIndex &node_index, const std::string &device_name, uint32_t device_id)
      : node_index_(node_index) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->device_name_ = device_name;
    address_common_->dtype_id_ = type_id;
    kernel_tensor_->SetStringFormat(format);
    kernel_tensor_->set_device_id(device_id);
  }

  virtual ~DeviceAddress() {
    if (!from_mem_pool() && deleter_ && GetDevicePtr() != nullptr) {
      deleter_(static_cast<uint8_t *>(GetDevicePtr()));
      SetDevicePtr(nullptr);
    } else {
      address_common_->pointer_ref_count_ = nullptr;
    }
  }

  virtual bool AsyncHostToDevice(size_t size, TypeId /* type */, const void *host_ptr) const { return true; }

  virtual bool AsyncHostToDevice(size_t size, const void *host_ptr) const { return true; }
  virtual bool AsyncDeviceToHost(size_t size, void *host_ptr) const { return true; }

  // Asynchronously copy host memory to device side.
  virtual bool AsyncHostToDevice(const ShapeVector &, size_t, TypeId, const void *, size_t) const { return true; }
  // Asynchronously copy device memory to host side.
  virtual bool AsyncDeviceToHost(const ShapeVector &, size_t, TypeId, void *, size_t) const { return true; }
  // Synchronously copy device memory to device side.
  virtual bool SyncDeviceToDevice(const DeviceSync *) const { return true; }
  virtual bool SyncDeviceToDevice(const ShapeVector &, size_t, TypeId, const void *, const std::string &) const {
    return true;
  }
  // Asynchronously copy device memory to device side.
  virtual bool AsyncDeviceToDevice(const ShapeVector &, size_t, TypeId, const void *, const std::string &) const {
    return true;
  }
  virtual bool CopyDeviceToHost(void *dst, const void *src, const size_t &size) const { return true; }
  virtual bool CopyHostToDevice(void *dst, const void *src, const size_t &size) const { return true; }
  virtual void DeviceSynchronizerInit() { MS_LOG(EXCEPTION) << "Not implemented."; }

  // Get the kernel tensor pointer.
  const KernelTensorPtr &kernel_tensor() const { return kernel_tensor_; }
  void set_kernel_tensor(const KernelTensorPtr &kernel_tensor) {
    kernel_tensor_ = kernel_tensor;
    address_common_ = kernel_tensor_->address_common();
  }

  void set_device_synchronizer(const DeviceSynchronizerPtr &device_synchronizer) {
    MS_EXCEPTION_IF_NULL(kernel_tensor_);
    kernel_tensor_->set_device_synchronizer(device_synchronizer);
  }

  const void *GetPtr() const {
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    return GetDevicePtr();
  }
  void set_ptr(void *ptr) {
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    address_common_->pointer_ref_count_->set_ptr(ptr);
    if (ptr != nullptr) {
      const auto &storage_info = GetStorageInfo();
      if (storage_info.host_ptr_ == nullptr && storage_info.file_name_.empty()) {
        status_ = DeviceAddressStatus::kInDevice;
      }
    }
  }
  size_t GetSize() const { return size(); }
  void SetSize(size_t size) { address_common_->size_ = size; }

  std::string format() const { return kernel::GetFormatFromEnumToStr(address_common_->format_); }
  void set_format(const std::string &format) { address_common_->format_ = kernel::GetFormatFromStrToEnum(format); }
  const std::string &padding_type() const { return padding_type_; }
  void set_padding_type(const std::string &padding_type) { padding_type_ = padding_type; }
  TypeId type_id() const { return address_common_->dtype_id_; }
  void set_type_id(TypeId type_id) { address_common_->dtype_id_ = type_id; }
  bool from_mem_pool() const { return address_common_->pointer_ref_count_->from_mem_pool(); }
  void set_from_mem_pool(bool from_mem_pool) const {
    address_common_->pointer_ref_count_->set_from_mem_pool(from_mem_pool);
  }
  virtual void set_communication_ptr(uint8_t *communication_ptr) { MS_LOG(EXCEPTION) << "Not implemented error."; }
  bool is_ptr_persisted() const { return is_ptr_persisted_; }
  void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
  void set_host_shape(const ShapeVector &shape) { kernel_tensor_->set_host_shape(shape); }
  const ShapeVector &host_shape() const { return kernel_tensor_->host_shape(); }
  void set_device_shape(const ShapeVector &shape) { device_shape_ = shape; }
  const ShapeVector &device_shape() const { return device_shape_; }
  bool from_persistent_mem() const { return from_persistent_mem_; }
  void set_from_persistent_mem(bool from_persistent_mem) { from_persistent_mem_ = from_persistent_mem; }
  bool need_recycle() const { return need_recycle_; }
  void set_need_recycle(bool need_recycle) { need_recycle_ = need_recycle; }
  virtual bool mem_offloaded() const { return false; }
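  // Illustrative sketch only (not part of the interface): given a concrete backend address `addr`
  // of type DeviceAddressPtr, the common attribute queries look like:
  //   addr->GetSize();   // buffer size in bytes
  //   addr->type_id();   // element type, e.g. kNumberTypeFloat32
  //   addr->format();    // format string, e.g. "DefaultFormat"
  //   addr->GetPtr();    // raw device pointer, nullptr until memory is assigned via set_ptr()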
  void set_status(DeviceAddressStatus status) { status_ = status; }
  DeviceAddressStatus status() const { return status_; }
  virtual DeviceType GetDeviceType() const { return DeviceType::kUnknown; }
  void *GetMutablePtr() const override {
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    return GetDevicePtr();
  }
  // Get the shape vector for Tensor/Sequence/Scalar.
  const ShapeVector &GetShapeVector() const { return address_common_->shape_vector_; }

  const TensorStorageInfoPtr GetTensorStorageInfo() const override {
    if (address_common_ == nullptr) {
      return nullptr;
    }

    return address_common_->tensor_storage_info_;
  }
  void set_tensor_storage_info(const TensorStorageInfoPtr &tensor_storage_info) {
    address_common_->tensor_storage_info_ = tensor_storage_info;
  }

  const std::string &device_name() const { return address_common_->device_name_; }
  uint32_t device_id() const { return address_common_->device_id_; }

  void set_stream_id(uint32_t stream_id) { address_common_->stream_id_ = stream_id; }
  const uint32_t stream_id() const { return address_common_->stream_id_; }

  void AddHeldByNode(const std::weak_ptr<ValueNode> &value_node) { (void)held_by_nodes_.emplace_back(value_node); }
  std::vector<std::weak_ptr<ValueNode>> held_by_nodes() const { return held_by_nodes_; }
  void ClearHeldByNodes() { held_by_nodes_.clear(); }

  virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
  KernelWithIndex GetNodeIndex() const {
    return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
                                       : KernelWithIndex{node_index_.first.lock(), node_index_.second};
  }

  size_t IncreaseCounter() { return address_common_->pointer_ref_count_->IncreaseCounter(); }
  size_t DecreaseCounter() { return address_common_->pointer_ref_count_->DecreaseCounter(); }

  // The related interface of reference count operation.
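  // Note: the static counters below (original_ref_count/ref_count) and the dynamic reference count
  // are all delegated to the PointerRefCount object shared through address_common_, so device
  // addresses that share one pointer_ref_count_ also share these counters.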
  void set_original_ref_count(size_t original_ref_count) const override {
    address_common_->pointer_ref_count_->set_original_ref_count(original_ref_count);
  }
  size_t original_ref_count() const override { return address_common_->pointer_ref_count_->original_ref_count(); }
  void set_ref_count(size_t ref_count) const override { address_common_->pointer_ref_count_->set_ref_count(ref_count); }
  size_t ref_count() const override { return address_common_->pointer_ref_count_->ref_count(); }
  void ResetRefCount() override { address_common_->pointer_ref_count_->ResetRefCount(); }

  void IncreaseOriginalRefCount() {
    if (original_ref_count() < SIZE_MAX) {
      address_common_->pointer_ref_count_->IncreaseOriginalRefCount();
    }
  }
  void DecreaseOriginalRefCount() {
    if ((original_ref_count() < SIZE_MAX) && (original_ref_count() > 0)) {
      address_common_->pointer_ref_count_->DecreaseOriginalRefCount();
    }
  }
  size_t DecreaseRefCount() { return address_common_->pointer_ref_count_->DecreaseRefCount(); }

  // The related interface of dynamic reference count operation.
  void set_dynamic_ref_count(int32_t dynamic_ref_count) {
    address_common_->pointer_ref_count_->set_dynamic_ref_count(dynamic_ref_count);
  }

  int32_t dynamic_ref_count() const { return address_common_->pointer_ref_count_->dynamic_ref_count(); }
  void IncreaseDynamicRefCount(const std::string &op_object) {
    address_common_->pointer_ref_count_->IncreaseDynamicRefCount(op_object);
  }
  int32_t DecreaseDynamicRefCount(const std::string &op_object) {
    return address_common_->pointer_ref_count_->DecreaseDynamicRefCount(op_object);
  }

  virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
                             TypeId host_type, bool trans_flag) const {
    return true;
  }
#ifdef ENABLE_DEBUGGER
  virtual bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt,
                             const ShapeVector &host_shape, TypeId host_type, size_t slot, bool keep_prev,
                             uint32_t root_graph_id, bool force_update, bool trans_flag, bool async_copy = true) const {
    return true;
  }
#endif

  // Return whether the DeviceAddress has a valid ptr.
  virtual bool IsPtrValid() const {
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    return GetDevicePtr() != nullptr;
  }

  bool IsNotNeedAlloc() const { return IsPtrValid() || TEST_FLAG(flag(), device::kDeviceAddressFlagNotUsed); }

  using SyncUserDataHandler = void (*)(DeviceAddress *const device_address);
  // Return the valid device ptr.
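  // If user data is attached (e.g. the output of a pyexecute kernel) and need_sync_user_data_ is set,
  // the handler registered under kSyncUserDataHandler is invoked once under ptr_mutex_ to materialize
  // the data before the device pointer is returned; otherwise the pointer is returned directly.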
  virtual void *GetValidPtr(size_t) {
    if (user_data() == nullptr || (!need_sync_user_data_)) {
      return GetDevicePtr();
    }
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    if (!need_sync_user_data_) {
      return GetDevicePtr();
    }
    auto sync_handler = user_data()->get<SyncUserDataHandler>(kSyncUserDataHandler);
    if (sync_handler == nullptr) {
      MS_LOG(WARNING) << "For device address:" << this << ", the sync user data handler is null.";
      return GetDevicePtr();
    }
    (*sync_handler)(this);
    need_sync_user_data_ = false;
    return GetDevicePtr();
  }

  inline void TouchSyncHandler() {
    if (!need_sync_user_data_ || kernel_tensor_->user_data() == nullptr) {
      return;
    }
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    auto sync_handler = user_data()->get<SyncUserDataHandler>(kSyncUserDataHandler);
    if (sync_handler == nullptr) {
      MS_LOG(WARNING) << "For device address:" << this << ", the sync user data handler is null.";
      return;
    }
    (*sync_handler)(this);
    need_sync_user_data_ = false;
  }

  // Offload data from device to host and free the device memory.
  virtual bool Offload(size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }

  // Load data from host to device and free the host memory.
  virtual bool Load(size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }

  // Move data to the destination hardware and free the resource on the source hardware.
  virtual bool MoveTo(StorageType, bool, size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }

  virtual bool Wait() const { MS_LOG(EXCEPTION) << "Not implemented."; }

  // Set the host ptr that data is offloaded to.
  virtual void SetOffloadPtr(void *) {}

  // Get the offloaded host ptr.
  virtual void *GetOffloadPtr() const { return nullptr; }

  virtual void SetStorageInfo(const StorageInfo &) {}
  virtual StorageInfo GetStorageInfo() const { return StorageInfo(); }

  virtual void Swap(DeviceAddress *other) {
    MS_EXCEPTION_IF_NULL(other);
    if (other == this) {
      return;
    }
    other->SetDevicePtr(GetDevicePtr());

    other->set_from_mem_pool(this->from_mem_pool());
    other->set_deleter(deleter());
    other->set_need_sync_user_data(need_sync_user_data_);
    SetDevicePtr(nullptr);
    this->set_from_mem_pool(false);
    deleter_ = nullptr;
    kernel_tensor()->set_task_id_on_stream(other->kernel_tensor()->task_id_on_stream());
    kernel_tensor()->set_managed_by_somas(other->kernel_tensor()->managed_by_somas());
  }

  virtual void set_swappable(bool) {}
  virtual bool swappable() { return false; }

  // Get the user data maintained by the DeviceAddress.
  const UserDataPtr &user_data() const override { return kernel_tensor_->user_data(); }

  // Set user data to the DeviceAddress.
  void set_user_data(const UserDataPtr &user_data) override { kernel_tensor_->set_user_data(user_data); }

  // Free the ptr in user data when the ref count is 0.
  virtual void ClearUserData() {}

  // The interface of flag.
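  // Flags are plain bit masks manipulated with SET_FLAG/CLEAR_FLAG/TEST_FLAG. Illustrative sketch:
  //   device_address->UpdateFlag(kDeviceAddressFlagRefNode);
  //   device_address->UpdateFlag(kDeviceAddressFlagNotUsed);
  //   TEST_FLAG(device_address->flag(), kDeviceAddressFlagNotUsed);  // true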
  size_t flag() const { return flag_; }
  void set_flag(size_t flag) { flag_ = flag; }
  void UpdateFlag(size_t flag) { SET_FLAG(flag_, flag); }
  void ClearFlag(size_t flag) { CLEAR_FLAG(flag_, flag); }

  std::pair<AnfNodeWeakPtr, size_t> node_index() const { return node_index_; }
  void set_deleter(const std::function<void(uint8_t *)> &deleter) { deleter_ = deleter; }
  std::function<void(uint8_t *)> deleter() const { return deleter_; }

  // For the output of a pyexecute kernel, the input data is stored in user data and the handler is used to sync data
  // from user data to the device ptr.
  bool need_sync_user_data() { return need_sync_user_data_; }
  void set_need_sync_user_data(bool need_sync_user_data) { need_sync_user_data_ = need_sync_user_data; }

  const PointerRefCountPtr &pointer_ref_count() const { return address_common_->pointer_ref_count_; }
  void set_pointer_ref_count(const PointerRefCountPtr &ptr_ref_cnt) {
    MS_EXCEPTION_IF_NULL(ptr_ref_cnt);
    address_common_->pointer_ref_count_ = ptr_ref_cnt;
  }

  void set_is_view(bool is_view) { is_view_ = is_view; }
  bool is_view() const { return is_view_; }
  AddressCommonPtr address_common() const { return address_common_; }

 protected:
  KernelTensorPtr kernel_tensor_{nullptr};
  // Address basic info.
  AddressCommonPtr address_common_{nullptr};
  size_t size() const { return address_common_->size_; }

  void *GetDevicePtr() const { return address_common_->pointer_ref_count_->ptr(); }
  void SetDevicePtr(void *ptr) const { address_common_->pointer_ref_count_->set_ptr(ptr); }

  void SetTypeId(TypeId type) const { address_common_->dtype_id_ = type; }

  ShapeVector device_shape_{};
  // {node, out_index}
  std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
  // The DeviceAddress is held by ValueNodes. These ValueNodes are outputs of the forward network.
  // We need to release the device memory when the reference count of the device address in the bprop graph is 0.
  std::vector<std::weak_ptr<ValueNode>> held_by_nodes_;
  // The device address of the node that owns the device address cannot be updated and replaced.
  // Application scenario: set to true when the hardware execution mode requires that the ptr cannot be changed
  // during execution.
  bool is_ptr_persisted_{false};
  // Thread lock for ptr_.
  mutable std::recursive_mutex ptr_mutex_;

  bool from_persistent_mem_{false};
  bool need_recycle_{false};

  // The padding type corresponding to the data format.
  std::string padding_type_;

  // The device address flag.
  size_t flag_{0};

  // Indicates whether the address is the input of a view op.
  // If so, the device address cannot be reused with the host address on CPU.
  bool is_view_{false};

  // The flag identifying where the data is stored.
  mutable DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
  // Whether data needs to be synced from user data via the registered handler.
  bool need_sync_user_data_{false};
  // The specified deleter to release memory.
  std::function<void(uint8_t *)> deleter_;

  friend class KernelRuntime;
  friend class MemoryManager;
  friend class mindspore::device::ascend::tasksink::TaskGenerator;
  friend class mindspore::device::cpu::CPUSimpleMemPlan;
  friend class mindspore::device::cpu::CPUMemoryManager;
  friend class mindspore::device::cpu::CPUKernelRuntime;
  friend class mindspore::device::cpu::CPUDeviceContext;
  friend class mindspore::device::gpu::GPUKernelRuntime;
  friend class mindspore::device::gpu::GPUMemoryManager;
  friend class mindspore::device::gpu::GPUDeviceContext;
  friend class mindspore::device::ascend::AscendKernelRuntime;
  friend class mindspore::device::ascend::AscendRuntimeCore;
  friend class mindspore::device::ascend::AscendMemoryManager;
  friend class mindspore::device::ascend::AscendDeviceContext;
#ifndef ENABLE_SECURITY
  friend class mindspore::device::ascend::DataDumper;
#endif
  friend class mindspore::SingleOpInferSession;
  friend class mindspore::RuntimeUtils;
};

using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;
using DeviceAddressPtrList = std::vector<DeviceAddressPtr>;
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_DEVICE_TENSOR_H