/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/actor/data_prepare_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/loop_count_actor.h"
#include "runtime/framework/actor/debug_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils.h"
#include "common/trans.h"

namespace mindspore {
namespace runtime {
namespace {
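// Allocate device memory for the device tensor when it has no ptr, then copy the host tensor data to the device.
// Failures are reported through the op context according to the execution strategy.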
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
                    const DeviceContext *device_context, OpContext<DeviceTensor> *const context,
                    GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(host_tensor);
  MS_EXCEPTION_IF_NULL(device_tensor);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);

  if ((device_tensor->GetPtr() == nullptr) &&
      (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize()))) {
    SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy, *context, *device_context, node->fullname_with_scope(),
                                                device_tensor->GetSize());
  }

  // Copy data from host tensor to device.
  auto host_tensor_size = LongToSize(host_tensor->data().nbytes());
  auto host_tensor_type = host_tensor->data_type();
  if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), host_tensor_size, host_tensor_type,
                                       host_tensor->data_c(), host_tensor->device_info().host_format_)) {
    std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope() +
                             ", host tensor size: " + std::to_string(host_tensor_size) +
                             ", host tensor type: " + std::to_string(static_cast<int>(host_tensor_type)) +
                             ", device tensor size: " + std::to_string(device_tensor->GetSize());
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy, (*context), error_info);
  }
}

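// Collect the device tensors and sizes of the node's inputs (is_input is true) or outputs that require a
// continuous memory block, and accumulate the total size.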
void FetchContinuousMemoryInfo(const CNodePtr &node, std::vector<DeviceTensorPtr> *const addr_list,
                               std::vector<size_t> *const size_list, size_t *const total_size, bool is_input) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(addr_list);
  MS_EXCEPTION_IF_NULL(size_list);
  MS_EXCEPTION_IF_NULL(total_size);

  const auto &kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  (*addr_list).clear();
  (*size_list).clear();
  *total_size = 0;

  if (is_input) {
    const auto &input_sizes = kernel_mod->GetInputSizeList();
    for (size_t i = 0; i < input_sizes.size(); ++i) {
      const auto &device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      *total_size += input_sizes[i];
      (void)size_list->emplace_back(input_sizes[i]);
      (void)addr_list->emplace_back(device_tensor);
    }
  } else {
    const auto &output_sizes = kernel_mod->GetOutputSizeList();
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      *total_size += output_sizes[i];
      (void)size_list->emplace_back(output_sizes[i]);
      (void)addr_list->emplace_back(device_tensor);
    }
  }
}
}  // namespace
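// Cache the execution strategy and build the continuous memory allocation lists from the nodes whose inputs
// or outputs require continuous memory.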
void DataPrepareActor::Init() {
  MS_EXCEPTION_IF_NULL(graph_compiler_info_);
  strategy_ = graph_compiler_info_->strategy_;
  if (graph_compiler_info_->graphs_.size() != graph_compiler_info_->device_contexts_.size()) {
    MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
  }

  for (auto &iter : continuous_memory_nodes_) {
    size_t total_size = 0;
    std::vector<size_t> size_list;
    std::vector<DeviceTensorPtr> addr_list;
    // Inputs need continuous memory.
    if (iter.second.first) {
      FetchContinuousMemoryInfo(iter.first.first, &addr_list, &size_list, &total_size, true);
      (void)continuous_memory_alloc_list_list_.emplace_back(addr_list);
      (void)size_list_list_.emplace_back(size_list);
      (void)total_size_list_.emplace_back(total_size);
      (void)continuous_memory_device_contexts_.emplace_back(iter.first.second);
    }

    // Outputs need continuous memory.
    if (iter.second.second) {
      FetchContinuousMemoryInfo(iter.first.first, &addr_list, &size_list, &total_size, false);
      (void)continuous_memory_alloc_list_list_.emplace_back(addr_list);
      (void)size_list_list_.emplace_back(size_list);
      (void)total_size_list_.emplace_back(total_size);
      (void)continuous_memory_device_contexts_.emplace_back(iter.first.second);
    }
  }
}

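// The entry of data preparation: fill the device tensor store and the host tensor queue from the input tensors,
// then allocate continuous memory or send output to trigger the step running.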
void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                   OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);

  // Prepare the running data of actors from the input tensors.
  if (input_tensors.size() > 0) {
    PrepareDataForDeviceTensorStore(input_tensors, context);
    if (strategy_ == GraphExecutionStrategy::kPipeline) {
      PrepareDataForHostTensorQueue(input_tensors, context);
    } else if (strategy_ == GraphExecutionStrategy::kStep) {
      PrepareDataForStepMode(input_tensors, context);
    }

    // The debug actor is blocking, so wait for its callback message before continuing.
    if (debug_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
      SendDebugReq(context);
      return;
    }
  }

  // Allocate continuous memory and send output to trigger the step running.
  if (continuous_memory_alloc_list_list_.size() > 0) {
    SendMemoryAllocReq(context);
  } else {
    SendOutput(context);
  }
}

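// Notify the debug actor that a step begins; the reply arrives through OnDebugFinish.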
void DataPrepareActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
  Async(*debug_aid_, &DebugActor::DebugOnStepBegin, graph_compiler_info_->graphs_,
        graph_compiler_info_->device_contexts_, context, &GetAID());
}

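// Callback of the debug actor: continue with continuous memory allocation or output sending.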
void DataPrepareActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (continuous_memory_alloc_list_list_.size() > 0) {
    SendMemoryAllocReq(context);
  } else {
    SendOutput(context);
  }
}

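// Ask the memory manager actor to allocate the continuous memory collected in Init.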
void DataPrepareActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  // Allocate continuous memory at the beginning of the step running.
  Async(memory_manager_aid_, &MemoryManagerActor::AllocateContinuousMemory, &continuous_memory_alloc_list_list_,
        &size_list_list_, &total_size_list_, &continuous_memory_device_contexts_, context, GetAID());
}

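// Callback of the memory manager actor: the continuous memory is ready, so send output.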
void DataPrepareActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  SendOutput(context);
}

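// Trigger the data source actors, the no-input kernel actors, or the loop count actor to start running.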
void DataPrepareActor::SendOutput(OpContext<DeviceTensor> *const context) {
  for (auto &data_source_aid : data_source_aids_) {
    Async(data_source_aid, &DataSourceActor::FetchData, context);
  }

  auto source_aid = const_cast<AID *>(&GetAID());
  for (auto &kernel_aid : no_input_kernel_aids_) {
    Async(kernel_aid, &KernelActor::RunOpControl, source_aid, context);
  }

  // Trigger the loop count actor when there is neither a data source actor nor a no-input kernel actor.
  if ((data_source_aids_.size() + no_input_kernel_aids_.size() == 0) && (loop_count_aid_ != nullptr)) {
    Async(*loop_count_aid_, &LoopCountActor::RunOpControl, source_aid, context);
  }
}

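// Prepare the device data of the persistent device tensors (value nodes and weights) for each graph and for
// the control nodes.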
void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                                       OpContext<DeviceTensor> *const context) {
  for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info_->graphs_[i];
    const auto &device_context = graph_compiler_info_->device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // Prepare the data of device tensor store (value nodes of graph).
    for (const auto &value_node : graph->graph_value_nodes()) {
      if (AnfAlgo::OutputAddrExist(value_node, 0)) {
        PrepareDataForValueNode(value_node, device_context, context);
      }
    }

    // Prepare the data of device tensor store (weights of graph).
    const auto &input_nodes = graph->input_nodes();
    const auto &tensors = input_tensors[i];
    for (size_t j = 0; j < input_nodes.size(); ++j) {
      const auto &input_node = input_nodes[j];
      const auto &input_tensor = tensors[j];
      MS_EXCEPTION_IF_NULL(input_node);
      if (!IsPersistentDeviceTensor(input_node)) {
        continue;
      }
      const auto front_node = FetchFrontNodeByBackendNode(input_node, graph);
      PrepareDataForWeightNode(input_node, front_node, input_tensor, device_context, context);
    }
  }

  PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context);
}

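// Fill the host tensor queue with the input tensors of the host queue data source actor in pipeline mode.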
void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                                     OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if ((host_data_source_actor_ == nullptr) || (host_tensor_queue_ == nullptr)) {
    return;
  }

  std::vector<TensorPtr> host_tensors;
  host_tensors.resize(host_data_source_actor_->data_nodes().size());
  // Fill host tensors.
  for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info_->graphs_[i];
    MS_EXCEPTION_IF_NULL(graph);

    const auto &input_nodes = graph->input_nodes();
    const auto &tensors = input_tensors[i];
    for (size_t j = 0; j < input_nodes.size(); ++j) {
      const auto &input_node = input_nodes[j];
      const auto &input_tensor = tensors[j];
      MS_EXCEPTION_IF_NULL(input_node);
      if (!IsHostQueueDSActor(input_node, graph, graph_compiler_info_->origin_parameters_order_, strategy_)) {
        continue;
      }
      auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
      if (tensor_position >= host_tensors.size()) {
        std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
      }
      host_tensors[tensor_position] = input_tensor;

      auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
      auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
      MS_EXCEPTION_IF_NULL(device_address);
      if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
        AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
      }
    }
  }

  PrepareHostTensorQueueForControlNode(input_tensors.back(), &host_tensors, context);

  host_tensor_queue_->Push(host_tensors);
}

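// Prepare the input data in step mode: fill the host tensor queue and sync the non-persistent inputs to the
// device one by one.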
void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                              OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  std::vector<TensorPtr> host_tensors;
  if ((host_data_source_actor_ != nullptr) && (host_tensor_queue_ != nullptr)) {
    host_tensors.resize(host_data_source_actor_->data_nodes().size());
  }

  for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info_->graphs_[i];
    const auto &device_context = graph_compiler_info_->device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);
    MS_EXCEPTION_IF_NULL(device_context);

    const auto &input_nodes = graph->input_nodes();
    const auto &tensors = input_tensors[i];
    for (size_t j = 0; j < input_nodes.size(); ++j) {
      const auto &input_node = input_nodes[j];
      const auto &input_tensor = tensors[j];
      MS_EXCEPTION_IF_NULL(input_node);
      MS_EXCEPTION_IF_NULL(input_tensor);
      if (IsPersistentDeviceTensor(input_node)) {
        continue;
      }

      if ((host_data_source_actor_ != nullptr) && (host_tensor_queue_ != nullptr)) {
        auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
        if (tensor_position >= host_tensors.size()) {
          std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
          SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
        }
        host_tensors[tensor_position] = input_tensor;
      }

      auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
      if (host_tensor_address != nullptr) {
        AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get());
        continue;
      }

      if (!AnfAlgo::OutputAddrExist(input_node, 0, false)) {
        TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, 0);
        if (output_type_id == kTypeUnknown) {
          output_type_id = AnfAlgo::GetOutputInferDataType(input_node, 0);
        }
        size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
        auto device_address = device_context->CreateDeviceAddress(
          nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0), output_type_id);
        AnfAlgo::SetOutputAddr(device_address, 0, input_node.get());
      }
      auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
      input_tensor->set_device_address(device_tensor);
      UpdateRefCount(device_tensor.get(), true);

      SyncTensorData(input_tensor, device_tensor, input_node, device_context, context, strategy_);
    }
  }

  if ((host_data_source_actor_ != nullptr) && (host_tensor_queue_ != nullptr)) {
    host_tensor_queue_->Push(host_tensors);
  }
}

// The branch of PrepareDataForValueNode that handles tensor values.
void DataPrepareActor::PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
                                                     const DeviceContext *device_context,
                                                     OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);

  std::vector<TensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);
  for (size_t i = 0; i < tensors.size(); i++) {
    const auto &tensor = tensors[i];
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }

    const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
    MS_EXCEPTION_IF_NULL(device_tensor);
    // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
    if (device_tensor->GetPtr() != nullptr) {
      return;
    }
    MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope() << ", output index: " << i;
    tensor->set_device_address(device_tensor);
    UpdateRefCount(device_tensor.get(), true);

    SyncTensorData(tensor, device_tensor, node, device_context, context, strategy_);
  }
}

// Prepare the device data for the persistent device tensor of the value node.
void DataPrepareActor::PrepareDataForValueNode(const ValueNodePtr &node, const DeviceContext *device_context,
                                               OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);
  auto &node_value = node->value();
  MS_EXCEPTION_IF_NULL(node_value);

  if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
    // The branch that handles tensor and value tuple values.
    PrepareDataForValueNodeTensor(node, node_value, device_context, context);
  } else if (node_value->isa<StringImm>()) {
    const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false);
    MS_EXCEPTION_IF_NULL(device_tensor);
    // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
    if (device_tensor->GetPtr() != nullptr) {
      return;
    }
    MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope();

    if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
      SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *device_context, node->fullname_with_scope(),
                                                  device_tensor->GetSize());
    }

    // Copy data from value to device.
    auto value = GetValue<std::string>(node_value);
    size_t tensor_size = value.size();
    ShapeVector shape = {1, SizeToLong(tensor_size)};
    if (!device_tensor->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
      std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
    }
  }
}

// Prepare the device data for the persistent device tensor of the weight node from the host tensor.
void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &front_node,
                                                const TensorPtr &tensor, const DeviceContext *device_context,
                                                OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(front_node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);
  if (tensor == nullptr) {
    return;
  }

  auto device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
  MS_EXCEPTION_IF_NULL(device_tensor);
  auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
  // Use the device address of the host tensor to set the device tensor.
  if (host_tensor_address != device_tensor) {
    if (host_tensor_address == nullptr) {
      host_tensor_address = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
                                                                device_tensor->format(), device_tensor->type_id());
      tensor->set_device_address(host_tensor_address);
      UpdateRefCount(host_tensor_address.get(), true);
    }
    MS_EXCEPTION_IF_NULL(host_tensor_address);
    if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
      AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
    } else {
      MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
                   << ", device tensor type:" << device_tensor->DeviceType();
    }
  }
  // In the shared-weight scene, the same host_tensor_address may correspond to different front nodes,
  // so the device tensor store always needs to be updated.
  DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);

  // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
  MS_EXCEPTION_IF_NULL(host_tensor_address);
  if (host_tensor_address->GetPtr() == nullptr) {
    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
                 << ", device type:" << host_tensor_address->DeviceType();
    SyncTensorData(tensor, host_tensor_address, backend_node, device_context, context, strategy_);
  }

  // Allocate memory on the other device (if any) and copy the data from the host tensor to it.
  const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  if (device_tensors.size() > 1) {
    auto another_device_tensor = (device_tensors[0] == host_tensor_address) ? device_tensors[1] : device_tensors[0];
    MS_EXCEPTION_IF_NULL(another_device_tensor);
    auto another_device_type = another_device_tensor->DeviceType();
    const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_});
    MS_EXCEPTION_IF_NULL(another_device_context);
    if ((another_device_tensor->GetPtr() == nullptr) &&
        (!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize()))) {
      SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *another_device_context,
                                                  backend_node->fullname_with_scope(),
                                                  another_device_tensor->GetSize());
    }

    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
                 << ", device type:" << another_device_type;
    if (!Copy(another_device_tensor.get(), host_tensor_address.get())) {
      std::string error_info = "Sync data error.";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
    }
  }
}

// In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor.
void DataPrepareActor::PrepareDataForControlWeightNode(
  const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, const DeviceContext *device_context,
  const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &host_parameter_to_weights,
  OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(front_node);
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(device_context);

  auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  bool need_update_device_tensor_store = device_tensors.empty();
  for (auto &device_tensor : device_tensors) {
    MS_EXCEPTION_IF_NULL(device_tensor);
    if (device_tensor->GetPtr() == nullptr) {
      need_update_device_tensor_store = true;
      break;
    }
  }
  if (need_update_device_tensor_store) {
    PrepareDataForWeightNode(node, front_node, tensor, device_context, context);
  }

  const auto iter = host_parameter_to_weights.find(front_node);
  if (iter == host_parameter_to_weights.end()) {
    return;
  }

  // Fetch all the device tensors of the host weight node and insert them as the weights of the other front nodes.
  const auto &sub_front_nodes = host_parameter_to_weights.at(front_node);
  device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  for (const auto &sub_front_node : sub_front_nodes) {
    for (const auto &device_tensor : device_tensors) {
      MS_EXCEPTION_IF_NULL(sub_front_node);
      DeviceTensorStore::GetInstance().Insert(sub_front_node.get(), device_tensor);
    }
  }
}

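// Prepare the device tensor store for the value nodes and weight parameters used by control nodes.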
void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeParserPtr &control_node_parser,
                                                              const std::vector<TensorPtr> &tensors,
                                                              OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(control_node_parser);
  for (const auto &value_node_with_context : control_node_parser->front_value_nodes()) {
    if (AnfAlgo::OutputAddrExist(value_node_with_context.first, 0)) {
      PrepareDataForValueNode(value_node_with_context.first->cast<ValueNodePtr>(), value_node_with_context.second,
                              context);
    }
  }

  const auto &control_node_parameters = control_node_parser->control_node_parameters();
  for (size_t i = 0; i < control_node_parameters.size(); ++i) {
    const auto &input_node = control_node_parameters[i];
    const auto &input_tensor = tensors[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsPersistentDeviceTensor(input_node)) {
      const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
      const auto &iter = front_to_backend_parameters.find(input_node);
      if (iter == front_to_backend_parameters.end()) {
        MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
                          << AnfAlgo::GetNodeDebugString(input_node);
      }
      const auto &node_with_context = iter->second;
      PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
                                      control_node_parser->host_parameter_to_weights(), context);
    }
  }
}

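// Fill the host tensors of the parameters used by control nodes into the host tensor queue data.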
void DataPrepareActor::PrepareHostTensorQueueForControlNode(const std::vector<TensorPtr> &tensors,
                                                            std::vector<TensorPtr> *const host_tensors,
                                                            OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(graph_compiler_info_->control_node_parser_);
  MS_EXCEPTION_IF_NULL(host_data_source_actor_);
  MS_EXCEPTION_IF_NULL(host_tensors);

  const auto &control_node_parameters = graph_compiler_info_->control_node_parser_->control_node_parameters();
  for (size_t i = 0; i < control_node_parameters.size(); ++i) {
    const auto &input_node = control_node_parameters[i];
    const auto &input_tensor = tensors[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsPersistentDeviceTensor(input_node)) {
      continue;
    }

    if (find(graph_compiler_info_->origin_parameters_order_.begin(),
             graph_compiler_info_->origin_parameters_order_.end(),
             input_node) == graph_compiler_info_->origin_parameters_order_.end()) {
      continue;
    }

    auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
    if (tensor_position >= host_tensors->size()) {
      std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
    }
    (*host_tensors)[tensor_position] = input_tensor;

    const AnfNodePtr &backend_node = host_data_source_actor_->FetchNode(tensor_position);
    auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
    auto device_address = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
    MS_EXCEPTION_IF_NULL(device_address);
    if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
      AnfAlgo::SetOutputAddr(tensor_address, 0, backend_node.get());
    }
  }
}
}  // namespace runtime
}  // namespace mindspore