/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/actor_common.h"
#include <memory>
#include <unordered_map>
#include "ops/framework_op_name.h"
#include "ops/framework_ops.h"
#include "ops/structure_op_name.h"
#include "runtime/graph_scheduler/device_tensor_store.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "utils/phase.h"
#ifndef BUILD_LITE
#include "runtime/graph_scheduler/actor/kernel_async_launch_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_infer_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_resize_actor.h"
#endif

namespace mindspore {
namespace runtime {
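// Default values of the global scheduling switches for the actor runtime (e.g. multi-thread execution,
// async kernel launch); they are typically reconfigured by the graph scheduler before execution.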
bool ActorDispatcher::is_multi_thread_execution_ = true;
bool ActorDispatcher::enable_multi_stream_ = false;
bool ActorDispatcher::has_kernel_need_user_data_ = false;
bool ActorDispatcher::is_memory_allocation_sync_ = true;
bool ActorDispatcher::is_memory_free_sync_ = true;
bool ActorDispatcher::enable_runtime_multi_pipeline_ = false;
bool ActorDispatcher::enable_async_launch_kernel_ = false;
bool ActorDispatcher::disable_kbk_sub_graph_execute_ = false;
bool ActorDispatcher::enable_static_shape_ = false;
bool ActorDispatcher::enable_trace_dynamic_memory_ = false;
bool ActorDispatcher::enable_use_trace_memory_ = false;

bool IsRunningFailed(const OpContext<DeviceTensor> *context) { return (context->error_info_ != ""); }

void ComputeThreadNums(size_t *actor_thread_num, size_t *actor_and_kernel_thread_num) {
  MS_EXCEPTION_IF_NULL(actor_thread_num);
  MS_EXCEPTION_IF_NULL(actor_and_kernel_thread_num);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  const size_t cpu_core_num = std::thread::hardware_concurrency();
  auto inter_op_parallel_num = static_cast<size_t>(context_ptr->get_param<uint32_t>(MS_CTX_INTER_OP_PARALLEL_NUM));
  auto runtime_num_threads = static_cast<size_t>(context_ptr->get_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS));
  size_t runtime_num_threads_min = std::min(runtime_num_threads, cpu_core_num);
  size_t inter_op_parallel_num_min = std::min(inter_op_parallel_num, cpu_core_num);
  const float kActorUsage = 0.18;
  const size_t kActorThreadMinNum = 1;
  // Compute the actor and kernel thread num.
  // The MemoryManagerActor binds a single thread, so if runtime_num_threads is 30, the actor num would be 5
  // and the kernel num would be 25.
  if (inter_op_parallel_num_min == 0) {
    size_t actor_thread_max_num =
      std::max(static_cast<size_t>(std::floor(runtime_num_threads_min * kActorUsage)), kActorThreadMinNum);
    *actor_thread_num = actor_thread_max_num;
    *actor_and_kernel_thread_num =
      runtime_num_threads_min > *actor_thread_num ? (runtime_num_threads_min) : (*actor_thread_num + 1);
  } else {
    *actor_thread_num = inter_op_parallel_num_min;
    *actor_and_kernel_thread_num = runtime_num_threads_min + *actor_thread_num;
  }

  if (*actor_and_kernel_thread_num > cpu_core_num) {
    MS_LOG(WARNING) << "The total num of thread pool is " << *actor_and_kernel_thread_num
                    << ", but the num of cpu core is " << cpu_core_num
                    << ", please set the threads within reasonable limits.";
  }
}

bool IsDeviceQueueDSActor(const AnfNodePtr &, GraphExecutionStrategy) { return false; }

bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph,
                        const std::vector<AnfNodePtr> &host_parameters, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(node);

  bool is_parameter_data = node->isa<Parameter>() && (!common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()));
  if (!is_parameter_data) {
    return false;
  }
  // Need to be updated every step.
  if (node->has_user_data(kForwardOutput)) {
    return true;
  }

  if (strategy == GraphExecutionStrategy::kStep) {
    MS_EXCEPTION_IF_NULL(graph);
    return graph->execution_order().size() > 1;
  }

  if (graph == nullptr) {
    return true;
  }

  // In control flow, only the parameters of the root funcgraph are in the host data source.
  const auto &front_node = graph->GetFrontAnfByBackendAnf(node);
  bool is_host = ((front_node == nullptr) ||
                  find(host_parameters.begin(), host_parameters.end(), front_node) != host_parameters.end());

  // Judge whether the node is an internal parameter.
  const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(node);
  if (internal_front_node.first == nullptr && is_host) {
    return true;
  }

  return false;
}

bool IsSwitchActor(const AnfNodePtr &node) { return common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch); }

bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  if (node->isa<Parameter>() && (!common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
    // Judge whether the node is an internal parameter.
    const auto &front_node = graph->GetOriginFrontNodeByInternalParameter(node);
    if (front_node.first != nullptr) {
      return true;
    }
  }
  return false;
}

bool IsCustomActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return AnfUtils::IsCustomActorNode(node);
}

bool IsKernelActor(const AnfNodePtr &node, GraphExecutionStrategy) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsCustomActor(node)) {
    return false;
  }

  if (!AnfUtils::IsRealCNodeKernel(node)) {
    return false;
  }

  return true;
}

bool IsSkippedKernelActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsKernelActor(node) && common::AnfAlgo::IsInplaceNode(node, "skip")) {
    return true;
  }
  return false;
}

bool IsRpcActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsKernelActor(node) && (common::AnfAlgo::GetCNodeName(node) == kRpcSendOpName ||
                              common::AnfAlgo::GetCNodeName(node) == kRpcRecvOpName)) {
    return true;
  }
  return false;
}

bool IsInnerControlFlowActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsKernelActor(node) && (common::AnfAlgo::GetCNodeName(node) == "ConditionSwitch" ||
                              common::AnfAlgo::GetCNodeName(node) == "ConditionGather")) {
    return true;
  }
  return false;
}

bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    return true;
  }

  // The node may be a Load node, so fetch the real parameter node.
  auto real_node = common::AnfAlgo::FetchRealNodeSkipMonadControl({node, 0}).first;
  MS_EXCEPTION_IF_NULL(real_node);
  if (real_node->isa<Parameter>() && common::AnfAlgo::IsParameterWeight(real_node->cast<ParameterPtr>())) {
    return true;
  }
  return false;
}

bool IsControlFlowActor(KernelTransformType actor_type) {
  return ((actor_type >= KernelTransformType::kSwitchActor) && (actor_type <= KernelTransformType::kStackActor));
}

bool IsMemoryActor(KernelTransformType actor_type) {
  return ((actor_type == KernelTransformType::kMemoryAllocActor) ||
          (actor_type == KernelTransformType::kMemoryFreeActor));
}

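// Whether the real launch of a kernel or graph should be skipped. Controlled by the MS_KERNEL_LAUNCH_SKIP
// environment variable: "ALL"/"all" skips every launch, a kernel name skips only that kernel, and simulation
// mode (kSimulationLevel) implies skipping all launches.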
bool IsSkippedLaunch(const CNodePtr &kernel, const KernelGraphPtr &kernel_graph) {
  static std::string launch_skipped = "";
  static bool first_get_launch_skipped_env = true;
  static const char kLaunchSkippedEnv[] = "MS_KERNEL_LAUNCH_SKIP";
  if (first_get_launch_skipped_env) {
    launch_skipped = common::GetEnv(kLaunchSkippedEnv);
    first_get_launch_skipped_env = false;
    if (launch_skipped.empty() && !common::GetEnv(kSimulationLevel).empty()) {
      launch_skipped = "ALL";
    }
  }

  if (launch_skipped.empty()) {
    return false;
  }

  std::string launch_name = "";
  std::string full_name = "";
  if (kernel != nullptr) {
    launch_name = common::AnfAlgo::GetCNodeName(kernel);
    full_name = kernel->fullname_with_scope();
  } else if (kernel_graph != nullptr) {
    launch_name = kernel_graph->ToString();
    full_name = kernel_graph->ToString();
  } else {
    MS_LOG(ERROR) << "The launch kernel or graph is nullptr";
    return false;
  }

  if ((launch_skipped == "ALL") || (launch_skipped == "all") || (launch_skipped == launch_name)) {
    MS_LOG(DEBUG) << "Skip the launch of " << full_name;
    return true;
  }

  return false;
}

bool EnableAsyncInfer() {
  static const char kEnableAsyncInferEnv[] = "MS_ENABLE_ASYNC_INFER";
  static bool ret = common::GetEnv(kEnableAsyncInferEnv) == "1";
  return ret;
}

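// Trace memory is only meaningful for inference (infer boost) with kbk sub-graph execution enabled,
// and is switched on by the MS_ENABLE_TRACE_MEMORY environment variable.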
bool EnableTraceMemory() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
  if (!enable_infer_boost) {
    return false;
  }

  if (!EnableKbkSubGraphExecute()) {
    return false;
  }

  static const char kEnableTraceMemoryEnv[] = "MS_ENABLE_TRACE_MEMORY";
  static bool ret = common::GetEnv(kEnableTraceMemoryEnv) == "1";
  if (ret) {
    MS_LOG(INFO) << "Enable trace memory to optimize dynamic memory management performance.";
  }
  return ret;
}

void ResetTraceMemoryStatus() {
  ActorDispatcher::set_enable_static_shape(false);
  ActorDispatcher::set_enable_trace_dynamic_memory(false);
  ActorDispatcher::set_enable_use_trace_memory(false);
}

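// Kernel-by-kernel sub-graph execution can be disabled explicitly via MS_ENABLE_KBK_SUBGRAPH_EXECUTE=0
// or by the dispatcher flag; otherwise it follows the infer boost switch and is used for inference only.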
bool EnableKbkSubGraphExecute() {
  static const char kEnableKbkSubGraphExecutedEnv[] = "MS_ENABLE_KBK_SUBGRAPH_EXECUTE";
  static bool disable_sub_graph_execute_mode = common::GetEnv(kEnableKbkSubGraphExecutedEnv) == "0";
  if (disable_sub_graph_execute_mode) {
    return false;
  }

  if (ActorDispatcher::disable_kbk_sub_graph_execute()) {
    return false;
  }

  // Only support sub graph execution mode for inference.
  // static const bool enable_internal_kernels = common::GetEnv("MS_ENABLE_INTERNAL_KERNELS") == "on";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  static const bool enable_internal_kernels = ms_context->IsEnableInferBoost();
  return enable_internal_kernels;
}

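// The defragmentation step frequency is read once from the allocator config
// (kAllocConf / kAllocDefragMemoryStepFreq) and defaults to 100 when unset or invalid.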
size_t GetDefragMemoryStepFreq() {
  static size_t defrag_memory_step_freq = 100L;

  static std::once_flag init_flag;
  std::call_once(init_flag, [&]() {
    MS_LOG(INFO) << "Init defrag memory step freq.";
    const auto &value = common::GetConfigValue(common::kAllocConf, common::kAllocDefragMemoryStepFreq);
    MS_LOG(INFO) << "Config defrag memory step freq : " << value << ".";
    if (value.size() != 0) {
      std::stringstream sstream(value);
      size_t config_value = 0;
      sstream >> config_value;
      if (config_value != 0) {
        defrag_memory_step_freq = config_value;
      }
    }
    MS_LOG(INFO) << "Defrag memory step freq : " << defrag_memory_step_freq << ".";
  });

  return defrag_memory_step_freq;
}

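// Drain the asynchronous infer/resize/launch pipeline actors and report whether execution has failed.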
bool WaitRuntimePipelineFinish(const OpContext<DeviceTensor> *context, bool wait_kernel_launch_finish) {
#ifndef BUILD_LITE
  if (ActorDispatcher::enable_runtime_multi_pipeline()) {
    KernelAsyncInferActor::GetInstance()->Wait();
    KernelAsyncResizeActor::GetInstance()->Wait();
  }

  if (ActorDispatcher::enable_async_launch_kernel() && wait_kernel_launch_finish) {
    KernelAsyncLaunchActor::GetInstance()->Wait();
  }

  if (ActorDispatcher::enable_async_launch_kernel() && IsRunningFailed(context)) {
    MS_LOG(ERROR) << "Wait runtime pipeline finish and an error occurred: " << context->error_info_;
    return false;
  }
  return true;
#else
  return true;
#endif
}

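// Copy data between device tensors: device-to-device on the same device type, otherwise host <-> device
// when one side is CPU. The copy size is the smaller of the two tensor sizes.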
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
  MS_EXCEPTION_IF_NULL(dst_device_tensor);
  MS_EXCEPTION_IF_NULL(src_device_tensor);
  if (src_device_tensor->GetSize() != dst_device_tensor->GetSize()) {
    MS_LOG(INFO) << "Copy size is not equal, input size:" << src_device_tensor->GetSize()
                 << ", output size:" << dst_device_tensor->GetSize();
  }

  // Some devices align the allocation size, so use the min of the two device sizes as the copy size.
  size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize());

  if (dst_device_tensor->GetDeviceType() == src_device_tensor->GetDeviceType()) {
    return dst_device_tensor->SyncDeviceToDevice(src_device_tensor);
  } else if (src_device_tensor->GetDeviceType() == device::DeviceType::kCPU) {
    // CPU device tensor copies to the other device tensor.
    return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr());
  } else if (dst_device_tensor->GetDeviceType() == device::DeviceType::kCPU) {
    // The other device tensor copies to the CPU device tensor.
    return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
  } else {
    MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->GetDeviceType()
                  << ", dst device type: " << dst_device_tensor->GetDeviceType();
    return false;
  }
}

void UpdateRefCount(DeviceTensor *const device_tensor, bool is_max_ref_count) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  if (is_max_ref_count) {
    device_tensor->set_original_ref_count(SIZE_MAX);
  } else {
    device_tensor->IncreaseOriginalRefCount();
  }
  device_tensor->ResetRefCount();
}

void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count) {
  MS_EXCEPTION_IF_NULL(node);
  auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx, false);
  UpdateRefCount(device_tensor.get(), is_max_ref_count);
}

void FreeMemoryByDeviceContext(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  // The device context may be inaccurate in the control flow scene, so fetch it by device name and device id.
  if ((device_context == nullptr) || (device_context->GetDeviceType() != device_tensor->GetDeviceType())) {
    const auto &new_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {device_tensor->device_name(), device_tensor->device_id()});
    MS_EXCEPTION_IF_NULL(new_device_context);
    new_device_context->device_res_manager_->FreeMemory(device_tensor);
  } else {
    device_context->device_res_manager_->FreeMemory(device_tensor);
  }
}

void FreeMemoryByValueNode(const std::vector<std::weak_ptr<ValueNode>> &held_by_nodes, DeviceTensor *device_tensor) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  device_tensor->ClearHeldByNodes();
  device_tensor->set_original_ref_count(SIZE_MAX);
  device_tensor->ResetRefCount();

  for (auto &node : held_by_nodes) {
    auto value_node = node.lock();
    MS_EXCEPTION_IF_NULL(value_node);
    auto value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    auto tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->set_device_address(nullptr);
    runtime::DeviceTensorStore::GetInstance().Remove(value_node.get());
  }
}

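// Classify a backend node into the actor type that should own it when the kernel graph is
// transformed into an actor set.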
KernelTransformType FetchKernelTransformType(const AnfNodePtr &node, const KernelGraphPtr &graph,
                                             const std::vector<AnfNodePtr> &host_parameters,
                                             GraphExecutionStrategy strategy) {
  // Fetch kernel graph.
  KernelGraphPtr kernel_graph = nullptr;
  if (graph == nullptr) {
    kernel_graph = AnfAlgo::FetchKernelGraph(node.get());
  } else {
    kernel_graph = graph;
  }
  if (kernel_graph == nullptr) {
    return KernelTransformType::kUnknown;
  }
  if (kernel_graph->is_any_type_input() && node != nullptr && node->isa<CNode>()) {
    return KernelTransformType::kAnyTypeKernelActor;
  }
  // In sink mode, the data exchange between child graphs is expressed as parameters. These parameters are stored
  // in the graph and should be obtained from the super kernel actor.
  if (kernel_graph->is_graph_run_mode() &&
      ((node == nullptr) || node->isa<CNode>() || kernel_graph->IsChildGraphResult(node))) {
    return KernelTransformType::kSuperKernelActor;
  }

  KernelTransformType type = KernelTransformType::kUnknown;
  MS_EXCEPTION_IF_NULL(node);
  auto real_node = common::AnfAlgo::FetchRealNodeSkipMonadControl({node, 0}).first;
  MS_EXCEPTION_IF_NULL(real_node);

  if (IsDeviceQueueDSActor(real_node, strategy)) {
    type = KernelTransformType::kDeviceDataSourceActor;
  } else if (IsHostQueueDSActor(real_node, kernel_graph, host_parameters, strategy)) {
    type = KernelTransformType::kHostDataSourceActor;
  } else if (IsCustomActor(real_node)) {
    type = KernelTransformType::kCustomActor;
  } else if (IsKernelActor(real_node, strategy)) {
    type = KernelTransformType::kKernelActor;
  } else if (IsInternalParameter(real_node, kernel_graph)) {
    type = KernelTransformType::kInternalParameter;
  } else if (IsPersistentDeviceTensor(real_node)) {
    type = KernelTransformType::kDeviceTensorStore;
  } else {
    // There may be a source kernel that does not need to be linked in PyNative mode.
    MS_LOG(DEBUG) << "Invalid from kernel: " << node->DebugString();
  }

  return type;
}

std::string FetchActorName(KernelTransformType kernel_type, const std::string &actor_set_name, const AnfNodePtr &node,
                           const KernelGraphPtr &graph) {
  // Fetch kernel graph.
  KernelGraphPtr kernel_graph = nullptr;
  if (graph == nullptr) {
    kernel_graph = AnfAlgo::FetchKernelGraph(node.get());
  } else {
    kernel_graph = graph;
  }
  if (kernel_graph == nullptr) {
    return "";
  }

  auto real_node = node;
  if (real_node != nullptr) {
    real_node = common::AnfAlgo::FetchRealNodeSkipMonadControl({node, 0}).first;
  }
  std::string actor_name = "";
  switch (kernel_type) {
    case KernelTransformType::kSuperKernelActor:
      actor_name = kernel_graph->ToString() + kSuperKernelActorNameSuffix;
      break;
    case KernelTransformType::kAnyTypeKernelActor:
      actor_name = kernel_graph->ToString() + kAnyTypeKernelActorNameSuffix;
      break;
    case KernelTransformType::kDeviceDataSourceActor:
      actor_name = actor_set_name + kDeviceDSActorNameSuffix + "_" + std::to_string(kernel_graph->graph_id());
      break;
    case KernelTransformType::kHostDataSourceActor:
      actor_name = actor_set_name + kHostDSActorNameSuffix;
      break;
    case KernelTransformType::kCustomActor:
      MS_EXCEPTION_IF_NULL(real_node);
      actor_name = AnfUtils::GetCustomActorName(real_node);
      break;
    case KernelTransformType::kKernelActor:
      MS_EXCEPTION_IF_NULL(real_node);
      actor_name = GetActorIdByKernel(real_node);
      break;
    case KernelTransformType::kKernelInferActor:
      MS_EXCEPTION_IF_NULL(real_node);
      actor_name = kKernelInferActorNamePrefix + real_node->fullname_with_scope();
      break;
    case KernelTransformType::kKernelResizeActor:
      MS_EXCEPTION_IF_NULL(real_node);
      actor_name = kKernelResizeActorNamePrefix + real_node->fullname_with_scope();
      break;
    default:
      break;
  }
  return actor_name;
}

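// Collect the input indexes whose ref inputs may be modified by this kernel. Only kernels driven by
// auto monad (i.e. with a monad input) can modify their inputs in place.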
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);

  bool has_monad = false;
  std::set<size_t> ref_input_indexes;
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto &input = cnode->inputs().at(i);
    if (HasAbstractMonad(input)) {
      has_monad = true;
    }
    if (common::AnfAlgo::HasAbstractRef(input)) {
      (void)ref_input_indexes.insert(i - 1);
    }
  }

  // Only the auto monad node will modify the input.
  if (has_monad) {
    return ref_input_indexes;
  } else {
    return {};
  }
}

std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &cnode, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::set<size_t> ref_output_indexes;

  auto output_num = AnfAlgo::GetOutputTensorNum(cnode);
  for (size_t i = 0; i < output_num; ++i) {
    session::AnfWithOutIndex output_pair(cnode, i);
    // Only the ref node will modify the ref input corresponding to the output.
    if (!graph->IsInRefOutputMap(output_pair)) {
      continue;
    }
    auto input_pair = graph->GetRefCorrespondOutput(output_pair);
    MS_EXCEPTION_IF_NULL(input_pair.first);
    if (common::AnfAlgo::HasAbstractRef(input_pair.first)) {
      (void)ref_output_indexes.insert(i);
    }
  }
  return ref_output_indexes;
}

bool is_embedding_cache_server() {
  return ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
}

void MemoryTraceManager::ReserveKernelMemoryBlocks(size_t size, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_context);
  (*kernel_memory_trace_blocks_)[device_context].reserve(size);
}

void MemoryTraceManager::PickMemoryTrackInfoForGraph(uint32_t graph_id) {
  if (graph_to_kernel_memory_trace_blocks_.find(graph_id) == graph_to_kernel_memory_trace_blocks_.end()) {
    graph_to_kernel_memory_trace_blocks_.emplace(
      graph_id, std::make_shared<std::map<const DeviceContext *, std::vector<KernelMemoryTraceBlockPtr>>>());
  }
  kernel_memory_trace_blocks_ = graph_to_kernel_memory_trace_blocks_[graph_id];
  MS_EXCEPTION_IF_NULL(kernel_memory_trace_blocks_);

  if (graph_to_merged_memory_trace_blocks_.find(graph_id) == graph_to_merged_memory_trace_blocks_.end()) {
    graph_to_merged_memory_trace_blocks_.emplace(
      graph_id, std::make_shared<std::map<const DeviceContext *, std::vector<MemoryTraceBlockPtr>>>());
  }
  merged_memory_trace_blocks_ = graph_to_merged_memory_trace_blocks_[graph_id];
  MS_EXCEPTION_IF_NULL(merged_memory_trace_blocks_);

  if (graph_to_kernel_blocks_.find(graph_id) == graph_to_kernel_blocks_.end()) {
    graph_to_kernel_blocks_.emplace(
      graph_id, std::make_shared<mindspore::HashMap<CNodePtr, std::vector<KernelMemoryTraceBlockPtr>>>());
  }
  kernel_to_block_ = graph_to_kernel_blocks_[graph_id];
  MS_EXCEPTION_IF_NULL(kernel_to_block_);
}

void MemoryTraceManager::AddKernelMemoryTraceBlock(const KernelMemoryTraceBlockPtr &block,
                                                   const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(block);
  MS_EXCEPTION_IF_NULL(block->start_);
  MS_EXCEPTION_IF_NULL(block->end_);
  (*kernel_memory_trace_blocks_)[device_context].emplace_back(block);
}

const std::shared_ptr<std::map<const DeviceContext *, std::vector<MemoryTraceBlockPtr>>>
&MemoryTraceManager::GetMergeBlocks() {
  return merged_memory_trace_blocks_;
}

const std::shared_ptr<mindspore::HashMap<CNodePtr, std::vector<KernelMemoryTraceBlockPtr>>>
&MemoryTraceManager::GetAllKernelBlocksnfo() {
  return kernel_to_block_;
}

void MemoryTraceManager::MergeBlocks() {
  merged_memory_trace_blocks_->clear();
  for (auto &item : *kernel_memory_trace_blocks_) {
    auto &device_context = item.first;
    auto &kernel_memory_trace_blocks = item.second;
    MergeBlocksForSameDeviceContext(&kernel_memory_trace_blocks, &((*merged_memory_trace_blocks_)[device_context]));
    MS_LOG(DEBUG) << "The number of merged blocks is " << (*merged_memory_trace_blocks_)[device_context].size()
                  << ", device type: " << device_context->device_context_key().device_name_;
  }
}

void MemoryTraceManager::MergeBlocksForSameDeviceContext(
  std::vector<KernelMemoryTraceBlockPtr> *kernel_memory_trace_blocks,
  std::vector<MemoryTraceBlockPtr> *merged_memory_trace_blocks) {
  MS_EXCEPTION_IF_NULL(kernel_memory_trace_blocks);
  MS_EXCEPTION_IF_NULL(merged_memory_trace_blocks);
  merged_memory_trace_blocks->clear();

  if (kernel_memory_trace_blocks->empty()) {
    MS_LOG(INFO) << "No block to merge.";
    return;
  }

  std::sort(kernel_memory_trace_blocks->begin(), kernel_memory_trace_blocks->end(),
            [](const KernelMemoryTraceBlockPtr &block1, const KernelMemoryTraceBlockPtr &block2) {
              return (block1->start_ < block2->start_) ||
                     ((block1->start_ == block2->start_) && (block1->end_ < block2->end_));
            });
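  // Classic interval merge: after sorting by start address, a block that starts at or beyond the end of the
  // last merged block opens a new merged block; otherwise it extends the current merged block if needed.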
  merged_memory_trace_blocks->emplace_back(std::make_shared<MemoryTraceBlock>((*kernel_memory_trace_blocks)[0]->start_,
                                                                              (*kernel_memory_trace_blocks)[0]->size_));
  (*kernel_memory_trace_blocks)[0]->in_memory_trace_block_index_ = 0;
  for (size_t i = 1; i < kernel_memory_trace_blocks->size(); i++) {
    auto &back = merged_memory_trace_blocks->back();
    auto &block = (*kernel_memory_trace_blocks)[i];
    if (block->start_ >= back->end_) {
      merged_memory_trace_blocks->emplace_back(std::make_shared<MemoryTraceBlock>(block->start_, block->size_));
    } else if (block->end_ > back->end_) {
      back->end_ = block->end_;
      back->size_ = back->end_ - back->start_;
    }
    block->in_memory_trace_block_index_ = merged_memory_trace_blocks->size() - 1;
  }

  // Reset the offset of each kernel block within its merged block.
  for (size_t i = 0; i < kernel_memory_trace_blocks->size(); i++) {
    auto &kernel_mem_block = (*kernel_memory_trace_blocks)[i];
    MS_EXCEPTION_IF_NULL(kernel_mem_block);
    const auto &mem_block = (*merged_memory_trace_blocks)[kernel_mem_block->in_memory_trace_block_index_];
    MS_EXCEPTION_IF_NULL(mem_block);
    if (kernel_mem_block->start_ < mem_block->start_) {
      MS_LOG(EXCEPTION) << "Invalid memory block, block start: " << kernel_mem_block->start_
                        << ", block end: " << kernel_mem_block->end_ << ", mem block start: " << mem_block->start_
                        << ", mem block end: " << mem_block->end_;
    }

    kernel_mem_block->offset_in_memory_trace_block_ = kernel_mem_block->start_ - mem_block->start_;
    (*kernel_to_block_)[kernel_mem_block->kernel_].emplace_back(kernel_mem_block);
  }
}

void MemoryTraceManager::Clear() {
  kernel_memory_trace_blocks_->clear();
  merged_memory_trace_blocks_->clear();
  kernel_to_block_->clear();
}

bool IsTwoPhaseInfer() {
  const auto &phase = PhaseManager::GetInstance().phase();
  return phase.find("prefill") != std::string::npos || phase.find("increment") != std::string::npos;
}

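// Registry mapping each backend kernel node to its generated actor id. GenerateActorIdByKernel prefixes a
// monotonically increasing index to keep actor names unique; GetActorIdByKernel falls back to the node's
// full name when no id has been generated for it.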
std::unordered_map<AnfNode *, std::string> actor_ids;
static size_t actor_index = 0;

std::string GetActorIdByKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (actor_ids.find(node.get()) == actor_ids.end()) {
    MS_LOG(INFO) << "Cannot get actor id by node:" << node->fullname_with_scope();
    return node->fullname_with_scope();
  }
  return actor_ids[node.get()];
}

std::string GenerateActorIdByKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto id = std::to_string(actor_index++) + "_" + node->fullname_with_scope();
  actor_ids[node.get()] = id;
  return id;
}
}  // namespace runtime
}  // namespace mindspore