/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/actor/actor_common.h"
#include <memory>
#include <unordered_map>
#include "ops/framework_op_name.h"
#include "ops/framework_ops.h"
#include "ops/structure_op_name.h"
#include "runtime/graph_scheduler/device_tensor_store.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "utils/phase.h"
#ifndef BUILD_LITE
#include "runtime/graph_scheduler/actor/kernel_async_launch_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_infer_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_resize_actor.h"
#endif

namespace mindspore {
namespace runtime {
bool ActorDispatcher::is_multi_thread_execution_ = true;
bool ActorDispatcher::enable_multi_stream_ = false;
bool ActorDispatcher::has_kernel_need_user_data_ = false;
bool ActorDispatcher::is_memory_allocation_sync_ = true;
bool ActorDispatcher::is_memory_free_sync_ = true;
bool ActorDispatcher::enable_runtime_multi_pipeline_ = false;
bool ActorDispatcher::enable_async_launch_kernel_ = false;
bool ActorDispatcher::disable_kbk_sub_graph_execute_ = false;
bool ActorDispatcher::enable_static_shape_ = false;
bool ActorDispatcher::enable_trace_dynamic_memory_ = false;
bool ActorDispatcher::enable_use_trace_memory_ = false;

bool IsRunningFailed(const OpContext<DeviceTensor> *context) { return (context->error_info_ != ""); }

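// Derive the actor thread num and the total actor-and-kernel thread num from the context options
// inter_op_parallel_num and runtime_num_threads, both capped by the number of CPU cores.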
void ComputeThreadNums(size_t *actor_thread_num, size_t *actor_and_kernel_thread_num) {
  MS_EXCEPTION_IF_NULL(actor_thread_num);
  MS_EXCEPTION_IF_NULL(actor_and_kernel_thread_num);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  const size_t cpu_core_num = std::thread::hardware_concurrency();
  auto inter_op_parallel_num = static_cast<size_t>(context_ptr->get_param<uint32_t>(MS_CTX_INTER_OP_PARALLEL_NUM));
  auto runtime_num_threads = static_cast<size_t>(context_ptr->get_param<uint32_t>(MS_CTX_RUNTIME_NUM_THREADS));
  size_t runtime_num_threads_min = std::min(runtime_num_threads, cpu_core_num);
  size_t inter_op_parallel_num_min = std::min(inter_op_parallel_num, cpu_core_num);
  const float kActorUsage = 0.18;
  const size_t kActorThreadMinNum = 1;
  // Compute the actor and kernel thread num.
  // The MemoryManagerActor binds a single thread, so if runtime_num_threads is 30, the actor thread num is 5
  // and the kernel thread num is 25.
  if (inter_op_parallel_num_min == 0) {
    size_t actor_thread_max_num =
      std::max(static_cast<size_t>(std::floor(runtime_num_threads_min * kActorUsage)), kActorThreadMinNum);
    *actor_thread_num = actor_thread_max_num;
    *actor_and_kernel_thread_num =
      runtime_num_threads_min > *actor_thread_num ? (runtime_num_threads_min) : (*actor_thread_num + 1);
  } else {
    *actor_thread_num = inter_op_parallel_num_min;
    *actor_and_kernel_thread_num = runtime_num_threads_min + *actor_thread_num;
  }

  if (*actor_and_kernel_thread_num > cpu_core_num) {
    MS_LOG(WARNING) << "The total num of thread pool is " << *actor_and_kernel_thread_num
                    << ", but the num of cpu core is " << cpu_core_num
                    << ", please set the threads within reasonable limits.";
  }
}

bool IsDeviceQueueDSActor(const AnfNodePtr &, GraphExecutionStrategy) { return false; }

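// Roughly: a non-weight parameter is handled by the host data source actor when it must be fed from the
// host every step (e.g. forward outputs), or when it is a root-graph parameter rather than an internal parameter.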
bool IsHostQueueDSActor(const AnfNodePtr &node, const KernelGraphPtr &graph,
                        const std::vector<AnfNodePtr> &host_parameters, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(node);

  bool is_parameter_data = node->isa<Parameter>() && (!common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()));
  if (!is_parameter_data) {
    return false;
  }
  // Forward outputs need to be updated every step.
  if (node->has_user_data(kForwardOutput)) {
    return true;
  }

  if (strategy == GraphExecutionStrategy::kStep) {
    MS_EXCEPTION_IF_NULL(graph);
    return graph->execution_order().size() > 1;
  }

  if (graph == nullptr) {
    return true;
  }

  // In control flow, only the parameters of the root funcgraph are in the host data source.
  const auto &front_node = graph->GetFrontAnfByBackendAnf(node);
  bool is_host = ((front_node == nullptr) ||
                  find(host_parameters.begin(), host_parameters.end(), front_node) != host_parameters.end());

  // Judge whether the node is an internal parameter.
  const auto &internal_front_node = graph->GetFrontNodeByInternalParameter(node);
  if (internal_front_node.first == nullptr && is_host) {
    return true;
  }

  return false;
}

bool IsSwitchActor(const AnfNodePtr &node) { return common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch); }

bool IsInternalParameter(const AnfNodePtr &node, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  if (node->isa<Parameter>() && (!common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
    // Judge whether node is internal parameter.
    const auto &front_node = graph->GetOriginFrontNodeByInternalParameter(node);
    if (front_node.first != nullptr) {
      return true;
    }
  }
  return false;
}

bool IsCustomActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return AnfUtils::IsCustomActorNode(node);
}

bool IsKernelActor(const AnfNodePtr &node, GraphExecutionStrategy) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsCustomActor(node)) {
    return false;
  }

  if (!AnfUtils::IsRealCNodeKernel(node)) {
    return false;
  }

  return true;
}

bool IsSkippedKernelActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsKernelActor(node) && common::AnfAlgo::IsInplaceNode(node, "skip")) {
    return true;
  }
  return false;
}

bool IsRpcActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsKernelActor(node) && (common::AnfAlgo::GetCNodeName(node) == kRpcSendOpName ||
                              common::AnfAlgo::GetCNodeName(node) == kRpcRecvOpName)) {
    return true;
  }
  return false;
}

bool IsInnerControlFlowActor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (IsKernelActor(node) && (common::AnfAlgo::GetCNodeName(node) == "ConditionSwitch" ||
                              common::AnfAlgo::GetCNodeName(node) == "ConditionGather")) {
    return true;
  }
  return false;
}

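// Value nodes and weight parameters hold persistent device tensors that are managed by the device tensor store.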
bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    return true;
  }

  // The node may be a load node, so fetch the real parameter node.
  auto real_node = common::AnfAlgo::FetchRealNodeSkipMonadControl({node, 0}).first;
  MS_EXCEPTION_IF_NULL(real_node);
  if (real_node->isa<Parameter>() && common::AnfAlgo::IsParameterWeight(real_node->cast<ParameterPtr>())) {
    return true;
  }
  return false;
}

bool IsControlFlowActor(KernelTransformType actor_type) {
  return ((actor_type >= KernelTransformType::kSwitchActor) && (actor_type <= KernelTransformType::kStackActor));
}

bool IsMemoryActor(KernelTransformType actor_type) {
  return ((actor_type == KernelTransformType::kMemoryAllocActor) ||
          (actor_type == KernelTransformType::kMemoryFreeActor));
}

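// Whether kernel launch should be skipped, controlled by the MS_KERNEL_LAUNCH_SKIP environment variable:
// "ALL"/"all" skips every launch, a kernel or graph name skips only that one, and simulation level implies "ALL".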
bool IsSkippedLaunch(const CNodePtr &kernel, const KernelGraphPtr &kernel_graph) {
  static std::string launch_skipped = "";
  static bool first_get_launch_skipped_env = true;
  static const char kLaunchSkippedEnv[] = "MS_KERNEL_LAUNCH_SKIP";
  if (first_get_launch_skipped_env) {
    launch_skipped = common::GetEnv(kLaunchSkippedEnv);
    first_get_launch_skipped_env = false;
    if (launch_skipped.empty() && !common::GetEnv(kSimulationLevel).empty()) {
      launch_skipped = "ALL";
    }
  }

  if (launch_skipped.empty()) {
    return false;
  }

  std::string launch_name = "";
  std::string full_name = "";
  if (kernel != nullptr) {
    launch_name = common::AnfAlgo::GetCNodeName(kernel);
    full_name = kernel->fullname_with_scope();
  } else if (kernel_graph != nullptr) {
    launch_name = kernel_graph->ToString();
    full_name = kernel_graph->ToString();
  } else {
    MS_LOG(ERROR) << "The launch kernel or graph is nullptr";
    return false;
  }

  if ((launch_skipped == "ALL") || (launch_skipped == "all") || (launch_skipped == launch_name)) {
    MS_LOG(DEBUG) << "Skip the launch of " << full_name;
    return true;
  }

  return false;
}

bool EnableAsyncInfer() {
  static const char kEnableAsyncInferdEnv[] = "MS_ENABLE_ASYNC_INFER";
  static bool ret = common::GetEnv(kEnableAsyncInferdEnv) == "1";
  return ret;
}

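// Memory tracing is only meaningful with infer boost and kernel-by-kernel sub graph execution, and is
// turned on by setting the MS_ENABLE_TRACE_MEMORY environment variable to "1".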
bool EnableTraceMemory() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
  if (!enable_infer_boost) {
    return false;
  }

  if (!EnableKbkSubGraphExecute()) {
    return false;
  }

  static const char kEnableTraceMemoryEnv[] = "MS_ENABLE_TRACE_MEMORY";
  static bool ret = common::GetEnv(kEnableTraceMemoryEnv) == "1";
  if (ret) {
    MS_LOG(INFO) << "Enable trace memory to optimize dynamic memory management performance.";
  }
  return ret;
}

void ResetTraceMemoryStatus() {
  ActorDispatcher::set_enable_static_shape(false);
  ActorDispatcher::set_enable_trace_dynamic_memory(false);
  ActorDispatcher::set_enable_use_trace_memory(false);
}

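// Kernel-by-kernel sub graph execution is enabled for inference (infer boost) unless it is explicitly
// disabled by MS_ENABLE_KBK_SUBGRAPH_EXECUTE=0 or by the ActorDispatcher flag.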
bool EnableKbkSubGraphExecute() {
  static const char kEnableKbkSubGraphExecutedEnv[] = "MS_ENABLE_KBK_SUBGRAPH_EXECUTE";
  static bool disable_sub_graph_execute_mode = common::GetEnv(kEnableKbkSubGraphExecutedEnv) == "0";
  if (disable_sub_graph_execute_mode) {
    return false;
  }

  if (ActorDispatcher::disable_kbk_sub_graph_execute()) {
    return false;
  }

  // Only support sub graph execution mode for inference.
  // static const bool enable_internal_kernels = common::GetEnv("MS_ENABLE_INTERNAL_KERNELS") == "on";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  static const bool enable_internal_kernels = ms_context->IsEnableInferBoost();
  return enable_internal_kernels;
}

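// The defrag memory step frequency defaults to 100 and can be overridden once through the allocator
// config item common::kAllocDefragMemoryStepFreq; a zero or unparsable value keeps the default.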
size_t GetDefragMemoryStepFreq() {
  static size_t defrag_memory_step_freq = 100L;

  static std::once_flag init_flag;
  std::call_once(init_flag, [&]() {
    MS_LOG(INFO) << "Init defrag memory step freq.";
    const auto &value = common::GetConfigValue(common::kAllocConf, common::kAllocDefragMemoryStepFreq);
    MS_LOG(INFO) << "Config defrag memory step freq : " << value << ".";
    if (value.size() != 0) {
      std::stringstream sstream(value);
      size_t config_value = 0;
      sstream >> config_value;
      if (config_value != 0) {
        defrag_memory_step_freq = config_value;
      }
    }
    MS_LOG(INFO) << "Defrag memory step freq : " << defrag_memory_step_freq << ".";
  });

  return defrag_memory_step_freq;
}

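// Wait for the asynchronous infer/resize actors (and optionally the launch actor) to drain their task
// queues, then report whether an error has been recorded during asynchronous launch.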
bool WaitRuntimePipelineFinish(const OpContext<DeviceTensor> *context, bool wait_kernel_launch_finish) {
#ifndef BUILD_LITE
  if (ActorDispatcher::enable_runtime_multi_pipeline()) {
    KernelAsyncInferActor::GetInstance()->Wait();
    KernelAsyncResizeActor::GetInstance()->Wait();
  }

  if (ActorDispatcher::enable_async_launch_kernel() && wait_kernel_launch_finish) {
    KernelAsyncLaunchActor::GetInstance()->Wait();
  }

  if (ActorDispatcher::enable_async_launch_kernel() && IsRunningFailed(context)) {
    MS_LOG(ERROR) << "Wait runtime pipeline finish and an error occurred: " << context->error_info_;
    return false;
  }
  return true;
#else
  return true;
#endif
}

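// Copy data between two device tensors, using the smaller of the two sizes and dispatching to
// device-to-device, host-to-device or device-to-host synchronization according to the device types.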
bool Copy(const DeviceTensor *dst_device_tensor, const DeviceTensor *src_device_tensor) {
  MS_EXCEPTION_IF_NULL(dst_device_tensor);
  MS_EXCEPTION_IF_NULL(src_device_tensor);
  if (src_device_tensor->GetSize() != dst_device_tensor->GetSize()) {
    MS_LOG(INFO) << "Copy size is not equal, input size:" << src_device_tensor->GetSize()
                 << ", output size:" << dst_device_tensor->GetSize();
  }

  // Some devices align the allocation size, so copy only the min size of the two device tensors.
  size_t copy_size = std::min(src_device_tensor->GetSize(), dst_device_tensor->GetSize());

  if (dst_device_tensor->GetDeviceType() == src_device_tensor->GetDeviceType()) {
    return dst_device_tensor->SyncDeviceToDevice(src_device_tensor);
  } else if (src_device_tensor->GetDeviceType() == device::DeviceType::kCPU) {
    // CPU device tensor copy to other device tensor.
    return dst_device_tensor->SyncHostToDevice(copy_size, src_device_tensor->GetPtr());
  } else if (dst_device_tensor->GetDeviceType() == device::DeviceType::kCPU) {
    // Other device tensor copy to CPU device tensor.
    return src_device_tensor->SyncDeviceToHost(copy_size, dst_device_tensor->GetMutablePtr());
  } else {
    MS_LOG(ERROR) << "Invalid device type, src device type: " << src_device_tensor->GetDeviceType()
                  << ", dst device type: " << dst_device_tensor->GetDeviceType();
    return false;
  }
}

void UpdateRefCount(DeviceTensor *const device_tensor, bool is_max_ref_count) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  if (is_max_ref_count) {
    device_tensor->set_original_ref_count(SIZE_MAX);
  } else {
    device_tensor->IncreaseOriginalRefCount();
  }
  device_tensor->ResetRefCount();
}

void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count) {
  MS_EXCEPTION_IF_NULL(node);
  auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx, false);
  UpdateRefCount(device_tensor.get(), is_max_ref_count);
}

void FreeMemoryByDeviceContext(DeviceTensor *const device_tensor, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  // The device context may be inaccurate in the control flow scene, so fetch it by device name and device id.
  if ((device_context == nullptr) || (device_context->GetDeviceType() != device_tensor->GetDeviceType())) {
    const auto &new_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {device_tensor->device_name(), device_tensor->device_id()});
    MS_EXCEPTION_IF_NULL(new_device_context);
    new_device_context->device_res_manager_->FreeMemory(device_tensor);
  } else {
    device_context->device_res_manager_->FreeMemory(device_tensor);
  }
}

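// Release a device tensor still referenced by value nodes: detach the device address from the tensors of
// those value nodes and remove them from the device tensor store.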
void FreeMemoryByValueNode(const std::vector<std::weak_ptr<ValueNode>> &held_by_nodes, DeviceTensor *device_tensor) {
  MS_EXCEPTION_IF_NULL(device_tensor);
  device_tensor->ClearHeldByNodes();
  device_tensor->set_original_ref_count(SIZE_MAX);
  device_tensor->ResetRefCount();

  for (auto &node : held_by_nodes) {
    auto value_node = node.lock();
    MS_EXCEPTION_IF_NULL(value_node);
    auto value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    auto tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->set_device_address(nullptr);
    runtime::DeviceTensorStore::GetInstance().Remove(value_node.get());
  }
}

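// Map a backend node to the kind of actor (or device tensor store) that will own it at runtime.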
KernelTransformType FetchKernelTransformType(const AnfNodePtr &node, const KernelGraphPtr &graph,
                                             const std::vector<AnfNodePtr> &host_parameters,
                                             GraphExecutionStrategy strategy) {
  // Fetch kernel graph.
  KernelGraphPtr kernel_graph = nullptr;
  if (graph == nullptr) {
    kernel_graph = AnfAlgo::FetchKernelGraph(node.get());
  } else {
    kernel_graph = graph;
  }
  if (kernel_graph == nullptr) {
    return KernelTransformType::kUnknown;
  }
  if (kernel_graph->is_any_type_input() && node != nullptr && node->isa<CNode>()) {
    return KernelTransformType::kAnyTypeKernelActor;
  }
  // In sink mode, the data exchange between child graphs is expressed as parameters. These parameters are stored
  // in the graph and should be obtained from the super kernel actor.
  if (kernel_graph->is_graph_run_mode() &&
      ((node == nullptr) || node->isa<CNode>() || kernel_graph->IsChildGraphResult(node))) {
    return KernelTransformType::kSuperKernelActor;
  }

  KernelTransformType type = KernelTransformType::kUnknown;
  MS_EXCEPTION_IF_NULL(node);
  auto real_node = common::AnfAlgo::FetchRealNodeSkipMonadControl({node, 0}).first;
  MS_EXCEPTION_IF_NULL(real_node);

  if (IsDeviceQueueDSActor(real_node, strategy)) {
    type = KernelTransformType::kDeviceDataSourceActor;
  } else if (IsHostQueueDSActor(real_node, kernel_graph, host_parameters, strategy)) {
    type = KernelTransformType::kHostDataSourceActor;
  } else if (IsCustomActor(real_node)) {
    type = KernelTransformType::kCustomActor;
  } else if (IsKernelActor(real_node, strategy)) {
    type = KernelTransformType::kKernelActor;
  } else if (IsInternalParameter(real_node, kernel_graph)) {
    type = KernelTransformType::kInternalParameter;
  } else if (IsPersistentDeviceTensor(real_node)) {
    type = KernelTransformType::kDeviceTensorStore;
  } else {
    // In PyNative mode there may be 'from' kernels that do not need to be linked.
    MS_LOG(DEBUG) << "Invalid from kernel: " << node->DebugString();
  }

  return type;
}

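// Build the actor name for a node according to its kernel transform type, following the naming
// suffix/prefix convention of each actor kind.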
std::string FetchActorName(KernelTransformType kernel_type, const std::string &actor_set_name, const AnfNodePtr &node,
                           const KernelGraphPtr &graph) {
  // Fetch kernel graph.
  KernelGraphPtr kernel_graph = nullptr;
  if (graph == nullptr) {
    kernel_graph = AnfAlgo::FetchKernelGraph(node.get());
  } else {
    kernel_graph = graph;
  }
  if (kernel_graph == nullptr) {
    return "";
  }

  auto real_node = node;
  if (real_node != nullptr) {
    real_node = common::AnfAlgo::FetchRealNodeSkipMonadControl({node, 0}).first;
  }
  std::string actor_name = "";
  switch (kernel_type) {
    case KernelTransformType::kSuperKernelActor:
      actor_name = kernel_graph->ToString() + kSuperKernelActorNameSuffix;
      break;
    case KernelTransformType::kAnyTypeKernelActor:
      actor_name = kernel_graph->ToString() + kAnyTypeKernelActorNameSuffix;
      break;
    case KernelTransformType::kDeviceDataSourceActor:
      actor_name = actor_set_name + kDeviceDSActorNameSuffix + "_" + std::to_string(kernel_graph->graph_id());
      break;
    case KernelTransformType::kHostDataSourceActor:
      actor_name = actor_set_name + kHostDSActorNameSuffix;
      break;
    case KernelTransformType::kCustomActor:
      MS_EXCEPTION_IF_NULL(real_node);
      actor_name = AnfUtils::GetCustomActorName(real_node);
      break;
    case KernelTransformType::kKernelActor:
      MS_EXCEPTION_IF_NULL(real_node);
      actor_name = GetActorIdByKernel(real_node);
      break;
    case KernelTransformType::kKernelInferActor:
      MS_EXCEPTION_IF_NULL(real_node);
      actor_name = kKernelInferActorNamePrefix + real_node->fullname_with_scope();
      break;
    case KernelTransformType::kKernelResizeActor:
      MS_EXCEPTION_IF_NULL(real_node);
      actor_name = kKernelResizeActorNamePrefix + real_node->fullname_with_scope();
      break;
    default:
      break;
  }
  return actor_name;
}

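// Collect the input indexes with ref abstracts that may be modified in place; only nodes with monad
// inputs (auto-monad nodes) report such indexes.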
std::set<size_t> FetchModifiableRefInputIndex(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);

  bool has_monad = false;
  std::set<size_t> ref_input_indexes;
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto &input = cnode->inputs().at(i);
    if (HasAbstractMonad(input)) {
      has_monad = true;
    }
    if (common::AnfAlgo::HasAbstractRef(input)) {
      (void)ref_input_indexes.insert(i - 1);
    }
  }

  // Only the auto monad node will modify the input.
  if (has_monad) {
    return ref_input_indexes;
  } else {
    return {};
  }
}

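// Collect the output indexes of a ref node whose corresponding ref inputs may be modified.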
std::set<size_t> FetchModifiableRefOutputIndex(const CNodePtr &cnode, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::set<size_t> ref_output_indexes;

  auto output_num = AnfAlgo::GetOutputTensorNum(cnode);
  for (size_t i = 0; i < output_num; ++i) {
    session::AnfWithOutIndex output_pair(cnode, i);
    // Only the ref node will modify the ref input corresponding to the output.
    if (!graph->IsInRefOutputMap(output_pair)) {
      continue;
    }
    auto input_pair = graph->GetRefCorrespondOutput(output_pair);
    MS_EXCEPTION_IF_NULL(input_pair.first);
    if (common::AnfAlgo::HasAbstractRef(input_pair.first)) {
      (void)ref_output_indexes.insert(i);
    }
  }
  return ref_output_indexes;
}

bool is_embedding_cache_server() {
  return ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
}

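// MemoryTraceManager records the memory blocks used by each kernel in a traced step, keyed by device
// context and graph id, and merges them into larger contiguous blocks with per-kernel offsets.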
void MemoryTraceManager::ReserveKernelMemoryBlocks(size_t size, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_context);
  (*kernel_memory_trace_blocks_)[device_context].reserve(size);
}

void MemoryTraceManager::PickMemoryTrackInfoForGraph(uint32_t graph_id) {
  if (graph_to_kernel_memory_trace_blocks_.find(graph_id) == graph_to_kernel_memory_trace_blocks_.end()) {
    graph_to_kernel_memory_trace_blocks_.emplace(
      graph_id, std::make_shared<std::map<const DeviceContext *, std::vector<KernelMemoryTraceBlockPtr>>>());
  }
  kernel_memory_trace_blocks_ = graph_to_kernel_memory_trace_blocks_[graph_id];
  MS_EXCEPTION_IF_NULL(kernel_memory_trace_blocks_);

  if (graph_to_merged_memory_trace_blocks_.find(graph_id) == graph_to_merged_memory_trace_blocks_.end()) {
    graph_to_merged_memory_trace_blocks_.emplace(
      graph_id, std::make_shared<std::map<const DeviceContext *, std::vector<MemoryTraceBlockPtr>>>());
  }
  merged_memory_trace_blocks_ = graph_to_merged_memory_trace_blocks_[graph_id];
  MS_EXCEPTION_IF_NULL(merged_memory_trace_blocks_);

  if (graph_to_kernel_blocks_.find(graph_id) == graph_to_kernel_blocks_.end()) {
    graph_to_kernel_blocks_.emplace(
      graph_id, std::make_shared<mindspore::HashMap<CNodePtr, std::vector<KernelMemoryTraceBlockPtr>>>());
  }
  kernel_to_block_ = graph_to_kernel_blocks_[graph_id];
  MS_EXCEPTION_IF_NULL(kernel_to_block_);
}

void MemoryTraceManager::AddKernelMemoryTraceBlock(const KernelMemoryTraceBlockPtr &block,
                                                   const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(block);
  MS_EXCEPTION_IF_NULL(block->start_);
  MS_EXCEPTION_IF_NULL(block->end_);
  (*kernel_memory_trace_blocks_)[device_context].emplace_back(block);
}

const std::shared_ptr<std::map<const DeviceContext *, std::vector<MemoryTraceBlockPtr>>>
  &MemoryTraceManager::GetMergeBlocks() {
  return merged_memory_trace_blocks_;
}

const std::shared_ptr<mindspore::HashMap<CNodePtr, std::vector<KernelMemoryTraceBlockPtr>>>
  &MemoryTraceManager::GetAllKernelBlocksnfo() {
  return kernel_to_block_;
}

void MemoryTraceManager::MergeBlocks() {
  merged_memory_trace_blocks_->clear();
  for (auto &item : *kernel_memory_trace_blocks_) {
    auto &device_context = item.first;
    auto &kernel_memory_trace_blocks = item.second;
    MergeBlocksForSameDeviceContext(&kernel_memory_trace_blocks, &((*merged_memory_trace_blocks_)[device_context]));
    MS_LOG(DEBUG) << "The number of merged blocks is " << (*merged_memory_trace_blocks_)[device_context].size()
                  << ", device type: " << device_context->device_context_key().device_name_;
  }
}

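// Sort the kernel blocks by address, coalesce overlapping ranges into merged blocks, and record for
// every kernel block the merged block it falls into together with its offset inside that block.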
void MemoryTraceManager::MergeBlocksForSameDeviceContext(
  std::vector<KernelMemoryTraceBlockPtr> *kernel_memory_trace_blocks,
  std::vector<MemoryTraceBlockPtr> *merged_memory_trace_blocks) {
  MS_EXCEPTION_IF_NULL(kernel_memory_trace_blocks);
  MS_EXCEPTION_IF_NULL(merged_memory_trace_blocks);
  merged_memory_trace_blocks->clear();

  if (kernel_memory_trace_blocks->empty()) {
    MS_LOG(INFO) << "No block to merge.";
    return;
  }

  std::sort(kernel_memory_trace_blocks->begin(), kernel_memory_trace_blocks->end(),
            [](const KernelMemoryTraceBlockPtr &block1, const KernelMemoryTraceBlockPtr &block2) {
              return (block1->start_ < block2->start_) ||
                     ((block1->start_ == block2->start_) && (block1->end_ < block2->end_));
            });
  merged_memory_trace_blocks->emplace_back(std::make_shared<MemoryTraceBlock>((*kernel_memory_trace_blocks)[0]->start_,
                                                                              (*kernel_memory_trace_blocks)[0]->size_));
  (*kernel_memory_trace_blocks)[0]->in_memory_trace_block_index_ = 0;
  for (size_t i = 1; i < kernel_memory_trace_blocks->size(); i++) {
    auto &back = merged_memory_trace_blocks->back();
    auto &block = (*kernel_memory_trace_blocks)[i];
    if (block->start_ >= back->end_) {
      merged_memory_trace_blocks->emplace_back(std::make_shared<MemoryTraceBlock>(block->start_, block->size_));
    } else if (block->end_ > back->end_) {
      back->end_ = block->end_;
      back->size_ = back->end_ - back->start_;
    }
    block->in_memory_trace_block_index_ = merged_memory_trace_blocks->size() - 1;
  }

  // Reset the offset of each kernel block within its merged block.
  for (size_t i = 0; i < kernel_memory_trace_blocks->size(); i++) {
    auto &kernel_mem_block = (*kernel_memory_trace_blocks)[i];
    MS_EXCEPTION_IF_NULL(kernel_mem_block);
    const auto &mem_block = (*merged_memory_trace_blocks)[kernel_mem_block->in_memory_trace_block_index_];
    MS_EXCEPTION_IF_NULL(mem_block);
    if (kernel_mem_block->start_ < mem_block->start_) {
      MS_LOG(EXCEPTION) << "Invalid memory block, block start: " << kernel_mem_block->start_
                        << ", block end: " << kernel_mem_block->end_ << ", mem block start: " << mem_block->start_
                        << ", mem block end: " << mem_block->end_;
    }

    kernel_mem_block->offset_in_memory_trace_block_ = kernel_mem_block->start_ - mem_block->start_;
    (*kernel_to_block_)[kernel_mem_block->kernel_].emplace_back(kernel_mem_block);
  }
}

void MemoryTraceManager::Clear() {
  kernel_memory_trace_blocks_->clear();
  merged_memory_trace_blocks_->clear();
  kernel_to_block_->clear();
}

bool IsTwoPhaseInfer() {
  const auto &phase = PhaseManager::GetInstance().phase();
  return phase.find("prefill") != std::string::npos || phase.find("increment") != std::string::npos;
}

std::unordered_map<AnfNode *, std::string> actor_ids;
static size_t actor_index = 0;

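// Actor ids are generated once per kernel as "<index>_<fullname_with_scope>"; lookups for nodes without
// a generated id fall back to the node's full name with scope.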
std::string GetActorIdByKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (actor_ids.find(node.get()) == actor_ids.end()) {
    MS_LOG(INFO) << "Cannot get actor id by node:" << node->fullname_with_scope();
    return node->fullname_with_scope();
  }
  return actor_ids[node.get()];
}

std::string GenerateActorIdByKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto id = std::to_string(actor_index++) + "_" + node->fullname_with_scope();
  actor_ids[node.get()] = id;
  return id;
}
}  // namespace runtime
}  // namespace mindspore