/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_DEVICE_TENSOR_H
#define MINDSPORE_DEVICE_TENSOR_H

#include <string>
#include <vector>
#include <memory>
#include <map>
#include <unordered_map>
#include <utility>
#include <mutex>
#include "ir/tensor.h"
#include "ir/dtype.h"
#include "ir/device_sync.h"
#include "utils/shape_utils.h"
#include "utils/check_convert_utils.h"
#include "include/common/utils/utils.h"
#include "include/backend/device_type.h"
#include "kernel/kernel.h"

namespace mindspore {
namespace device {
namespace cpu {
class CPUSimpleMemPlan;
class CPUMemoryManager;
class CPUKernelRuntime;
class CPUDeviceContext;
}  // namespace cpu
namespace ascend {
class AscendKernelRuntime;
class AscendRuntimeCore;
class AscendMemoryManager;
class AscendDeviceContext;
#ifndef ENABLE_SECURITY
class DataDumper;
#endif
namespace tasksink {
class TaskGenerator;
}  // namespace tasksink
}  // namespace ascend
namespace gpu {
class GPUKernelRuntime;
class GPUMemoryManager;
class GPUDeviceContext;
}  // namespace gpu
}  // namespace device
class SingleOpInferSession;
class RuntimeUtils;
}  // namespace mindspore

namespace mindspore {
namespace device {
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
using kernel::AddressCommon;
using kernel::AddressCommonPtr;
using kernel::KernelTensor;
using kernel::KernelTensorPtr;

struct StorageInfo {
  void *host_ptr_{nullptr};
  std::string file_name_{""};
  bool host_ptr_mutable_{true};
  bool file_name_mutable_{true};
};

enum class StorageType { kDevice, kHost, kFile };

enum class DeviceAddressStatus {
  kInDevice,
  kInHost,
  kInFile,
  kInDeviceToHost,
  kInHostToDevice,
  kInHostToFile,
  kInFileToHost
};

// The flags of a device address; the non-zero flags are bit values that can be combined.
constexpr size_t kDeviceAddressFlagInit = 0;
// Indicates the device address of a ref node.
constexpr size_t kDeviceAddressFlagRefNode = 1;
// Indicates the device address of a node that has no user.
constexpr size_t kDeviceAddressFlagNotUsed = 2;
// Indicates the device address of a node that has an init arg and does not need a device address.
constexpr size_t kDeviceAddressFlagIgnoreDevicePtr = 4;
// Indicates that the ptr of the device address is nullptr.
constexpr size_t kDeviceAddressFlagNullptr = 8;
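
// Illustrative sketch (not part of the original header): since the flags are bit
// values, they can be combined and tested with the SET_FLAG/TEST_FLAG helpers used
// elsewhere in this file, e.g.:
//   device_address->UpdateFlag(kDeviceAddressFlagRefNode);
//   device_address->UpdateFlag(kDeviceAddressFlagNullptr);
//   if (TEST_FLAG(device_address->flag(), kDeviceAddressFlagRefNode)) {
//     // Handle the ref-node address.
//   }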

class DeviceAddress : public mindspore::DeviceSync {
 public:
  explicit DeviceAddress(const KernelTensorPtr &kernel_tensor)
      : kernel_tensor_(kernel_tensor), address_common_(kernel_tensor_->address_common()) {}

  explicit DeviceAddress(void *ptr, size_t size) {
    address_common_ = std::make_shared<AddressCommon>(ptr, size);
    kernel_tensor_ = std::make_shared<KernelTensor>();
  }
  explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->dtype_id_ = type_id;
    kernel_tensor_->SetStringFormat(format);
  }
  explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
                         const KernelWithIndex &node_index)
      : node_index_(node_index) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->dtype_id_ = type_id;
    kernel_tensor_->SetStringFormat(format);
  }

  explicit DeviceAddress(void *ptr, size_t size, const std::string &device_name, uint32_t device_id) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->device_name_ = device_name;
    kernel_tensor_->set_device_id(device_id);
  }
  explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id, const std::string &device_name,
                         uint32_t device_id) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->device_name_ = device_name;
    address_common_->dtype_id_ = type_id;
    kernel_tensor_->SetStringFormat(format);
    kernel_tensor_->set_device_id(device_id);
  }
  explicit DeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector, const Format &format, TypeId type_id,
                         const std::string &device_name, uint32_t device_id, uint32_t stream_id) {
    address_common_ =
      std::make_shared<AddressCommon>(ptr, size, shape_vector, format, type_id, device_name, device_id, stream_id);
  }
  explicit DeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
                         const KernelWithIndex &node_index, const std::string &device_name, uint32_t device_id)
      : node_index_(node_index) {
    kernel_tensor_ = std::make_shared<KernelTensor>();
    address_common_ = kernel_tensor_->address_common();
    address_common_->pointer_ref_count_->set_ptr(ptr);
    address_common_->size_ = size;
    address_common_->device_name_ = device_name;
    address_common_->dtype_id_ = type_id;
    kernel_tensor_->SetStringFormat(format);
    kernel_tensor_->set_device_id(device_id);
  }
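
  // Illustrative usage sketch (not part of the original header); the device name,
  // device id and format string here are assumptions for the example:
  //   auto addr = std::make_shared<DeviceAddress>(dev_ptr, /*size=*/4096, "DefaultFormat",
  //                                               kNumberTypeFloat32, "Ascend", /*device_id=*/0);
  //   addr->set_from_mem_pool(true);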

  virtual ~DeviceAddress() {
    if (!from_mem_pool() && deleter_ && GetDevicePtr() != nullptr) {
      deleter_(static_cast<uint8_t *>(GetDevicePtr()));
      SetDevicePtr(nullptr);
    } else {
      address_common_->pointer_ref_count_ = nullptr;
    }
  }
  virtual bool AsyncHostToDevice(size_t size, TypeId /* type */, const void *host_ptr) const { return true; }

  virtual bool AsyncHostToDevice(size_t size, const void *host_ptr) const { return true; }
  virtual bool AsyncDeviceToHost(size_t size, void *host_ptr) const { return true; }

  // Asynchronously copy host memory to the device side.
  virtual bool AsyncHostToDevice(const ShapeVector &, size_t, TypeId, const void *, size_t) const { return true; }
  // Asynchronously copy device memory to the host side.
  virtual bool AsyncDeviceToHost(const ShapeVector &, size_t, TypeId, void *, size_t) const { return true; }
  // Synchronously copy device memory to the device side.
  virtual bool SyncDeviceToDevice(const DeviceSync *) const { return true; }
  virtual bool SyncDeviceToDevice(const ShapeVector &, size_t, TypeId, const void *, const std::string &) const {
    return true;
  }
  // Asynchronously copy device memory to the device side.
  virtual bool AsyncDeviceToDevice(const ShapeVector &, size_t, TypeId, const void *, const std::string &) const {
    return true;
  }
  virtual bool CopyDeviceToHost(void *dst, const void *src, const size_t &size) const { return true; }
  virtual bool CopyHostToDevice(void *dst, const void *src, const size_t &size) const { return true; }
  virtual void DeviceSynchronizerInit() { MS_LOG(EXCEPTION) << "Not implemented."; }

  // Get the kernel tensor pointer.
  const KernelTensorPtr &kernel_tensor() const { return kernel_tensor_; }
  void set_kernel_tensor(const KernelTensorPtr &kernel_tensor) {
    kernel_tensor_ = kernel_tensor;
    address_common_ = kernel_tensor_->address_common();
  }

  void set_device_synchronizer(const DeviceSynchronizerPtr &device_synchronizer) {
    MS_EXCEPTION_IF_NULL(kernel_tensor_);
    kernel_tensor_->set_device_synchronizer(device_synchronizer);
  }

  const void *GetPtr() const {
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    return GetDevicePtr();
  }
  void set_ptr(void *ptr) {
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    address_common_->pointer_ref_count_->set_ptr(ptr);
    if (ptr != nullptr) {
      const auto &storage_info = GetStorageInfo();
      if (storage_info.host_ptr_ == nullptr && storage_info.file_name_.empty()) {
        status_ = DeviceAddressStatus::kInDevice;
      }
    }
  }
  size_t GetSize() const { return size(); }
  void SetSize(size_t size) { address_common_->size_ = size; }

  std::string format() const { return kernel::GetFormatFromEnumToStr(address_common_->format_); }
  void set_format(const std::string &format) { address_common_->format_ = kernel::GetFormatFromStrToEnum(format); }
  const std::string &padding_type() const { return padding_type_; }
  void set_padding_type(const std::string &padding_type) { padding_type_ = padding_type; }
  TypeId type_id() const { return address_common_->dtype_id_; }
  void set_type_id(TypeId type_id) { address_common_->dtype_id_ = type_id; }
  bool from_mem_pool() const { return address_common_->pointer_ref_count_->from_mem_pool(); }
  void set_from_mem_pool(bool from_mem_pool) const {
    address_common_->pointer_ref_count_->set_from_mem_pool(from_mem_pool);
  }
  virtual void set_communication_ptr(uint8_t *communication_ptr) { MS_LOG(EXCEPTION) << "Not implemented."; }
  bool is_ptr_persisted() const { return is_ptr_persisted_; }
  void set_is_ptr_persisted(bool is_ptr_persisted) { is_ptr_persisted_ = is_ptr_persisted; }
  void set_host_shape(const ShapeVector &shape) { kernel_tensor_->set_host_shape(shape); }
  const ShapeVector &host_shape() const { return kernel_tensor_->host_shape(); }
  void set_device_shape(const ShapeVector &shape) { device_shape_ = shape; }
  const ShapeVector &device_shape() const { return device_shape_; }
  bool from_persistent_mem() const { return from_persistent_mem_; }
  void set_from_persistent_mem(bool from_persistent_mem) { from_persistent_mem_ = from_persistent_mem; }
  bool need_recycle() const { return need_recycle_; }
  void set_need_recycle(bool need_recycle) { need_recycle_ = need_recycle; }
  virtual bool mem_offloaded() const { return false; }
  void set_status(DeviceAddressStatus status) { status_ = status; }
  DeviceAddressStatus status() const { return status_; }
  virtual DeviceType GetDeviceType() const { return DeviceType::kUnknown; }
  void *GetMutablePtr() const override {
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    return GetDevicePtr();
  }
  // Get the shape vector for Tensor/Sequence/Scalar.
  const ShapeVector &GetShapeVector() const { return address_common_->shape_vector_; }

  const TensorStorageInfoPtr GetTensorStorageInfo() const override {
    if (address_common_ == nullptr) {
      return nullptr;
    }

    return address_common_->tensor_storage_info_;
  }
  void set_tensor_storage_info(const TensorStorageInfoPtr &tensor_storage_info) {
    address_common_->tensor_storage_info_ = tensor_storage_info;
  }

  const std::string &device_name() const { return address_common_->device_name_; }
  uint32_t device_id() const { return address_common_->device_id_; }

  void set_stream_id(uint32_t stream_id) { address_common_->stream_id_ = stream_id; }
  uint32_t stream_id() const { return address_common_->stream_id_; }

  void AddHeldByNode(const std::weak_ptr<ValueNode> &value_node) { (void)held_by_nodes_.emplace_back(value_node); }
  std::vector<std::weak_ptr<ValueNode>> held_by_nodes() const { return held_by_nodes_; }
  void ClearHeldByNodes() { held_by_nodes_.clear(); }

  virtual void SetNodeIndex(const AnfNodePtr &node, size_t out_index) { node_index_ = {node, out_index}; }
  KernelWithIndex GetNodeIndex() const {
    return node_index_.first.expired() ? KernelWithIndex{nullptr, node_index_.second}
                                       : KernelWithIndex{node_index_.first.lock(), node_index_.second};
  }

  size_t IncreaseCounter() { return address_common_->pointer_ref_count_->IncreaseCounter(); }
  size_t DecreaseCounter() { return address_common_->pointer_ref_count_->DecreaseCounter(); }

  // The related interfaces of reference count operations.
  void set_original_ref_count(size_t original_ref_count) const override {
    address_common_->pointer_ref_count_->set_original_ref_count(original_ref_count);
  }
  size_t original_ref_count() const override { return address_common_->pointer_ref_count_->original_ref_count(); }
  void set_ref_count(size_t ref_count) const override { address_common_->pointer_ref_count_->set_ref_count(ref_count); }
  size_t ref_count() const override { return address_common_->pointer_ref_count_->ref_count(); }
  void ResetRefCount() override { address_common_->pointer_ref_count_->ResetRefCount(); }

  void IncreaseOriginalRefCount() {
    if (original_ref_count() < SIZE_MAX) {
      address_common_->pointer_ref_count_->IncreaseOriginalRefCount();
    }
  }
  void DecreaseOriginalRefCount() {
    if ((original_ref_count() < SIZE_MAX) && (original_ref_count() > 0)) {
      address_common_->pointer_ref_count_->DecreaseOriginalRefCount();
    }
  }
  size_t DecreaseRefCount() { return address_common_->pointer_ref_count_->DecreaseRefCount(); }

  // The related interfaces of dynamic reference count operations.
  void set_dynamic_ref_count(int32_t dynamic_ref_count) {
    address_common_->pointer_ref_count_->set_dynamic_ref_count(dynamic_ref_count);
  }

  int32_t dynamic_ref_count() const { return address_common_->pointer_ref_count_->dynamic_ref_count(); }
  void IncreaseDynamicRefCount(const std::string &op_object) {
    address_common_->pointer_ref_count_->IncreaseDynamicRefCount(op_object);
  }
  int32_t DecreaseDynamicRefCount(const std::string &op_object) {
    return address_common_->pointer_ref_count_->DecreaseDynamicRefCount(op_object);
  }
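
  // Illustrative sketch (not part of the original header): the original (static) ref
  // count is decided at compile time and restored into the runtime ref count each
  // step; memory may be reclaimed once the runtime count drops to zero, e.g.:
  //   addr->set_original_ref_count(2);
  //   addr->ResetRefCount();  // Reset the runtime ref count from the original one.
  //   if (addr->DecreaseRefCount() == 0) {
  //     // Safe to release the device memory.
  //   }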

  virtual bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
                             TypeId host_type, bool trans_flag) const {
    return true;
  }
#ifdef ENABLE_DEBUGGER
  virtual bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt,
                             const ShapeVector &host_shape, TypeId host_type, size_t slot, bool keep_prev,
                             uint32_t root_graph_id, bool force_update, bool trans_flag, bool async_copy = true) const {
    return true;
  }
#endif

  // Return whether the DeviceAddress has a valid ptr.
  virtual bool IsPtrValid() const {
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    return GetDevicePtr() != nullptr;
  }

  bool IsNotNeedAlloc() const { return IsPtrValid() || TEST_FLAG(flag(), device::kDeviceAddressFlagNotUsed); }

  using SyncUserDataHandler = void (*)(DeviceAddress *const device_address);
  // Return the valid device ptr, syncing it from user data first when necessary.
  virtual void *GetValidPtr(size_t) {
    if (user_data() == nullptr || (!need_sync_user_data_)) {
      return GetDevicePtr();
    }
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    if (!need_sync_user_data_) {
      return GetDevicePtr();
    }
    auto sync_handler = user_data()->get<SyncUserDataHandler>(kSyncUserDataHandler);
    if (sync_handler == nullptr) {
      MS_LOG(WARNING) << "For device address:" << this << ", the sync user data handler is null.";
      return GetDevicePtr();
    }
    (*sync_handler)(this);
    need_sync_user_data_ = false;
    return GetDevicePtr();
  }

  inline void TouchSyncHandler() {
    if (!need_sync_user_data_ || kernel_tensor_->user_data() == nullptr) {
      return;
    }
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    auto sync_handler = user_data()->get<SyncUserDataHandler>(kSyncUserDataHandler);
    if (sync_handler == nullptr) {
      MS_LOG(WARNING) << "For device address:" << this << ", the sync user data handler is null.";
      return;
    }
    (*sync_handler)(this);
    need_sync_user_data_ = false;
  }
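
  // Illustrative sketch (not part of the original header): a kernel whose output is
  // produced into user data (e.g. the pyexecute kernel mentioned below) can register
  // a SyncUserDataHandler under the kSyncUserDataHandler key; the first GetValidPtr()
  // call then materializes the device ptr. The handler name is an assumption:
  //   void SyncFromUserData(DeviceAddress *const addr) { /* user data -> device ptr */ }
  //   ...
  //   addr->set_need_sync_user_data(true);  // Makes GetValidPtr() invoke the handler once.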

  // Offload data from device to host and free the device memory.
  virtual bool Offload(size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }

  // Load data from host to device and free the host memory.
  virtual bool Load(size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }

  // Move data to the destination hardware and free the resource on the source hardware.
  virtual bool MoveTo(StorageType, bool, size_t) { MS_LOG(EXCEPTION) << "Not implemented."; }

  virtual bool Wait() const { MS_LOG(EXCEPTION) << "Not implemented."; }

  // Set the host ptr that data is offloaded to.
  virtual void SetOffloadPtr(void *) {}

  // Get the offloaded host ptr.
  virtual void *GetOffloadPtr() const { return nullptr; }

  virtual void SetStorageInfo(const StorageInfo &) {}
  virtual StorageInfo GetStorageInfo() const { return StorageInfo(); }
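
  // Illustrative sketch (not part of the original header) of the offload flow a
  // backend subclass is expected to implement; treating the size_t argument as a
  // stream id is an assumption for the example:
  //   if (addr->Offload(stream_id)) {  // device -> host, frees the device memory
  //     addr->set_status(DeviceAddressStatus::kInHost);
  //   }
  //   if (addr->Load(stream_id)) {  // host -> device, frees the host memory
  //     addr->set_status(DeviceAddressStatus::kInDevice);
  //   }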

  virtual void Swap(DeviceAddress *other) {
    MS_EXCEPTION_IF_NULL(other);
    if (other == this) {
      return;
    }
    other->SetDevicePtr(GetDevicePtr());

    other->set_from_mem_pool(this->from_mem_pool());
    other->set_deleter(deleter());
    other->set_need_sync_user_data(need_sync_user_data_);
    SetDevicePtr(nullptr);
    this->set_from_mem_pool(false);
    deleter_ = nullptr;
    kernel_tensor()->set_task_id_on_stream(other->kernel_tensor()->task_id_on_stream());
    kernel_tensor()->set_managed_by_somas(other->kernel_tensor()->managed_by_somas());
  }
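
  // Illustrative note (not part of the original header): after `a->Swap(b)`, `b` owns
  // a's device ptr, mem-pool flag and deleter, while `a` is left with a null ptr and
  // no deleter; only the stream task id and the somas flag flow from `b` back to `a`.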

  virtual void set_swappable(bool) {}
  virtual bool swappable() { return false; }

  // Get the user data maintained by the DeviceAddress.
  const UserDataPtr &user_data() const override { return kernel_tensor_->user_data(); }

  // Set user data to the DeviceAddress.
  void set_user_data(const UserDataPtr &user_data) override { kernel_tensor_->set_user_data(user_data); }

  // Free the ptr in user data when the ref count is 0.
  virtual void ClearUserData() {}

  // The interfaces of the flag.
  size_t flag() const { return flag_; }
  void set_flag(size_t flag) { flag_ = flag; }
  void UpdateFlag(size_t flag) { SET_FLAG(flag_, flag); }
  void ClearFlag(size_t flag) { CLEAR_FLAG(flag_, flag); }

  std::pair<AnfNodeWeakPtr, size_t> node_index() const { return node_index_; }
  void set_deleter(const std::function<void(uint8_t *)> &deleter) { deleter_ = deleter; }
  std::function<void(uint8_t *)> deleter() const { return deleter_; }

  // For the output of a pyexecute kernel, the input data is stored in user data, and the handler is used to sync
  // the data from user data to the device ptr.
  bool need_sync_user_data() { return need_sync_user_data_; }
  void set_need_sync_user_data(bool need_sync_user_data) { need_sync_user_data_ = need_sync_user_data; }

  const PointerRefCountPtr &pointer_ref_count() const { return address_common_->pointer_ref_count_; }
  void set_pointer_ref_count(const PointerRefCountPtr &ptr_ref_cnt) {
    MS_EXCEPTION_IF_NULL(ptr_ref_cnt);
    address_common_->pointer_ref_count_ = ptr_ref_cnt;
  }

  void set_is_view(bool is_view) { is_view_ = is_view; }
  bool is_view() const { return is_view_; }
  AddressCommonPtr address_common() const { return address_common_; }

 protected:
  KernelTensorPtr kernel_tensor_{nullptr};
  // Address basic info.
  AddressCommonPtr address_common_{nullptr};
  size_t size() const { return address_common_->size_; }

  void *GetDevicePtr() const { return address_common_->pointer_ref_count_->ptr(); }
  void SetDevicePtr(void *ptr) const { address_common_->pointer_ref_count_->set_ptr(ptr); }

  void SetTypeId(TypeId type) const { address_common_->dtype_id_ = type; }

  ShapeVector device_shape_{};
  // {node, out_index}
  std::pair<AnfNodeWeakPtr, size_t> node_index_{AnfNodePtr(nullptr), 0};
  // The DeviceAddress is held by ValueNodes. These ValueNodes are outputs of the forward network.
  // We need to release the device memory when the reference count of the device address in the bprop graph is 0.
  std::vector<std::weak_ptr<ValueNode>> held_by_nodes_;
  // The ptr of a persisted device address cannot be updated or replaced.
  // Application scenario: set to true when the hardware execution mode requires that the ptr cannot be changed
  // during execution.
  bool is_ptr_persisted_{false};
  // Thread lock for ptr_.
  mutable std::recursive_mutex ptr_mutex_;

  bool from_persistent_mem_{false};
  bool need_recycle_{false};

  // The padding type corresponding to the data format.
  std::string padding_type_;

  // The device address flag.
  size_t flag_{0};

  // Indicates whether the address is the input of a view op.
  // If so, the device address cannot be reused with the host address on CPU.
  bool is_view_{false};

  // The flag identifying where the data is stored.
  mutable DeviceAddressStatus status_{DeviceAddressStatus::kInDevice};
  // Whether the data still needs to be synced from user data (see SyncUserDataHandler).
  bool need_sync_user_data_{false};
  // The specified deleter to release the memory.
  std::function<void(uint8_t *)> deleter_;
  friend class KernelRuntime;
  friend class MemoryManager;
  friend class mindspore::device::ascend::tasksink::TaskGenerator;
  friend class mindspore::device::cpu::CPUSimpleMemPlan;
  friend class mindspore::device::cpu::CPUMemoryManager;
  friend class mindspore::device::cpu::CPUKernelRuntime;
  friend class mindspore::device::cpu::CPUDeviceContext;
  friend class mindspore::device::gpu::GPUKernelRuntime;
  friend class mindspore::device::gpu::GPUMemoryManager;
  friend class mindspore::device::gpu::GPUDeviceContext;
  friend class mindspore::device::ascend::AscendKernelRuntime;
  friend class mindspore::device::ascend::AscendRuntimeCore;
  friend class mindspore::device::ascend::AscendMemoryManager;
  friend class mindspore::device::ascend::AscendDeviceContext;
#ifndef ENABLE_SECURITY
  friend class mindspore::device::ascend::DataDumper;
#endif
  friend class mindspore::SingleOpInferSession;
  friend class mindspore::RuntimeUtils;
};

using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;
using DeviceAddressPtrList = std::vector<DeviceAddressPtr>;
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_DEVICE_TENSOR_H