1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <algorithm>
18 #include <set>
19 #include "runtime/graph_scheduler/actor/data_prepare_actor.h"
20 #include "runtime/graph_scheduler/actor/memory_manager_actor.h"
21 #include "runtime/graph_scheduler/actor/kernel_actor.h"
22 #include "runtime/graph_scheduler/actor/loop_count_actor.h"
23 #include "runtime/graph_scheduler/actor/debug_actor.h"
24 #include "runtime/graph_scheduler/actor/profiler_actor.h"
25 #include "runtime/hardware/device_context_manager.h"
26 #include "runtime/device/auto_mem_offload.h"
27 #include "runtime/device/device_address_utils.h"
28 #include "mindrt/include/async/async.h"
29 #include "utils/log_adapter.h"
30 #include "utils/phase.h"
31 #include "include/common/utils/convert_utils.h"
32 #include "include/backend/distributed/recovery/recovery_context.h"
33 #include "include/backend/mem_reuse/mem_tracker.h"
34 #if defined(__linux__) && defined(WITH_BACKEND)
35 #include "runtime/graph_scheduler/rpc_node_scheduler.h"
36 #include "runtime/graph_scheduler/embedding_cache_scheduler.h"
37 #endif
38 
39 namespace mindspore {
40 namespace runtime {
41 using distributed::recovery::RecoveryContext;
42 namespace {
43 constexpr size_t kNormalTensorNum = 1;
44 constexpr size_t kMapTensorNum = 3;
45 constexpr size_t kMapTensorKeyIndex = 0;
46 constexpr size_t kMapTensorValueIndex = 1;
47 constexpr size_t kMapTensorStatusIndex = 2;
48 constexpr size_t kPinMemThreshold = 1024 << 10;
49 
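// Returns true when the tensor's base shape is a sequence shape with zero elements (an empty tuple/list tensor).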
50 bool IsEmptySequenceTensor(const TensorPtr &tensor) {
51   MS_EXCEPTION_IF_NULL(tensor);
52   if (tensor->base_shape_ptr() == nullptr || (!tensor->base_shape_ptr()->isa<abstract::SequenceShape>())) {
53     return false;
54   }
55   const auto &sequence_shape = tensor->base_shape_ptr()->cast<abstract::SequenceShapePtr>();
56   MS_EXCEPTION_IF_NULL(sequence_shape);
57   return sequence_shape->size() == 0;
58 }
59 
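// Data is taken over by the memory offload (swap) manager for non-CPU devices when MS_CTX_ENABLE_MEM_OFFLOAD is set.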
60 bool IsDataTakenOverByMemOffload(const DeviceContext *device_context) {
61   MS_EXCEPTION_IF_NULL(device_context);
62   if (device_context->GetDeviceType() == device::DeviceType::kCPU) {
63     return false;
64   }
65   auto ms_context = MsContext::GetInstance();
66   MS_EXCEPTION_IF_NULL(ms_context);
67   return ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD);
68 }
69 
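// Builds the swap storage info for a host tensor: reuse its offload file or (pinned) host data when the data types
// match; otherwise allocate host memory from the swap manager and convert the data type into it.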
70 device::StorageInfo GetStorageInfo(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor,
71                                    const DeviceContext *device_context) {
72   MS_EXCEPTION_IF_NULL(host_tensor);
73   MS_EXCEPTION_IF_NULL(device_tensor);
74   MS_EXCEPTION_IF_NULL(device_context);
75   MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
76   auto swap_manager = device_context->device_res_manager_->swap_manager();
77   MS_EXCEPTION_IF_NULL(swap_manager);
78   if (host_tensor->data_type() == device_tensor->type_id()) {
79     const auto &offload_file = host_tensor->GetOffloadFilePath();
80     if (!offload_file.empty()) {
81       return {nullptr, offload_file};
82     } else if (host_tensor->Size() > kPinMemThreshold) {
83       host_tensor->PinMemory(swap_manager->GetPinMemPool());
84     }
85     return {host_tensor->data_c(), ""};
86   }
87   const auto shape_size = abstract::ShapeSize(host_tensor->shape());
88   const auto data_size = host_tensor->Size();
89   const trans::TypeIdArgs type_args{host_tensor->data_c(), shape_size, host_tensor->data_type(),
90                                     device_tensor->type_id(), data_size};
91   auto offload_ptr = swap_manager->AllocHostMemory(device_tensor->GetSize());
92   MS_EXCEPTION_IF_NULL(offload_ptr);
93   bool trans_ret = trans::TransDataType(type_args, offload_ptr);
94   if (!trans_ret) {
95     MS_LOG(EXCEPTION) << "Trans data type for offload ptr failed, src type: "
96                       << TypeIdToString(host_tensor->data_type())
97                       << ", dst type: " << TypeIdToString(device_tensor->type_id());
98   }
99   return {offload_ptr, ""};
100 }
101 
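// Registers the allocation task and memory info of the device tensor with the memory tracker.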
102 void UpdateTracker(const std::string &task_name, const AnfNodePtr &node, const std::string &graph_str,
103                    device::tracker::MemType mem_type, const DeviceTensorPtr &device_tensor) {
104   device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, task_name, node->fullname_with_scope(), graph_str);
105   device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, task_name, mem_type, device_tensor->GetSize(),
106                                                  device_tensor.get());
107 }
108 
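// Allocates device memory if needed and copies the host tensor (or the key/value/status tensors of a map tensor) to
// the device, or records the storage info when the data is taken over by the swap manager.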
109 void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
110                     const DeviceContext *device_context, OpContext<DeviceTensor> *const context,
111                     GraphExecutionStrategy strategy) {
112   MS_EXCEPTION_IF_NULL(host_tensor);
113   MS_EXCEPTION_IF_NULL(device_tensor);
114   MS_EXCEPTION_IF_NULL(node);
115   MS_EXCEPTION_IF_NULL(device_context);
116   MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
117   MS_EXCEPTION_IF_NULL(context);
118   const bool taken_over_by_swap_manager = IsDataTakenOverByMemOffload(device_context);
119   auto allocator_type = node->isa<ValueNode>() ? device::AllocatorType::kConstantValue : device::AllocatorType::kWeight;
120   device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), allocator_type, 0);
121   bool need_alloc_memory = !taken_over_by_swap_manager && (device_tensor->GetPtr() == nullptr);
122   auto graph_str = (node->func_graph() == nullptr) ? "" : node->func_graph()->ToString();
123   auto mem_type = node->isa<ValueNode>() ? device::tracker::MemType::kConstantValue : device::tracker::MemType::kWeight;
124   if (need_alloc_memory) {
125     UpdateTracker("SyncTensorData", node, graph_str, mem_type, device_tensor);
126     if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex)) {
127       SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy, *context, *device_context, node->fullname_with_scope(),
128                                                   device_tensor->GetSize());
129     }
130     if (common::IsNeedProfileMemory()) {
131       auto output_address = reinterpret_cast<std::uintptr_t>(device_tensor.get());
132       MS_LOG(WARNING) << "Need Profile Memory, alloc type: SyncTensorData, device address class ptr: " << output_address
133                       << ", node: " << node->fullname_with_scope() << ", graph: " << graph_str
134                       << ", device address size: " << device_tensor->GetSize()
135                       << ", device address addr: " << device_tensor->GetPtr();
136     }
137   }
138 
139   auto get_tensor_by_index = [&host_tensor](size_t index) {
140     if (!host_tensor->isa<tensor::MapTensor>()) {
141       return host_tensor;
142     }
143     const auto &map_tensor = host_tensor->cast<tensor::MapTensorPtr>();
144     MS_EXCEPTION_IF_NULL(map_tensor);
145     switch (index) {
146       case kMapTensorKeyIndex:
147         return map_tensor->key_tensor();
148       case kMapTensorValueIndex:
149         return map_tensor->value_tensor();
150       case kMapTensorStatusIndex:
151         return map_tensor->status_tensor();
152       default:
153         MS_LOG(EXCEPTION) << "Invalid index:" << index << " for map tensor:" << host_tensor->ToString();
154     }
155   };
156 
157   ShapeVector host_shape = {};
158   // GetRuntimePaddingShape doesn't support the value tuple node.
159   if (!node->isa<ValueNode>()) {
160     host_shape = trans::GetRuntimePaddingShape(node, 0);
161   }
162   auto get_tensor_num = (host_tensor->isa<tensor::MapTensor>() ? kMapTensorNum : kNormalTensorNum);
163   for (size_t i = 0; i < get_tensor_num; ++i) {
164     const auto &real_host_tensor = get_tensor_by_index(i);
165     MS_EXCEPTION_IF_NULL(real_host_tensor);
166     // Copy data from host tensor to device.
167     auto host_tensor_size = LongToSize(real_host_tensor->data().nbytes());
168     auto host_tensor_type = real_host_tensor->data_type();
169     if (node->isa<ValueNode>()) {
170       host_shape = real_host_tensor->shape();
171     }
172     if (taken_over_by_swap_manager) {
173       device_tensor->SetStorageInfo(GetStorageInfo(real_host_tensor, device_tensor, device_context));
174     } else if (!device_tensor->SyncHostToDevice(host_shape, host_tensor_size, host_tensor_type,
175                                                 real_host_tensor->device_info().host_format_,
176                                                 real_host_tensor->data_ptr())) {
177       std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope() +
178                                ", host tensor size: " + std::to_string(host_tensor_size) +
179                                ", host tensor type: " + std::to_string(static_cast<int>(host_tensor_type)) +
180                                ", device tensor size: " + std::to_string(device_tensor->GetSize());
181       SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy, (*context), error_info);
182     }
183   }
184 }
185 
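// Collects the device addresses and sizes of the node's inputs or outputs that require continuous memory.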
186 void FetchContinuousMemoryInfo(const CNodePtr &node, std::vector<DeviceTensorPtr> *const addr_list,
187                                std::vector<size_t> *const size_list, size_t *const total_size, bool is_input) {
188   MS_EXCEPTION_IF_NULL(node);
189   MS_EXCEPTION_IF_NULL(addr_list);
190   MS_EXCEPTION_IF_NULL(size_list);
191   MS_EXCEPTION_IF_NULL(total_size);
192 
193   (*addr_list).clear();
194   (*size_list).clear();
195   *total_size = 0;
196 
197   if (is_input) {
198     const auto &input_sizes = AnfAlgo::GetNodeInputSizeList(node);
199     for (size_t i = 0; i < input_sizes.size(); ++i) {
200       const auto &device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, false);
201       MS_EXCEPTION_IF_NULL(device_tensor);
202       *total_size += input_sizes[i];
203       (void)size_list->emplace_back(input_sizes[i]);
204       (void)addr_list->emplace_back(device_tensor);
205     }
206   } else {
207     const auto &kernel_mod = AnfAlgo::GetKernelMod(node);
208     MS_EXCEPTION_IF_NULL(kernel_mod);
209     const auto &output_sizes = kernel_mod->GetOutputSizeList();
210     for (size_t i = 0; i < output_sizes.size(); ++i) {
211       const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
212       MS_EXCEPTION_IF_NULL(device_tensor);
213       *total_size += output_sizes[i];
214       (void)size_list->emplace_back(output_sizes[i]);
215       (void)addr_list->emplace_back(device_tensor);
216     }
217   }
218 }
219 
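// Recursively flattens a value sequence into individual values; CSR/COO tensors are expanded into their component
// tensors followed by their shape as Int64 scalars.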
220 void ValueTupleToValue(const ValuePtr &value, std::vector<ValuePtr> *const values) {
221   MS_EXCEPTION_IF_NULL(value);
222   MS_EXCEPTION_IF_NULL(values);
223   if (value->isa<ValueSequence>()) {
224     auto value_tuple = value->cast<ValueSequencePtr>();
225     MS_EXCEPTION_IF_NULL(value_tuple);
226     for (size_t i = 0; i < value_tuple->size(); ++i) {
227       ValuePtr element = value_tuple->value()[i];
228       MS_EXCEPTION_IF_NULL(element);
229 
230       if (element->isa<ValueSequence>()) {
231         ValueTupleToValue(element, values);
232       } else {
233         (void)values->emplace_back(element);
234       }
235     }
236   } else if (value->isa<tensor::CSRTensor>()) {
237     auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
238     MS_EXCEPTION_IF_NULL(csr_tensor);
239     MS_EXCEPTION_IF_NULL(csr_tensor->GetIndptr());
240     MS_EXCEPTION_IF_NULL(csr_tensor->GetIndices());
241     MS_EXCEPTION_IF_NULL(csr_tensor->GetValues());
242     (void)values->emplace_back(csr_tensor->GetIndptr());
243     (void)values->emplace_back(csr_tensor->GetIndices());
244     (void)values->emplace_back(csr_tensor->GetValues());
245     (void)std::transform(csr_tensor->shape().begin(), csr_tensor->shape().end(), std::back_inserter(*values),
246                          [](int64_t n) { return std::make_shared<Int64Imm>(n); });
247   } else if (value->isa<tensor::COOTensor>()) {
248     auto coo_tensor = value->cast<tensor::COOTensorPtr>();
249     MS_EXCEPTION_IF_NULL(coo_tensor);
250     MS_EXCEPTION_IF_NULL(coo_tensor->GetIndices());
251     MS_EXCEPTION_IF_NULL(coo_tensor->GetValues());
252     (void)values->emplace_back(coo_tensor->GetIndices());
253     (void)values->emplace_back(coo_tensor->GetValues());
254     (void)std::transform(coo_tensor->shape().begin(), coo_tensor->shape().end(), std::back_inserter(*values),
255                          [](int64_t n) { return std::make_shared<Int64Imm>(n); });
256   } else {
257     (void)values->emplace_back(value);
258   }
259 }
260 
261 // The device address of an input ref node may be modified by the input tensor, so the device address of the ref node needs to be updated.
262 void UpdateDeviceAddressByRefInputNode(const std::vector<KernelGraphPtr> &graphs,
263                                        const std::set<AnfNode *> &modified_input_nodes) {
264   for (const auto &graph : graphs) {
265     MS_EXCEPTION_IF_NULL(graph);
266     // The DeviceAddress of the graph parameter has been updated.
267     if (graph->is_graph_run_mode()) {
268       continue;
269     }
270 
271     for (auto &iter : graph->GetRefMap()) {
272       auto &output_pair = iter.first;
273       auto &input_pair = iter.second;
274       MS_EXCEPTION_IF_NULL(output_pair.first);
275       MS_EXCEPTION_IF_NULL(input_pair.first);
276       if (modified_input_nodes.count(input_pair.first.get()) == 0) {
277         continue;
278       }
279       // The output device tensor of the ref node actor can't be changed while running; only the ptr of the output
280       // device address can be modified, and `ref_count` must be set to `SIZE_MAX` to avoid it being cleaned. So only
281       // the persistent device tensor is supported.
282       if (!IsPersistentDeviceTensor(input_pair.first)) {
283         MS_LOG(INFO) << "The input parameter: " << input_pair.first->fullname_with_scope()
284                      << " isn't the ref parameter used by the ref node: "
285                      << output_pair.first->fullname_with_scope();
286         continue;
287       }
288 
289       MS_LOG(INFO) << "Update the ptr of ref node: " << output_pair.first->fullname_with_scope()
290                    << " by the modified ref input parameter: " << input_pair.first->fullname_with_scope();
291       auto ref_node_output_addr = AnfAlgo::GetMutableOutputAddr(output_pair.first, output_pair.second, false);
292       MS_EXCEPTION_IF_NULL(ref_node_output_addr);
293       const auto &front_input_node = AnfAlgo::FetchFrontNodeByBackendNode(input_pair.first, *graph);
294       auto input_addr =
295         DeviceTensorStore::GetInstance().Fetch(front_input_node.get(), ref_node_output_addr->GetDeviceType());
296       // Subgraphs may share the same backend input parameter, so fetching the device tensor store by the front node
297       // of this subgraph may return nullptr; in that case use the output addr of the input parameter directly.
298       if (input_addr == nullptr) {
299         input_addr = AnfAlgo::GetMutableOutputAddr(input_pair.first, input_pair.second, false);
300       }
301       MS_EXCEPTION_IF_NULL(input_addr);
302       MS_EXCEPTION_IF_CHECK_FAIL((ref_node_output_addr->GetDeviceType() == input_addr->GetDeviceType()),
303                                  "The device type of ref node is not equal.");
304       ref_node_output_addr->set_ptr(input_addr->GetMutablePtr());
305       ref_node_output_addr->set_original_ref_count(SIZE_MAX);
306       ref_node_output_addr->ResetRefCount();
307     }
308   }
309 }
310 
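// A tensor needs to be re-synced when recovery requires weights to be synced to device, or when it holds sub data.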
311 bool IsNeedSync(const TensorPtr &tensor) {
312   if (RecoveryContext::GetInstance()->enable_recovery() &&
313       RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
314     return true;
315   }
316 
317   if (tensor == nullptr) {
318     return false;
319   }
320   // Sub data needs to be synced each step.
321   auto data_ptr = tensor->data_ptr();
322   return data_ptr != nullptr && data_ptr->is_sub_data();
323 }
324 
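// Syncs tensors that own sub data so their host data is up to date before preparing inputs.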
325 void SyncTensorTrunk(const std::vector<std::vector<TensorPtr>> &input_tensors) {
326   for (auto &tensors : input_tensors) {
327     for (auto &tensor : tensors) {
328       if (tensor == nullptr) {
329         continue;
330       }
331       auto data_ptr = tensor->data_ptr();
332       if (data_ptr != nullptr && data_ptr->has_sub_data()) {
333         tensor->data_sync();
334       }
335     }
336   }
337 }
338 
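// Recomputes the device address size from the input tensor's shape transformed into the device format of the node.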
339 void UpdateDataNodeDeviceAddressSize(const AnfNodePtr &input_node, const TensorPtr &input_tensor,
340                                      const device::DeviceAddressPtr &device_address) {
341   MS_EXCEPTION_IF_NULL(input_node);
342   MS_EXCEPTION_IF_NULL(input_tensor);
343   MS_EXCEPTION_IF_NULL(device_address);
344   TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, 0);
345   if (output_type_id == kTypeUnknown) {
346     output_type_id = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
347   }
348   auto device_shape =
349     trans::TransShapeToDevice(input_tensor->shape(), device_address->format(), input_node, 0, output_type_id);
350   size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
351   auto device_address_size = type_size * SizeOf(device_shape);
352   MS_LOG(INFO) << "Size of device_address is updated from " << device_address->GetSize() << " to "
353                << device_address_size;
354   device_address->SetSize(device_address_size);
355 }
356 }  // namespace
357 
358 mindspore::HashSet<const tensor::Tensor *> DataPrepareActor::tensors_need_reprepare_ = {};
359 
360 std::atomic<size_t> DataPrepareActor::execution_count_ = 0;
361 
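// Initializes the execution strategy, parameter-input and dynamic-shape flags, and the continuous memory allocation
// lists from the graph compiler info.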
362 void DataPrepareActor::Init() {
363   MS_EXCEPTION_IF_NULL(graph_compiler_info_);
364   strategy_ = graph_compiler_info_->strategy_;
365   if (graph_compiler_info_->graphs_.size() != graph_compiler_info_->device_contexts_.size()) {
366     MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
367   }
368 
369   size_t host_data_size = 0;
370   if (host_data_source_actor_ != nullptr) {
371     host_data_size = host_data_source_actor_->data_nodes().size();
372   }
373   has_parameter_input_ = graph_compiler_info_->inputs_num_ > host_data_size;
374   MS_LOG(INFO) << graph_compiler_info_->name_
375                << " has the parameter input num: " << graph_compiler_info_->inputs_num_ - host_data_size;
376 
377   for (const auto &graph : graph_compiler_info_->graphs_) {
378     MS_EXCEPTION_IF_NULL(graph);
379     if (graph->is_dynamic_shape()) {
380       has_dynamic_shape_ = true;
381       break;
382     }
383   }
384 
385   for (auto &iter : continuous_memory_nodes_) {
386     size_t total_size = 0;
387     std::vector<size_t> size_list;
388     std::vector<DeviceTensorPtr> addr_list;
389     // Inputs need continuous memory.
390     if (iter.second.first) {
391       const auto &cnode = iter.first.first;
392       FetchContinuousMemoryInfo(cnode, &addr_list, &size_list, &total_size, true);
393       (void)continuous_memory_alloc_list_list_.emplace_back(addr_list);
394       (void)size_list_list_.emplace_back(size_list);
395       (void)stream_id_list_.emplace_back(kDefaultStreamIndex);
396       (void)total_size_list_.emplace_back(total_size);
397       (void)continuous_memory_device_contexts_.emplace_back(iter.first.second);
398     }
399 
400     // Outputs need continuous memory.
401     if (iter.second.second) {
402       const auto &cnode = iter.first.first;
403       FetchContinuousMemoryInfo(cnode, &addr_list, &size_list, &total_size, false);
404       (void)continuous_memory_alloc_list_list_.emplace_back(addr_list);
405       (void)size_list_list_.emplace_back(size_list);
406       (void)stream_id_list_.emplace_back(kDefaultStreamIndex);
407       (void)total_size_list_.emplace_back(total_size);
408       (void)continuous_memory_device_contexts_.emplace_back(iter.first.second);
409     }
410   }
411 }
412 
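// For dynamic-shape parameters, updates the output kernel tensor's shape from the input tensor and adjusts the
// device address size accordingly.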
413 void DataPrepareActor::UpdateDynamicShapeAndSize(const AnfNodePtr &input_node, const TensorPtr &input_tensor) const {
414   MS_EXCEPTION_IF_NULL(input_node);
415   if (input_tensor == nullptr || IsEmptySequenceTensor(input_tensor)) {
416     return;
417   }
418   if (!input_node->isa<Parameter>()) {
419     return;
420   }
421   auto input_param = input_node->cast<ParameterPtr>();
422   MS_EXCEPTION_IF_NULL(input_param);
423   auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
424   MS_EXCEPTION_IF_NULL(device_address);
425   if (!input_param->has_dynamic_shape() && !IsDynamic(device_address->host_shape())) {
426     return;
427   }
428 
429   // Update shape.
430   MS_LOG(DEBUG) << "Update dynamic shape for parameter:" << input_param->DebugString();
431   const auto &output_kernel_tensor = AnfAlgo::GetOutputKernelTensor(input_node, 0);
432   MS_EXCEPTION_IF_NULL(output_kernel_tensor);
433   if (input_tensor->base_shape_ptr() == nullptr || (!input_tensor->base_shape_ptr()->isa<abstract::SequenceShape>())) {
434     output_kernel_tensor->SetShape(input_tensor->ToAbstract()->GetShape());
435     return;
436   }
437   output_kernel_tensor->SetShape(input_tensor->base_shape_ptr());
438 
439   // Update size.
440   auto device_format = device_address->format();
441   static const std::set<std::string> kNormalFormat = {
442     kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
443   };
444   if (kNormalFormat.find(device_format) != kNormalFormat.end()) {
445     auto tensor_data_size = input_tensor->data().nbytes();
446     MS_LOG(DEBUG) << "Set device address:" << device_address << " size from:" << device_address->GetSize()
447                   << " to:" << tensor_data_size;
448     device_address->SetSize(tensor_data_size);
449   } else {
450     MS_LOG(DEBUG) << "Update data node device address size";
451     // Size of 5D format device_address is larger than tensor_data_size.
452     UpdateDataNodeDeviceAddressSize(input_node, input_tensor, device_address);
453   }
454 }
455 
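// When the input tensor already holds a device address compatible with the data node (same device type, type id and
// an equivalent format), assign it to the node and record the node so ref node addresses can be refreshed later.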
456 void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_node, const TensorPtr &input_tensor) {
457   MS_EXCEPTION_IF_NULL(input_tensor);
458   MS_EXCEPTION_IF_NULL(input_node);
459 
460   auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
461   if (tensor_address == nullptr) {
462     return;
463   }
464 
465   auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
466   MS_EXCEPTION_IF_NULL(device_address);
467   if (tensor_address == device_address) {
468     tensor_address->SetNodeIndex(input_node, 0);
469     tensor_address->set_original_ref_count(SIZE_MAX);
470     tensor_address->ResetRefCount();
471     return;
472   }
473 
474   // If the tensor address and the device address are different (heterogeneous scenarios), or the device address is
475   // persisted, update the device address data in the data source actor process.
476   if (device_address->is_ptr_persisted() || (tensor_address->GetDeviceType() != device_address->GetDeviceType()) ||
477       (!AnfAlgo::IsEquivalentFormat(tensor_address->format(), device_address->format())) ||
478       (tensor_address->type_id() != device_address->type_id())) {
479     MS_LOG(DEBUG) << "Cannot update address of " << input_node->DebugString();
480     return;
481   }
482 
483   // Assign the tensor address to the input data node and set `ref_count` to `SIZE_MAX` to avoid it being cleaned.
484   (void)address_modified_input_nodes_.insert(input_node.get());
485   tensor_address->set_flag(device_address->flag());
486   DeviceAddressUtils::UpdateDeviceAddressHostInfoByNode(tensor_address, input_node, 0);
487   AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
488   MS_LOG(DEBUG) << "Update device address of " << input_node->DebugString() << " to " << tensor_address.get()
489                 << ", kernel tensor addr:" << tensor_address->kernel_tensor().get()
490                 << " ptr:" << tensor_address->GetPtr();
491   tensor_address->SetNodeIndex(input_node, 0);
492   tensor_address->set_original_ref_count(SIZE_MAX);
493   tensor_address->ResetRefCount();
494 }
495 
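// Caches the input tensors the first time any of them carries sub data.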
496 void DataPrepareActor::SetInitTensorsIfNeeded(const std::vector<std::vector<TensorPtr>> &input_tensors) {
497   if (!init_tensors_.empty()) {
498     return;
499   }
500   bool need_save = std::any_of(input_tensors.begin(), input_tensors.end(), [](const std::vector<TensorPtr> &tensors) {
501     return std::any_of(tensors.begin(), tensors.end(), [](const TensorPtr &tensor) {
502       if (tensor == nullptr) {
503         return false;
504       }
505       auto data_ptr = tensor->data_ptr();
506       return data_ptr != nullptr && data_ptr->is_sub_data();
507     });
508   });
509   if (need_save) {
510     init_tensors_ = input_tensors;
511   }
512 }
513 
514 void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &input_tensors, const VectorRef &args,
515                                    OpContext<DeviceTensor> *const context, GraphExecutionStrategy real_strategy) {
516   MS_EXCEPTION_IF_NULL(context);
517   uint64_t start_time = 0;
518   PROFILER_START(start_time);
519 
520 #if defined(__linux__) && defined(WITH_BACKEND)
521   // Update rpc actors' status.
522   RpcActorStatusUpdater::GetInstance().UpdateRpcActorStatus(graph_compiler_info_->name_);
523 #endif
524 
525   try {
526     // Preprocess before preparing data for the data prepare actor.
527     PreprocessBeforePrepareData();
528   } catch (const std::exception &e) {
529     MsException::Instance().SetException();
530     std::string error_info = e.what();
531     SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
532   }
533 
534   MS_LOG(DEBUG) << "Data prepare actor(" << GetAID().Name() << ") prepares data.";
535   real_strategy_ = real_strategy;
536   // Convert actor running data from input tensors.
537   if (!input_tensors.empty()) {
538     SyncTensorTrunk(input_tensors);
539     SetInitTensorsIfNeeded(input_tensors);
540   }
541   try {
542     auto ms_context = MsContext::GetInstance();
543     MS_EXCEPTION_IF_NULL(ms_context);
544     static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
545     if (first_step_ || !tensors_need_reprepare_.empty() || (has_parameter_input_ && !enable_infer_boost)) {
546       PrepareDataForDeviceTensorStore(input_tensors, args, context);
547     }
548     PrepareDataForHostTensorQueue(input_tensors, args, context);
549   } catch (const std::exception &e) {
550     std::string error_info = e.what();
551     SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
552   }
553 
554   first_step_ = false;
555   if (IsRunningFailed(context)) {
556     return;
557   }
558   MS_EXCEPTION_IF_NULL(graph_compiler_info_);
559   if (!address_modified_input_nodes_.empty()) {
560     UpdateDeviceAddressByRefInputNode(graph_compiler_info_->graphs_, address_modified_input_nodes_);
561     address_modified_input_nodes_.clear();
562   }
563 
564   // The debug actor is blocked; we must wait for its callback message before continuing.
565   if (debug_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
566     SendDebugReq(context);
567     return;
568   }
569 
570   if (profiler_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
571     SendProfilerReq(context);
572     return;
573   }
574 
575   PROFILER_END(start_time, runtime::ProfilerModule::kRuntime, runtime::ProfilerEvent::kPreLaunch, GetAID().Name(),
576                false);
577 
578   // Allocate continuous memory and send output to trigger the step running.
579   if (continuous_memory_alloc_list_list_.size() > 0) {
580     SendMemoryAllocReq(context);
581   } else {
582     PostRun(context);
583   }
584 }
585 
586 void DataPrepareActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
587   MS_EXCEPTION_IF_NULL(graph_compiler_info_);
588   ActorDispatcher::SendSync(*debug_aid_, &DebugActor::DebugOnStepBegin, graph_compiler_info_->graphs_,
589                             graph_compiler_info_->origin_parameters_order_, graph_compiler_info_->device_contexts_,
590                             context, &GetAID());
591   OnDebugFinish(context);
592 }
593 
594 void DataPrepareActor::SendProfilerReq(OpContext<DeviceTensor> *const context) {
595   MS_EXCEPTION_IF_NULL(graph_compiler_info_);
596   ActorDispatcher::SendSync(*profiler_aid_, &ProfilerActor::ProfilerOnStepBegin, graph_compiler_info_->graphs_,
597                             graph_compiler_info_->origin_parameters_order_, graph_compiler_info_->device_contexts_,
598                             context, &GetAID());
599   OnDebugFinish(context);
600 }
601 
602 void DataPrepareActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
603   MS_EXCEPTION_IF_NULL(context);
604   if (continuous_memory_alloc_list_list_.size() > 0) {
605     SendMemoryAllocReq(context);
606   } else {
607     PostRun(context);
608   }
609 }
610 
611 void DataPrepareActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
612   // Allocate continuous memory at the beginning of the step running.
613   if (ActorDispatcher::is_memory_allocation_sync()) {
614     if (!ActorDispatcher::enable_use_trace_memory()) {
615       ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateContinuousMemory,
616                                 &continuous_memory_alloc_list_list_, &size_list_list_, &stream_id_list_,
617                                 &total_size_list_, &continuous_memory_device_contexts_, context, GetAID());
618     }
619     OnMemoryAllocFinish(context);
620   } else {
621     ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateContinuousMemory,
622                           &continuous_memory_alloc_list_list_, &size_list_list_, &stream_id_list_, &total_size_list_,
623                           &continuous_memory_device_contexts_, context, GetAID());
624   }
625 }
626 
627 void DataPrepareActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
628   MS_EXCEPTION_IF_NULL(context);
629   if (IsRunningFailed(context)) {
630     return;
631   }
632 
633   PostRun(context);
634 }
635 
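// Fetches an input tensor either from the given tensor list or, when the list is empty, from the args by the front
// node's position in the origin parameters.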
636 TensorPtr DataPrepareActor::FetchInputTensor(const std::vector<TensorPtr> &tensors, size_t tensor_index,
637                                              const VectorRef &args, const KernelWithIndex &front_node) const {
638   if (!tensors.empty()) {
639     MS_EXCEPTION_IF_CHECK_FAIL((tensor_index < tensors.size()), "The tensor index is out of range.");
640     auto tensor = tensors[tensor_index];
641     // The tensor needs to be converted to contiguous before being given to the actors.
642     // After the view feature is supported in the graph mode, the following code will be deleted.
643     DeviceAddressUtils::ConvertContiguousTensorSync(tensor);
644     return tensor;
645   }
646 
647   MS_EXCEPTION_IF_NULL(front_node.first);
648   const auto &iter = std::find(graph_compiler_info_->origin_parameters_order_.begin(),
649                                graph_compiler_info_->origin_parameters_order_.end(), front_node.first);
650   if (iter == graph_compiler_info_->origin_parameters_order_.end()) {
651     MS_LOG(INFO) << "Not origin parameter:  " << front_node.first->fullname_with_scope();
652     return nullptr;
653   }
654   auto arg_index = iter - graph_compiler_info_->origin_parameters_order_.begin();
655   auto tensor = FetchInputTensorByArg(args, arg_index, front_node);
656   // The tensor needs to be converted to contiguous before being given to the actors.
657   // After the view feature is supported in the graph mode, the following code will be deleted.
658   DeviceAddressUtils::ConvertContiguousTensorSync(tensor);
659   return tensor;
660 }
661 
662 TensorPtr DataPrepareActor::FetchInputTensorByArg(const VectorRef &args, size_t arg_index,
663                                                   const KernelWithIndex &front_node) const {
664   if (arg_index >= args.size()) {
665     MS_LOG(INFO) << "Arg index out of args range, index is " << arg_index << " and args size is " << args.size();
666     return nullptr;
667   }
668 
669   std::vector<tensor::TensorPtr> flatten_tensors;
670   AnfAlgo::FlattenInputArg(args[arg_index], front_node.first, &flatten_tensors);
671   auto input_tensor_index = FetchInputTensorIndex(front_node);
672   if (input_tensor_index >= flatten_tensors.size()) {
673     MS_LOG(INFO) << "Input tensor index out of args range, index is " << input_tensor_index << " and tensors size is "
674                  << flatten_tensors.size();
675     return nullptr;
676   }
677 
678   auto tensor = flatten_tensors[input_tensor_index];
679   // The tensor needs to be converted to contiguous before being given to the actors.
680   // After the view feature is supported in the graph mode, the following code will be deleted.
681   DeviceAddressUtils::ConvertContiguousTensorSync(tensor);
682 
683   if (tensor != nullptr && tensor->update_value_callback() == nullptr && tensor->is_parameter()) {
684     static auto callback = [](const tensor::Tensor *tensor) { tensors_need_reprepare_.insert(tensor); };
685     tensor->set_update_value_callback(callback);
686   }
687 
688   if (tensor != nullptr && !tensors_need_reprepare_.empty() && tensor->is_parameter()) {
689     auto erased_num = tensors_need_reprepare_.erase(tensor.get());
690     MS_LOG(DEBUG) << "Erase " << erased_num << " tensor which is reprepared.";
691   }
692 
693   return tensor;
694 }
695 
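// Returns the index into the flattened input tensors for the front node: the output index for (non-dynamic) sequence
// parameters, otherwise 0.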
696 size_t DataPrepareActor::FetchInputTensorIndex(const KernelWithIndex &front_node) const {
697   MS_EXCEPTION_IF_NULL(front_node.first);
698   if (common::AnfAlgo::IsDynamicSequence(front_node.first)) {
699     return 0;
700   }
701 
702   const auto &abs = front_node.first->abstract();
703   MS_EXCEPTION_IF_NULL(abs);
704   if (abs->isa<abstract::AbstractSequence>()) {
705     return front_node.second;
706   }
707 
708   return 0;
709 }
710 
711 void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
712                                                        const VectorRef &args, OpContext<DeviceTensor> *const context) {
713   MS_LOG(INFO) << "Prepare store data, input tensor size: " << input_tensors.size() << ", arg size: " << args.size();
714   ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, "PrepareStoreData", true);
715   MS_EXCEPTION_IF_NULL(graph_compiler_info_);
716   const auto &parser = graph_compiler_info_->control_node_parser_;
717   MS_EXCEPTION_IF_NULL(parser);
718   for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
719     const auto &graph = graph_compiler_info_->graphs_[i];
720     const auto &device_context = graph_compiler_info_->device_contexts_[i];
721     MS_EXCEPTION_IF_NULL(graph);
722     MS_LOG(DEBUG) << "prepare data for graph:" << graph->ToString();
723     // Prepare the data of device tensor store(value nodes of graph).
724     for (const auto &value_node : graph->graph_value_nodes()) {
725       MS_EXCEPTION_IF_NULL(value_node);
726       if (AnfAlgo::OutputAddrExist(value_node, 0)) {
727         const auto &front_node = AnfAlgo::FetchFrontNodeByBackendNode(value_node, *graph);
728         MS_EXCEPTION_IF_NULL(front_node);
729         MS_LOG(DEBUG) << "Prepare data for value node:" << value_node->fullname_with_scope()
730                       << ", debug name:" << value_node->DebugString() << ", front node:" << front_node->DebugString()
731                       << " for graph:" << graph->ToString();
732         const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
733         const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
734         MS_EXCEPTION_IF_NULL(device_tensor);
735         // If front_node has more than one device tensor, the node may be used in multiple graphs,
736         // so we clear the ignore flag of the device address.
737         if (TEST_FLAG(device_tensor->flag(), device::kDeviceAddressFlagIgnoreDevicePtr) && device_tensors.size() > 1) {
738           device_tensor->ClearFlag(device::kDeviceAddressFlagIgnoreDevicePtr);
739         }
740         // If the node address has the ignore flag, we do not prepare device data for it.
741         if (!TEST_FLAG(device_tensor->flag(), device::kDeviceAddressFlagIgnoreDevicePtr)) {
742           PrepareDataForValueNode(value_node, front_node, device_context, context);
743         }
744       }
745     }
746 
747     // Prepare the data of device tensor store(weights of graph).
748     const auto &input_nodes = graph->input_nodes();
749     for (size_t j = 0; j < input_nodes.size(); ++j) {
750       const auto &input_node = input_nodes[j];
751       MS_EXCEPTION_IF_NULL(input_node);
752       const auto &real_device_context = device::FetchRealDeviceContext(input_node, device_context);
753       MS_EXCEPTION_IF_NULL(real_device_context);
754       const auto &front_node = AnfAlgo::FetchFrontNodeByBackendNode(input_node, *graph);
755       if (IsPersistentDeviceTensor(input_node) && parser->IsRootGraphPersistentDeviceTensor(front_node)) {
756         std::vector<TensorPtr> graph_tensors = input_tensors.empty() ? std::vector<TensorPtr>() : input_tensors[i];
757         TensorPtr input_tensor = FetchInputTensor(graph_tensors, j, args, {front_node, 0});
758         PrepareDataForWeightNode(input_node, front_node, input_tensor, real_device_context, context);
759       }
760     }
761   }
762   if (RecoveryContext::GetInstance()->enable_recovery() &&
763       RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
764     RecoveryContext::GetInstance()->set_need_sync_weight_to_device(false);
765   }
766 
767   std::vector<TensorPtr> control_input = input_tensors.empty() ? std::vector<TensorPtr>() : input_tensors.back();
768   PrepareDeviceTensorStoreForControlNode(parser, control_input, args, context);
769 }
770 
771 void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,
772                                                      const VectorRef &args, OpContext<DeviceTensor> *const context) {
773   MS_LOG(INFO) << "Prepare host data, input tensor size: " << input_tensors.size() << ", arg size: " << args.size();
774   ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, "PrepareHostData", true);
775   MS_EXCEPTION_IF_NULL(context);
776   MS_EXCEPTION_IF_NULL(graph_compiler_info_);
777   if ((host_data_source_actor_ == nullptr) || (host_tensor_queue_ == nullptr)) {
778     return;
779   }
780 
781   if (input_tensors.empty()) {
782     PrepareDataForHostTensorQueueNew(args, context);
783     return;
784   }
785 
786   // Fill host tensors.
787   std::vector<TensorPtr> host_tensors;
788   host_tensors.resize(host_data_source_actor_->data_nodes().size());
789   for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
790     const auto &graph = graph_compiler_info_->graphs_[i];
791     MS_EXCEPTION_IF_NULL(graph);
792 
793     const auto &input_nodes = graph->input_nodes();
794     const auto &tensors = input_tensors[i];
795     if (input_nodes.size() != tensors.size()) {
796       std::string error_info = "Invalid tensor size:" + std::to_string(tensors.size()) +
797                                " and input node size:" + std::to_string(input_nodes.size()) +
798                                " for kernel graph:" + graph->ToString();
799       SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
800     }
801     for (size_t j = 0; j < input_nodes.size(); ++j) {
802       const auto &input_node = input_nodes[j];
803       const auto &input_tensor = tensors[j];
804       MS_EXCEPTION_IF_NULL(input_node);
805       if (!IsHostQueueDSActor(input_node, graph, graph_compiler_info_->origin_parameters_order_, strategy_) ||
806           input_tensor == nullptr) {
807         continue;
808       }
809 
810       auto tensor_position = host_data_source_actor_->FetchNodePosition({input_node, 0});
811       if (tensor_position >= host_tensors.size()) {
812         std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
813         SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
814       }
815       MS_LOG(DEBUG) << "Set tensor position:" << tensor_position << " for input data.";
816       host_tensors[tensor_position] = input_tensor;
817 
818       // Synchronize dynamic shape info of the input tensor to the parameter node of graph.
819       if (graph->is_dynamic_shape()) {
820         UpdateDynamicShapeAndSize(input_node, input_tensor);
821       }
822 
823       UpdateDeviceAddressForDataNode(input_node, input_tensor);
824     }
825   }
826 
827   PrepareHostTensorQueueForControlNode(input_tensors.back(), &host_tensors, context);
828 
829   host_tensor_queue_->Push(host_tensors);
830 }
831 
832 void DataPrepareActor::PrepareDataForHostTensorQueueNew(const VectorRef &args, OpContext<DeviceTensor> *const context) {
833   MS_EXCEPTION_IF_NULL(context);
834   size_t host_data_size = host_data_source_actor_->data_nodes().size();
835   size_t current_data_num = 0;
836   std::vector<TensorPtr> host_tensors;
837   host_tensors.resize(host_data_size);
838   host_tensors_.resize(host_data_size);
839   bool isDyn = false;
840   // Fill host tensors.
841   for (size_t i = 0; i < graph_compiler_info_->origin_parameters_order_.size(); ++i) {
842     if (current_data_num == host_data_size) {
843       break;
844     }
845     const auto &origin_parameter = graph_compiler_info_->origin_parameters_order_[i];
846     MS_EXCEPTION_IF_NULL(origin_parameter);
847     // The input data comes before the parameter weights.
848     if (common::AnfAlgo::IsParameterWeight(origin_parameter->cast<ParameterPtr>())) {
849       MS_LOG(DEBUG) << "Skip the prepare host data for parameter: " << origin_parameter->fullname_with_scope();
850       continue;
851     }
852 
853     auto iter = graph_compiler_info_->origin_parameters_to_backend_parameters_.find(origin_parameter);
854     if (iter == graph_compiler_info_->origin_parameters_to_backend_parameters_.end()) {
855       MS_LOG(DEBUG) << "Not find the parameter in the origin parameters: " << origin_parameter->fullname_with_scope();
856       continue;
857     }
858 
859     for (auto origin_to_backend_pair : iter->second) {
860       auto input_tensor = FetchInputTensorByArg(args, i, origin_to_backend_pair.first);
861       if (input_tensor == nullptr) {
862         MS_LOG(ERROR) << "The input tensor is nullptr for arg index: " << i
863                       << ", parameter: " << origin_parameter->fullname_with_scope();
864         continue;
865       }
866       // Single op (run in PyNative mode) output feeds the net (graph mode) input.
867       runtime::DeviceAddressUtils::CreateKernelTensor(input_tensor);
868       auto tensor_position = host_data_source_actor_->FetchNodePosition(origin_to_backend_pair.second);
869       if (tensor_position >= host_tensors.size()) {
870         std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
871         SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
872       }
873       if (host_tensors[tensor_position] != nullptr) {
874         continue;
875       }
876       MS_LOG(INFO) << "Set host tensor position:" << tensor_position
877                    << " for input parameter:" << origin_parameter->fullname_with_scope();
878 
879       if (!isDyn) {
880         if (host_tensors_[tensor_position] != input_tensor->shape()) {
881           isDyn = true;
882         }
883       }
884       host_tensors_[tensor_position] = input_tensor->shape();
885       host_tensors[tensor_position] = input_tensor;
886       ++current_data_num;
887 
888       UpdateDynamicShapeAndSize(origin_to_backend_pair.second.first, input_tensor);
889 
890       // To avoid the device `ptr_` being held by both the input tensor and the output tensor, the input tensor
891       // address cannot be set directly on the input control node, which may be a passthrough node. The device `ptr_`
892       // is re-malloced and a device-to-device copy from the input tensor address is done in the data source process.
893       if (origin_to_backend_pair.first.first != origin_to_backend_pair.second.first) {
894         UpdateDeviceAddressForDataNode(origin_to_backend_pair.second.first, input_tensor);
895       }
896     }
897   }
898 
899   auto ms_context = MsContext::GetInstance();
900   MS_EXCEPTION_IF_NULL(ms_context);
901   static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
902   if (enable_infer_boost && has_dynamic_shape_ && EnableKbkSubGraphExecute()) {
903     ActorDispatcher::set_enable_static_shape(!isDyn);
904 
905     const auto &phase = PhaseManager::GetInstance().phase();
906     bool is_increment_graph = (phase.find("increment") != std::string::npos);
907     if (EnableTraceMemory() && is_increment_graph) {
908       if (continuous_memory_alloc_list_list_.size() > 0) {
909         MS_LOG(EXCEPTION)
910           << "Can not support continuous memory allocate in dynamic shape graph when enable trace memory.";
911       }
912       if (!ActorDispatcher::enable_static_shape()) {
913         ActorDispatcher::set_enable_trace_dynamic_memory(true);
914       } else {
915         ActorDispatcher::set_enable_use_trace_memory(true);
916       }
917     }
918   }
919   host_tensor_queue_->Push(host_tensors);
920 }
921 
922 // The branch of PrepareDataForValueNode that handles values of tensor type.
923 void DataPrepareActor::PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
924                                                      const AnfNodePtr &front_node, const DeviceContext *device_context,
925                                                      OpContext<DeviceTensor> *const context) const {
926   MS_EXCEPTION_IF_NULL(node);
927   MS_EXCEPTION_IF_NULL(node_value);
928   MS_EXCEPTION_IF_NULL(device_context);
929   MS_EXCEPTION_IF_NULL(context);
930 
931   auto tensor = node_value->cast<TensorPtr>();
932   MS_EXCEPTION_IF_NULL(tensor);
933   if (tensor->is_forward_output()) {
934     return;
935   }
936 
937   if (!first_step_) {
938     return;
939   }
940 
941   const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false);
942   MS_EXCEPTION_IF_NULL(device_tensor);
943   // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
944   if (device_tensor->IsPtrValid()) {
945     CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
946     return;
947   }
948   MS_LOG(INFO) << "Prepare device data for value node: " << node->DebugString() << ", output index: " << 0
949                << " device address:" << device_tensor;
950   tensor->set_device_address(device_tensor);
951   UpdateRefCount(device_tensor.get(), true);
952 
953   SyncTensorData(tensor, device_tensor, node, device_context, context, real_strategy_);
954   MS_LOG(DEBUG) << "Prepare device data for value node: " << node->DebugString() << ", output index: " << 0
955                 << " device address:" << device_tensor << " ptr:" << device_tensor->GetPtr();
956   CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
957 }
958 
959 void DataPrepareActor::PrepareDataForControlValueNode(const KernelWithIndex &node_with_index,
960                                                       const DeviceContext *device_context,
961                                                       OpContext<DeviceTensor> *const context,
962                                                       const ControlNodeParserPtr &parser) const {
963   MS_EXCEPTION_IF_NULL(device_context);
964   MS_EXCEPTION_IF_NULL(context);
965   MS_EXCEPTION_IF_NULL(node_with_index.first);
966   MS_EXCEPTION_IF_NULL(parser);
967   if (!node_with_index.first->isa<ValueNode>()) {
968     return;
969   }
970 
971   const auto &node = node_with_index.first->cast<ValueNodePtr>();
972   MS_EXCEPTION_IF_NULL(node);
973   size_t index = node_with_index.second;
974   MS_LOG(DEBUG) << "Prepare data for control value node:" << node->DebugString() << " index:" << index;
975   auto node_value = node->value();
976   if (common::AnfAlgo::IsDynamicSequence(node)) {
977     auto tensor = AnfAlgo::SequenceToTensor(node_value);
978     parser->AddControlNodeTensor(tensor);
979     node_value = tensor;
980     AnfAlgo::UpdateValueNodeShape(node);
981   }
982   MS_EXCEPTION_IF_NULL(node_value);
983   std::vector<ValuePtr> values;
984   ValueTupleToValue(node_value, &values);
985 
986   if (node_with_index.second >= values.size()) {
987     MS_LOG(INFO) << "Invalid index:" << node_with_index.second << " for node:" << node->DebugString();
988     return;
989   }
990   const auto &value = values[index];
991   MS_EXCEPTION_IF_NULL(value);
992   TensorPtr tensor = nullptr;
993   if (value->isa<StringImm>()) {
994     PrepareDataForStringValue(node, index, node, device_context, context);
995     return;
996   } else if (!value->isa<tensor::Tensor>()) {
997     tensor = parser->CreateTensorForValue(value);
998   } else {
999     tensor = value->cast<tensor::TensorPtr>();
1000   }
1001 
1002   MS_EXCEPTION_IF_NULL(tensor);
1003   const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, index, false);
1004   MS_EXCEPTION_IF_NULL(device_tensor);
1005   if (device_tensor->GetPtr() != nullptr) {
1006     return;
1007   }
1008 
1009   tensor->set_device_address(device_tensor);
1010   UpdateRefCount(device_tensor.get(), true);
1011 
1012   device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->DebugString(), device::AllocatorType::kConstantValue, 0);
1013   auto graph_str = (node->func_graph() == nullptr) ? "" : node->func_graph()->ToString();
1014   UpdateTracker("PrepareDataForControlValueNode", node, graph_str, device::tracker::MemType::kConstantValue,
1015                 device_tensor);
1016   if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex)) {
1017     SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context, node->fullname_with_scope(),
1018                                                 device_tensor->GetSize());
1019   }
1020   if (common::IsNeedProfileMemory()) {
1021     auto output_address = reinterpret_cast<uintptr_t>(device_tensor.get());
1022     MS_LOG(WARNING) << "Need Profile Memory, alloc type: PrepareDataForControlValueNode, device address class ptr: "
1023                     << output_address << ", node: " << node->fullname_with_scope() << ", graph: " << graph_str
1024                     << ", device address size: " << device_tensor->GetSize()
1025                     << ", device address addr: " << device_tensor->GetPtr();
1026   }
1027 
1028   if (tensor->data_ptr() == nullptr && device_tensor->GetSize() == 0) {
1029     MS_LOG(INFO) << "Empty tuple sync";
1030     return;
1031   }
1032 
1033   auto host_tensor_size = LongToSize(tensor->data().nbytes());
1034   auto host_tensor_type = tensor->data_type();
1035   auto shape = tensor->shape();
1036   if (!device_tensor->SyncHostToDevice(shape, host_tensor_size, host_tensor_type, tensor->device_info().host_format_,
1037                                        tensor->data_ptr())) {
1038     std::string error_info = "Sync host to device failed for node:" + node->DebugString();
1039     SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
1040   }
1041 }
1042 
1043 void DataPrepareActor::PrepareDataForStringValue(const ValueNodePtr &node, size_t index, const AnfNodePtr &front_node,
1044                                                  const DeviceContext *device_context,
1045                                                  OpContext<DeviceTensor> *const context) const {
1046   MS_EXCEPTION_IF_NULL(node);
1047   if (!IsValueNode<StringImm>(node)) {
1048     return;
1049   }
1050   auto &node_value = node->value();
1051   MS_EXCEPTION_IF_NULL(node_value);
1052 
1053   const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, index, false);
1054   MS_EXCEPTION_IF_NULL(device_tensor);
1055   // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
1056   if (device_tensor->GetPtr() != nullptr) {
1057     if (first_step_) {
1058       CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
1059     }
1060     return;
1061   }
1062   MS_LOG(INFO) << "Prepare device data for value node: " << node->DebugString();
1063 
1064   device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), device::AllocatorType::kConstantValue,
1065                                                      0);
1066   auto graph_str = (node->func_graph() == nullptr) ? "" : node->func_graph()->ToString();
1067   UpdateTracker("PrepareDataForStringValue", node, graph_str, device::tracker::MemType::kConstantValue, device_tensor);
1068   if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex)) {
1069     SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context, node->fullname_with_scope(),
1070                                                 device_tensor->GetSize());
1071   }
1072   if (common::IsNeedProfileMemory()) {
1073     auto output_address = reinterpret_cast<uintptr_t>(device_tensor.get());
1074     MS_LOG(WARNING) << "Need Profile Memory, alloc type: PrepareDataForValueNode, device address class ptr: "
1075                     << output_address << ", device address size: " << device_tensor->GetSize()
1076                     << ", node: " << node->fullname_with_scope() << ", graph: " << graph_str
1077                     << ", device address addr: " << device_tensor->GetPtr();
1078   }
1079 
1080   // Copy data from value to device.
1081   auto value = GetValue<std::string>(node_value);
1082   size_t tensor_size = value.size();
1083   ShapeVector shape = {1, SizeToLong(tensor_size)};
1084   // account '\0' to string size, keep consistent with method `CreateDeviceAddressForScalarAndString` defined in
1085   // `device_address_utils.cc`
1086   size_t string_tensor_size = tensor_size + 1;
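  // Illustrative example (comment only): for node_value = "abc", the sync below is called with
  // shape {1, 3} and string_tensor_size = 4, so the device buffer also holds the terminating '\0'.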
  if (!device_tensor->SyncHostToDevice(shape, string_tensor_size, kObjectTypeString, value.data())) {
    std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
  }
  CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
}

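// Prepare the device data for a value node holding a ValueSequence or Scalar. This only runs on the first
// step, and empty value sequences are skipped because they have no device data to sync.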
void DataPrepareActor::PrepareDataForSequenceAndScalarValue(const ValueNodePtr &node, size_t index,
                                                            const AnfNodePtr &front_node,
                                                            const DeviceContext *device_context,
                                                            OpContext<DeviceTensor> *const context) const {
  if (!first_step_) {
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);
  auto &node_value = node->value();
  MS_EXCEPTION_IF_NULL(node_value);

  if ((!node_value->isa<ValueSequence>()) && (!node_value->isa<Scalar>())) {
    return;
  }

  if (node_value->isa<ValueSequence>() && node_value->cast<ValueSequencePtr>()->size() == 0) {
    return;
  }

  const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, index, false);
  MS_EXCEPTION_IF_NULL(device_tensor);
  // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
  if (device_tensor->GetPtr() != nullptr) {
    CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
    return;
  }

  UpdateRefCount(device_tensor.get(), true);
  MS_LOG(INFO) << "Prepare device data for value node: " << node->DebugString();
  device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), device::AllocatorType::kConstantValue,
                                                     0);
  // 1. Allocate device memory for value node.
  auto graph_str = (node->func_graph() == nullptr) ? "" : node->func_graph()->ToString();
  UpdateTracker("PrepareDataForSequenceAndScalarValue", node, graph_str, device::tracker::MemType::kConstantValue,
                device_tensor);
  if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex)) {
    SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context, node->fullname_with_scope(),
                                                device_tensor->GetSize());
  }
  if (common::IsNeedProfileMemory()) {
    auto output_address = reinterpret_cast<uintptr_t>(device_tensor.get());
    MS_LOG(WARNING) << "Need Profile Memory, alloc type: PrepareDataForSequenceAndScalarValue, "
                    << "device address class ptr: " << output_address
                    << ", device address size: " << device_tensor->GetSize()
                    << ", node: " << node->fullname_with_scope() << ", graph: " << graph_str
                    << ", device address addr: " << device_tensor->GetPtr();
  }

  // 2. Sync copy data from host to device.
  const auto &kernel_tensor = device_tensor->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  if (!device_tensor->SyncHostToDevice(kernel_tensor->GetShapeVector(), kernel_tensor->size(),
                                       kernel_tensor->dtype_id(), kernel_tensor->GetValuePtr())) {
    std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope();
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
  }

  // 3. Handle heterogeneous scene.
  CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
}

// Prepare the device data for persistent device tensor of value node.
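// Dispatches by value type: Tensor, ValueSequence/Scalar and StringImm values are prepared by the dedicated
// helpers above, while None and Type values need no device data.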
void DataPrepareActor::PrepareDataForValueNode(const ValueNodePtr &node, const AnfNodePtr &front_node,
                                               const DeviceContext *device_context,
                                               OpContext<DeviceTensor> *const context) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(front_node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);
  auto &node_value = node->value();
  MS_EXCEPTION_IF_NULL(node_value);
  MS_LOG(DEBUG) << "Prepare data for value node:" << node->DebugString()
                << " front node:" << front_node->DebugString();
  if (node_value->isa<tensor::Tensor>()) {
    PrepareDataForValueNodeTensor(node, node_value, front_node, device_context, context);
  } else if (node_value->isa<ValueSequence>() || node_value->isa<Scalar>()) {
    PrepareDataForSequenceAndScalarValue(node, 0, front_node, device_context, context);
  } else if (node_value->isa<StringImm>()) {
    PrepareDataForStringValue(node, 0, front_node, device_context, context);
  } else if (node_value->isa<None>() || node_value->isa<Type>()) {
    MS_LOG(DEBUG) << "No need to prepare data for None or Type value node:" << node->DebugString();
  } else {
    MS_LOG(WARNING) << "Unsupported value type for value node: " << node->fullname_with_scope();
  }
}

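// In the heterogeneous scene a front node can own device tensors on several devices. Copy the prepared data
// from host_tensor_address to every other device tensor recorded in the DeviceTensorStore, allocating the
// target memory first when it has not been allocated yet.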
void DataPrepareActor::CopyDataFromDeviceTensorStore(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
                                                     const device::DeviceAddressPtr &host_tensor_address,
                                                     const DeviceContext *device_context,
                                                     OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);
  const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  for (auto &another_device_tensor : device_tensors) {
    if (another_device_tensor == host_tensor_address) {
      continue;
    }
    MS_EXCEPTION_IF_NULL(another_device_tensor);
    auto another_device_name = device::GetDeviceNameByType(another_device_tensor->GetDeviceType());
    const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {another_device_name, device_context->device_context_key().device_id_});
    MS_EXCEPTION_IF_NULL(another_device_context);
    auto type = backend_node->isa<ValueNode>() ? device::AllocatorType::kConstantValue : device::AllocatorType::kWeight;
    device::DynamicMemAllocatorDebugInfo::SetDebugInfo(backend_node->fullname_with_scope(), type, 0);
    bool need_alloc_memory = (another_device_tensor->GetPtr() == nullptr);
    auto graph_str = (backend_node->func_graph() == nullptr) ? "" : backend_node->func_graph()->ToString();
    if (need_alloc_memory) {
      auto mem_type =
        backend_node->isa<ValueNode>() ? device::tracker::MemType::kConstantValue : device::tracker::MemType::kWeight;
      UpdateTracker("CopyDataFromDeviceTensorStore", backend_node, graph_str, mem_type, another_device_tensor);
    }
    if (need_alloc_memory && (!another_device_context->device_res_manager_->AllocateMemory(another_device_tensor.get(),
                                                                                           kDefaultStreamIndex))) {
      SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *another_device_context,
                                                  backend_node->fullname_with_scope(),
                                                  another_device_tensor->GetSize());
    }
    if (common::IsNeedProfileMemory() && need_alloc_memory) {
      auto output_address = reinterpret_cast<uintptr_t>(another_device_tensor.get());
      MS_LOG(WARNING) << "Need Profile Memory, alloc type: CopyDataFromDeviceTensorStore, device address class ptr: "
                      << output_address << ", device address size: " << another_device_tensor->GetSize()
                      << ", device address addr: " << another_device_tensor->GetPtr()
                      << ", node: " << backend_node->fullname_with_scope() << ", graph: " << graph_str
                      << ", frontnode: " << (front_node == nullptr ? "null" : front_node->DebugString());
    }

    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
                 << ", device name:" << another_device_name << " from device address:" << host_tensor_address
                 << " to:" << another_device_tensor;
    if (!Copy(another_device_tensor.get(), host_tensor_address.get())) {
      std::string error_info = "Copy data from the host tensor address to the device tensor store failed.";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
    }
  }
}

// Prepare the device data for persistent device tensor of weight node from host tensor.
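// The host tensor and the backend parameter may carry different device addresses (different device type,
// non-equivalent format, or a persisted pointer); this function unifies them, refreshes the DeviceTensorStore
// entry of the front node, and syncs the host data to device when needed.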
void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &front_node,
                                                const TensorPtr &tensor, const DeviceContext *device_context,
                                                OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(front_node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  MS_EXCEPTION_IF_NULL(context);
  auto param_node = backend_node->cast<ParameterPtr>();
  if (param_node != nullptr) {
    auto param_info = param_node->param_info();
    // A parameter without ParamInfo is treated as used.
    bool used = (param_info == nullptr) || !param_info->ignore_device_addr();
    if (!used) {
      MS_LOG(DEBUG) << backend_node->DebugString()
                    << " is never used by a real kernel in the graphs, skip allocation.";
      return;
    }
  }
  if (tensor == nullptr) {
    return;
  }

  auto device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
  MS_EXCEPTION_IF_NULL(device_tensor);
  auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
  // Use the device address of host tensor to set device tensor.
  bool is_need_sync = IsNeedSync(tensor);
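  // The branches below bind the host tensor and the backend parameter to a single device address:
  // 1) No host device address yet: reuse the backend device tensor when the device types match, otherwise
  //    create a new device address on the current device and attach it to the host tensor.
  // 2) Device types differ (the fake heterogeneous scenario): sync the tensor data back to host and fall back
  //    to the backend device tensor when only one device tensor is stored for the front node.
  // 3) Same device type but different addresses: either copy into the backend address (persisted pointer or
  //    non-equivalent format), or take over the backend output address and record the node in
  //    address_modified_input_nodes_.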
  if (host_tensor_address != device_tensor) {
    if (host_tensor_address == nullptr) {
      if (device_tensor->GetDeviceType() != device_context->GetDeviceType()) {
        const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
          {backend_node, 0}, nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id(),
          device_tensor->host_shape(), device_context->device_context_key().device_name_,
          device_context->device_context_key().device_id_);
        kernel_tensor->set_stream_id(device_tensor->stream_id());
        host_tensor_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
        MS_EXCEPTION_IF_NULL(host_tensor_address);
        MS_LOG(DEBUG) << "Create device tensor:" << host_tensor_address << " type:" << host_tensor_address->type_id();
        host_tensor_address->set_from_persistent_mem(tensor->is_parameter());
      } else {
        host_tensor_address = device_tensor;
      }
      is_need_sync = true;
      tensor->set_device_address(host_tensor_address);
      UpdateRefCount(host_tensor_address.get(), true);
    }
    MS_EXCEPTION_IF_NULL(host_tensor_address);

    if (host_tensor_address->GetDeviceType() != device_tensor->GetDeviceType()) {
      MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->GetDeviceType()
                   << ", device tensor type:" << device_tensor->GetDeviceType();
      // The fake heterogeneous scenario.
      if (DeviceTensorStore::GetInstance().Fetch(front_node.get()).size() == 1) {
        tensor->data_sync();
        host_tensor_address = device_tensor;
        tensor->set_device_address(device_tensor);
        is_need_sync = true;
      }
    } else if (host_tensor_address != device_tensor) {
      // In the training + inference scenario, the device address of the weight node cannot be changed when
      // multi-graph sink mode is enabled.
      if (device_tensor->is_ptr_persisted() ||
          !AnfAlgo::IsEquivalentFormat(host_tensor_address->format(), device_tensor->format())) {
        if ((device_tensor->GetPtr() == nullptr) &&
            (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex))) {
          SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context,
                                                      backend_node->fullname_with_scope(), device_tensor->GetSize());
        }
        if (!Copy(device_tensor.get(), host_tensor_address.get())) {
          std::string error_info = "Copy data from the host tensor address to the device tensor failed.";
          SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
        }
        host_tensor_address = device_tensor;
        tensor->set_device_address(device_tensor);
      } else {
        (void)address_modified_input_nodes_.insert(backend_node.get());
        host_tensor_address->set_flag(device_tensor->flag());
        DeviceAddressUtils::UpdateDeviceAddressHostInfoByNode(host_tensor_address, backend_node, 0);
        AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
      }
    }
  }
  // Maybe the same host_tensor_address corresponds to different front_nodes in the shared weight scene,
  // so the device tensor store always needs to be updated.
  MS_EXCEPTION_IF_NULL(host_tensor_address);
  host_tensor_address->SetNodeIndex(backend_node, 0);
  DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);

  // Sync the host data when a sync is required or the device ptr has not been prepared yet.
  if (is_need_sync || (!host_tensor_address->IsPtrValid())) {
    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->DebugString()
                 << ", device type:" << host_tensor_address->GetDeviceType();
    SyncTensorData(tensor, host_tensor_address, backend_node, device_context, context, real_strategy_);
  }

  // Allocate memory on the other devices and copy data from the host tensor to them (if any).
  CopyDataFromDeviceTensorStore(front_node, backend_node, host_tensor_address, device_context, context);
}

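// Prepare the persistent device tensors used by control flow: first the front value nodes collected by the
// control node parser, then the root graph parameters that are persistent device tensors.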
void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeParserPtr &control_node_parser,
                                                              const std::vector<TensorPtr> &tensors,
                                                              const VectorRef &args,
                                                              OpContext<DeviceTensor> *const context) const {
  MS_EXCEPTION_IF_NULL(control_node_parser);
  if (!control_node_parser->IsInited()) {
    return;
  }

  for (const auto &value_node_with_context : control_node_parser->front_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node_with_context.first.first);
    if (value_node_with_context.first.first->kernel_info() != nullptr &&
        AnfAlgo::OutputAddrExist(value_node_with_context.first.first, 0)) {
      PrepareDataForControlValueNode(value_node_with_context.first, value_node_with_context.second, context,
                                     control_node_parser);
    }
  }

  const auto &control_node_parameters = control_node_parser->control_node_parameters();
  if (!tensors.empty() && control_node_parameters.size() != tensors.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Invalid tensor size.");
  }
  for (size_t i = 0; i < control_node_parameters.size(); ++i) {
    auto &front_parameter = control_node_parameters[i].first;
    MS_EXCEPTION_IF_NULL(front_parameter);
    if (!control_node_parser->IsRootGraphPersistentDeviceTensor(front_parameter)) {
      continue;
    }

    TensorPtr tensor = FetchInputTensor(tensors, i, args, control_node_parameters[i]);
    if (tensor == nullptr) {
      continue;
    }

    auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_parameter.get());
    if (device_tensors.empty()) {
      MS_LOG(WARNING) << "Failed to get device tensor for front node:" << front_parameter->DebugString();
      continue;
    }
    MS_EXCEPTION_IF_NULL(device_tensors[0]);
    auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
    if ((device_tensors[0] == host_tensor_address) || (device_tensors[0]->IsPtrValid())) {
      continue;
    }

    auto node = (device_tensors[0]->GetNodeIndex()).first;
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(INFO) << "Prepare device data for weight node by root graph parameter:"
                 << front_parameter->fullname_with_scope() << ", backend node:" << node->DebugString()
                 << ", device type:" << device_tensors[0]->GetDeviceType();
    if (host_tensor_address == nullptr) {
      tensor->set_device_address(device_tensors[0]);
      auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
        {device_tensors[0]->device_name(), device_tensors[0]->device_id()});
      SyncTensorData(tensor, device_tensors[0], node, device_context, context, GraphExecutionStrategy::kPipeline);
    } else {
      if (host_tensor_address->GetSize() != device_tensors[0]->GetSize()) {
        MS_LOG(WARNING) << "Please check the size of parameter:" << front_parameter->fullname_with_scope()
                        << ", host tensor size:" << host_tensor_address->GetSize()
                        << ", device tensor size:" << device_tensors[0]->GetSize();
      }
      host_tensor_address->SetNodeIndex(node, 0);
      UpdateRefCount(host_tensor_address.get(), true);
      DeviceTensorStore::GetInstance().Remove(front_parameter.get());
      DeviceTensorStore::GetInstance().Insert(front_parameter.get(), host_tensor_address);
    }
  }
}

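// Fill the host tensor queue slots that correspond to control node parameters. Persistent device tensors and
// parameters outside origin_parameters_order_ are skipped, and a slot that has already been filled is left
// untouched.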
void DataPrepareActor::PrepareHostTensorQueueForControlNode(const std::vector<TensorPtr> &tensors,
                                                            std::vector<TensorPtr> *const host_tensors,
                                                            OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(graph_compiler_info_);
  MS_EXCEPTION_IF_NULL(graph_compiler_info_->control_node_parser_);
  MS_EXCEPTION_IF_NULL(host_data_source_actor_);
  MS_EXCEPTION_IF_NULL(host_tensors);

  const auto &control_node_parameters = graph_compiler_info_->control_node_parser_->control_node_parameters();
  for (size_t i = 0; i < control_node_parameters.size(); ++i) {
    const auto &input_node = control_node_parameters[i].first;
    const auto &input_tensor = tensors[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsPersistentDeviceTensor(input_node)) {
      continue;
    }

    if (find(graph_compiler_info_->origin_parameters_order_.begin(),
             graph_compiler_info_->origin_parameters_order_.end(),
             input_node) == graph_compiler_info_->origin_parameters_order_.end()) {
      continue;
    }

    auto tensor_position = host_data_source_actor_->FetchNodePosition(control_node_parameters[i]);
    if (tensor_position >= host_tensors->size()) {
      std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
    }
    if ((*host_tensors)[tensor_position] != nullptr) {
      continue;
    }
    MS_LOG(DEBUG) << "Set tensor position:" << tensor_position << " for input data.";
    (*host_tensors)[tensor_position] = input_tensor;

    UpdateDynamicShapeAndSize(input_node, input_tensor);
    // To avoid the device `ptr_` being held by both the input tensor and the output tensor, the input tensor
    // address cannot be set directly on the input control node, which may be a passthrough node. Instead, the
    // device `ptr_` is re-malloced and a device-to-device copy from the input tensor address is done in the
    // data source process.
  }
}

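// Housekeeping run before each data preparation: advance the embedding cache global step (Linux backend
// builds only) and periodically defrag device memory for every device context.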
void DataPrepareActor::PreprocessBeforePrepareData() const {
  // Embedding cache mode needs to record the number of global steps executed by the compute graph.
  // On the first step the compute graph needs to wait for the embedding cache prefetch to warm up, to prevent
  // the GetNext operator in the compute graph from timing out.
#if defined(__linux__) && defined(WITH_BACKEND)
  EmbeddingCacheScheduler::GetInstance().IncreaseGraphStep(GetAID());
#endif

  // Try to defrag memory periodically, at most once per device context in a step.
  auto defrag_memory_step_freq = GetDefragMemoryStepFreq();
  if (++execution_count_ % defrag_memory_step_freq == 0) {
    MS_EXCEPTION_IF_NULL(graph_compiler_info_);
    std::set<const DeviceContext *> defrag_memory_contexts;
    for (auto &device_context : graph_compiler_info_->device_contexts_) {
      MS_EXCEPTION_IF_NULL(device_context);
      if (defrag_memory_contexts.count(device_context) == 0) {
        device_context->device_res_manager_->DefragMemory();
      }
      (void)defrag_memory_contexts.insert(device_context);
    }
  }
}
}  // namespace runtime
}  // namespace mindspore