/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/actor/data_prepare_actor.h"
#include "runtime/framework/actor/memory_manager_actor.h"
#include "runtime/framework/actor/kernel_actor.h"
#include "runtime/framework/actor/loop_count_actor.h"
#include "runtime/framework/actor/debug_actor.h"
#include "runtime/hardware/device_context_manager.h"
#include "mindrt/include/async/async.h"
#include "utils/log_adapter.h"
#include "utils/convert_utils.h"
#include "common/trans.h"

namespace mindspore {
namespace runtime {
namespace {
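// Allocate device memory for the device tensor if needed, then sync the host tensor data to the device.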
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
                    const DeviceContext *device_context, OpContext<DeviceTensor> *const context,
                    GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(host_tensor);
  MS_EXCEPTION_IF_NULL(device_tensor);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);

  if ((device_tensor->GetPtr() == nullptr) &&
      (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize()))) {
    SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy, *context, *device_context, node->fullname_with_scope(),
                                                device_tensor->GetSize());
  }

  // Copy the data from the host tensor to the device.
  auto host_tensor_size = LongToSize(host_tensor->data().nbytes());
  auto host_tensor_type = host_tensor->data_type();
  if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), host_tensor_size, host_tensor_type,
                                       host_tensor->data_c(), host_tensor->device_info().host_format_)) {
    std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope() +
                             ", host tensor size: " + std::to_string(host_tensor_size) +
                             ", host tensor type: " + std::to_string(static_cast<int>(host_tensor_type)) +
                             ", device tensor size: " + std::to_string(device_tensor->GetSize());
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy, (*context), error_info);
  }
}

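// Collect the device tensors and sizes of the node's inputs (or outputs) that require continuous memory.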
void FetchContinuousMemoryInfo(const CNodePtr &node, std::vector<DeviceTensorPtr> *const addr_list,
                               std::vector<size_t> *const size_list, size_t *const total_size, bool is_input) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(addr_list);
  MS_EXCEPTION_IF_NULL(size_list);
  MS_EXCEPTION_IF_NULL(total_size);

  const auto &kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  addr_list->clear();
  size_list->clear();
  *total_size = 0;

  if (is_input) {
    const auto &input_sizes = kernel_mod->GetInputSizeList();
    for (size_t i = 0; i < input_sizes.size(); ++i) {
      const auto &device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      *total_size += input_sizes[i];
      (void)size_list->emplace_back(input_sizes[i]);
      (void)addr_list->emplace_back(device_tensor);
    }
  } else {
    const auto &output_sizes = kernel_mod->GetOutputSizeList();
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      *total_size += output_sizes[i];
      (void)size_list->emplace_back(output_sizes[i]);
      (void)addr_list->emplace_back(device_tensor);
    }
  }
}
}  // namespace
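// Initialize the execution strategy and cache the continuous memory allocation information (device tensors, sizes
// and device contexts) for the nodes whose inputs or outputs require contiguous memory.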
void DataPrepareActor::Init() {
  MS_EXCEPTION_IF_NULL(graph_compiler_info_);
  strategy_ = graph_compiler_info_->strategy_;
  if (graph_compiler_info_->graphs_.size() != graph_compiler_info_->device_contexts_.size()) {
    MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
  }

  for (auto &iter : continuous_memory_nodes_) {
    size_t total_size = 0;
    std::vector<size_t> size_list;
    std::vector<DeviceTensorPtr> addr_list;
    // Inputs need continuous memory.
    if (iter.second.first) {
      FetchContinuousMemoryInfo(iter.first.first, &addr_list, &size_list, &total_size, true);
      (void)continuous_memory_alloc_list_list_.emplace_back(addr_list);
      (void)size_list_list_.emplace_back(size_list);
      (void)total_size_list_.emplace_back(total_size);
      (void)continuous_memory_device_contexts_.emplace_back(iter.first.second);
    }

    // Outputs need continuous memory.
    if (iter.second.second) {
      FetchContinuousMemoryInfo(iter.first.first, &addr_list, &size_list, &total_size, false);
      (void)continuous_memory_alloc_list_list_.emplace_back(addr_list);
      (void)size_list_list_.emplace_back(size_list);
      (void)total_size_list_.emplace_back(total_size);
      (void)continuous_memory_device_contexts_.emplace_back(iter.first.second);
    }
  }
}

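// Entry of data preparation: fill the device tensor store and the host tensor queue from the input tensors,
// then trigger the debug actor, the continuous memory allocation and the step running.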
void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                   OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);

  // Convert the actor running data from the input tensors.
  if (input_tensors.size() > 0) {
    PrepareDataForDeviceTensorStore(input_tensors, context);
    if (strategy_ == GraphExecutionStrategy::kPipeline) {
      PrepareDataForHostTensorQueue(input_tensors, context);
    } else if (strategy_ == GraphExecutionStrategy::kStep) {
      PrepareDataForStepMode(input_tensors, context);
    }

    // The debug actor is blocked, so wait for the debug actor callback message before continuing.
    if (debug_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
      SendDebugReq(context);
      return;
    }
  }

  // Allocate the continuous memory and send the output to trigger the step running.
  if (continuous_memory_alloc_list_list_.size() > 0) {
    SendMemoryAllocReq(context);
  } else {
    SendOutput(context);
  }
}

void DataPrepareActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
  Async(*debug_aid_, &DebugActor::DebugOnStepBegin, graph_compiler_info_->graphs_,
        graph_compiler_info_->device_contexts_, context, &GetAID());
}

void DataPrepareActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if (continuous_memory_alloc_list_list_.size() > 0) {
    SendMemoryAllocReq(context);
  } else {
    SendOutput(context);
  }
}

void DataPrepareActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  // Allocate the continuous memory at the beginning of the step running.
  Async(memory_manager_aid_, &MemoryManagerActor::AllocateContinuousMemory, &continuous_memory_alloc_list_list_,
        &size_list_list_, &total_size_list_, &continuous_memory_device_contexts_, context, GetAID());
}

void DataPrepareActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  SendOutput(context);
}

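// Kick off the step running: notify the data source actors and the no-input kernel actors, or notify the loop
// count actor directly when neither exists.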
void DataPrepareActor::SendOutput(OpContext<DeviceTensor> *const context) {
  for (auto &data_source_aid : data_source_aids_) {
    Async(data_source_aid, &DataSourceActor::FetchData, context);
  }

  auto source_aid = const_cast<AID *>(&GetAID());
  for (auto &kernel_aid : no_input_kernel_aids_) {
    Async(kernel_aid, &KernelActor::RunOpControl, source_aid, context);
  }

  // Trigger the loop count actor directly when there is no data source actor and no no-input kernel actor.
  if ((data_source_aids_.size() + no_input_kernel_aids_.size() == 0) && (loop_count_aid_ != nullptr)) {
    Async(*loop_count_aid_, &LoopCountActor::RunOpControl, source_aid, context);
  }
}

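// Prepare the device data of the device tensor store: the value nodes and weights of each graph, plus the nodes
// used by the control nodes.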
void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                                       OpContext<DeviceTensor> *const context) {
  for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info_->graphs_[i];
    const auto &device_context = graph_compiler_info_->device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // Prepare the data of the device tensor store (value nodes of the graph).
    for (const auto &value_node : graph->graph_value_nodes()) {
      if (AnfAlgo::OutputAddrExist(value_node, 0)) {
        PrepareDataForValueNode(value_node, device_context, context);
      }
    }

    // Prepare the data of the device tensor store (weights of the graph).
    const auto &input_nodes = graph->input_nodes();
    const auto &tensors = input_tensors[i];
    for (size_t j = 0; j < input_nodes.size(); ++j) {
      const auto &input_node = input_nodes[j];
      const auto &input_tensor = tensors[j];
      MS_EXCEPTION_IF_NULL(input_node);
      if (!IsPersistentDeviceTensor(input_node)) {
        continue;
      }
      const auto front_node = FetchFrontNodeByBackendNode(input_node, graph);
      PrepareDataForWeightNode(input_node, front_node, input_tensor, device_context, context);
    }
  }

  PrepareDeviceTensorStoreForControlNode(graph_compiler_info_->control_node_parser_, input_tensors.back(), context);
}

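// Fill the host tensor queue with the input tensors that correspond to the host data source nodes (pipeline mode).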
void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                                     OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  if ((host_data_source_actor_ == nullptr) || (host_tensor_queue_ == nullptr)) {
    return;
  }

  std::vector<TensorPtr> host_tensors;
  host_tensors.resize(host_data_source_actor_->data_nodes().size());
  // Fill the host tensors.
  for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info_->graphs_[i];
    MS_EXCEPTION_IF_NULL(graph);

    const auto &input_nodes = graph->input_nodes();
    const auto &tensors = input_tensors[i];
    for (size_t j = 0; j < input_nodes.size(); ++j) {
      const auto &input_node = input_nodes[j];
      const auto &input_tensor = tensors[j];
      MS_EXCEPTION_IF_NULL(input_node);
      if (!IsHostQueueDSActor(input_node, graph, graph_compiler_info_->origin_parameters_order_, strategy_)) {
        continue;
      }
      auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
      if (tensor_position >= host_tensors.size()) {
        std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
        SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
      }
      host_tensors[tensor_position] = input_tensor;

      auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
      auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
      MS_EXCEPTION_IF_NULL(device_address);
      if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
        AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
      }
    }
  }

  PrepareHostTensorQueueForControlNode(input_tensors.back(), &host_tensors, context);

  host_tensor_queue_->Push(host_tensors);
}

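// In step mode, prepare the device addresses of the graph inputs, sync the host data to device, and fill the host
// tensor queue.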
void DataPrepareActor::PrepareDataForStepMode(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                              OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(context);
  std::vector<TensorPtr> host_tensors;
  if ((host_data_source_actor_ != nullptr) && (host_tensor_queue_ != nullptr)) {
    host_tensors.resize(host_data_source_actor_->data_nodes().size());
  }

  for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info_->graphs_[i];
    const auto &device_context = graph_compiler_info_->device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);
    MS_EXCEPTION_IF_NULL(device_context);

    const auto &input_nodes = graph->input_nodes();
    const auto &tensors = input_tensors[i];
    for (size_t j = 0; j < input_nodes.size(); ++j) {
      const auto &input_node = input_nodes[j];
      const auto &input_tensor = tensors[j];
      MS_EXCEPTION_IF_NULL(input_node);
      MS_EXCEPTION_IF_NULL(input_tensor);
      if (IsPersistentDeviceTensor(input_node)) {
        continue;
      }

      if ((host_data_source_actor_ != nullptr) && (host_tensor_queue_ != nullptr)) {
        auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
        if (tensor_position >= host_tensors.size()) {
          std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
          SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
        }
        host_tensors[tensor_position] = input_tensor;
      }

      auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
      if (host_tensor_address != nullptr) {
        AnfAlgo::SetOutputAddr(host_tensor_address, 0, input_node.get());
        continue;
      }

      if (!AnfAlgo::OutputAddrExist(input_node, 0, false)) {
        TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, 0);
        if (output_type_id == kTypeUnknown) {
          output_type_id = AnfAlgo::GetOutputInferDataType(input_node, 0);
        }
        size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, 0);
        auto device_address = device_context->CreateDeviceAddress(
          nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, 0), output_type_id);
        AnfAlgo::SetOutputAddr(device_address, 0, input_node.get());
      }
      auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
      input_tensor->set_device_address(device_tensor);
      UpdateRefCount(device_tensor.get(), true);

      SyncTensorData(input_tensor, device_tensor, input_node, device_context, context, strategy_);
    }
  }

  if ((host_data_source_actor_ != nullptr) && (host_tensor_queue_ != nullptr)) {
    host_tensor_queue_->Push(host_tensors);
  }
}

// The branch of PrepareDataForValueNode that handles the tensor value type.
void DataPrepareActor::PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
                                                     const DeviceContext *device_context,
                                                     OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);

  std::vector<TensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);
  for (size_t i = 0; i < tensors.size(); i++) {
    const auto &tensor = tensors[i];
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }

    const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
    MS_EXCEPTION_IF_NULL(device_tensor);
    // If the ptr of the device tensor is not nullptr, the device data has already been prepared.
    if (device_tensor->GetPtr() != nullptr) {
      return;
    }
    MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope() << ", output index: " << i;
    tensor->set_device_address(device_tensor);
    UpdateRefCount(device_tensor.get(), true);

    SyncTensorData(tensor, device_tensor, node, device_context, context, strategy_);
  }
}

// Prepare the device data for the persistent device tensor of the value node.
void DataPrepareActor::PrepareDataForValueNode(const ValueNodePtr &node, const DeviceContext *device_context,
                                               OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);
  auto &node_value = node->value();
  MS_EXCEPTION_IF_NULL(node_value);

  if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
    // The branch that handles the tensor value type.
    PrepareDataForValueNodeTensor(node, node_value, device_context, context);
  } else if (node_value->isa<StringImm>()) {
    const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false);
    MS_EXCEPTION_IF_NULL(device_tensor);
    // If the ptr of the device tensor is not nullptr, the device data has already been prepared.
    if (device_tensor->GetPtr() != nullptr) {
      return;
    }
    MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope();

    if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
      SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *device_context, node->fullname_with_scope(),
                                                  device_tensor->GetSize());
    }

    // Copy the data from the string value to the device.
    auto value = GetValue<std::string>(node_value);
    size_t tensor_size = value.size();
    ShapeVector shape = {1, SizeToLong(tensor_size)};
    if (!device_tensor->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
      std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope();
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
    }
  }
}

// Prepare the device data for the persistent device tensor of the weight node from the host tensor.
void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &front_node,
                                                const TensorPtr &tensor, const DeviceContext *device_context,
                                                OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(front_node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);
  if (tensor == nullptr) {
    return;
  }

  auto device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
  MS_EXCEPTION_IF_NULL(device_tensor);
  auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
  // Use the device address of the host tensor to set the device tensor.
  if (host_tensor_address != device_tensor) {
    if (host_tensor_address == nullptr) {
      host_tensor_address = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
                                                                device_tensor->format(), device_tensor->type_id());
      tensor->set_device_address(host_tensor_address);
      UpdateRefCount(host_tensor_address.get(), true);
    }
    MS_EXCEPTION_IF_NULL(host_tensor_address);
    if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
      AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
    } else {
      MS_LOG(INFO) << "The device types are not equal, host tensor type: " << host_tensor_address->DeviceType()
                   << ", device tensor type: " << device_tensor->DeviceType();
    }
  }
  // In the shared weight scene, the same host_tensor_address may correspond to different front nodes, so the device
  // tensor store always needs to be updated.
  DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);

  // If the ptr of the device tensor is not nullptr, the device data has already been prepared.
  MS_EXCEPTION_IF_NULL(host_tensor_address);
  if (host_tensor_address->GetPtr() == nullptr) {
    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
                 << ", device type:" << host_tensor_address->DeviceType();
    SyncTensorData(tensor, host_tensor_address, backend_node, device_context, context, strategy_);
  }

  // Allocate memory on the other device and copy the host tensor data to it (if another device tensor exists).
  const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  if (device_tensors.size() > 1) {
    auto another_device_tensor = (device_tensors[0] == host_tensor_address) ? device_tensors[1] : device_tensors[0];
    MS_EXCEPTION_IF_NULL(another_device_tensor);
    auto another_device_type = another_device_tensor->DeviceType();
    const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_});
    MS_EXCEPTION_IF_NULL(another_device_context);
    if ((another_device_tensor->GetPtr() == nullptr) &&
        (!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize()))) {
      SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy_, *context, *another_device_context,
                                                  backend_node->fullname_with_scope(),
                                                  another_device_tensor->GetSize());
    }

    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
                 << ", device type:" << another_device_type;
    if (!Copy(another_device_tensor.get(), host_tensor_address.get())) {
      std::string error_info = "Sync data error.";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
    }
  }
}

// In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor.
void DataPrepareActor::PrepareDataForControlWeightNode(
  const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, const DeviceContext *device_context,
  const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &host_parameter_to_weights,
  OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(front_node);
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(device_context);

  auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  bool need_update_device_tensor_store = device_tensors.empty();
  for (auto &device_tensor : device_tensors) {
    MS_EXCEPTION_IF_NULL(device_tensor);
    if (device_tensor->GetPtr() == nullptr) {
      need_update_device_tensor_store = true;
      break;
    }
  }
  if (need_update_device_tensor_store) {
    PrepareDataForWeightNode(node, front_node, tensor, device_context, context);
  }

  const auto iter = host_parameter_to_weights.find(front_node);
  if (iter == host_parameter_to_weights.end()) {
    return;
  }

  // Fetch all the device tensors of the host weight node and insert them as the weights of the other nodes.
  const auto &sub_front_nodes = host_parameter_to_weights.at(front_node);
  device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  for (const auto &sub_front_node : sub_front_nodes) {
    for (const auto &device_tensor : device_tensors) {
      MS_EXCEPTION_IF_NULL(sub_front_node);
      DeviceTensorStore::GetInstance().Insert(sub_front_node.get(), device_tensor);
    }
  }
}

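// Prepare the device data of the device tensor store used by the control nodes: the front value nodes and the
// weight parameters referenced by the control nodes.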
void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeParserPtr &control_node_parser,
                                                              const std::vector<TensorPtr> &tensors,
                                                              OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(control_node_parser);
  for (const auto &value_node_with_context : control_node_parser->front_value_nodes()) {
    if (AnfAlgo::OutputAddrExist(value_node_with_context.first, 0)) {
      PrepareDataForValueNode(value_node_with_context.first->cast<ValueNodePtr>(), value_node_with_context.second,
                              context);
    }
  }

  const auto &control_node_parameters = control_node_parser->control_node_parameters();
  for (size_t i = 0; i < control_node_parameters.size(); ++i) {
    const auto &input_node = control_node_parameters[i];
    const auto &input_tensor = tensors[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsPersistentDeviceTensor(input_node)) {
      const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters();
      const auto &iter = front_to_backend_parameters.find(input_node);
      if (iter == front_to_backend_parameters.end()) {
        MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
                          << AnfAlgo::GetNodeDebugString(input_node);
      }
      const auto &node_with_context = iter->second;
      PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
                                      control_node_parser->host_parameter_to_weights(), context);
    }
  }
}

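// Fill the host tensor list with the input tensors that correspond to the control node parameters.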
void DataPrepareActor::PrepareHostTensorQueueForControlNode(const std::vector<TensorPtr> &tensors,
                                                            std::vector<TensorPtr> *const host_tensors,
                                                            OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(graph_compiler_info_->control_node_parser_);
  MS_EXCEPTION_IF_NULL(host_data_source_actor_);
  MS_EXCEPTION_IF_NULL(host_tensors);

  const auto &control_node_parameters = graph_compiler_info_->control_node_parser_->control_node_parameters();
  for (size_t i = 0; i < control_node_parameters.size(); ++i) {
    const auto &input_node = control_node_parameters[i];
    const auto &input_tensor = tensors[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsPersistentDeviceTensor(input_node)) {
      continue;
    }

    if (std::find(graph_compiler_info_->origin_parameters_order_.begin(),
                  graph_compiler_info_->origin_parameters_order_.end(),
                  input_node) == graph_compiler_info_->origin_parameters_order_.end()) {
      continue;
    }

    auto tensor_position = host_data_source_actor_->FetchNodePosition(input_node);
    if (tensor_position >= host_tensors->size()) {
      std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy_, (*context), error_info);
    }
    (*host_tensors)[tensor_position] = input_tensor;

    const AnfNodePtr &backend_node = host_data_source_actor_->FetchNode(tensor_position);
    auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
    auto device_address = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
    MS_EXCEPTION_IF_NULL(device_address);
    if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
      AnfAlgo::SetOutputAddr(tensor_address, 0, backend_node.get());
    }
  }
}
}  // namespace runtime
}  // namespace mindspore