1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <algorithm>
18 #include <set>
19 #include "runtime/graph_scheduler/actor/data_prepare_actor.h"
20 #include "runtime/graph_scheduler/actor/memory_manager_actor.h"
21 #include "runtime/graph_scheduler/actor/kernel_actor.h"
22 #include "runtime/graph_scheduler/actor/loop_count_actor.h"
23 #include "runtime/graph_scheduler/actor/debug_actor.h"
24 #include "runtime/graph_scheduler/actor/profiler_actor.h"
25 #include "runtime/hardware/device_context_manager.h"
26 #include "runtime/device/auto_mem_offload.h"
27 #include "runtime/device/device_address_utils.h"
28 #include "mindrt/include/async/async.h"
29 #include "utils/log_adapter.h"
30 #include "utils/phase.h"
31 #include "include/common/utils/convert_utils.h"
32 #include "include/backend/distributed/recovery/recovery_context.h"
33 #include "include/backend/mem_reuse/mem_tracker.h"
34 #if defined(__linux__) && defined(WITH_BACKEND)
35 #include "runtime/graph_scheduler/rpc_node_scheduler.h"
36 #include "runtime/graph_scheduler/embedding_cache_scheduler.h"
37 #endif
38
39 namespace mindspore {
40 namespace runtime {
41 using distributed::recovery::RecoveryContext;
42 namespace {
43 constexpr size_t kNormalTensorNum = 1;
44 constexpr size_t kMapTensorNum = 3;
45 constexpr size_t kMapTensorKeyIndex = 0;
46 constexpr size_t kMapTensorValueIndex = 1;
47 constexpr size_t kMapTensorStatusIndex = 2;
48 constexpr size_t kPinMemThreshold = 1024 << 10;
49
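// Returns true only when the tensor carries a sequence base shape whose element count is zero.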
bool IsEmptySequenceTensor(const TensorPtr &tensor) {
51 MS_EXCEPTION_IF_NULL(tensor);
52 if (tensor->base_shape_ptr() == nullptr || (!tensor->base_shape_ptr()->isa<abstract::SequenceShape>())) {
53 return false;
54 }
55 const auto &sequence_shape = tensor->base_shape_ptr()->cast<abstract::SequenceShapePtr>();
56 MS_EXCEPTION_IF_NULL(sequence_shape);
57 return sequence_shape->size() == 0;
58 }
59
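// The swap manager takes over the host data when memory offload is enabled and the target device is not CPU.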
bool IsDataTakenOverByMemOffload(const DeviceContext *device_context) {
61 MS_EXCEPTION_IF_NULL(device_context);
62 if (device_context->GetDeviceType() == device::DeviceType::kCPU) {
63 return false;
64 }
65 auto ms_context = MsContext::GetInstance();
66 MS_EXCEPTION_IF_NULL(ms_context);
67 return ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD);
68 }
69
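// Builds the storage info handed to the swap manager: the tensor's offload file if it has one, its (possibly pinned)
// host data, or a newly allocated host buffer converted to the device data type when the data types differ.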
device::StorageInfo GetStorageInfo(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor,
                                   const DeviceContext *device_context) {
72 MS_EXCEPTION_IF_NULL(host_tensor);
73 MS_EXCEPTION_IF_NULL(device_tensor);
74 MS_EXCEPTION_IF_NULL(device_context);
75 MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
76 auto swap_manager = device_context->device_res_manager_->swap_manager();
77 MS_EXCEPTION_IF_NULL(swap_manager);
78 if (host_tensor->data_type() == device_tensor->type_id()) {
79 const auto &offload_file = host_tensor->GetOffloadFilePath();
80 if (!offload_file.empty()) {
81 return {nullptr, offload_file};
82 } else if (host_tensor->Size() > kPinMemThreshold) {
83 host_tensor->PinMemory(swap_manager->GetPinMemPool());
84 }
85 return {host_tensor->data_c(), ""};
86 }
87 const auto shape_size = abstract::ShapeSize(host_tensor->shape());
88 const auto data_size = host_tensor->Size();
89 const trans::TypeIdArgs type_args{host_tensor->data_c(), shape_size, host_tensor->data_type(),
90 device_tensor->type_id(), data_size};
91 auto offload_ptr = swap_manager->AllocHostMemory(device_tensor->GetSize());
92 MS_EXCEPTION_IF_NULL(offload_ptr);
93 bool trans_ret = trans::TransDataType(type_args, offload_ptr);
94 if (!trans_ret) {
95 MS_LOG(EXCEPTION) << "Trans data type for offload ptr failed, src type: "
96 << TypeIdToString(host_tensor->data_type())
97 << ", dst type: " << TypeIdToString(device_tensor->type_id());
98 }
99 return {offload_ptr, ""};
100 }
101
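// Records the allocation task and the memory info of the device tensor in the memory tracker.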
void UpdateTracker(const std::string &task_name, const AnfNodePtr &node, const std::string &graph_str,
                   device::tracker::MemType mem_type, const DeviceTensorPtr &device_tensor) {
104 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, task_name, node->fullname_with_scope(), graph_str);
105 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, task_name, mem_type, device_tensor->GetSize(),
106 device_tensor.get());
107 }
108
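// Allocates device memory if needed and copies the host tensor (or the key/value/status tensors of a MapTensor) to
// the device, or hands the host data over to the swap manager when memory offload takes over.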
void SyncTensorData(const TensorPtr &host_tensor, const DeviceTensorPtr &device_tensor, const AnfNodePtr &node,
                    const DeviceContext *device_context, OpContext<DeviceTensor> *const context,
                    GraphExecutionStrategy strategy) {
112 MS_EXCEPTION_IF_NULL(host_tensor);
113 MS_EXCEPTION_IF_NULL(device_tensor);
114 MS_EXCEPTION_IF_NULL(node);
115 MS_EXCEPTION_IF_NULL(device_context);
116 MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
117 MS_EXCEPTION_IF_NULL(context);
118 const bool taken_over_by_swap_manager = IsDataTakenOverByMemOffload(device_context);
119 auto allocator_type = node->isa<ValueNode>() ? device::AllocatorType::kConstantValue : device::AllocatorType::kWeight;
120 device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), allocator_type, 0);
121 bool need_alloc_memory = !taken_over_by_swap_manager && (device_tensor->GetPtr() == nullptr);
122 auto graph_str = (node->func_graph() == nullptr) ? "" : node->func_graph()->ToString();
123 auto mem_type = node->isa<ValueNode>() ? device::tracker::MemType::kConstantValue : device::tracker::MemType::kWeight;
124 if (need_alloc_memory) {
125 UpdateTracker("SyncTensorData", node, graph_str, mem_type, device_tensor);
126 if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex)) {
127 SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(strategy, *context, *device_context, node->fullname_with_scope(),
128 device_tensor->GetSize());
129 }
130 if (common::IsNeedProfileMemory()) {
131 auto output_address = reinterpret_cast<std::uintptr_t>(device_tensor.get());
132 MS_LOG(WARNING) << "Need Profile Memory, alloc type: SyncTensorData, device address class ptr: " << output_address
133 << ", node: " << node->fullname_with_scope() << ", graph: " << graph_str
134 << ", device address size: " << device_tensor->GetSize()
135 << ", device address addr: " << device_tensor->GetPtr();
136 }
137 }
138
139 auto get_tensor_by_index = [&host_tensor](size_t index) {
140 if (!host_tensor->isa<tensor::MapTensor>()) {
141 return host_tensor;
142 }
143 const auto &map_tensor = host_tensor->cast<tensor::MapTensorPtr>();
144 MS_EXCEPTION_IF_NULL(map_tensor);
145 switch (index) {
146 case kMapTensorKeyIndex:
147 return map_tensor->key_tensor();
148 case kMapTensorValueIndex:
149 return map_tensor->value_tensor();
150 case kMapTensorStatusIndex:
151 return map_tensor->status_tensor();
152 default:
153 MS_LOG(EXCEPTION) << "Invalid index:" << index << " for map tensor:" << host_tensor->ToString();
154 }
155 };
156
157 ShapeVector host_shape = {};
158 // GetRuntimePaddingShape doesn't support the value tuple node.
159 if (!node->isa<ValueNode>()) {
160 host_shape = trans::GetRuntimePaddingShape(node, 0);
161 }
162 auto get_tensor_num = (host_tensor->isa<tensor::MapTensor>() ? kMapTensorNum : kNormalTensorNum);
163 for (size_t i = 0; i < get_tensor_num; ++i) {
164 const auto &real_host_tensor = get_tensor_by_index(i);
165 MS_EXCEPTION_IF_NULL(real_host_tensor);
166 // Copy data from host tensor to device.
167 auto host_tensor_size = LongToSize(real_host_tensor->data().nbytes());
168 auto host_tensor_type = real_host_tensor->data_type();
169 if (node->isa<ValueNode>()) {
170 host_shape = real_host_tensor->shape();
171 }
172 if (taken_over_by_swap_manager) {
173 device_tensor->SetStorageInfo(GetStorageInfo(real_host_tensor, device_tensor, device_context));
174 } else if (!device_tensor->SyncHostToDevice(host_shape, host_tensor_size, host_tensor_type,
175 real_host_tensor->device_info().host_format_,
176 real_host_tensor->data_ptr())) {
177 std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope() +
178 ", host tensor size: " + std::to_string(host_tensor_size) +
179 ", host tensor type: " + std::to_string(static_cast<int>(host_tensor_type)) +
180 ", device tensor size: " + std::to_string(device_tensor->GetSize());
181 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(strategy, (*context), error_info);
182 }
183 }
184 }
185
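// Collects the device addresses and sizes of the node's inputs or outputs that require continuous memory.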
void FetchContinuousMemoryInfo(const CNodePtr &node, std::vector<DeviceTensorPtr> *const addr_list,
                               std::vector<size_t> *const size_list, size_t *const total_size, bool is_input) {
188 MS_EXCEPTION_IF_NULL(node);
189 MS_EXCEPTION_IF_NULL(addr_list);
190 MS_EXCEPTION_IF_NULL(size_list);
191 MS_EXCEPTION_IF_NULL(total_size);
192
193 (*addr_list).clear();
194 (*size_list).clear();
195 *total_size = 0;
196
197 if (is_input) {
    const auto &input_sizes = AnfAlgo::GetNodeInputSizeList(node);
    for (size_t i = 0; i < input_sizes.size(); ++i) {
      const auto &device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(node, i, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      *total_size += input_sizes[i];
      (void)size_list->emplace_back(input_sizes[i]);
204 (void)addr_list->emplace_back(device_tensor);
205 }
206 } else {
207 const auto &kernel_mod = AnfAlgo::GetKernelMod(node);
208 MS_EXCEPTION_IF_NULL(kernel_mod);
209 const auto &output_sizes = kernel_mod->GetOutputSizeList();
210 for (size_t i = 0; i < output_sizes.size(); ++i) {
211 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
212 MS_EXCEPTION_IF_NULL(device_tensor);
213 *total_size += output_sizes[i];
214 (void)size_list->emplace_back(output_sizes[i]);
215 (void)addr_list->emplace_back(device_tensor);
216 }
217 }
218 }
219
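// Recursively flattens a value sequence into single values; CSR/COO tensors are expanded into their component
// tensors followed by their shape scalars.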
void ValueTupleToValue(const ValuePtr &value, std::vector<ValuePtr> *const values) {
221 MS_EXCEPTION_IF_NULL(value);
222 MS_EXCEPTION_IF_NULL(values);
223 if (value->isa<ValueSequence>()) {
224 auto value_tuple = value->cast<ValueSequencePtr>();
225 MS_EXCEPTION_IF_NULL(value_tuple);
226 for (size_t i = 0; i < value_tuple->size(); ++i) {
227 ValuePtr element = value_tuple->value()[i];
228 MS_EXCEPTION_IF_NULL(element);
229
230 if (element->isa<ValueSequence>()) {
231 ValueTupleToValue(element, values);
232 } else {
233 (void)values->emplace_back(element);
234 }
235 }
236 } else if (value->isa<tensor::CSRTensor>()) {
237 auto csr_tensor = value->cast<tensor::CSRTensorPtr>();
238 MS_EXCEPTION_IF_NULL(csr_tensor);
239 MS_EXCEPTION_IF_NULL(csr_tensor->GetIndptr());
240 MS_EXCEPTION_IF_NULL(csr_tensor->GetIndices());
241 MS_EXCEPTION_IF_NULL(csr_tensor->GetValues());
242 (void)values->emplace_back(csr_tensor->GetIndptr());
243 (void)values->emplace_back(csr_tensor->GetIndices());
244 (void)values->emplace_back(csr_tensor->GetValues());
245 (void)std::transform(csr_tensor->shape().begin(), csr_tensor->shape().end(), std::back_inserter(*values),
246 [](int64_t n) { return std::make_shared<Int64Imm>(n); });
247 } else if (value->isa<tensor::COOTensor>()) {
248 auto coo_tensor = value->cast<tensor::COOTensorPtr>();
249 MS_EXCEPTION_IF_NULL(coo_tensor);
250 MS_EXCEPTION_IF_NULL(coo_tensor->GetIndices());
251 MS_EXCEPTION_IF_NULL(coo_tensor->GetValues());
252 (void)values->emplace_back(coo_tensor->GetIndices());
253 (void)values->emplace_back(coo_tensor->GetValues());
254 (void)std::transform(coo_tensor->shape().begin(), coo_tensor->shape().end(), std::back_inserter(*values),
255 [](int64_t n) { return std::make_shared<Int64Imm>(n); });
256 } else {
257 (void)values->emplace_back(value);
258 }
259 }
260
// The device address of a ref input node may be modified by the input tensor, so the device address of the
// corresponding ref node needs to be updated as well.
void UpdateDeviceAddressByRefInputNode(const std::vector<KernelGraphPtr> &graphs,
                                       const std::set<AnfNode *> &modified_input_nodes) {
264 for (const auto &graph : graphs) {
265 MS_EXCEPTION_IF_NULL(graph);
266 // The DeviceAddress of the graph parameter has been updated.
267 if (graph->is_graph_run_mode()) {
268 continue;
269 }
270
271 for (auto &iter : graph->GetRefMap()) {
272 auto &output_pair = iter.first;
273 auto &input_pair = iter.second;
274 MS_EXCEPTION_IF_NULL(output_pair.first);
275 MS_EXCEPTION_IF_NULL(input_pair.first);
276 if (modified_input_nodes.count(input_pair.first.get()) == 0) {
277 continue;
278 }
      // The output device tensor of a ref node actor can't be replaced while running; only the ptr of the output
      // device address can be modified, and `ref_count` must be set to `SIZE_MAX` to avoid being freed. Therefore,
      // only persistent device tensors are supported here.
282 if (!IsPersistentDeviceTensor(input_pair.first)) {
        MS_LOG(INFO) << "The input parameter: " << input_pair.first->fullname_with_scope()
                     << " is not a persistent ref parameter used by the ref node: "
                     << output_pair.first->fullname_with_scope();
286 continue;
287 }
288
289 MS_LOG(INFO) << "Update the ptr of ref node: " << output_pair.first->fullname_with_scope()
290 << " by the modified ref input parameter: " << input_pair.first->fullname_with_scope();
291 auto ref_node_output_addr = AnfAlgo::GetMutableOutputAddr(output_pair.first, output_pair.second, false);
292 MS_EXCEPTION_IF_NULL(ref_node_output_addr);
293 const auto &front_input_node = AnfAlgo::FetchFrontNodeByBackendNode(input_pair.first, *graph);
294 auto input_addr =
295 DeviceTensorStore::GetInstance().Fetch(front_input_node.get(), ref_node_output_addr->GetDeviceType());
      // Subgraphs may share the same backend input parameter, so fetching the device tensor store by the front node
      // of this subgraph may return nullptr; in that case use the output addr of the input parameter directly.
298 if (input_addr == nullptr) {
299 input_addr = AnfAlgo::GetMutableOutputAddr(input_pair.first, input_pair.second, false);
300 }
301 MS_EXCEPTION_IF_NULL(input_addr);
      MS_EXCEPTION_IF_CHECK_FAIL((ref_node_output_addr->GetDeviceType() == input_addr->GetDeviceType()),
                                 "The device type of the ref node does not match that of its input.");
304 ref_node_output_addr->set_ptr(input_addr->GetMutablePtr());
305 ref_node_output_addr->set_original_ref_count(SIZE_MAX);
306 ref_node_output_addr->ResetRefCount();
307 }
308 }
309 }
310
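// Weights need to be re-synced when recovery requires it or when the tensor holds sub data.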
bool IsNeedSync(const TensorPtr &tensor) {
312 if (RecoveryContext::GetInstance()->enable_recovery() &&
313 RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
314 return true;
315 }
316
317 if (tensor == nullptr) {
318 return false;
319 }
  // Sub data needs to be synced in each step.
321 auto data_ptr = tensor->data_ptr();
322 return data_ptr != nullptr && data_ptr->is_sub_data();
323 }
324
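// Sync the trunk tensors that contain sub data before data preparation.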
void SyncTensorTrunk(const std::vector<std::vector<TensorPtr>> &input_tensors) {
326 for (auto &tensors : input_tensors) {
327 for (auto &tensor : tensors) {
328 if (tensor == nullptr) {
329 continue;
330 }
331 auto data_ptr = tensor->data_ptr();
332 if (data_ptr != nullptr && data_ptr->has_sub_data()) {
333 tensor->data_sync();
334 }
335 }
336 }
337 }
338
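// Recomputes the device address size from the tensor shape transformed into the device format, which is needed for
// non-default formats (e.g. 5D formats) whose device size differs from the host tensor size.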
void UpdateDataNodeDeviceAddressSize(const AnfNodePtr &input_node, const TensorPtr &input_tensor,
                                     const device::DeviceAddressPtr &device_address) {
341 MS_EXCEPTION_IF_NULL(input_node);
342 MS_EXCEPTION_IF_NULL(input_tensor);
343 MS_EXCEPTION_IF_NULL(device_address);
344 TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, 0);
345 if (output_type_id == kTypeUnknown) {
346 output_type_id = common::AnfAlgo::GetOutputInferDataType(input_node, 0);
347 }
348 auto device_shape =
349 trans::TransShapeToDevice(input_tensor->shape(), device_address->format(), input_node, 0, output_type_id);
350 size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
351 auto device_address_size = type_size * SizeOf(device_shape);
352 MS_LOG(INFO) << "Size of device_address is updated from " << device_address->GetSize() << " to "
353 << device_address_size;
354 device_address->SetSize(device_address_size);
355 }
356 } // namespace
357
358 mindspore::HashSet<const tensor::Tensor *> DataPrepareActor::tensors_need_reprepare_ = {};
359
360 std::atomic<size_t> DataPrepareActor::execution_count_ = 0;
361
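// Initialize the execution strategy, detect dynamic-shape graphs and collect the continuous-memory allocation info.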
void DataPrepareActor::Init() {
363 MS_EXCEPTION_IF_NULL(graph_compiler_info_);
364 strategy_ = graph_compiler_info_->strategy_;
365 if (graph_compiler_info_->graphs_.size() != graph_compiler_info_->device_contexts_.size()) {
366 MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
367 }
368
369 size_t host_data_size = 0;
370 if (host_data_source_actor_ != nullptr) {
371 host_data_size = host_data_source_actor_->data_nodes().size();
372 }
373 has_parameter_input_ = graph_compiler_info_->inputs_num_ > host_data_size;
374 MS_LOG(INFO) << graph_compiler_info_->name_
375 << " has the parameter input num: " << graph_compiler_info_->inputs_num_ - host_data_size;
376
377 for (const auto &graph : graph_compiler_info_->graphs_) {
378 MS_EXCEPTION_IF_NULL(graph);
379 if (graph->is_dynamic_shape()) {
380 has_dynamic_shape_ = true;
381 break;
382 }
383 }
384
385 for (auto &iter : continuous_memory_nodes_) {
386 size_t total_size = 0;
387 std::vector<size_t> size_list;
388 std::vector<DeviceTensorPtr> addr_list;
389 // Inputs need continuous memory.
390 if (iter.second.first) {
391 const auto &cnode = iter.first.first;
392 FetchContinuousMemoryInfo(cnode, &addr_list, &size_list, &total_size, true);
393 (void)continuous_memory_alloc_list_list_.emplace_back(addr_list);
394 (void)size_list_list_.emplace_back(size_list);
395 (void)stream_id_list_.emplace_back(kDefaultStreamIndex);
396 (void)total_size_list_.emplace_back(total_size);
397 (void)continuous_memory_device_contexts_.emplace_back(iter.first.second);
398 }
399
400 // Outputs need continuous memory.
401 if (iter.second.second) {
402 const auto &cnode = iter.first.first;
403 FetchContinuousMemoryInfo(cnode, &addr_list, &size_list, &total_size, false);
404 (void)continuous_memory_alloc_list_list_.emplace_back(addr_list);
405 (void)size_list_list_.emplace_back(size_list);
406 (void)stream_id_list_.emplace_back(kDefaultStreamIndex);
407 (void)total_size_list_.emplace_back(total_size);
408 (void)continuous_memory_device_contexts_.emplace_back(iter.first.second);
409 }
410 }
411 }
412
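// Update the parameter's kernel tensor shape and device address size from the input tensor for dynamic-shape inputs.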
void DataPrepareActor::UpdateDynamicShapeAndSize(const AnfNodePtr &input_node, const TensorPtr &input_tensor) const {
414 MS_EXCEPTION_IF_NULL(input_node);
415 if (input_tensor == nullptr || IsEmptySequenceTensor(input_tensor)) {
416 return;
417 }
418 if (!input_node->isa<Parameter>()) {
419 return;
420 }
421 auto input_param = input_node->cast<ParameterPtr>();
422 MS_EXCEPTION_IF_NULL(input_param);
423 auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
424 MS_EXCEPTION_IF_NULL(device_address);
425 if (!input_param->has_dynamic_shape() && !IsDynamic(device_address->host_shape())) {
426 return;
427 }
428
429 // Update shape.
430 MS_LOG(DEBUG) << "Update dynamic shape for parameter:" << input_param->DebugString();
431 const auto &output_kernel_tensor = AnfAlgo::GetOutputKernelTensor(input_node, 0);
432 MS_EXCEPTION_IF_NULL(output_kernel_tensor);
433 if (input_tensor->base_shape_ptr() == nullptr || (!input_tensor->base_shape_ptr()->isa<abstract::SequenceShape>())) {
434 output_kernel_tensor->SetShape(input_tensor->ToAbstract()->GetShape());
435 return;
436 }
437 output_kernel_tensor->SetShape(input_tensor->base_shape_ptr());
438
439 // Update size.
440 auto device_format = device_address->format();
441 static const std::set<std::string> kNormalFormat = {
442 kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN,
443 };
444 if (kNormalFormat.find(device_format) != kNormalFormat.end()) {
445 auto tensor_data_size = input_tensor->data().nbytes();
446 MS_LOG(DEBUG) << "Set device address:" << device_address << " size from:" << device_address->GetSize()
447 << " to:" << tensor_data_size;
448 device_address->SetSize(tensor_data_size);
449 } else {
450 MS_LOG(DEBUG) << "Update data node device address size";
451 // Size of 5D format device_address is larger than tensor_data_size.
452 UpdateDataNodeDeviceAddressSize(input_node, input_tensor, device_address);
453 }
454 }
455
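// Bind the input tensor's device address to the data node when both addresses are compatible (same device type,
// equivalent format and type id, not ptr-persisted), so the tensor data can be reused without copying.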
void DataPrepareActor::UpdateDeviceAddressForDataNode(const AnfNodePtr &input_node, const TensorPtr &input_tensor) {
457 MS_EXCEPTION_IF_NULL(input_tensor);
458 MS_EXCEPTION_IF_NULL(input_node);
459
460 auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
461 if (tensor_address == nullptr) {
462 return;
463 }
464
465 auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
466 MS_EXCEPTION_IF_NULL(device_address);
467 if (tensor_address == device_address) {
468 tensor_address->SetNodeIndex(input_node, 0);
469 tensor_address->set_original_ref_count(SIZE_MAX);
470 tensor_address->ResetRefCount();
471 return;
472 }
473
  // If the tensor address and the device address differ (heterogeneous scenario), or the device address is
  // persisted, the device address data is updated in the data source actor process instead.
476 if (device_address->is_ptr_persisted() || (tensor_address->GetDeviceType() != device_address->GetDeviceType()) ||
477 (!AnfAlgo::IsEquivalentFormat(tensor_address->format(), device_address->format())) ||
478 (tensor_address->type_id() != device_address->type_id())) {
479 MS_LOG(DEBUG) << "Cannot update address of " << input_node->DebugString();
480 return;
481 }
482
  // Assign the tensor address to the input data node and set `ref_count` to `SIZE_MAX` to avoid it being freed.
484 (void)address_modified_input_nodes_.insert(input_node.get());
485 tensor_address->set_flag(device_address->flag());
486 DeviceAddressUtils::UpdateDeviceAddressHostInfoByNode(tensor_address, input_node, 0);
487 AnfAlgo::SetOutputAddr(tensor_address, 0, input_node.get());
488 MS_LOG(DEBUG) << "Update device address of " << input_node->DebugString() << " to " << tensor_address.get()
489 << ", kernel tensor addr:" << tensor_address->kernel_tensor().get()
490 << " ptr:" << tensor_address->GetPtr();
491 tensor_address->SetNodeIndex(input_node, 0);
492 tensor_address->set_original_ref_count(SIZE_MAX);
493 tensor_address->ResetRefCount();
494 }
495
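// Cache the input tensors the first time any of them holds sub data.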
void DataPrepareActor::SetInitTensorsIfNeeded(const std::vector<std::vector<TensorPtr>> &input_tensors) {
497 if (!init_tensors_.empty()) {
498 return;
499 }
500 bool need_save = std::any_of(input_tensors.begin(), input_tensors.end(), [](const std::vector<TensorPtr> &tensors) {
501 return std::any_of(tensors.begin(), tensors.end(), [](const TensorPtr &tensor) {
502 if (tensor == nullptr) {
503 return false;
504 }
505 auto data_ptr = tensor->data_ptr();
506 return data_ptr != nullptr && data_ptr->is_sub_data();
507 });
508 });
509 if (need_save) {
510 init_tensors_ = input_tensors;
511 }
512 }
513
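// Entry of the data prepare actor: prepares data for the device tensor store and the host tensor queue, then
// triggers the debug/profiler flow or the continuous-memory allocation before posting the run.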
void DataPrepareActor::PrepareData(const std::vector<std::vector<TensorPtr>> &input_tensors, const VectorRef &args,
                                   OpContext<DeviceTensor> *const context, GraphExecutionStrategy real_strategy) {
516 MS_EXCEPTION_IF_NULL(context);
517 uint64_t start_time = 0;
518 PROFILER_START(start_time);
519
520 #if defined(__linux__) && defined(WITH_BACKEND)
521 // Update rpc actors' status.
522 RpcActorStatusUpdater::GetInstance().UpdateRpcActorStatus(graph_compiler_info_->name_);
523 #endif
524
525 try {
526 // Preprocess before prepare data for data prepare actor.
527 PreprocessBeforePrepareData();
528 } catch (const std::exception &e) {
529 MsException::Instance().SetException();
530 std::string error_info = e.what();
531 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
532 }
533
534 MS_LOG(DEBUG) << "Data prepare actor(" << GetAID().Name() << ") prepares data.";
535 real_strategy_ = real_strategy;
536 // Convert actor running data from input tensors.
537 if (!input_tensors.empty()) {
538 SyncTensorTrunk(input_tensors);
539 SetInitTensorsIfNeeded(input_tensors);
540 }
541 try {
542 auto ms_context = MsContext::GetInstance();
543 MS_EXCEPTION_IF_NULL(ms_context);
544 static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
545 if (first_step_ || !tensors_need_reprepare_.empty() || (has_parameter_input_ && !enable_infer_boost)) {
546 PrepareDataForDeviceTensorStore(input_tensors, args, context);
547 }
548 PrepareDataForHostTensorQueue(input_tensors, args, context);
549 } catch (const std::exception &e) {
550 std::string error_info = e.what();
551 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
552 }
553
554 first_step_ = false;
555 if (IsRunningFailed(context)) {
556 return;
557 }
558 MS_EXCEPTION_IF_NULL(graph_compiler_info_);
559 if (!address_modified_input_nodes_.empty()) {
560 UpdateDeviceAddressByRefInputNode(graph_compiler_info_->graphs_, address_modified_input_nodes_);
561 address_modified_input_nodes_.clear();
562 }
563
  // The debug actor is blocked; wait for its callback message before continuing.
565 if (debug_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
566 SendDebugReq(context);
567 return;
568 }
569
570 if (profiler_aid_ != nullptr && strategy_ == GraphExecutionStrategy::kPipeline) {
571 SendProfilerReq(context);
572 return;
573 }
574
575 PROFILER_END(start_time, runtime::ProfilerModule::kRuntime, runtime::ProfilerEvent::kPreLaunch, GetAID().Name(),
576 false);
577
578 // Allocate continuous memory and send output to trigger the step running.
579 if (continuous_memory_alloc_list_list_.size() > 0) {
580 SendMemoryAllocReq(context);
581 } else {
582 PostRun(context);
583 }
584 }
585
void DataPrepareActor::SendDebugReq(OpContext<DeviceTensor> *const context) {
587 MS_EXCEPTION_IF_NULL(graph_compiler_info_);
588 ActorDispatcher::SendSync(*debug_aid_, &DebugActor::DebugOnStepBegin, graph_compiler_info_->graphs_,
589 graph_compiler_info_->origin_parameters_order_, graph_compiler_info_->device_contexts_,
590 context, &GetAID());
591 OnDebugFinish(context);
592 }
593
void DataPrepareActor::SendProfilerReq(OpContext<DeviceTensor> *const context) {
595 MS_EXCEPTION_IF_NULL(graph_compiler_info_);
596 ActorDispatcher::SendSync(*profiler_aid_, &ProfilerActor::ProfilerOnStepBegin, graph_compiler_info_->graphs_,
597 graph_compiler_info_->origin_parameters_order_, graph_compiler_info_->device_contexts_,
598 context, &GetAID());
599 OnDebugFinish(context);
600 }
601
void DataPrepareActor::OnDebugFinish(OpContext<DeviceTensor> *const context) {
603 MS_EXCEPTION_IF_NULL(context);
604 if (continuous_memory_alloc_list_list_.size() > 0) {
605 SendMemoryAllocReq(context);
606 } else {
607 PostRun(context);
608 }
609 }
610
void DataPrepareActor::SendMemoryAllocReq(OpContext<DeviceTensor> *const context) {
  // Allocate continuous memory at the beginning of the step running.
613 if (ActorDispatcher::is_memory_allocation_sync()) {
614 if (!ActorDispatcher::enable_use_trace_memory()) {
615 ActorDispatcher::SendSync(memory_manager_aid_, &MemoryManagerActor::AllocateContinuousMemory,
616 &continuous_memory_alloc_list_list_, &size_list_list_, &stream_id_list_,
617 &total_size_list_, &continuous_memory_device_contexts_, context, GetAID());
618 }
619 OnMemoryAllocFinish(context);
620 } else {
621 ActorDispatcher::Send(memory_manager_aid_, &MemoryManagerActor::AllocateContinuousMemory,
622 &continuous_memory_alloc_list_list_, &size_list_list_, &stream_id_list_, &total_size_list_,
623 &continuous_memory_device_contexts_, context, GetAID());
624 }
625 }
626
void DataPrepareActor::OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) {
628 MS_EXCEPTION_IF_NULL(context);
629 if (IsRunningFailed(context)) {
630 return;
631 }
632
633 PostRun(context);
634 }
635
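// Fetch the input tensor from the prepared tensors if any, otherwise from the flattened args by the position of the
// front node in the origin parameters.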
TensorPtr DataPrepareActor::FetchInputTensor(const std::vector<TensorPtr> &tensors, size_t tensor_index,
                                             const VectorRef &args, const KernelWithIndex &front_node) const {
638 if (!tensors.empty()) {
639 MS_EXCEPTION_IF_CHECK_FAIL((tensor_index < tensors.size()), "The tensor index is out of range.");
640 auto tensor = tensors[tensor_index];
641 // The tensor needs to be converted to contiguous before being given to the actors.
642 // After the view feature is supported in the graph mode, the following code will be deleted.
643 DeviceAddressUtils::ConvertContiguousTensorSync(tensor);
644 return tensor;
645 }
646
647 MS_EXCEPTION_IF_NULL(front_node.first);
648 const auto &iter = std::find(graph_compiler_info_->origin_parameters_order_.begin(),
649 graph_compiler_info_->origin_parameters_order_.end(), front_node.first);
650 if (iter == graph_compiler_info_->origin_parameters_order_.end()) {
651 MS_LOG(INFO) << "Not origin parameter: " << front_node.first->fullname_with_scope();
652 return nullptr;
653 }
654 auto arg_index = iter - graph_compiler_info_->origin_parameters_order_.begin();
655 auto tensor = FetchInputTensorByArg(args, arg_index, front_node);
656 // The tensor needs to be converted to contiguous before being given to the actors.
657 // After the view feature is supported in the graph mode, the following code will be deleted.
658 DeviceAddressUtils::ConvertContiguousTensorSync(tensor);
659 return tensor;
660 }
661
TensorPtr DataPrepareActor::FetchInputTensorByArg(const VectorRef &args, size_t arg_index,
                                                  const KernelWithIndex &front_node) const {
664 if (arg_index >= args.size()) {
665 MS_LOG(INFO) << "Arg index out of args range, index is " << arg_index << " and args size is " << args.size();
666 return nullptr;
667 }
668
669 std::vector<tensor::TensorPtr> flatten_tensors;
670 AnfAlgo::FlattenInputArg(args[arg_index], front_node.first, &flatten_tensors);
671 auto input_tensor_index = FetchInputTensorIndex(front_node);
672 if (input_tensor_index >= flatten_tensors.size()) {
673 MS_LOG(INFO) << "Input tensor index out of args range, index is " << input_tensor_index << " and tensors size is "
674 << flatten_tensors.size();
675 return nullptr;
676 }
677
678 auto tensor = flatten_tensors[input_tensor_index];
679 // The tensor needs to be converted to contiguous before being given to the actors.
680 // After the view feature is supported in the graph mode, the following code will be deleted.
681 DeviceAddressUtils::ConvertContiguousTensorSync(tensor);
682
683 if (tensor != nullptr && tensor->update_value_callback() == nullptr && tensor->is_parameter()) {
684 static auto callback = [](const tensor::Tensor *tensor) { tensors_need_reprepare_.insert(tensor); };
685 tensor->set_update_value_callback(callback);
686 }
687
688 if (tensor != nullptr && !tensors_need_reprepare_.empty() && tensor->is_parameter()) {
689 auto erased_num = tensors_need_reprepare_.erase(tensor.get());
690 MS_LOG(DEBUG) << "Erase " << erased_num << " tensor which is reprepared.";
691 }
692
693 return tensor;
694 }
695
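// For sequence parameters (except dynamic sequences) the element index inside the flattened tensors is used,
// otherwise the index is always 0.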
size_t DataPrepareActor::FetchInputTensorIndex(const KernelWithIndex &front_node) const {
697 MS_EXCEPTION_IF_NULL(front_node.first);
698 if (common::AnfAlgo::IsDynamicSequence(front_node.first)) {
699 return 0;
700 }
701
702 const auto &abs = front_node.first->abstract();
703 MS_EXCEPTION_IF_NULL(abs);
704 if (abs->isa<abstract::AbstractSequence>()) {
705 return front_node.second;
706 }
707
708 return 0;
709 }
710
void DataPrepareActor::PrepareDataForDeviceTensorStore(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                                       const VectorRef &args, OpContext<DeviceTensor> *const context) {
713 MS_LOG(INFO) << "Prepare store data, input tensor size: " << input_tensors.size() << ", arg size: " << args.size();
714 ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, "PrepareStoreData", true);
715 MS_EXCEPTION_IF_NULL(graph_compiler_info_);
716 const auto &parser = graph_compiler_info_->control_node_parser_;
717 MS_EXCEPTION_IF_NULL(parser);
718 for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
719 const auto &graph = graph_compiler_info_->graphs_[i];
720 const auto &device_context = graph_compiler_info_->device_contexts_[i];
721 MS_EXCEPTION_IF_NULL(graph);
722 MS_LOG(DEBUG) << "prepare data for graph:" << graph->ToString();
723 // Prepare the data of device tensor store(value nodes of graph).
724 for (const auto &value_node : graph->graph_value_nodes()) {
725 MS_EXCEPTION_IF_NULL(value_node);
726 if (AnfAlgo::OutputAddrExist(value_node, 0)) {
727 const auto &front_node = AnfAlgo::FetchFrontNodeByBackendNode(value_node, *graph);
728 MS_EXCEPTION_IF_NULL(front_node);
729 MS_LOG(DEBUG) << "Prepare data for value node:" << value_node->fullname_with_scope()
730 << ", debug name:" << value_node->DebugString() << ", front node:" << front_node->DebugString()
731 << " for graph:" << graph->ToString();
732 const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
733 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
734 MS_EXCEPTION_IF_NULL(device_tensor);
        // If front_node has more than one device tensor, the node may be used in multiple graphs,
        // so clear the ignore flag of the device address.
        if (TEST_FLAG(device_tensor->flag(), device::kDeviceAddressFlagIgnoreDevicePtr) && device_tensors.size() > 1) {
          device_tensor->ClearFlag(device::kDeviceAddressFlagIgnoreDevicePtr);
        }
        // If the node address has the ignore flag, do not prepare device data for it.
        if (!TEST_FLAG(device_tensor->flag(), device::kDeviceAddressFlagIgnoreDevicePtr)) {
742 PrepareDataForValueNode(value_node, front_node, device_context, context);
743 }
744 }
745 }
746
747 // Prepare the data of device tensor store(weights of graph).
748 const auto &input_nodes = graph->input_nodes();
749 for (size_t j = 0; j < input_nodes.size(); ++j) {
750 const auto &input_node = input_nodes[j];
751 MS_EXCEPTION_IF_NULL(input_node);
752 const auto &real_device_context = device::FetchRealDeviceContext(input_node, device_context);
753 MS_EXCEPTION_IF_NULL(real_device_context);
754 const auto &front_node = AnfAlgo::FetchFrontNodeByBackendNode(input_node, *graph);
755 if (IsPersistentDeviceTensor(input_node) && parser->IsRootGraphPersistentDeviceTensor(front_node)) {
756 std::vector<TensorPtr> graph_tensors = input_tensors.empty() ? std::vector<TensorPtr>() : input_tensors[i];
757 TensorPtr input_tensor = FetchInputTensor(graph_tensors, j, args, {front_node, 0});
758 PrepareDataForWeightNode(input_node, front_node, input_tensor, real_device_context, context);
759 }
760 }
761 }
762 if (RecoveryContext::GetInstance()->enable_recovery() &&
763 RecoveryContext::GetInstance()->need_sync_weight_to_device()) {
764 RecoveryContext::GetInstance()->set_need_sync_weight_to_device(false);
765 }
766
767 std::vector<TensorPtr> control_input = input_tensors.empty() ? std::vector<TensorPtr>() : input_tensors.back();
768 PrepareDeviceTensorStoreForControlNode(parser, control_input, args, context);
769 }
770
void DataPrepareActor::PrepareDataForHostTensorQueue(const std::vector<std::vector<TensorPtr>> &input_tensors,
                                                     const VectorRef &args, OpContext<DeviceTensor> *const context) {
773 MS_LOG(INFO) << "Prepare host data, input tensor size: " << input_tensors.size() << ", arg size: " << args.size();
774 ProfilerRecorder profiler(ProfilerModule::kRuntime, ProfilerEvent::kPreLaunch, "PrepareHostData", true);
775 MS_EXCEPTION_IF_NULL(context);
776 MS_EXCEPTION_IF_NULL(graph_compiler_info_);
777 if ((host_data_source_actor_ == nullptr) || (host_tensor_queue_ == nullptr)) {
778 return;
779 }
780
781 if (input_tensors.empty()) {
782 PrepareDataForHostTensorQueueNew(args, context);
783 return;
784 }
785
786 // Fill host tensors.
787 std::vector<TensorPtr> host_tensors;
788 host_tensors.resize(host_data_source_actor_->data_nodes().size());
789 for (size_t i = 0; i < graph_compiler_info_->graphs_.size(); ++i) {
790 const auto &graph = graph_compiler_info_->graphs_[i];
791 MS_EXCEPTION_IF_NULL(graph);
792
793 const auto &input_nodes = graph->input_nodes();
794 const auto &tensors = input_tensors[i];
795 if (input_nodes.size() != tensors.size()) {
796 std::string error_info = "Invalid tensor size:" + std::to_string(tensors.size()) +
797 " and input node size:" + std::to_string(input_nodes.size()) +
798 " for kernel graph:" + graph->ToString();
799 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
800 }
801 for (size_t j = 0; j < input_nodes.size(); ++j) {
802 const auto &input_node = input_nodes[j];
803 const auto &input_tensor = tensors[j];
804 MS_EXCEPTION_IF_NULL(input_node);
805 if (!IsHostQueueDSActor(input_node, graph, graph_compiler_info_->origin_parameters_order_, strategy_) ||
806 input_tensor == nullptr) {
807 continue;
808 }
809
810 auto tensor_position = host_data_source_actor_->FetchNodePosition({input_node, 0});
811 if (tensor_position >= host_tensors.size()) {
812 std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
813 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
814 }
815 MS_LOG(DEBUG) << "Set tensor position:" << tensor_position << " for input data.";
816 host_tensors[tensor_position] = input_tensor;
817
818 // Synchronize dynamic shape info of the input tensor to the parameter node of graph.
819 if (graph->is_dynamic_shape()) {
820 UpdateDynamicShapeAndSize(input_node, input_tensor);
821 }
822
823 UpdateDeviceAddressForDataNode(input_node, input_tensor);
824 }
825 }
826
827 PrepareHostTensorQueueForControlNode(input_tensors.back(), &host_tensors, context);
828
829 host_tensor_queue_->Push(host_tensors);
830 }
831
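// Prepare the host tensors directly from the args (no prepared input tensors), and detect whether the input shapes
// changed since the last step to switch between the static and dynamic trace memory modes.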
void DataPrepareActor::PrepareDataForHostTensorQueueNew(const VectorRef &args, OpContext<DeviceTensor> *const context) {
833 MS_EXCEPTION_IF_NULL(context);
834 size_t host_data_size = host_data_source_actor_->data_nodes().size();
835 size_t current_data_num = 0;
836 std::vector<TensorPtr> host_tensors;
837 host_tensors.resize(host_data_size);
838 host_tensors_.resize(host_data_size);
839 bool isDyn = false;
840 // Fill host tensors.
841 for (size_t i = 0; i < graph_compiler_info_->origin_parameters_order_.size(); ++i) {
842 if (current_data_num == host_data_size) {
843 break;
844 }
845 const auto &origin_parameter = graph_compiler_info_->origin_parameters_order_[i];
846 MS_EXCEPTION_IF_NULL(origin_parameter);
    // The input data comes before the parameter weights.
848 if (common::AnfAlgo::IsParameterWeight(origin_parameter->cast<ParameterPtr>())) {
849 MS_LOG(DEBUG) << "Skip the prepare host data for parameter: " << origin_parameter->fullname_with_scope();
850 continue;
851 }
852
853 auto iter = graph_compiler_info_->origin_parameters_to_backend_parameters_.find(origin_parameter);
854 if (iter == graph_compiler_info_->origin_parameters_to_backend_parameters_.end()) {
855 MS_LOG(DEBUG) << "Not find the parameter in the origin parameters: " << origin_parameter->fullname_with_scope();
856 continue;
857 }
858
859 for (auto origin_to_backend_pair : iter->second) {
860 auto input_tensor = FetchInputTensorByArg(args, i, origin_to_backend_pair.first);
861 if (input_tensor == nullptr) {
862 MS_LOG(ERROR) << "The input tensor is nullptr for arg index: " << i
863 << ", parameter: " << origin_parameter->fullname_with_scope();
864 continue;
865 }
      // A single op's output (run in PyNative mode) may feed a net input (run in graph mode).
867 runtime::DeviceAddressUtils::CreateKernelTensor(input_tensor);
868 auto tensor_position = host_data_source_actor_->FetchNodePosition(origin_to_backend_pair.second);
869 if (tensor_position >= host_tensors.size()) {
870 std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
871 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
872 }
873 if (host_tensors[tensor_position] != nullptr) {
874 continue;
875 }
876 MS_LOG(INFO) << "Set host tensor position:" << tensor_position
877 << " for input parameter:" << origin_parameter->fullname_with_scope();
878
879 if (!isDyn) {
880 if (host_tensors_[tensor_position] != input_tensor->shape()) {
881 isDyn = true;
882 }
883 }
884 host_tensors_[tensor_position] = input_tensor->shape();
885 host_tensors[tensor_position] = input_tensor;
886 ++current_data_num;
887
888 UpdateDynamicShapeAndSize(origin_to_backend_pair.second.first, input_tensor);
889
      // To avoid the device `ptr_` being held by both the input tensor and the output tensor, the input tensor
      // address cannot be set directly on the input control node, which may be a passthrough node. The device
      // `ptr_` is re-allocated and a device-to-device copy from the input tensor address is done in the data
      // source process.
893 if (origin_to_backend_pair.first.first != origin_to_backend_pair.second.first) {
894 UpdateDeviceAddressForDataNode(origin_to_backend_pair.second.first, input_tensor);
895 }
896 }
897 }
898
899 auto ms_context = MsContext::GetInstance();
900 MS_EXCEPTION_IF_NULL(ms_context);
901 static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
902 if (enable_infer_boost && has_dynamic_shape_ && EnableKbkSubGraphExecute()) {
903 ActorDispatcher::set_enable_static_shape(!isDyn);
904
905 const auto &phase = PhaseManager::GetInstance().phase();
906 bool is_increment_graph = (phase.find("increment") != std::string::npos);
907 if (EnableTraceMemory() && is_increment_graph) {
908 if (continuous_memory_alloc_list_list_.size() > 0) {
909 MS_LOG(EXCEPTION)
910 << "Can not support continuous memory allocate in dynamic shape graph when enable trace memory.";
911 }
912 if (!ActorDispatcher::enable_static_shape()) {
913 ActorDispatcher::set_enable_trace_dynamic_memory(true);
914 } else {
915 ActorDispatcher::set_enable_use_trace_memory(true);
916 }
917 }
918 }
919 host_tensor_queue_->Push(host_tensors);
920 }
921
// The branch of PrepareDataForValueNode that handles values of tensor type.
void DataPrepareActor::PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
                                                     const AnfNodePtr &front_node, const DeviceContext *device_context,
                                                     OpContext<DeviceTensor> *const context) const {
926 MS_EXCEPTION_IF_NULL(node);
927 MS_EXCEPTION_IF_NULL(node_value);
928 MS_EXCEPTION_IF_NULL(device_context);
929 MS_EXCEPTION_IF_NULL(context);
930
931 auto tensor = node_value->cast<TensorPtr>();
932 MS_EXCEPTION_IF_NULL(tensor);
933 if (tensor->is_forward_output()) {
934 return;
935 }
936
937 if (!first_step_) {
938 return;
939 }
940
941 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false);
942 MS_EXCEPTION_IF_NULL(device_tensor);
943 // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
944 if (device_tensor->IsPtrValid()) {
945 CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
946 return;
947 }
948 MS_LOG(INFO) << "Prepare device data for value node: " << node->DebugString() << ", output index: " << 0
949 << " device address:" << device_tensor;
950 tensor->set_device_address(device_tensor);
951 UpdateRefCount(device_tensor.get(), true);
952
953 SyncTensorData(tensor, device_tensor, node, device_context, context, real_strategy_);
954 MS_LOG(DEBUG) << "Prepare device data for value node: " << node->DebugString() << ", output index: " << 0
955 << " device address:" << device_tensor << " ptr:" << device_tensor->GetPtr();
956 CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
957 }
958
void DataPrepareActor::PrepareDataForControlValueNode(const KernelWithIndex &node_with_index,
                                                      const DeviceContext *device_context,
                                                      OpContext<DeviceTensor> *const context,
                                                      const ControlNodeParserPtr &parser) const {
963 MS_EXCEPTION_IF_NULL(device_context);
964 MS_EXCEPTION_IF_NULL(context);
965 MS_EXCEPTION_IF_NULL(node_with_index.first);
966 MS_EXCEPTION_IF_NULL(parser);
967 if (!node_with_index.first->isa<ValueNode>()) {
968 return;
969 }
970
971 const auto &node = node_with_index.first->cast<ValueNodePtr>();
972 MS_EXCEPTION_IF_NULL(node);
973 size_t index = node_with_index.second;
974 MS_LOG(DEBUG) << "Prepare data for control value node:" << node->DebugString() << " index:" << index;
975 auto node_value = node->value();
976 if (common::AnfAlgo::IsDynamicSequence(node)) {
977 auto tensor = AnfAlgo::SequenceToTensor(node_value);
978 parser->AddControlNodeTensor(tensor);
979 node_value = tensor;
980 AnfAlgo::UpdateValueNodeShape(node);
981 }
982 MS_EXCEPTION_IF_NULL(node_value);
983 std::vector<ValuePtr> values;
984 ValueTupleToValue(node_value, &values);
985
986 if (node_with_index.second >= values.size()) {
987 MS_LOG(INFO) << "Invalid index:" << node_with_index.second << " for node:" << node->DebugString();
988 return;
989 }
990 const auto &value = values[index];
991 MS_EXCEPTION_IF_NULL(value);
992 TensorPtr tensor = nullptr;
993 if (value->isa<StringImm>()) {
994 PrepareDataForStringValue(node, index, node, device_context, context);
995 return;
996 } else if (!value->isa<tensor::Tensor>()) {
997 tensor = parser->CreateTensorForValue(value);
998 } else {
999 tensor = value->cast<tensor::TensorPtr>();
1000 }
1001
1002 MS_EXCEPTION_IF_NULL(tensor);
1003 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, index, false);
1004 MS_EXCEPTION_IF_NULL(device_tensor);
1005 if (device_tensor->GetPtr() != nullptr) {
1006 return;
1007 }
1008
1009 tensor->set_device_address(device_tensor);
1010 UpdateRefCount(device_tensor.get(), true);
1011
1012 device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->DebugString(), device::AllocatorType::kConstantValue, 0);
1013 auto graph_str = (node->func_graph() == nullptr) ? "" : node->func_graph()->ToString();
1014 UpdateTracker("PrepareDataForControlValueNode", node, graph_str, device::tracker::MemType::kConstantValue,
1015 device_tensor);
1016 if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex)) {
1017 SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context, node->fullname_with_scope(),
1018 device_tensor->GetSize());
1019 }
1020 if (common::IsNeedProfileMemory()) {
1021 auto output_address = reinterpret_cast<uintptr_t>(device_tensor.get());
1022 MS_LOG(WARNING) << "Need Profile Memory, alloc type: PrepareDataForControlValueNode, device address class ptr: "
1023 << output_address << ", node: " << node->fullname_with_scope() << ", graph: " << graph_str
1024 << ", device address size: " << device_tensor->GetSize()
1025 << ", device address addr: " << device_tensor->GetPtr();
1026 }
1027
1028 if (tensor->data_ptr() == nullptr && device_tensor->GetSize() == 0) {
1029 MS_LOG(INFO) << "Empty tuple sync";
1030 return;
1031 }
1032
1033 auto host_tensor_size = LongToSize(tensor->data().nbytes());
1034 auto host_tensor_type = tensor->data_type();
1035 auto shape = tensor->shape();
1036 if (!device_tensor->SyncHostToDevice(shape, host_tensor_size, host_tensor_type, tensor->device_info().host_format_,
1037 tensor->data_ptr())) {
1038 std::string error_info = "Sync host to device failed for node:" + node->DebugString();
1039 SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info);
1040 }
1041 }
1042
void DataPrepareActor::PrepareDataForStringValue(const ValueNodePtr &node, size_t index, const AnfNodePtr &front_node,
                                                 const DeviceContext *device_context,
                                                 OpContext<DeviceTensor> *const context) const {
1046 MS_EXCEPTION_IF_NULL(node);
1047 if (!IsValueNode<StringImm>(node)) {
1048 return;
1049 }
1050 auto &node_value = node->value();
1051 MS_EXCEPTION_IF_NULL(node_value);
1052
1053 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, index, false);
1054 MS_EXCEPTION_IF_NULL(device_tensor);
1055 // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
1056 if (device_tensor->GetPtr() != nullptr) {
1057 if (first_step_) {
1058 CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
1059 }
1060 return;
1061 }
1062 MS_LOG(INFO) << "Prepare device data for value node: " << node->DebugString();
1063
1064 device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), device::AllocatorType::kConstantValue,
1065 0);
1066 auto graph_str = (node->func_graph() == nullptr) ? "" : node->func_graph()->ToString();
1067 UpdateTracker("PrepareDataForStringValue", node, graph_str, device::tracker::MemType::kConstantValue, device_tensor);
1068 if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex)) {
1069 SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context, node->fullname_with_scope(),
1070 device_tensor->GetSize());
1071 }
1072 if (common::IsNeedProfileMemory()) {
1073 auto output_address = reinterpret_cast<uintptr_t>(device_tensor.get());
1074 MS_LOG(WARNING) << "Need Profile Memory, alloc type: PrepareDataForValueNode, device address class ptr: "
1075 << output_address << ", device address size: " << device_tensor->GetSize()
1076 << ", node: " << node->fullname_with_scope() << ", graph: " << graph_str
1077 << ", device address addr: " << device_tensor->GetPtr();
1078 }
1079
1080 // Copy data from value to device.
1081 auto value = GetValue<std::string>(node_value);
1082 size_t tensor_size = value.size();
1083 ShapeVector shape = {1, SizeToLong(tensor_size)};
  // Account for the trailing '\0' in the string size, consistent with `CreateDeviceAddressForScalarAndString`
  // defined in `device_address_utils.cc`.
1086 size_t string_tensor_size = tensor_size + 1;
1087 if (!device_tensor->SyncHostToDevice(shape, string_tensor_size, kObjectTypeString, value.data())) {
1088 std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope();
1089 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
1090 }
1091 CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
1092 }
1093
void DataPrepareActor::PrepareDataForSequenceAndScalarValue(const ValueNodePtr &node, size_t index,
                                                            const AnfNodePtr &front_node,
                                                            const DeviceContext *device_context,
                                                            OpContext<DeviceTensor> *const context) const {
1098 if (!first_step_) {
1099 return;
1100 }
1101 MS_EXCEPTION_IF_NULL(node);
1102 MS_EXCEPTION_IF_NULL(device_context);
1103 MS_EXCEPTION_IF_NULL(context);
1104 auto &node_value = node->value();
1105 MS_EXCEPTION_IF_NULL(node_value);
1106
1107 if ((!node_value->isa<ValueSequence>()) && (!node_value->isa<Scalar>())) {
1108 return;
1109 }
1110
1111 if (node_value->isa<ValueSequence>() && node_value->cast<ValueSequencePtr>()->size() == 0) {
1112 return;
1113 }
1114
1115 const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, index, false);
1116 MS_EXCEPTION_IF_NULL(device_tensor);
1117 // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
1118 if (device_tensor->GetPtr() != nullptr) {
1119 CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
1120 return;
1121 }
1122
1123 UpdateRefCount(device_tensor.get(), true);
1124 MS_LOG(INFO) << "Prepare device data for value node: " << node->DebugString();
1125 device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), device::AllocatorType::kConstantValue,
1126 0);
1127 // 1. Allocate device memory for value node.
1128 auto graph_str = (node->func_graph() == nullptr) ? "" : node->func_graph()->ToString();
1129 UpdateTracker("PrepareDataForSequenceAndScalarValue", node, graph_str, device::tracker::MemType::kConstantValue,
1130 device_tensor);
1131 if (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex)) {
1132 SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context, node->fullname_with_scope(),
1133 device_tensor->GetSize());
1134 }
1135 if (common::IsNeedProfileMemory()) {
1136 auto output_address = reinterpret_cast<uintptr_t>(device_tensor.get());
1137 MS_LOG(WARNING) << "Need Profile Memory, alloc type: PrepareDataForValueNode, device address class ptr: "
1138 << output_address << ", device address size: " << device_tensor->GetSize()
1139 << ", node: " << node->fullname_with_scope() << ", graph: " << graph_str
1140 << ", device address addr: " << device_tensor->GetPtr();
1141 }
1142
1143 // 2. Sync copy data from host to device.
1144 const auto &kernel_tensor = device_tensor->kernel_tensor();
1145 MS_EXCEPTION_IF_NULL(kernel_tensor);
1146 if (!device_tensor->SyncHostToDevice(kernel_tensor->GetShapeVector(), kernel_tensor->size(),
1147 kernel_tensor->dtype_id(), kernel_tensor->GetValuePtr())) {
1148 std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope();
1149 SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
1150 }
1151
1152 // 3. Handle heterogeneous scene.
1153 CopyDataFromDeviceTensorStore(front_node, node, device_tensor, device_context, context);
1154 }
1155
1156 // Prepare the device data for persistent device tensor of value node.
void DataPrepareActor::PrepareDataForValueNode(const ValueNodePtr &node, const AnfNodePtr &front_node,
                                               const DeviceContext *device_context,
                                               OpContext<DeviceTensor> *const context) const {
1160 MS_EXCEPTION_IF_NULL(node);
1161 MS_EXCEPTION_IF_NULL(front_node);
1162 MS_EXCEPTION_IF_NULL(device_context);
1163 MS_EXCEPTION_IF_NULL(context);
1164 auto &node_value = node->value();
1165 MS_EXCEPTION_IF_NULL(node_value);
1166 MS_LOG(DEBUG) << "Prepare data for value node:" << node->DebugString() << " front node:" << front_node->DebugString();
1167 if (node_value->isa<tensor::Tensor>()) {
1168 PrepareDataForValueNodeTensor(node, node_value, front_node, device_context, context);
1169 } else if (node_value->isa<ValueSequence>() || node_value->isa<Scalar>()) {
1170 PrepareDataForSequenceAndScalarValue(node, 0, front_node, device_context, context);
1171 } else if (node_value->isa<StringImm>()) {
1172 PrepareDataForStringValue(node, 0, front_node, device_context, context);
1173 } else if (node_value->isa<None>() || node_value->isa<Type>()) {
1174 MS_LOG(DEBUG) << "No need to prepare data for None or type value node:" << node->DebugString();
1175 } else {
1176 MS_LOG(WARNING) << "Not support the value type: " << node->fullname_with_scope();
1177 }
1178 }

void DataPrepareActor::CopyDataFromDeviceTensorStore(const AnfNodePtr &front_node, const AnfNodePtr &backend_node,
                                                     const device::DeviceAddressPtr &host_tensor_address,
                                                     const DeviceContext *device_context,
                                                     OpContext<DeviceTensor> *context) const {
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(context);
  const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
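  // Copy the prepared data to all the other device addresses of the front node for the heterogeneous scenario.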
  for (auto &another_device_tensor : device_tensors) {
    if (another_device_tensor == host_tensor_address) {
      continue;
    }
    MS_EXCEPTION_IF_NULL(another_device_tensor);
    auto another_device_name = device::GetDeviceNameByType(another_device_tensor->GetDeviceType());
    const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {another_device_name, device_context->device_context_key().device_id_});
    MS_EXCEPTION_IF_NULL(another_device_context);
    auto type = backend_node->isa<ValueNode>() ? device::AllocatorType::kConstantValue : device::AllocatorType::kWeight;
    device::DynamicMemAllocatorDebugInfo::SetDebugInfo(backend_node->fullname_with_scope(), type, 0);
    bool need_alloc_memory = (another_device_tensor->GetPtr() == nullptr);
    auto graph_str = (backend_node->func_graph() == nullptr) ? "" : backend_node->func_graph()->ToString();
    if (need_alloc_memory) {
      auto mem_type =
        backend_node->isa<ValueNode>() ? device::tracker::MemType::kConstantValue : device::tracker::MemType::kWeight;
      UpdateTracker("CopyDataFromDeviceTensorStore", backend_node, graph_str, mem_type, another_device_tensor);
    }
    if (need_alloc_memory && (!another_device_context->device_res_manager_->AllocateMemory(another_device_tensor.get(),
                                                                                           kDefaultStreamIndex))) {
      SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *another_device_context,
                                                  backend_node->fullname_with_scope(),
                                                  another_device_tensor->GetSize());
    }
    if (common::IsNeedProfileMemory() && need_alloc_memory) {
      auto output_address = reinterpret_cast<uintptr_t>(another_device_tensor.get());
      MS_LOG(WARNING) << "Need Profile Memory, alloc type: CopyDataFromDeviceTensorStore, device address class ptr: "
                      << output_address << ", device address size: " << another_device_tensor->GetSize()
                      << ", device address addr: " << another_device_tensor->GetPtr()
                      << ", node: " << backend_node->fullname_with_scope() << ", graph: " << graph_str
                      << ", frontnode: " << (front_node == nullptr ? "null" : front_node->DebugString());
    }

    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
                 << ", device name:" << another_device_name << " from device address:" << host_tensor_address
                 << " to:" << another_device_tensor;
    if (!Copy(another_device_tensor.get(), host_tensor_address.get())) {
      std::string error_info = "Sync data error.";
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
    }
  }
}

// Prepare the device data for persistent device tensor of weight node from host tensor.
void DataPrepareActor::PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &front_node,
                                                const TensorPtr &tensor, const DeviceContext *device_context,
                                                OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(front_node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  MS_EXCEPTION_IF_NULL(context);
  auto param_node = backend_node->cast<ParameterPtr>();
  if (param_node != nullptr) {
    auto param_info = param_node->param_info();
    MS_EXCEPTION_IF_NULL(param_info);
    bool used = !param_info->ignore_device_addr();
    if (!used) {
      MS_LOG(DEBUG) << backend_node->DebugString()
                    << " the parameter is never used by any real kernel in the graphs, skip allocation.";
      return;
    }
  }
  if (tensor == nullptr) {
    return;
  }

  auto device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
  MS_EXCEPTION_IF_NULL(device_tensor);
  auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
  // Use the device address of host tensor to set device tensor.
  bool is_need_sync = IsNeedSync(tensor);
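  // The host tensor and the backend node may hold different device addresses; unify them here and record whether the
  // host data needs to be synchronized to device.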
  if (host_tensor_address != device_tensor) {
    if (host_tensor_address == nullptr) {
      if (device_tensor->GetDeviceType() != device_context->GetDeviceType()) {
        const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
          {backend_node, 0}, nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id(),
          device_tensor->host_shape(), device_context->device_context_key().device_name_,
          device_context->device_context_key().device_id_);
        kernel_tensor->set_stream_id(device_tensor->stream_id());
        host_tensor_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
        MS_EXCEPTION_IF_NULL(host_tensor_address);
        MS_LOG(DEBUG) << "Create device tensor:" << host_tensor_address << " type:" << host_tensor_address->type_id();
        host_tensor_address->set_from_persistent_mem(tensor->is_parameter());
      } else {
        host_tensor_address = device_tensor;
      }
      is_need_sync = true;
      tensor->set_device_address(host_tensor_address);
      UpdateRefCount(host_tensor_address.get(), true);
    }
    MS_EXCEPTION_IF_NULL(host_tensor_address);

    if (host_tensor_address->GetDeviceType() != device_tensor->GetDeviceType()) {
      MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->GetDeviceType()
                   << ", device tensor type:" << device_tensor->GetDeviceType();
      // The fake heterogeneous scenario.
      if (DeviceTensorStore::GetInstance().Fetch(front_node.get()).size() == 1) {
        tensor->data_sync();
        host_tensor_address = device_tensor;
        tensor->set_device_address(device_tensor);
        is_need_sync = true;
      }
    } else if (host_tensor_address != device_tensor) {
      // In the training + inference scenario, the device address of the weight node cannot be changed when the
      // multi-graph sink mode is enabled.
      if (device_tensor->is_ptr_persisted() ||
          !AnfAlgo::IsEquivalentFormat(host_tensor_address->format(), device_tensor->format())) {
        if ((device_tensor->GetPtr() == nullptr) &&
            (!device_context->device_res_manager_->AllocateMemory(device_tensor.get(), kDefaultStreamIndex))) {
          SET_OPCONTEXT_MEMORY_ALLOC_FAIL_BY_STRATEGY(real_strategy_, *context, *device_context,
                                                      backend_node->fullname_with_scope(), device_tensor->GetSize());
        }
        if (!Copy(device_tensor.get(), host_tensor_address.get())) {
          std::string error_info = "Sync data error.";
          SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
        }
        host_tensor_address = device_tensor;
        tensor->set_device_address(device_tensor);
      } else {
        (void)address_modified_input_nodes_.insert(backend_node.get());
        host_tensor_address->set_flag(device_tensor->flag());
        DeviceAddressUtils::UpdateDeviceAddressHostInfoByNode(host_tensor_address, backend_node, 0);
        AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
      }
    }
  }
  // The same host_tensor_address may correspond to different front nodes in the shared weight scenario, so the
  // device tensor store always needs to be updated.
  MS_EXCEPTION_IF_NULL(host_tensor_address);
  host_tensor_address->SetNodeIndex(backend_node, 0);
  DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);

  // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
  if (is_need_sync || (!host_tensor_address->IsPtrValid())) {
    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->DebugString()
                 << ", device type:" << host_tensor_address->GetDeviceType();
    SyncTensorData(tensor, host_tensor_address, backend_node, device_context, context, real_strategy_);
  }

  // Allocate another device memory and copy data from host tensor to another device (if any).
  CopyDataFromDeviceTensorStore(front_node, backend_node, host_tensor_address, device_context, context);
}

void DataPrepareActor::PrepareDeviceTensorStoreForControlNode(const ControlNodeParserPtr &control_node_parser,
                                                              const std::vector<TensorPtr> &tensors,
                                                              const VectorRef &args,
                                                              OpContext<DeviceTensor> *const context) const {
  MS_EXCEPTION_IF_NULL(control_node_parser);
  if (!control_node_parser->IsInited()) {
    return;
  }

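  // Prepare data for the front value nodes of control nodes whose kernel info and output address already exist.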
  for (const auto &value_node_with_context : control_node_parser->front_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node_with_context.first.first);
    if (value_node_with_context.first.first->kernel_info() != nullptr &&
        AnfAlgo::OutputAddrExist(value_node_with_context.first.first, 0)) {
      PrepareDataForControlValueNode(value_node_with_context.first, value_node_with_context.second, context,
                                     control_node_parser);
    }
  }

  const auto &control_node_parameters = control_node_parser->control_node_parameters();
  if (!tensors.empty() && control_node_parameters.size() != tensors.size()) {
    SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), "Invalid tensor size.");
  }
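  // Prepare data for the root graph parameters that are persistent device tensors of control nodes.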
  for (size_t i = 0; i < control_node_parameters.size(); ++i) {
    auto &front_parameter = control_node_parameters[i].first;
    MS_EXCEPTION_IF_NULL(front_parameter);
    if (!control_node_parser->IsRootGraphPersistentDeviceTensor(front_parameter)) {
      continue;
    }

    TensorPtr tensor = FetchInputTensor(tensors, i, args, control_node_parameters[i]);
    if (tensor == nullptr) {
      continue;
    }

    auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_parameter.get());
    if (device_tensors.empty()) {
      MS_LOG(WARNING) << "Failed to get device tensor for front node:" << front_parameter->DebugString();
      continue;
    }
    MS_EXCEPTION_IF_NULL(device_tensors[0]);
    auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
    if ((device_tensors[0] == host_tensor_address) || (device_tensors[0]->IsPtrValid())) {
      continue;
    }

    auto node = (device_tensors[0]->GetNodeIndex()).first;
    MS_EXCEPTION_IF_NULL(node);
    MS_LOG(INFO) << "Prepare device data for weight node by root graph parameter:"
                 << front_parameter->fullname_with_scope() << ", backend node:" << node->DebugString()
                 << ", device type:" << device_tensors[0]->GetDeviceType();
    if (host_tensor_address == nullptr) {
      tensor->set_device_address(device_tensors[0]);
      auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
        {device_tensors[0]->device_name(), device_tensors[0]->device_id()});
      SyncTensorData(tensor, device_tensors[0], node, device_context, context, GraphExecutionStrategy::kPipeline);
    } else {
      if (host_tensor_address->GetSize() != device_tensors[0]->GetSize()) {
        MS_LOG(WARNING) << "Please check the size of parameter:" << front_parameter->fullname_with_scope()
                        << ", host tensor size:" << host_tensor_address->GetSize()
                        << ", device tensor size:" << device_tensors[0]->GetSize();
      }
      host_tensor_address->SetNodeIndex(node, 0);
      UpdateRefCount(host_tensor_address.get(), true);
      DeviceTensorStore::GetInstance().Remove(front_parameter.get());
      DeviceTensorStore::GetInstance().Insert(front_parameter.get(), host_tensor_address);
    }
  }
}

void DataPrepareActor::PrepareHostTensorQueueForControlNode(const std::vector<TensorPtr> &tensors,
                                                            std::vector<TensorPtr> *const host_tensors,
                                                            OpContext<DeviceTensor> *const context) {
  MS_EXCEPTION_IF_NULL(graph_compiler_info_);
  MS_EXCEPTION_IF_NULL(graph_compiler_info_->control_node_parser_);
  MS_EXCEPTION_IF_NULL(host_data_source_actor_);
  MS_EXCEPTION_IF_NULL(host_tensors);

  const auto &control_node_parameters = graph_compiler_info_->control_node_parser_->control_node_parameters();
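  // Fill the host tensor queue with the input tensors of the control node parameters, skipping persistent device
  // tensors and parameters that are not in the origin parameter order.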
  for (size_t i = 0; i < control_node_parameters.size(); ++i) {
    const auto &input_node = control_node_parameters[i].first;
    const auto &input_tensor = tensors[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsPersistentDeviceTensor(input_node)) {
      continue;
    }

    if (find(graph_compiler_info_->origin_parameters_order_.begin(),
             graph_compiler_info_->origin_parameters_order_.end(),
             input_node) == graph_compiler_info_->origin_parameters_order_.end()) {
      continue;
    }

    auto tensor_position = host_data_source_actor_->FetchNodePosition(control_node_parameters[i]);
    if (tensor_position >= host_tensors->size()) {
      std::string error_info = "The position of tensor is out of range: " + std::to_string(tensor_position);
      SET_OPCONTEXT_FAIL_RET_WITH_ERROR_BY_STRATEGY(real_strategy_, (*context), error_info);
    }
    if ((*host_tensors)[tensor_position] != nullptr) {
      continue;
    }
    MS_LOG(DEBUG) << "Set tensor position:" << tensor_position << " for input data.";
    (*host_tensors)[tensor_position] = input_tensor;

    UpdateDynamicShapeAndSize(input_node, input_tensor);
    // To avoid the device `ptr_` being held by both the input tensor and the output tensor, the input tensor address
    // cannot be set directly on the input control node, which may be a passthrough node. Instead, the device `ptr_`
    // is re-allocated and the data is copied device-to-device from the input tensor address in the data source
    // process.
  }
}

void DataPrepareActor::PreprocessBeforePrepareData() const {
  // The embedding cache mode needs to record the number of global steps executed by the compute graph.
  // In the first step, the compute graph needs to wait for the embedding cache prefetching to warm up, to prevent
  // the GetNext operator in the compute graph from timing out.
#if defined(__linux__) && defined(WITH_BACKEND)
  EmbeddingCacheScheduler::GetInstance().IncreaseGraphStep(GetAID());
#endif

  // Try to defrag memory.
  auto defrag_memory_step_freq = GetDefragMemoryStepFreq();
  if (++execution_count_ % defrag_memory_step_freq == 0) {
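    // Each device context only needs to be defragged once, so record the contexts that have already been handled.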
    std::set<const DeviceContext *> defrag_memory_contexts;
    for (auto &device_context : graph_compiler_info_->device_contexts_) {
      MS_EXCEPTION_IF_NULL(device_context);
      if ((defrag_memory_contexts.count(device_context) == 0)) {
        device_context->device_res_manager_->DefragMemory();
      }
      (void)defrag_memory_contexts.insert(device_context);
    }
  }
}
}  // namespace runtime
}  // namespace mindspore