• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "plugin/device/ascend/hal/device/ascend_device_address.h"
17 #include <memory>
18 #include <vector>
19 #include <unordered_map>
20 #include <utility>
21 #include <set>
22 #include "graph/types.h"
23 #include "pybind_api/gil_scoped_long_running.h"
24 #include "runtime/device/kernel_runtime_manager.h"
25 #include "runtime/device/memory_manager.h"
26 #include "runtime/device/convert_tensor_utils.h"
27 #include "plugin/device/ascend/hal/device/ascend_event.h"
28 #include "plugin/device/ascend/hal/device/ascend_stream_manager.h"
29 #include "ir/dtype/type.h"
30 #include "ir/tensor.h"
31 #include "abstract/utils.h"
32 #include "include/common/utils/utils.h"
33 #include "runtime/device/ms_device_shape_transfer.h"
34 #include "plugin/device/ascend/hal/device/ascend_device_synchronizer.h"
35 #ifndef ENABLE_SECURITY
36 #include "include/backend/debug/data_dump/dump_json_parser.h"
37 #endif
38 #ifdef ENABLE_DEBUGGER
39 #include "debug/tensor_load.h"
40 #endif
41 #include "transform/symbol/acl_rt_symbol.h"
42 #include "transform/symbol/symbol_utils.h"
43 
44 namespace py = pybind11;
45 namespace mindspore {
46 namespace device {
47 namespace ascend {
// Byte widths used for host/device dtype-conversion size checks.
const auto kFloat16Bytes = 2;
const auto kFloatBytes = sizeof(float);
const auto kFloat64Bytes = 8;
// Serializes TransData kernel launches across threads (used by
// SyncDeviceToHostAndConvertFormatBasedOnTransData). NOTE(review): keeps the
// original "mutx" spelling since other translation units may reference it.
static std::recursive_mutex transdata_mutx;

#if defined(RT_MEMORY_P2PDMA)
// Guards the process-wide DMA staging buffer used by CopyBetweenFileDeviceDirectly.
static std::mutex dma_lock;
#endif
56 
IsUseTransDataTypeFormat(const std::pair<std::string,std::string> & type_format)57 bool IsUseTransDataTypeFormat(const std::pair<std::string, std::string> &type_format) {
58   static const std::set<std::pair<std::string, std::string>> use_trans_data = {
59     std::make_pair("float16", mindspore::kOpFormat_NC1HWC0), std::make_pair("float32", mindspore::kOpFormat_NC1HWC0),
60     std::make_pair("bool", mindspore::kOpFormat_NC1HWC0),    std::make_pair("float32", mindspore::kOpFormat_FRAC_Z),
61     std::make_pair("float16", mindspore::kOpFormat_FRAC_Z),  std::make_pair("float16", mindspore::kOpFormat_FRAC_NZ),
62     std::make_pair("float32", mindspore::kOpFormat_FRAC_NZ), std::make_pair("int32", mindspore::kOpFormat_FRAC_NZ),
63     std::make_pair("float16", mindspore::kOpFormat_NHWC),    std::make_pair("float32", mindspore::kOpFormat_NHWC),
64     std::make_pair("int8", mindspore::kOpFormat_NHWC),       std::make_pair("int16", mindspore::kOpFormat_NHWC),
65     std::make_pair("int32", mindspore::kOpFormat_NHWC),      std::make_pair("int64", mindspore::kOpFormat_NHWC),
66     std::make_pair("uint8", mindspore::kOpFormat_NHWC),      std::make_pair("uint16", mindspore::kOpFormat_NHWC),
67     std::make_pair("uint32", mindspore::kOpFormat_NHWC),     std::make_pair("uint64", mindspore::kOpFormat_NHWC),
68     std::make_pair("float16", mindspore::kOpFormat_HWCN),    std::make_pair("float32", mindspore::kOpFormat_HWCN),
69     std::make_pair("int8", mindspore::kOpFormat_HWCN),       std::make_pair("int16", mindspore::kOpFormat_HWCN),
70     std::make_pair("int32", mindspore::kOpFormat_HWCN),      std::make_pair("int64", mindspore::kOpFormat_HWCN),
71     std::make_pair("uint8", mindspore::kOpFormat_HWCN),      std::make_pair("uint16", mindspore::kOpFormat_HWCN),
72     std::make_pair("uint32", mindspore::kOpFormat_HWCN),     std::make_pair("uint64", mindspore::kOpFormat_HWCN)};
73   return use_trans_data.find(type_format) != use_trans_data.end();
74 }
75 
// Formats whose device layout matches the host layout, so copies between host
// and device need no format transformation (see SyncDeviceToHost/SyncHostToDeviceImpl).
static const std::set<std::string> basic_format = {kOpFormat_NCHW, kOpFormat_DEFAULT, kOpFormat_NCDHW, kOpFormat_ND};
77 
IsOpNeedTransFormat(const std::string & format)78 bool IsOpNeedTransFormat(const std::string &format) {
79   static const std::set<std::string> op_need_trans_format = {
80     kOpFormat_NHWC,    kOpFormat_HWCN,        kOpFormat_NC1HWC0,       kOpFormat_FRAC_Z,   kOpFormat_C1HWNCoC0,
81     kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D};
82   return op_need_trans_format.find(format) != op_need_trans_format.end();
83 }
84 
// Installs the Ascend-specific synchronizer used for host<->device data syncs.
void AscendDeviceAddress::DeviceSynchronizerInit() {
  set_device_synchronizer(std::make_shared<AscendDeviceSynchronizer>());
}
88 
SyncHostMemoryToDeviceWithCopySrc(void * dst,const void * src,uint64_t size,aclrtMemcpyKind kind,KernelRuntime * runtime_instance) const89 void AscendDeviceAddress::SyncHostMemoryToDeviceWithCopySrc(void *dst, const void *src, uint64_t size,
90                                                             aclrtMemcpyKind kind,
91                                                             KernelRuntime *runtime_instance) const {
92   MS_EXCEPTION_IF_NULL(runtime_instance);
93 
94   MS_LOG(DEBUG) << "Begin, size:" << size;
95   std::shared_ptr<uint8_t[]> buffer(new (std::nothrow) uint8_t[size]);
96   MS_EXCEPTION_IF_NULL(buffer);
97   auto ret_code = memcpy_s(buffer.get(), size, src, size);
98   // Return ERANGE when the copy size is larger than SECUREC_MEM_MAX_LEN.
99   if (ret_code == ERANGE) {
100     ConvertSameType(buffer.get(), src, size, type_id());
101   }
102 
103   const auto stream = AscendStreamMng::GetInstance().GetStream(this->stream_id());
104   auto ret = runtime_instance->MemcpyAsync(dst, buffer.get(), size, static_cast<int32_t>(kind), stream);
105   if (!ret) {
106     MS_LOG(EXCEPTION) << "MemcpyAsync failed!";
107   }
108 
109   device::CallbackFunc callback_func = [buffer]() {
110     // Clear buffer automatically.
111     MS_LOG(DEBUG) << "callback_func exec, buffer cnt:" << buffer.use_count();
112   };
113   auto device_context = GetDeviceContext();
114   MS_EXCEPTION_IF_NULL(device_context);
115   auto callback_ret = device_context->GetKernelExecutor(false)->LaunchCallback(callback_func, this->stream_id());
116   if (!callback_ret) {
117     MS_LOG(EXCEPTION) << "LaunchCallback failed";
118   }
119 }
120 
// Synchronous host-to-device copy for tensor data backed by a numpy array.
// Numpy-owned memory may be released by Python at any time, so the stream is
// drained and a blocking aclrtMemcpy is used instead of an async dispatch.
void AscendDeviceAddress::SyncHostMemoryToDeviceForTensorFromNumpy(void *dst, const void *src, uint64_t size,
                                                                   aclrtMemcpyKind kind,
                                                                   KernelRuntime *runtime_instance) const {
  MS_EXCEPTION_IF_NULL(runtime_instance);
  MS_LOG(DEBUG) << "Begin, size:" << size;

  runtime_instance->SetContextForce();
  // Memcpy needs to be synchronized first, if tensor data is from numpy.
  const auto stream = AscendStreamMng::GetInstance().GetStream(this->stream_id());
  // cppcheck-suppress unreadVariable
  auto lock = device::KernelRuntime::LockRuntime(stream);
  if (!AscendStreamMng::GetInstance().SyncStream(stream)) {
    MS_EXCEPTION(DeviceProcessError) << "Sync stream error!";
  }

  auto ret_rt_memcpy = CALL_ASCEND_API(aclrtMemcpy, dst, size, src, size, kind);
  MS_LOG(DEBUG) << "tensor is_from_numpy, sync it first";
  if (ret_rt_memcpy != ACL_ERROR_NONE) {
    MS_EXCEPTION(DeviceProcessError) << "aclrtMemcpy failed";
  }
}
142 
// Asynchronous host-to-device copy where `tensor_data` owns `src`. A stream
// callback captures the TensorDataPtr so the host memory stays alive until the
// async memcpy has been consumed by the stream.
void AscendDeviceAddress::SyncHostMemoryToDeviceWithTensorData(void *dst, const void *src, uint64_t size,
                                                               aclrtMemcpyKind kind,
                                                               const tensor::TensorDataPtr &tensor_data,
                                                               KernelRuntime *runtime_instance) const {
  MS_EXCEPTION_IF_NULL(runtime_instance);

  MS_LOG(DEBUG) << "Begin, size:" << size;
  const auto stream = AscendStreamMng::GetInstance().GetStream(this->stream_id());
  auto ret = runtime_instance->MemcpyAsync(dst, src, size, static_cast<int32_t>(kind), stream);
  if (!ret) {
    MS_LOG(EXCEPTION) << "MemcpyAsync failed!";
  }
  // The callback's only purpose is to hold a reference to tensor_data until the
  // stream reaches this point.
  device::CallbackFunc callback_func = [tensor_data]() {
    // Clear tensor_data automatically.
    MS_LOG(DEBUG) << "callback_func exec, tensor_data cnt:" << tensor_data.use_count();
  };
  auto device_context = GetDeviceContext();
  MS_EXCEPTION_IF_NULL(device_context);
  auto callback_ret = device_context->GetKernelExecutor(false)->LaunchCallback(callback_func, this->stream_id());
  if (!callback_ret) {
    MS_LOG(EXCEPTION) << "LaunchCallback failed";
  }
}
166 
// Copies `size` bytes between host and device in the direction given by `kind`.
// Graph mode (and any non host-to-device copy) drains the stream and issues a
// synchronous aclrtMemcpy; PyNative host-to-device copies are dispatched
// asynchronously, with `tensor_data` (or a staging copy) keeping `src` alive.
void AscendDeviceAddress::SyncMemory(void *dst, const void *src, uint64_t size, aclrtMemcpyKind kind,
                                     const tensor::TensorDataPtr &tensor_data) const {
  // Nothing to do for empty copies.
  if (size == 0) {
    return;
  }
  if (dst == nullptr) {
    MS_LOG(EXCEPTION) << "dst ptr is null, please check the address is set correctly.";
  }
  if (src == nullptr) {
    MS_LOG(EXCEPTION) << "src ptr is null, please check the address is set correctly.";
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto execution_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->SetContext();

  // Only apply asynchronous copy in Pynative && ACL_MEMCPY_HOST_TO_DEVICE mode
  if (execution_mode != kPynativeMode || kind != ACL_MEMCPY_HOST_TO_DEVICE) {
    auto ret = runtime_instance->SyncStream();
    if (!ret) {
      MS_LOG(EXCEPTION) << "Sync stream error!";
    }
    // When memory profiling is active, the actual copy is skipped.
    if (!common::IsNeedProfileMemory()) {
      auto ret_rt_memcpy = CALL_ASCEND_API(aclrtMemcpy, dst, size, src, size, kind);
      if (ret_rt_memcpy != ACL_ERROR_NONE) {
        MS_EXCEPTION(DeviceProcessError) << "aclrtMemcpy failed";
      }
    }
  } else {
    if (tensor_data == nullptr) {
      // tensor_data is nullptr. Need to copy host first, then dispatch callbacks.
      SyncHostMemoryToDeviceWithCopySrc(dst, src, size, kind, runtime_instance);
      return;
    }
    if (tensor_data->is_from_numpy()) {
      // Numpy-backed memory may be freed by Python; copy synchronously.
      SyncHostMemoryToDeviceForTensorFromNumpy(dst, src, size, kind, runtime_instance);
    } else {
      SyncHostMemoryToDeviceWithTensorData(dst, src, size, kind, tensor_data, runtime_instance);
    }
  }
}
211 
Float64ToFloatAndSyncHostToDevice(void * dst,size_t dst_size,const void * src,size_t src_size,const tensor::TensorDataPtr & tensor_data) const212 bool AscendDeviceAddress::Float64ToFloatAndSyncHostToDevice(void *dst, size_t dst_size, const void *src,
213                                                             size_t src_size,
214                                                             const tensor::TensorDataPtr &tensor_data) const {
215   if (src_size / kFloat64Bytes != dst_size / kFloatBytes) {
216     MS_INTERNAL_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]";
217   }
218   size_t elem_num = dst_size / sizeof(float);
219   auto host_tmp = std::vector<float>(elem_num);
220   DoubleToFloat(host_tmp.data(), src, elem_num);
221   SyncMemory(dst, host_tmp.data(), dst_size, ACL_MEMCPY_HOST_TO_DEVICE, tensor_data);
222   return true;
223 }
224 
SyncDeviceToHostAndFloatToFloat64(void * dst,size_t dst_size,const void * src,size_t src_size) const225 bool AscendDeviceAddress::SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *src,
226                                                             size_t src_size) const {
227   if (src_size / kFloatBytes != dst_size / kFloat64Bytes) {
228     MS_INTERNAL_EXCEPTION(ArgumentError) << "src_size[" << src_size << "], dst_size[" << dst_size << "]";
229   }
230   size_t elem_num = src_size / sizeof(float);
231   auto host_tmp = std::vector<float>(elem_num);
232   SyncMemory(host_tmp.data(), src, src_size, ACL_MEMCPY_DEVICE_TO_HOST);
233   FloatToDouble(dst, host_tmp.data(), elem_num);
234   return true;
235 }
236 
// Installs the deleter invoked when the device pointer's reference count drops
// to zero. Pointers not owned by the memory pool are left untouched. When a
// communication buffer exists, the pool allocation actually starts at
// communication_ptr_, so that pointer is the one returned to the pool.
void AscendDeviceAddress::SetDevicePtrDeleter() {
  if (!address_common_) {
    return;
  }

  address_common_->pointer_ref_count_->set_deleter(
    // Capture communication_ptr_ by value: the deleter may outlive this object.
    [communication_ptr = this->communication_ptr_](void *ptr, bool from_mem_pool) {
      if (ptr == nullptr || !from_mem_pool) {
        return;
      }

      if (communication_ptr != nullptr) {
        AscendMemoryPool::GetInstance().FreeTensorMem(communication_ptr);
      } else {
        AscendMemoryPool::GetInstance().FreeTensorMem(ptr);
      }
    });
}
255 
BindDevice() const256 void AscendDeviceAddress::BindDevice() const {
257   auto ms_context = MsContext::GetInstance();
258   MS_EXCEPTION_IF_NULL(ms_context);
259   if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
260     return;
261   }
262 
263   // Bind device by device name and device id on the current thread.
264   if (!device_name().empty()) {
265     auto ascend_device_context = GetDeviceContext();
266     MS_EXCEPTION_IF_NULL(ascend_device_context);
267     if (!ascend_device_context->device_res_manager_->BindDeviceToCurrentThread(false)) {
268       MS_LOG(WARNING) << "Bind device to current thread failed.";
269     }
270   } else {
271     MS_LOG(DEBUG) << "Device name is null.";
272   }
273 }
274 
// Blocks until the kernel runtime's stream has drained all queued work.
// Throws on failure.
void AscendDeviceAddress::SyncStream() const {
  MS_LOG(DEBUG) << "SyncStream Start!";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  auto ret = runtime_instance->SyncStream();
  if (!ret) {
    MS_LOG(EXCEPTION) << "Sync stream error!";
  }
  MS_LOG(DEBUG) << "SyncStream Finish!";
}
288 
SyncStream(size_t stream_id) const289 bool AscendDeviceAddress::SyncStream(size_t stream_id) const {
290   const auto stream = AscendStreamMng::GetInstance().GetStream(stream_id);
291   MS_EXCEPTION_IF_NULL(stream);
292   BindDevice();
293   if (!AscendStreamMng::GetInstance().SyncStream(stream)) {
294     MS_LOG(ERROR) << "Sync default stream failed.";
295     return false;
296   }
297   return true;
298 }
299 
// Copies `size` bytes from device memory `src` to host memory `dst`.
// Delegates to CopyBetweenHostDevice with the host-to-device flag cleared.
bool AscendDeviceAddress::CopyDeviceToHost(void *dst, const void *src, size_t size, bool async,
                                           size_t stream_id) const {
  return CopyBetweenHostDevice(dst, src, size, async, stream_id, false);
}
304 
// Copies `size` bytes from host memory `src` to device memory `dst`.
// Delegates to CopyBetweenHostDevice with the host-to-device flag set.
bool AscendDeviceAddress::CopyHostToDevice(void *dst, const void *src, size_t size, bool async,
                                           size_t stream_id) const {
  return CopyBetweenHostDevice(dst, src, size, async, stream_id, true);
}
309 
// Writes `size` bytes of device memory at `ptr` directly to `file_name` via
// the P2P DMA path (no host staging). Returns false when DMA is unavailable.
bool AscendDeviceAddress::DeviceToFileDirectly(void *ptr, size_t size, const std::string &file_name,
                                               size_t stream_id) const {
  return CopyBetweenFileDeviceDirectly(ptr, file_name, size, stream_id, false);
}
314 
// Reads `size` bytes from `file_name` directly into device memory at `ptr` via
// the P2P DMA path (no host staging). Returns false when DMA is unavailable.
bool AscendDeviceAddress::FileToDeviceDirectly(void *ptr, size_t size, const std::string &file_name,
                                               size_t stream_id) const {
  return CopyBetweenFileDeviceDirectly(ptr, file_name, size, stream_id, true);
}
319 
// Streams data between a file and device memory chunk-by-chunk through the
// shared P2P DMA staging buffer. Compiled only when RT_MEMORY_P2PDMA is
// defined; otherwise returns false so callers can fall back to a host path.
bool AscendDeviceAddress::CopyBetweenFileDeviceDirectly(void *ptr, const std::string &file_name, size_t size,
                                                        size_t stream_id, bool file_to_device) const {
#if defined(RT_MEMORY_P2PDMA)
  void *dargs = AscendDmaHandle::GetInstance().GetDargs();
  void *buf = AscendDmaHandle::GetInstance().GetBuf();
  if (dargs == nullptr || buf == nullptr) {
    return false;
  }
  // The DMA staging buffer is shared, so all file<->device copies serialize here.
  std::lock_guard<std::mutex> lock(dma_lock);
  // O_DIRECT bypasses the page cache. NOTE(review): O_DIRECT generally requires
  // aligned buffers/sizes -- presumably the DMA buffer satisfies this; verify.
  auto open_flag = file_to_device ? (O_RDWR | O_DIRECT) : (O_RDWR | O_CREAT | O_DIRECT);
  auto nvme_fd = open(file_name.c_str(), open_flag, S_IRUSR | S_IWUSR);
  if (nvme_fd < 0) {
    MS_LOG(ERROR) << "Open file failed, file name:" << file_name;
    return false;
  }
  size_t buf_size = AscendDmaHandle::GetInstance().GetSize();
  // Number of staging-buffer-sized chunks, rounding up.
  size_t count = (size + buf_size - 1) / buf_size;
  for (size_t i = 0; i < count; i++) {
    size_t ptr_offset = i * buf_size;
    size_t cur_size = (i == count - 1) ? (size - ptr_offset) : buf_size;
    if (file_to_device) {
      // read() returning -1 wraps to SIZE_MAX here, so the != comparison also
      // catches I/O errors, not just short reads.
      size_t ret_size = read(nvme_fd, buf, cur_size);
      if (ret_size != cur_size || !SyncStream(stream_id)) {
        MS_LOG(ERROR) << "Read file failed, file name:" << file_name << ", size:" << size;
        close(nvme_fd);
        return false;
      }
      DeviceToDevice(static_cast<uint8_t *>(ptr) + ptr_offset, dargs, cur_size, stream_id);
    } else {
      DeviceToDevice(dargs, static_cast<uint8_t *>(ptr) + ptr_offset, cur_size, stream_id);
      size_t ret_size = write(nvme_fd, buf, cur_size);
      if (ret_size != cur_size || !SyncStream(stream_id)) {
        MS_LOG(ERROR) << "Write file failed, file name:" << file_name << ", size:" << size;
        close(nvme_fd);
        return false;
      }
    }
  }
  close(nvme_fd);
  return true;
#else
  return false;
#endif
}
364 
// Copies `size` bytes between two device buffers on stream `stream_id`, then
// blocks until the copy completes so callers may reuse either buffer.
void AscendDeviceAddress::DeviceToDevice(void *dst, void *src, size_t size, size_t stream_id) const {
  MS_EXCEPTION_IF_NULL(dst);
  MS_EXCEPTION_IF_NULL(src);
  const auto stream = AscendStreamMng::GetInstance().GetStream(stream_id);
  MS_EXCEPTION_IF_NULL(stream);
  BindDevice();
  auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, dst, size, src, size, ACL_MEMCPY_DEVICE_TO_DEVICE, stream);
  if (ret != ACL_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call aclrtMemcpyAsync device to device failed, the error num[" << ret << "].";
  }
  if (!AscendStreamMng::GetInstance().SyncStream(stream_id)) {
    MS_LOG(EXCEPTION) << "Sync default failed.";
  }
}
379 
// Whole-buffer device-to-host copy with no dtype/format conversion: drains the
// stream, ensures the data is resident on device, then copies `size` bytes.
bool AscendDeviceAddress::SyncDeviceToHost(size_t size, void *const host_ptr) const {
  MS_EXCEPTION_IF_NULL(host_ptr);
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  BindDevice();
  SyncStream();
  if (!MoveToDevice(false)) {
    MS_LOG(WARNING) << "Move data to device failed, check previous log for details.";
  }
  CopyDeviceToHost(host_ptr, size);
  return true;
}
391 
// Whole-buffer host-to-device copy with no dtype/format conversion and no
// owning TensorData (nullptr passed through to the copy path).
bool AscendDeviceAddress::SyncHostToDevice(size_t size, const void *host_ptr) const {
  MS_EXCEPTION_IF_NULL(host_ptr);
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  BindDevice();
  if (!MoveToDevice(false)) {
    MS_LOG(WARNING) << "Move data to device failed, check previous log for details.";
  }
  CopyHostToDevice(host_ptr, size, nullptr);
  return true;
}
402 
// Copies device data to `host_ptr`, converting dtype and/or device format to
// the host's expectations when they differ.
bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size, mindspore::TypeId type,
                                           void *host_ptr) const {
  MS_LOG(DEBUG) << "SyncDeviceToHost, Device(format:" << format() << ", type_id:" << TypeIdLabel(type_id())
                << ", size:" << GetSize() << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
  // Monad values carry no data; nothing to copy.
  if (type_id() > kMonadTypeBegin && type_id() < kMonadTypeEnd) {
    return true;
  }
  BindDevice();
  SyncStream();
  if (!MoveToDevice(false)) {
    MS_LOG(WARNING) << "Move data to device failed, check previous log for details.";
  }
  bool sync_ok = false;
  ShapeVector host_shape = shape;
  // Treat scalars as one-element tensors.
  if (host_shape.empty()) {
    (void)host_shape.emplace_back(1);
  }
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  if (basic_format.find(format()) != basic_format.end()) {
    if (type_id() == type) {
      // Same layout and dtype: plain byte copy.
      CopyDeviceToHost(host_ptr, size);
      sync_ok = true;
    } else if (type_id() == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
      if (mem_offloaded()) {
        // Data already lives in host offload memory; widen in place.
        FloatToDouble(host_ptr, loadable_mem_->offload_ptr_, GetSize() / sizeof(float));
        sync_ok = true;
      } else {
        sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, GetDevicePtr(), GetSize());
      }
    } else {
      // Same layout, different dtype: download then convert on host.
      auto shape_size = abstract::ShapeSize(host_shape);
      auto host = std::vector<uint8_t>(GetSize());
      CopyDeviceToHost(host.data(), GetSize());
      const trans::TypeIdArgs type_args{host.data(), shape_size, type_id(), type, GetSize()};
      sync_ok = trans::TransDataType(type_args, host_ptr);
      if (!sync_ok) {
        MS_LOG(ERROR) << "Trans data type failed.";
        return false;
      }
    }
  } else {
    if (IsOpNeedTransFormat(format())) {
      sync_ok = SyncDeviceToHostAndConvertFormat(shape, size, type, host_ptr);
    } else {
      MS_LOG(INFO) << "Can not find format transfer function for :" << format();
    }
  }
  if (!sync_ok) {
    MS_LOG(ERROR) << "Unsupported to trans, dev_format:" << format() << ", dev_type:" << TypeIdLabel(type_id())
                  << ", host_type:" << TypeIdLabel(type);
    return false;
  }
  return sync_ok;
}
457 
GetDeviceShape(ShapeVector * host_shape) const458 ShapeVector AscendDeviceAddress::GetDeviceShape(ShapeVector *host_shape) const {
459   MS_EXCEPTION_IF_NULL(host_shape);
460   ShapeVector device_shape;
461   auto node_index = GetNodeIndex();
462   if (format() == kOpFormat_FRAC_NZ || format() == kOpFormat_NCDHW) {
463     device_shape = trans::TransShapeToDevice(*host_shape, format(), node_index.first, node_index.second, type_id());
464   } else {
465     if (!DeviceAddress::host_shape().empty()) {
466       host_shape->clear();
467       *host_shape = DeviceAddress::host_shape();
468     }
469     *host_shape = trans::PaddingShape(*host_shape, format());
470     device_shape = trans::TransShapeToDevice(*host_shape, format(), node_index.first, node_index.second, type_id());
471   }
472   return device_shape;
473 }
474 
// Builds a LaunchTransData helper that converts this address's data from
// `ori_format` to `dst_format` on device. Only FRAC_Z data carries a group
// count, which is looked up through the cached accessor.
std::shared_ptr<LaunchTransData> AscendDeviceAddress::CreateLaunchTransData(const ShapeVector &host_shape,
                                                                            const std::string &ori_format,
                                                                            const std::string &dst_format) const {
  int64_t groups = 1;
  if (format() == kOpFormat_FRAC_Z) {
    groups = GetGroupsWithCache();
  }
  auto launch_trans_data = std::make_shared<LaunchTransData>(this->stream_id(), type_id(), GetSize(), ori_format,
                                                             dst_format, host_shape, groups);
  MS_EXCEPTION_IF_NULL(launch_trans_data);
  return launch_trans_data;
}
487 
// Converts device-format data to NCHW on the device via a TransData kernel,
// then downloads the result (converting dtype on host if it differs).
bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormatBasedOnTransData(const ShapeVector &host_shape, size_t size,
                                                                           mindspore::TypeId type,
                                                                           void *host_ptr) const {
  bool sync_ok = true;
  const std::string dst_format = kOpFormat_NCHW;
  // Lazily create and cache the TransData launcher for this address.
  if (launch_transdata_ == nullptr) {
    launch_transdata_ = CreateLaunchTransData(host_shape, format(), dst_format);
    MS_EXCEPTION_IF_NULL(launch_transdata_);
  }
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  // launch transdata
  GilReleaseWithCheck release_gil;
  launch_transdata_->SetInputAddr(GetMutablePtr());
  {
    // Kernel launches are serialized globally across all device addresses.
    std::lock_guard<std::recursive_mutex> lock_launch(transdata_mutx);
    launch_transdata_->LaunchOpKernel();
  }

  SyncStream();
  auto output_addr_vec = launch_transdata_->GetKernelOutputAddr();
  if (output_addr_vec.size() != 1) {
    // Free the kernel's scratch memory before raising.
    launch_transdata_->FreeDeviceMem();
    MS_LOG(EXCEPTION) << "Launch transdata outputs should have only one output, actual output size: "
                      << output_addr_vec.size();
  }
  if (type_id() == type) {
    SyncMemory(host_ptr, output_addr_vec[0], size, ACL_MEMCPY_DEVICE_TO_HOST);
  } else {
    // Download into a staging buffer, then convert dtype into host_ptr.
    auto host = std::vector<uint8_t>(size);
    SyncMemory(host.data(), output_addr_vec[0], size, ACL_MEMCPY_DEVICE_TO_HOST);
    auto shape_size = abstract::ShapeSize(host_shape);
    const trans::TypeIdArgs type_args{host.data(), shape_size, type_id(), type, size};
    sync_ok = trans::TransDataType(type_args, host_ptr);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans data type failed.";
      launch_transdata_->FreeDeviceMem();
      return false;
    }
  }
  launch_transdata_->FreeDeviceMem();
  return sync_ok;
}
530 
// Copies device data to host, converting the device format back to NCHW and
// the dtype when it differs. In PyNative mode, supported (dtype, format) pairs
// take the on-device TransData kernel path instead of host-side conversion.
bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &shape, size_t size,
                                                           mindspore::TypeId type, void *host_ptr) const {
  MS_LOG(DEBUG) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format()
                << ", type_id:" << TypeIdLabel(type_id()) << ", size:" << GetSize()
                << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
  // Maps TypeId to the dtype names accepted by IsUseTransDataTypeFormat.
  static const std::unordered_map<mindspore::TypeId, std::string> type_id_name_map = {
    {mindspore::kNumberTypeBool, "bool"},       {mindspore::kNumberTypeInt8, "int8"},
    {mindspore::kNumberTypeInt16, "int16"},     {mindspore::kNumberTypeInt32, "int32"},
    {mindspore::kNumberTypeInt64, "int64"},     {mindspore::kNumberTypeFloat16, "float16"},
    {mindspore::kNumberTypeFloat32, "float32"}, {mindspore::kNumberTypeUInt8, "uint8"},
    {mindspore::kNumberTypeUInt16, "uint16"},   {mindspore::kNumberTypeUInt32, "uint32"},
    {mindspore::kNumberTypeUInt64, "uint64"}};
  bool sync_ok = false;
  ShapeVector host_shape = shape;
  // Treat scalars as one-element tensors.
  if (host_shape.empty()) {
    (void)host_shape.emplace_back(1);
  }
  auto device_shape = GetDeviceShape(&host_shape);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
      type_id_name_map.find(type_id()) != type_id_name_map.end() && !mem_offloaded()) {
    std::pair<std::string, std::string> type_format = std::make_pair(type_id_name_map.at(type_id()), format());
    if (IsUseTransDataTypeFormat(type_format)) {
      // On-device TransData kernel path.
      sync_ok = SyncDeviceToHostAndConvertFormatBasedOnTransData(host_shape, size, type, host_ptr);
      return sync_ok;
    }
  }
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  auto host_tmp = std::vector<uint8_t>(GetSize());
  CopyDeviceToHost(host_tmp.data(), GetSize());
  auto node_index = GetNodeIndex();
  if (type_id() != type) {
    // Convert format first, then dtype, through an intermediate host buffer.
    const trans::FormatArgs format_args{host_tmp.data(), GetSize(),    kOpFormat_NCHW, format(),
                                        host_shape,      device_shape, type_id()};
    auto host = std::vector<uint8_t>(GetSize());
    sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data(), node_index.first, node_index.second);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans format failed.";
      return false;
    }
    auto shape_size = abstract::ShapeSize(host_shape);
    const trans::TypeIdArgs type_args{host.data(), shape_size, type_id(), type, size};
    sync_ok = trans::TransDataType(type_args, host_ptr);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans data type failed.";
      return false;
    }
  } else {
    // Same dtype: convert format directly into the caller's buffer.
    const trans::FormatArgs format_args{host_tmp.data(), GetSize(),    kOpFormat_NCHW, format(),
                                        host_shape,      device_shape, type_id()};
    sync_ok = trans::TransFormatFromDeviceToHost(format_args, host_ptr, node_index.first, node_index.second);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans format failed.";
      return false;
    }
  }
  return sync_ok;
}
590 
// Core host-to-device sync: converts dtype (and device format when needed)
// before copying. When non-null, `tensor_data` owns `host_ptr` and enables the
// asynchronous PyNative copy path inside SyncMemory.
bool AscendDeviceAddress::SyncHostToDeviceImpl(const ShapeVector &shape, size_t size, mindspore::TypeId type,
                                               const void *host_ptr, const std::string &format,
                                               const tensor::TensorDataPtr &tensor_data) const {
  MS_LOG(DEBUG) << "SyncHostToDevice, Device(format:" << DeviceAddress::format()
                << ", type_id:" << TypeIdLabel(type_id()) << ", size:" << GetSize() << "), Host(format:" << format
                << ", type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
  // Monad values carry no data; nothing to copy.
  if (type_id() > kMonadTypeBegin && type_id() < kMonadTypeEnd) {
    return true;
  }
  BindDevice();
  if (!MoveToDevice(false)) {
    MS_LOG(WARNING) << "Move data to device failed, check previous log for details.";
  }
  bool sync_ok = false;
  ShapeVector host_shape = shape;
  // Treat scalars as one-element tensors.
  if (host_shape.empty()) {
    (void)host_shape.emplace_back(1);
  }
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  if (DeviceAddress::format() == format || basic_format.find(DeviceAddress::format()) != basic_format.end()) {
    if (type_id() == type) {
      // Same layout and dtype: plain byte copy.
      CopyHostToDevice(host_ptr, size, tensor_data);
      sync_ok = true;
    } else if (type_id() == kNumberTypeFloat32 && type == kNumberTypeFloat64) {
      // Narrow float64 host data into the device's float32 storage.
      sync_ok = Float64ToFloatAndSyncHostToDevice(GetDevicePtr(), GetSize(), host_ptr, size, tensor_data);
    } else {
      // Generic dtype conversion on host, then a plain copy.
      auto shape_size = abstract::ShapeSize(host_shape);
      const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id(), size};
      auto host_tmp = std::vector<uint8_t>(GetSize());
      sync_ok = trans::TransDataType(type_args, host_tmp.data());
      if (!sync_ok) {
        MS_LOG(ERROR) << "Trans data type failed.";
        return false;
      }
      CopyHostToDevice(host_tmp.data(), GetSize(), tensor_data);
    }
  } else {
    if (IsOpNeedTransFormat(DeviceAddress::format())) {
      sync_ok = ConvertFormatAndSyncHostToDevice(shape, size, type, host_ptr, tensor_data);
    } else {
      MS_LOG(INFO) << "Can not find format transfer function for :" << DeviceAddress::format();
    }
  }
  if (!sync_ok) {
    MS_LOG(ERROR) << "Unsupported trans, dev_format:" << DeviceAddress::format()
                  << ", dev_type:" << TypeIdLabel(type_id()) << ", host_type:" << TypeIdLabel(type);
    return false;
  }
  return sync_ok;
}
641 
// Overload for raw host memory: forwards to the shared implementation
// (the tensor_data argument falls back to its default declared in the header).
bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size, mindspore::TypeId type,
                                           const void *host_ptr, const std::string &format) const {
  return SyncHostToDeviceImpl(shape, size, type, host_ptr, format);
}
646 
// Overload taking TensorData: copies from tensor_data->data() and forwards the
// TensorData handle to the implementation as well.
bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type,
                                           const std::string &format, const tensor::TensorDataPtr &tensor_data) const {
  MS_EXCEPTION_IF_NULL(tensor_data);
  return SyncHostToDeviceImpl(shape, size, type, tensor_data->data(), format, tensor_data);
}
652 
// Device-to-device copy for addresses whose format or dtype differ:
// reads the source into an intermediate host tensor, then syncs that tensor
// into this address (conversions happen inside the two sync calls).
bool AscendDeviceAddress::SyncDeviceToDeviceWithDiffFormatType(const DeviceSync *src_device_addr) const {
  MS_EXCEPTION_IF_NULL(src_device_addr);
  // Monad types carry no real payload; nothing to copy.
  if (type_id() > kMonadTypeBegin && type_id() < kMonadTypeEnd) {
    return true;
  }

  auto src_device_address = dynamic_cast<const AscendDeviceAddress *>(src_device_addr);
  MS_EXCEPTION_IF_NULL(src_device_address);
  BindDevice();
  auto host_shape = src_device_address->host_shape();
  if (host_shape.empty()) {
    MS_LOG(WARNING) << "Host shape of source device address is empty, emplace back shape [1],  device address size: "
                    << src_device_address->GetSize()
                    << ", device address type: " << TypeIdLabel(src_device_address->type_id());
    (void)host_shape.emplace_back(1);
  }
  // Intermediate host tensor in the source's dtype and host shape.
  auto host_tensor = std::make_shared<tensor::Tensor>(src_device_address->type_id(), host_shape);
  MS_EXCEPTION_IF_NULL(host_tensor);
  auto host_tensor_size = LongToSize(host_tensor->data().nbytes());
  auto host_tensor_type = host_tensor->data_type();
  // Stage 1: source device -> host tensor.
  if (!src_device_address->SyncDeviceToHost(host_shape, host_tensor_size, host_tensor_type, host_tensor->data_c())) {
    MS_LOG(ERROR) << "Sync device to device failed at the stage of sync device to intermediate Tensor.";
    return false;
  }
  // Stage 2: host tensor -> this device address.
  if (!SyncHostToDevice(host_shape, host_tensor_size, host_tensor_type, host_tensor->data_c(),
                        host_tensor->device_info().host_format_)) {
    MS_LOG(ERROR) << "Sync device to device failed at the stage of sync intermediate tensor to device.";
    return false;
  }
  return true;
}
684 
// Copies data from `src_device_addr` into this address. Matching format and
// type use a direct device copy (staging offloaded source data through a
// temporary device buffer first); otherwise falls back to the intermediate
// host-tensor path in SyncDeviceToDeviceWithDiffFormatType.
bool AscendDeviceAddress::SyncDeviceToDevice(const DeviceSync *src_device_addr) const {
  MS_EXCEPTION_IF_NULL(src_device_addr);
  auto src_device_address = dynamic_cast<const AscendDeviceAddress *>(src_device_addr);
  MS_EXCEPTION_IF_NULL(src_device_address);
  if (!src_device_address->MoveToDevice(false)) {
    MS_LOG(WARNING) << "Move data to device failed, check previous log for details.";
  }
  if (format() == src_device_address->format() && type_id() == src_device_address->type_id()) {
    if (src_device_address->mem_offloaded()) {
      // Source currently lives in offload (host) memory: copy it into a
      // temporary device buffer so the regular device-to-device path applies.
      auto device_context = GetDeviceContext();
      MS_EXCEPTION_IF_NULL(device_context);
      void *temp_device_ptr = device_context->device_res_manager_->AllocateMemory(src_device_address->GetSize());
      MS_EXCEPTION_IF_NULL(temp_device_ptr);
      SyncMemory(temp_device_ptr, src_device_address->GetOffloadPtr(), src_device_address->GetSize(),
                 ACL_MEMCPY_HOST_TO_DEVICE);
      const auto ret = SyncDeviceToDevice(ShapeVector(), src_device_address->GetSize(), src_device_address->type_id(),
                                          temp_device_ptr, src_device_address->format());
      device_context->device_res_manager_->FreeMemory(temp_device_ptr);
      return ret;
    }
    return SyncDeviceToDevice(ShapeVector(), src_device_address->GetSize(), src_device_address->type_id(),
                              src_device_address->GetPtr(), src_device_address->format());
  } else {
    MS_LOG(INFO) << "Can not copy from device to device directly, format or type is different, src(format:"
                 << src_device_address->format() << ", type_id:" << TypeIdLabel(src_device_address->type_id())
                 << "), dst(format:" << format() << ", type_id:" << TypeIdLabel(type_id())
                 << ", use the intermediate Tensor copy instead.";
    return SyncDeviceToDeviceWithDiffFormatType(src_device_addr);
  }
}
715 
SyncDeviceToDevice(const ShapeVector & shape,size_t size,TypeId type,const void * src_ptr,const std::string & format) const716 bool AscendDeviceAddress::SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
717                                              const std::string &format) const {
718   bool ret = AsyncDeviceToDevice(shape, size, type, src_ptr, format);
719   if (!ret) {
720     return ret;
721   }
722   SyncStream();
723   return true;
724 }
725 
AsyncDeviceToDevice(const ShapeVector &,size_t size,TypeId type,const void * src_ptr,const std::string & format) const726 bool AscendDeviceAddress::AsyncDeviceToDevice(const ShapeVector & /* shape */, size_t size, TypeId type,
727                                               const void *src_ptr, const std::string &format) const {
728   MS_LOG(DEBUG) << "AsyncDeviceToDevice, dst(format:" << DeviceAddress::format()
729                 << ", type_id:" << TypeIdLabel(type_id()) << ", size:" << GetSize() << "), src(format:" << format
730                 << ", type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
731   if (GetDevicePtr() == src_ptr) {
732     MS_LOG(INFO) << "Dst addr is same with src addr, no need memcpy data.";
733     return true;
734   }
735   if (type_id() > kMonadTypeBegin && type_id() < kMonadTypeEnd) {
736     return true;
737   }
738   if (GetSize() < size) {
739     MS_LOG(ERROR) << "Src size is greater than det size, src size is: " << size << ", dst size is: " << GetSize();
740     return false;
741   }
742   if (DeviceAddress::format() != format || type_id() != type) {
743     MS_LOG(ERROR) << "Format or type is different, src(format:" << format << ", type_id:" << TypeIdLabel(type)
744                   << "), dst(format:" << DeviceAddress::format() << "), type_id:" << TypeIdLabel(type_id());
745     return false;
746   }
747 
748   BindDevice();
749   if (!MoveToDevice(false)) {
750     MS_LOG(WARNING) << "Move data to device failed, check previous log for details.";
751   }
752   std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
753   auto ms_context = MsContext::GetInstance();
754   MS_EXCEPTION_IF_NULL(ms_context);
755   auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
756   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
757   MS_EXCEPTION_IF_NULL(runtime_instance);
758   bool ret;
759   if (mem_offloaded()) {
760     ret = runtime_instance->MemcpyAsync(loadable_mem_->offload_ptr_, src_ptr, size,
761                                         static_cast<int32_t>(ACL_MEMCPY_DEVICE_TO_HOST),
762                                         runtime_instance->compute_stream());
763   } else {
764     ret =
765       runtime_instance->MemcpyAsync(GetDevicePtr(), src_ptr, size, static_cast<int32_t>(ACL_MEMCPY_DEVICE_TO_DEVICE),
766                                     runtime_instance->compute_stream());
767   }
768   if (!ret) {
769     MS_LOG(ERROR) << "MemcpyAsync failed!";
770   }
771   return ret;
772 }
773 
AsyncHostToDevice(size_t size,TypeId,const void * host_ptr) const774 bool AscendDeviceAddress::AsyncHostToDevice(size_t size, TypeId /* type */, const void *host_ptr) const {
775   MS_ERROR_IF_NULL(host_ptr);
776   BindDevice();
777   if (!MoveToDevice(false)) {
778     MS_LOG(WARNING) << "Move data to device failed, check previous log for details.";
779   }
780   MS_ERROR_IF_NULL(GetDevicePtr());
781 
782   auto ms_context = MsContext::GetInstance();
783   MS_EXCEPTION_IF_NULL(ms_context);
784   auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
785   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
786   MS_EXCEPTION_IF_NULL(runtime_instance);
787 
788   auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, GetDevicePtr(), size, host_ptr, size, ACL_MEMCPY_HOST_TO_DEVICE,
789                              runtime_instance->compute_stream());
790   if (ret != ACL_ERROR_NONE) {
791     MS_LOG(ERROR) << "Call aclrtMemcpyAsync host to device failed, the error num[" << ret << "]";
792     return false;
793   }
794   return true;
795 }
796 
AsyncHostToDevice(const ShapeVector &,size_t size,TypeId,const void * host_ptr,size_t stream_id) const797 bool AscendDeviceAddress::AsyncHostToDevice(const ShapeVector & /* shape */, size_t size, TypeId /* type */,
798                                             const void *host_ptr, size_t stream_id) const {
799   MS_ERROR_IF_NULL(host_ptr);
800   BindDevice();
801   if (!MoveToDevice(false)) {
802     MS_LOG(WARNING) << "Move data to device failed, check previous log for details.";
803   }
804   MS_ERROR_IF_NULL(GetDevicePtr());
805   auto stream = AscendStreamMng::GetInstance().GetStream(stream_id);
806   if (stream == nullptr) {
807     stream = AscendStreamMng::GetInstance().GetStream(kDefaultStreamIndex);
808   }
809   MS_ERROR_IF_NULL(stream);
810 
811   auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, GetDevicePtr(), size, host_ptr, size, ACL_MEMCPY_HOST_TO_DEVICE, stream);
812   if (ret != ACL_ERROR_NONE) {
813     MS_LOG(ERROR) << "Call aclrtMemcpyAsync host to device failed, the error num[" << ret << "]";
814     return false;
815   }
816   return true;
817 }
818 
AsyncDeviceToHost(const ShapeVector &,size_t size,TypeId,void * host_ptr,size_t stream_id) const819 bool AscendDeviceAddress::AsyncDeviceToHost(const ShapeVector & /* shape */, size_t size, TypeId /* type */,
820                                             void *host_ptr, size_t stream_id) const {
821   MS_ERROR_IF_NULL(host_ptr);
822   BindDevice();
823   if (!MoveToDevice(false)) {
824     MS_LOG(ERROR) << "Move data to device failed, check previous log for details.";
825     return false;
826   }
827   MS_ERROR_IF_NULL(GetDevicePtr());
828   const auto stream = AscendStreamMng::GetInstance().GetStream(stream_id);
829   MS_ERROR_IF_NULL(stream);
830   auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, host_ptr, size, GetDevicePtr(), size, ACL_MEMCPY_DEVICE_TO_HOST, stream);
831   if (ret != ACL_ERROR_NONE) {
832     MS_LOG(ERROR) << "Call aclrtMemcpyAsync device to host failed, the error num[" << ret << "]";
833     return false;
834   }
835   return true;
836 }
837 
// Syncs host data whose format differs from the device format: computes the
// device-side shape, converts dtype (if needed) and format in temporary host
// buffers, then copies the result to the device.
bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &shape, size_t size,
                                                           mindspore::TypeId type, const void *host_ptr,
                                                           const tensor::TensorDataPtr &tensor_data) const {
  bool sync_ok = false;
  MS_LOG(DEBUG) << "ConvertFormatAndSyncHostToDevice, Device(format:" << format()
                << ", type_id:" << TypeIdLabel(type_id()) << ", size:" << GetSize()
                << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")";
  ShapeVector host_shape = shape;
  // Treat scalars as shape [1] for the transformation helpers.
  if (host_shape.empty()) {
    (void)host_shape.emplace_back(1);
  }
  auto node_index = GetNodeIndex();
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  // Refresh groups_ before shape transformation (result itself is unused here).
  (void)GetGroupsWithCache();
  std::vector<int64_t> device_shape;
  if (format() == kOpFormat_FRAC_NZ) {
    // FRAC_NZ derives the device shape from the raw host shape without padding.
    device_shape = trans::TransShapeToDevice(host_shape, format(), node_index.first, node_index.second, type_id());
  } else {
    host_shape = trans::PaddingShape(host_shape, format());
    device_shape = trans::TransShapeToDevice(host_shape, format(), node_index.first, node_index.second, type_id());
  }
  if (type_id() != type) {
    // Step 1: dtype conversion into host_tmp; step 2: format conversion into dst_tmp.
    auto shape_size = abstract::ShapeSize(host_shape);
    const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id(), size};
    auto host_tmp = std::vector<uint8_t>(GetSize());
    sync_ok = trans::TransDataType(type_args, host_tmp.data());
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans data type failed.";
      return false;
    }
    const trans::FormatArgs format_args{host_tmp.data(), GetSize(),    kOpFormat_NCHW, format(),
                                        host_shape,      device_shape, type_id()};
    auto dst_tmp = std::vector<uint8_t>(GetSize());
    sync_ok = trans::TransFormat(format_args, dst_tmp.data(), node_index.first, node_index.second);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans format failed.";
      return false;
    }
    CopyHostToDevice(dst_tmp.data(), GetSize(), tensor_data);
  } else {
    // Same dtype: only the format conversion is needed.
    const trans::FormatArgs format_args{host_ptr,   GetSize(),    kOpFormat_NCHW, format(),
                                        host_shape, device_shape, type_id()};
    auto host_tmp = std::vector<uint8_t>(GetSize());
    sync_ok = trans::TransFormat(format_args, host_tmp.data(), node_index.first, node_index.second);
    if (!sync_ok) {
      MS_LOG(ERROR) << "Trans format failed.";
      return false;
    }
    CopyHostToDevice(host_tmp.data(), GetSize(), tensor_data);
  }
  return sync_ok;
}
890 
// Releases the memory held by this address: the offload (host) buffer if any,
// and pool-allocated device memory. Waits for outstanding events first.
void AscendDeviceAddress::ClearDeviceMemory() {
  std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
  (void)Wait();  // ensure no pending swap/copy still references the buffers
  if (loadable_mem_ != nullptr && loadable_mem_->offload_ptr_ != nullptr) {
    auto device_context = GetDeviceContext();
    MS_EXCEPTION_IF_NULL(device_context);
    device_context->device_res_manager_->FreeOffloadMemory(loadable_mem_->offload_ptr_);
    loadable_mem_->offload_ptr_ = nullptr;
  }
  if (GetDevicePtr() != nullptr && from_mem_pool()) {
    if (communication_ptr_ != nullptr) {
      // When communication_ptr_ is set, that is the base pointer the pool
      // handed out; free it instead of the (possibly offset) device pointer.
      AscendMemoryPool::GetInstance().FreeTensorMem(communication_ptr_);
      communication_ptr_ = nullptr;
    } else {
      AscendMemoryPool::GetInstance().FreeTensorMem(GetDevicePtr());
    }
    SetDevicePtr(nullptr);
  }
}
910 
CopyDeviceToHost(void * dst,uint64_t size) const911 void AscendDeviceAddress::CopyDeviceToHost(void *dst, uint64_t size) const {
912   MS_EXCEPTION_IF_NULL(dst);
913   if (mem_offloaded()) {
914     MS_EXCEPTION_IF_NULL(loadable_mem_->offload_ptr_);
915     SyncMemory(dst, loadable_mem_->offload_ptr_, size, ACL_MEMCPY_HOST_TO_HOST);
916   } else {
917     MS_EXCEPTION_IF_NULL(GetDevicePtr());
918     SyncMemory(dst, GetDevicePtr(), size, ACL_MEMCPY_DEVICE_TO_HOST);
919   }
920 }
921 
// Synchronously copies `size` bytes of host data into this address. String
// tensors get a ge::StringHead prepended on the device; offloaded data is
// written into the offload (host) buffer instead.
void AscendDeviceAddress::CopyHostToDevice(const void *src, uint64_t size,
                                           const tensor::TensorDataPtr &tensor_data) const {
  MS_EXCEPTION_IF_NULL(src);

  if (mem_offloaded()) {
    MS_EXCEPTION_IF_NULL(loadable_mem_->offload_ptr_);
    SyncMemory(loadable_mem_->offload_ptr_, src, size, ACL_MEMCPY_HOST_TO_HOST, tensor_data);
  } else {
    MS_EXCEPTION_IF_NULL(GetDevicePtr());
    if (type_id() == kObjectTypeString) {
      // NOTE: For string type, ge::StringHead.len does not include '\0', since kernel_tensor allocated size including
      // '\0', see method `CreateDeviceAddressForScalarAndString` defined in `device_address_utils.cc`, and method
      // `PrepareDataForStringValue` defined in `data_prepare_actor.cc`, so here pass `size - 1` to `head.len`.
      // NOTE: method `CopyHostToDevice` can be triggered from the two scenarios as below:
      // 1. method `CopyNoneTensorDataToDevice` in `device_address_utils.cc` passes a kernel tensor, the parameter
      // `size` include `ge::StringHead`
      // 2. method `PrepareDataForStringValue` in `data_prepare_actor.cc` passes a raw string, the parameter `size` does
      // not include `ge::StringHead`
      if (size == GetSize() && size >= sizeof(ge::StringHead)) {
        size -= sizeof(ge::StringHead);
      }
      ge::StringHead head{.addr = sizeof(ge::StringHead), .len = static_cast<int64_t>(size) - 1};
      // sync string head info from host to device
      SyncMemory(GetDevicePtr(), &head, sizeof(ge::StringHead), ACL_MEMCPY_HOST_TO_DEVICE, nullptr);
      // sync string body (real contents) from host to device, right after the head
      SyncMemory(static_cast<void *>(static_cast<char *>(GetDevicePtr()) + sizeof(ge::StringHead)), src, size,
                 ACL_MEMCPY_HOST_TO_DEVICE, tensor_data);
      MS_LOG(DEBUG) << "Copy string info to device, ge::StringHead.len=" << head.len
                    << ", text=" << std::string(static_cast<const char *>(src), head.len)
                    << ", device_addr=" << GetDevicePtr();
    } else {
      SyncMemory(GetDevicePtr(), src, size, ACL_MEMCPY_HOST_TO_DEVICE, tensor_data);
    }
  }
}
957 
CopyBetweenHostDevice(void * dst,const void * src,size_t size,bool async,size_t stream_id,bool host_to_device) const958 bool AscendDeviceAddress::CopyBetweenHostDevice(void *dst, const void *src, size_t size, bool async, size_t stream_id,
959                                                 bool host_to_device) const {
960   MS_EXCEPTION_IF_NULL(dst);
961   MS_EXCEPTION_IF_NULL(src);
962   auto copy_kind = host_to_device ? ACL_MEMCPY_HOST_TO_DEVICE : ACL_MEMCPY_DEVICE_TO_HOST;
963   const auto stream = AscendStreamMng::GetInstance().GetStream(stream_id);
964   MS_EXCEPTION_IF_NULL(stream);
965   BindDevice();
966   auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, dst, size, src, size, copy_kind, stream);
967   if (ret != ACL_ERROR_NONE) {
968     MS_LOG(ERROR) << "Call aclrtMemcpyAsync device to host failed, the error num[" << ret << "]";
969     return false;
970   }
971   if (async) {
972     auto record_event = std::make_shared<AscendEvent>();
973     record_event->set_record_stream(stream);
974     record_event->RecordEvent();
975     if (loadable_mem_ == nullptr) {
976       loadable_mem_ = std::make_unique<LoadableMember>();
977     }
978     loadable_mem_->swap_event_.device_event_ = record_event;
979   } else {
980     if (!AscendStreamMng::GetInstance().SyncStream(stream)) {
981       MS_LOG(ERROR) << "Sync default stream failed.";
982       return false;
983     }
984   }
985   return true;
986 }
987 
// Blocking device-to-host copy between raw pointers; always reports success.
bool AscendDeviceAddress::CopyDeviceToHost(void *dst, const void *src, const size_t &size) const {
  SyncMemory(dst, src, size, ACL_MEMCPY_DEVICE_TO_HOST);
  return true;
}
992 
// Blocking host-to-device copy between raw pointers; always reports success.
bool AscendDeviceAddress::CopyHostToDevice(void *dst, const void *src, const size_t &size) const {
  SyncMemory(dst, src, size, ACL_MEMCPY_HOST_TO_DEVICE);
  return true;
}
997 
AsyncDeviceToHost(size_t size,void * host_ptr) const998 bool AscendDeviceAddress::AsyncDeviceToHost(size_t size, void *host_ptr) const {
999   MS_EXCEPTION_IF_NULL(host_ptr);
1000   if (GetDevicePtr() == host_ptr) {
1001     MS_LOG(INFO) << "Dst addr is same with src addr, no need copy data.";
1002     return true;
1003   }
1004   BindDevice();
1005   MS_EXCEPTION_IF_NULL(GetDevicePtr());
1006   auto device_context = GetDeviceContext();
1007   MS_EXCEPTION_IF_NULL(device_context);
1008   auto stream_id = device_context->device_res_manager_->GetCurrentStreamId();
1009   auto stream = device_context->device_res_manager_->GetStream(stream_id);
1010   if (stream == nullptr) {
1011     stream = device_context->device_res_manager_->GetStream(kDefaultStreamIndex);
1012   }
1013   MS_ERROR_IF_NULL(stream);
1014   auto ret = CALL_ASCEND_API(aclrtMemcpyAsync, host_ptr, size, GetDevicePtr(), size, ACL_MEMCPY_DEVICE_TO_HOST, stream);
1015   if (ret != ACL_ERROR_NONE) {
1016     MS_LOG(ERROR) << "Call aclrtMemcpyAsync host to device failed, the error num[" << ret << "]";
1017     return false;
1018   }
1019   return true;
1020 }
1021 
// Copies `size` bytes of host memory into this address, lazily allocating
// device memory on the chosen stream when this address has none yet.
bool AscendDeviceAddress::AsyncHostToDevice(size_t size, const void *host_ptr) const {
  MS_EXCEPTION_IF_NULL(host_ptr);
  if (GetDevicePtr() == host_ptr) {
    MS_LOG(INFO) << "Dst addr is same with src addr, no need copy data.";
    return true;
  }
  BindDevice();
  auto device_context = GetDeviceContext();
  MS_EXCEPTION_IF_NULL(device_context);
  auto stream_id = device_context->device_res_manager_->GetCurrentStreamId();
  auto stream = device_context->device_res_manager_->GetStream(stream_id);
  if (stream == nullptr) {
    // Fall back to the default stream; keep stream_id consistent for the
    // allocation below.
    stream = device_context->device_res_manager_->GetStream(kDefaultStreamIndex);
    stream_id = kDefaultStreamIndex;
  }
  MS_ERROR_IF_NULL(stream);
  // Lazily allocate device memory bound to the selected stream.
  if (GetDevicePtr() == nullptr) {
    auto ptr = device_context->device_res_manager_->AllocateMemory(size, stream_id);
    MS_EXCEPTION_IF_NULL(ptr);
    SetDevicePtr(ptr);
  }
  auto device_id = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id);
  MS_EXCEPTION_IF_NULL(runtime_instance);
  runtime_instance->SetContext();
  // NOTE(review): despite the "Async" name this delegates to
  // SyncHostMemoryToDeviceWithCopySrc — confirm the intended semantics with callers.
  SyncHostMemoryToDeviceWithCopySrc(GetDevicePtr(), host_ptr, size, ACL_MEMCPY_HOST_TO_DEVICE, runtime_instance);
  return true;
}
1050 
// Destructor: releases only the offload buffer and loadable resources.
AscendDeviceAddress::~AscendDeviceAddress() {
  try {
    // Only release offload memory, release device memory when `kernel_tensor_` in base class destroyed, because maybe
    // multi AscendDeviceAddress objects use same device pointer in ref case.
    std::lock_guard<std::recursive_mutex> lock(ptr_mutex_);
    (void)Wait();
    if (loadable_mem_ != nullptr && loadable_mem_->offload_ptr_ != nullptr) {
      auto device_context = GetDeviceContext();
      MS_EXCEPTION_IF_NULL(device_context);
      device_context->device_res_manager_->FreeOffloadMemory(loadable_mem_->offload_ptr_);
      loadable_mem_->offload_ptr_ = nullptr;
    }
    LoadableDeviceAddress::ReleaseResource();
  } catch (const std::exception &e) {
    // Destructors must not throw; swallow and log.
    MS_LOG(ERROR) << "AscendDeviceAddress destructor failed: " << e.what();
  } catch (...) {
    MS_LOG(ERROR) << "AscendDeviceAddress destructor failed.";
  }
}
1070 
1071 #ifndef ENABLE_SECURITY
1072 /*
1073  * Feature group: Dump.
1074  * Target device group: Ascend.
1075  * Runtime category: Old runtime, MindRT.
1076  * Description: Dump tensor data to file for e2e dump.
1077  */
// Dumps this address's data to `filepath` for e2e dump. With `trans_flag` the
// data is first converted to `host_type`/`host_fmt` through SyncDeviceToHost;
// otherwise the raw device bytes are copied out in the device format.
bool AscendDeviceAddress::DumpMemToFile(const std::string &filepath, const std::string &host_fmt,
                                        const ShapeVector &host_shape, TypeId host_type, bool trans_flag) const {
  if (GetSize() == 0) {
    MS_LOG(INFO) << "the operator in filepath: " << filepath << ", size == 0";
    return true;
  }
  bool ret = false;
  if (filepath.empty()) {
    MS_LOG(ERROR) << "Dump file path is null!";
    return ret;
  }
  if (trans_flag) {
    std::string path = filepath + '.' + host_fmt;
    MS_LOG(INFO) << "E2E Dump path is " << path;
    // Only plain numeric types can be materialized as a host tensor here
    // (complex64 is explicitly excluded).
    if (host_type > TypeId::kNumberTypeEnd || host_type < TypeId::kNumberTypeBegin ||
        host_type == kNumberTypeComplex64) {
      MS_LOG(INFO) << "Cannot create tensor with type: " << TypeIdLabel(host_type);
      return false;
    }
    mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape);
    MS_EXCEPTION_IF_NULL(out_tensor);
    size_t host_size = LongToSize(out_tensor->data().nbytes());
    ret = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c());
    if (!ret) {
      MS_LOG(ERROR) << "Copy device mem to host failed";
      return ret;
    }
    ret = DumpJsonParser::DumpToFile(path, out_tensor->data_c(), host_size, host_shape, host_type);
  } else {
    // No conversion: copy raw device bytes after draining the stream.
    auto host_tmp = std::vector<uint8_t>(GetSize());
    BindDevice();
    SyncStream();
    auto ret_rt_memcpy =
      CALL_ASCEND_API(aclrtMemcpy, host_tmp.data(), GetSize(), GetDevicePtr(), GetSize(), ACL_MEMCPY_DEVICE_TO_HOST);
    if (ret_rt_memcpy != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "SyncDeviceToHost: aclrtMemcpy mem size[" << GetSize() << "] fail, ret[" << ret_rt_memcpy << "]";
      return false;
    }
    std::string path = filepath + '.' + format();
    MS_LOG(INFO) << "E2E Dump path is " << path;
    ret = DumpJsonParser::DumpToFile(path, host_tmp.data(), GetSize(), host_shape, type_id());
  }

  return ret;
}
1123 #endif
1124 
GetGroupsWithCache() const1125 int64_t AscendDeviceAddress::GetGroupsWithCache() const {
1126   auto node = GetNodeIndex();
1127   if (node.first != nullptr) {
1128     groups_ = common::AnfAlgo::GetAttrGroups(node.first, node.second);
1129   }
1130   return groups_;
1131 }
1132 
1133 #ifdef ENABLE_DEBUGGER
1134 /*
1135  * Feature group: Dump, Online debugger.
1136  * Target device group: Ascend.
1137  * Runtime category: Old runtime, MindRT.
1138  * Description: Load tensor to host and create tensor_data object for the loaded tensor.
1139  */
// Loads this address's data to host and registers it with the debugger under
// `tensor_name` (slot `slot`, graph `root_graph_id`). `trans_flag` converts
// to `host_fmt`/`host_type`; `async_copy` selects the stream-synchronized
// SyncDeviceToHost path over a blocking aclrtMemcpy.
bool AscendDeviceAddress::LoadMemToHost(const std::string &tensor_name, int execution_order,
                                        const std::string &host_fmt, const ShapeVector &host_shape, TypeId host_type,
                                        size_t slot, bool keep_prev, uint32_t root_graph_id, bool force_update,
                                        bool trans_flag, bool async_copy) const {
  bool ret = false;
  auto debugger = Debugger::GetInstance();
  MS_EXCEPTION_IF_NULL(debugger);
  // Skip tensors already loaded this step unless the caller forces a refresh.
  if (debugger->TensorExistsInCurrent(tensor_name) && !force_update) {
    MS_LOG(INFO) << tensor_name << " already loaded for this step so not loading it again.";
    return true;
  }
  // TensorData is freed up in AscendSession class
  auto tensor_data = std::make_shared<mindspore::TensorData>();
  MS_EXCEPTION_IF_NULL(tensor_data);
  tensor_data->SetName(tensor_name);
  tensor_data->SetExecutionOrder(execution_order);
  tensor_data->SetSlot(slot);

  // Only plain numeric types can be materialized (complex64 excluded).
  if (host_type > TypeId::kNumberTypeEnd || host_type < TypeId::kNumberTypeBegin || host_type == kNumberTypeComplex64) {
    MS_LOG(INFO) << "Cannot create tensor with type: " << TypeIdLabel(host_type);
    return false;
  }
  mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(host_type, host_shape);
  MS_EXCEPTION_IF_NULL(out_tensor);
  size_t host_size = LongToSize(out_tensor->data().nbytes());
  if (host_size == 0) {
    MS_LOG(INFO) << "Tensor size is 0 for tensor: " << tensor_name;
    return true;
  }
  bool ret_sync = false;
  if (async_copy) {
    if (trans_flag) {
      // Convert format/type while copying to host.
      ret_sync = SyncDeviceToHost(host_shape, host_size, host_type, out_tensor->data_c());
    } else {
      ret_sync = SyncDeviceToHost(host_size, out_tensor->data_c());
    }
  } else {
    // copy device to host using sync mode
    auto ret_rt_memcpy = CALL_ASCEND_API(aclrtMemcpy, out_tensor->data_c(), host_size, GetDevicePtr(), GetSize(),
                                         ACL_MEMCPY_DEVICE_TO_HOST);
    if (ret_rt_memcpy != ACL_ERROR_NONE) {
      MS_LOG(ERROR) << "SyncDeviceToHost: aclrtMemcpy mem size[" << GetSize() << "] fail, ret[" << ret_rt_memcpy << "]";
      return false;
    } else {
      ret_sync = true;
    }
  }
  if (!ret_sync) {
    MS_LOG(ERROR) << "Convert format or Copy device mem to host failed";
    return ret;
  }
  MS_LOG(INFO) << "E2E tensor name is " << tensor_name;
  // Wire the loaded tensor into the debugger's TensorData record.
  tensor_data->SetTensor(out_tensor);
  tensor_data->SetDataPtr(static_cast<char *>(out_tensor->data_c()));
  tensor_data->SetByteSize(LongToSize(out_tensor->data().nbytes()));
  tensor_data->SetType(host_type);
  tensor_data->SetShape(out_tensor->shape());
  tensor_data->SetRootGraphId(root_graph_id);
  std::string tensor_format = trans_flag ? host_fmt : format();
  tensor_data->SetFormat(tensor_format);
  ret = debugger->LoadNewTensor(tensor_data, keep_prev);
  MS_LOG(INFO) << "Load tensor '" << tensor_name << "' into debugger tensor loader successfully: format("
               << tensor_format << ")";
  return ret;
}
1205 #endif
1206 }  // namespace ascend
1207 }  // namespace device
1208 }  // namespace mindspore
1209