/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/kernel_runtime.h"
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
#include <set>
#include <shared_mutex>
#include "ops/ascend_op_name.h"
#include "ops/nn_optimizer_op_name.h"
#include "ops/sequence_ops.h"
#include "include/backend/optimizer/helper.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/kernel_graph.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#include "frontend/operator/ops.h"
#include "ir/value.h"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/parallel_context.h"
#include "include/common/debug/env_config_parser.h"
#include "kernel/framework_utils.h"

using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;

namespace mindspore {
namespace device {
constexpr size_t kAtomicCleanInputSize = 2;
namespace {
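// Collect the graph inputs, plus any Parameter that is consumed by a kernel in the
// execution order but missing from graph.inputs(), without duplicates.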
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
  auto graph_inputs = graph.inputs();
  std::vector<AnfNodePtr> result(graph_inputs.begin(), graph_inputs.end());
  std::set<AnfNodePtr> inputs_set(graph_inputs.begin(), graph_inputs.end());
  auto kernels = graph.execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t i = 0; i < input_num; ++i) {
      auto input_node = kernel->input(i + 1);
      auto input_real_node = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0).first;
      MS_EXCEPTION_IF_NULL(input_real_node);
      if (input_real_node->isa<Parameter>() && inputs_set.find(input_real_node) == inputs_set.end()) {
        (void)inputs_set.insert(input_real_node);
        (void)result.emplace_back(input_real_node);
      }
    }
  }
  return result;
}

// Check whether a mutex already exists for the given stream.
std::pair<bool, std::mutex *> CheckStreamMutexExist(
  const void *stream, const mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> &mtxs_for_streams,
  std::shared_mutex *shd_mtx) {
  MS_EXCEPTION_IF_NULL(stream);
  MS_EXCEPTION_IF_NULL(shd_mtx);
  std::shared_lock<std::shared_mutex> shd_lock(*shd_mtx);
  auto iter = mtxs_for_streams.find(stream);
  if (iter != mtxs_for_streams.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    return std::make_pair(true, iter->second.get());
  }
  return std::make_pair(false, nullptr);
}

// Create a mutex for the given stream.
std::mutex *CreateStreamMutex(const void *stream, std::shared_mutex *shd_mtx,
                              mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> *mtxs_for_streams) {
  MS_EXCEPTION_IF_NULL(stream);
  MS_EXCEPTION_IF_NULL(shd_mtx);
  MS_EXCEPTION_IF_NULL(mtxs_for_streams);

  std::unique_lock<std::shared_mutex> unq_lock(*shd_mtx);
  auto ret_pair = mtxs_for_streams->emplace(stream, std::make_shared<std::mutex>());

  MS_EXCEPTION_IF_NULL(ret_pair.first->second);
  return ret_pair.first->second.get();
}

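// Decide whether output <node, index> needs device memory allocated here. The logic below
// reads as: allocation is always needed unless the graph runs in zero-copy mode
// (kFlagEnableZeroCopyInGraph) and the output is (or may become) a graph output, in which
// case the buffer is expected to be provided from outside the graph.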
bool IsNeedAllocMem(const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &graph = node->func_graph();
  if (graph == nullptr) {
    return true;
  }
  if (!graph->has_flag(kFlagEnableZeroCopyInGraph)) {
    return true;
  }
  if (node->isa<Parameter>() && graph->output() == nullptr) {
    return false;
  }
  const auto &outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
  return std::find_if(outputs.begin(), outputs.end(), [&node, &index](const KernelWithIndex &output) {
           const auto &real_output = common::AnfAlgo::FetchRealNodeSkipMonadControl(output);
           return ((real_output.first == node) && (real_output.second == index));
         }) == outputs.end();
}
}  // namespace
constexpr size_t kMinInputSize = 2;
KernelRuntime::TbeLaunchKernelModCallBack KernelRuntime::tbe_call_ = nullptr;
KernelRuntime::~KernelRuntime() {
  stream_ = nullptr;
  copy_data_stream_ = nullptr;
  communication_stream_ = nullptr;
}

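// Serialize kernel launches per stream: each stream lazily gets its own mutex, created
// under an exclusive lock and afterwards looked up under a shared lock.
// A typical call pattern (illustrative only):
//   auto lock = KernelRuntime::LockRuntime(stream_ptr);
//   // ... launch kernels on stream_ptr while the guard is alive ...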
std::lock_guard<std::mutex> KernelRuntime::LockRuntime(const void *stream) {
  MS_EXCEPTION_IF_NULL(stream);
  // Read-write lock for accessing the mtxs_for_streams map.
  // Once the mutex of each stream is created, mtxs_for_streams can be read concurrently to improve performance.
  static std::shared_mutex shd_mtx;
  static mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> mtxs_for_streams;

  std::mutex *stream_mtx = nullptr;
  // Check whether a mutex already exists for the stream.
  std::pair<bool, std::mutex *> ret_pair = CheckStreamMutexExist(stream, mtxs_for_streams, &shd_mtx);
  if (ret_pair.first) {
    stream_mtx = ret_pair.second;
  } else {
    // Create a mutex for the stream.
    stream_mtx = CreateStreamMutex(stream, &shd_mtx, &mtxs_for_streams);
  }

  MS_EXCEPTION_IF_NULL(stream_mtx);
  return std::lock_guard<std::mutex>(*stream_mtx);
}

bool KernelRuntime::Load(const session::KernelGraph &, bool) {
  MS_LOG(INFO) << "Call default load.";
  return true;
}

bool KernelRuntime::LoadData(const session::KernelGraph &) {
  MS_LOG(INFO) << "Call default load data.";
  return false;
}

bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
  MS_EXCEPTION_IF_NULL(kernel);
  if (AnfAlgo::OutputAddrExist(kernel, index)) {
    // In subgraph sink mode, if the kernel does not need to allocate memory, it cannot be skipped.
    const auto &address = AnfAlgo::GetOutputAddr(kernel, index, IsNeedAllocMem(kernel, index));
    MS_EXCEPTION_IF_NULL(address);
    return address->GetDeviceType() == GetTargetDeviceType() &&
           address->format() == AnfAlgo::GetOutputFormat(kernel, 0);
  }
  return false;
}

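// Whole-graph memory assignment entry point. With the memory scheduler enabled, device
// addresses are created empty and memory is managed by MemScheduler at run time;
// otherwise static memory (inputs, value nodes, outputs) is assigned first and the
// remaining kernels use dynamic memory.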
void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (UseMemScheduler()) {
    AssignStaticMemoryValueNode(graph);
    ResetNodeAddress(graph);
    AddCommunicationMemInfo(graph);
  } else {
    MS_EXCEPTION_IF_NULL(mem_manager_);
    mem_manager_->ResetDynamicMemory();
    AssignStaticMemory(graph);
    AssignDynamicMemory(graph);
  }
  UpdateRefNodeOutputMem(graph);
}

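// Communication ops are given one contiguous block covering all of their inputs (and,
// symmetrically, their outputs below), presumably so a fused collective can address them
// as a single buffer. The Get*Info helpers collect per-tensor addresses and aligned
// sizes, then MallocContinuousMemFromMemPool carves one allocation across them.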
void KernelRuntime::GetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
                                              DeviceAddressPtrList *address_list,
                                              std::vector<size_t> *align_size_list) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(total_size);
  MS_EXCEPTION_IF_NULL(address_list);
  MS_EXCEPTION_IF_NULL(align_size_list);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, true);
    auto input_node = input_node_with_index.first;
    MS_EXCEPTION_IF_NULL(input_node);
    DeviceAddressPtr address = nullptr;
    if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
      address = AnfAlgo::GetMutableOutputAddr(input_node, input_node_with_index.second);
    } else {
      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
    }
    MS_EXCEPTION_IF_NULL(address);
    auto align_size = MemoryManager::GetCommonAlignSize(address->size());
    *total_size += align_size;
    (void)address_list->emplace_back(address);
    (void)align_size_list->emplace_back(align_size);
  }
}

void KernelRuntime::AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const {
  if (!common::AnfAlgo::IsCommunicationOp(node)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);

  size_t total_size = 0;
  DeviceAddressPtrList address_list;
  std::vector<size_t> align_size_list;
  GetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list);
  if (align_size_list.empty()) {
    MS_LOG(WARNING) << "No inputs for " << node->fullname_with_scope();
    return;
  }

  if (!mem_manager_->MallocContinuousMemFromMemPool(address_list, total_size, align_size_list,
                                                    AnfAlgo::GetStreamId(node))) {
    MS_LOG(EXCEPTION) << "Allocate continuous memory failed, total_size:" << total_size;
  }
}

void KernelRuntime::GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size,
                                               DeviceAddressPtrList *address_list,
                                               std::vector<size_t> *align_size_list) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(total_size);
  MS_EXCEPTION_IF_NULL(align_size_list);
  MS_EXCEPTION_IF_NULL(address_list);

  const auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  const auto output_size_list = kernel_mod->GetOutputSizeList();
  for (size_t i = 0; i < output_size_list.size(); ++i) {
    DeviceAddressPtr address = nullptr;
    if (AnfAlgo::OutputAddrExist(node, i)) {
      address = AnfAlgo::GetMutableOutputAddr(node, i);
    } else {
      const std::string output_format = AnfAlgo::GetOutputFormat(node, i);
      const auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
      const auto tensor_size = AnfAlgo::GetOutputTensorMemSize(node, i);
      address = CreateDeviceAddress(nullptr, tensor_size, output_format, output_type, {node, i});
      AnfAlgo::SetOutputAddr(address, i, node.get());
    }
    MS_EXCEPTION_IF_NULL(address);
    auto align_size = MemoryManager::GetCommonAlignSize(address->size());
    *total_size += align_size;
    (void)align_size_list->emplace_back(align_size);
    (void)address_list->emplace_back(address);
  }
}

void KernelRuntime::AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const {
  if (!common::AnfAlgo::IsCommunicationOp(node)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);

  size_t total_size = 0;
  std::vector<size_t> align_size_list;
  std::vector<DeviceAddressPtr> address_list;
  GetCommunicationOutputInfo(node, &total_size, &address_list, &align_size_list);
  if (align_size_list.empty()) {
    MS_LOG(WARNING) << "No output for " << node->fullname_with_scope();
    return;
  }

  if (!mem_manager_->MallocContinuousMemFromMemPool(address_list, total_size, align_size_list,
                                                    AnfAlgo::GetStreamId(node))) {
    MS_LOG(EXCEPTION) << "Allocate continuous memory failed, total_size:" << total_size;
  }
}

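// Recreate empty device addresses (no underlying memory yet) for every kernel's inputs,
// outputs and workspaces; the memory scheduler provides actual memory later.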
void KernelRuntime::ResetNodeAddress(const session::KernelGraph &kernel_graph) {
  auto kernels = kernel_graph.execution_order();
  for (auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t j = 0; j < input_num; ++j) {
      auto input_index = AnfAlgo::GetInputGraphIdxByKernelIdx(kernel, j);
      KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, input_index, true);
      auto index = kernel_with_index.second;
      auto &input_node = kernel_with_index.first;
      if (NodeOutputDeviceAddressExist(input_node, index)) {
        continue;
      }
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, index);
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        continue;
      }
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
      auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index),
                                                output_type_id, {input_node, index});
      AnfAlgo::SetOutputAddr(device_address, index, input_node.get());
    }

    auto output_sizes = kernel_mod->GetOutputSizeList();
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
      AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i,
                             kernel.get());
    }
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_sizes.size(); ++i) {
      AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32),
                                i, kernel.get());
    }
  }
}

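// Single-op (PyNative) memory assignment. Communication nodes are handled first since
// they need contiguous blocks; then inputs, value nodes, outputs and workspaces.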
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                      const session::KernelGraph &graph, bool is_gradient_out,
                                      const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->ResetDynamicMemory();

  for (const auto &node : graph.execution_order()) {
    AssignCommunicationOutputFromMemoryPool(node);
    AssignCommunicationInputFromMemoryPool(node);
  }

  RunOpAssignInputMemory(input_tensors, graph);
  AssignStaticMemoryValueNode(graph);
  for (const auto &node : graph.execution_order()) {
    RunOpAssignOutputMemory(node, tensor_to_node, is_gradient_out);
    RunOpAssignWorkSpaceMemory(node);
  }
  UpdateRefNodeOutputMem(graph);
}

void KernelRuntime::RunOpClearMemory(const session::KernelGraph &graph) const {
  // Clear the memory resource of input parameters.
  for (const auto &input_node : graph.inputs()) {
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
  }
  // Clear the memory resource of input value nodes.
  for (const auto &value_node : graph.graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  }
  for (const auto &cnode : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(cnode);
    // Clear the output memory resource.
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t index = 0; index < output_num; ++index) {
      AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
    }
    // Clear the workspace memory resource.
    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
    for (size_t index = 0; index < workspace_lists.size(); ++index) {
      AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
    }
  }
}

#ifdef ENABLE_DEBUGGER
bool KernelRuntime::DumpDataEnabled() {
  // Returns true if e2e dump is enabled.
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  return dump_json_parser.e2e_dump_enabled();
}

bool KernelRuntime::DumpDataEnabledIteration() {
  // Returns true if e2e dump is enabled and the current iteration must be dumped.
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (!dump_json_parser.e2e_dump_enabled()) {
    return false;
  }

  auto cur_iter = dump_json_parser.cur_dump_iter();
  return dump_json_parser.IsDumpIter(cur_iter);
}
#endif

void KernelRuntime::AssignStaticMemory(const session::KernelGraph &graph) {
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  AssignStaticMemoryOutput(graph);
}

void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                           const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (input_tensors.size() != graph.inputs().size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
                      << " should be equal to graph input parameter size " << graph.inputs().size();
  }

  for (size_t input_index = 0; input_index < graph.inputs().size(); ++input_index) {
    auto item = graph.inputs()[input_index];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      auto current_tensor = input_tensors[input_index];
      MS_EXCEPTION_IF_NULL(current_tensor);
      auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(current_tensor->device_address());
      // The device address has already been created.
      if (output_address != nullptr && output_address->GetDeviceType() == GetTargetDeviceType()) {
        if (output_address->GetDevicePtr() == nullptr) {
          if (!mem_manager_->MallocMemFromMemPool(output_address, output_address->size())) {
            MS_LOG(EXCEPTION) << "Allocate memory failed, size:" << output_address->size();
          }
        }

        AnfAlgo::SetOutputAddr(output_address, index, item.get());
        continue;
      }
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        output_type_id = common::AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      // Create a new device address.
      auto device_address =
        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
      MS_EXCEPTION_IF_NULL(device_address);
      MS_EXCEPTION_IF_NULL(mem_manager_);
      device_address->set_from_persistent_mem(true);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;
      }
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}

void KernelRuntime::RunOpAssignOutputMemory(const AnfNodePtr &kernel,
                                            const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                                            bool is_gradient_out) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }

  // Use the device addresses allocated in RunOpMallocPre.
  for (auto &iter : tensor_to_node) {
    auto device_address = iter.first->device_address();
    AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(device_address), iter.second.second,
                           iter.second.first.get());
  }

  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if (AnfAlgo::OutputAddrExist(kernel, i, false)) {
      auto address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
      MS_EXCEPTION_IF_NULL(address);
      if (address->GetDevicePtr() == nullptr) {
        MS_EXCEPTION_IF_NULL(mem_manager_);
        if (!mem_manager_->MallocMemFromMemPool(address, address->size())) {
          MS_LOG(EXCEPTION) << "Allocate memory failed, size:" << address->size();
        }
      }
      continue;
    }
    if (common::AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName ||
        common::AnfAlgo::GetCNodeName(kernel) == kApplyMomentumDOpName) {
      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
      continue;
    }
    std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i});
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
    if (is_gradient_out) {
      device_address->set_from_persistent_mem(true);
    }
    auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << output_sizes[i];
    }
    AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  }
}

void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (kernel->isa<CNode>()) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_lists.size(); ++i) {
      auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
      MS_EXCEPTION_IF_NULL(device_address);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << workspace_lists[i];
      }
      AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
    }
  }
}

void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value,
                                                const session::KernelGraph &graph) const {
  if (pre_output_value == nullptr) {
    return;
  }
  std::vector<tensor::BaseTensorPtr> pre_output_tensors;
  TensorValueToTensor(pre_output_value, &pre_output_tensors);
  auto output_nodes = graph.outputs();
  if (pre_output_tensors.size() != output_nodes.size()) {
    MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size()
                      << "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]";
  }
  // Share output addresses with the pre output tensors.
  for (size_t i = 0; i < output_nodes.size(); ++i) {
    auto output_node_with_index = common::AnfAlgo::VisitKernel(output_nodes[i], 0);
    auto output_node = output_node_with_index.first;
    MS_EXCEPTION_IF_NULL(output_node);
    if (!output_node->isa<CNode>()) {
      if (output_node->isa<Parameter>()) {
        auto param = output_node->cast<ParameterPtr>();
        if (param != nullptr && !param->has_default()) {
          MS_LOG(EXCEPTION) << "The output parameter should be a real parameter!";
        }
      }
      continue;
    }
    auto real_output_cnode = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(real_output_cnode);
    MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
    if (pre_output_tensors[i]->device_address() == nullptr) {
      MS_LOG(INFO) << "The address of pre output tensor [" << i << "] is a nullptr!";
      continue;
    }
    if (common::AnfAlgo::IsNopNode(real_output_cnode)) {
      if (real_output_cnode->size() < kMinInputSize) {
        MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString()
                          << " should be larger than one!";
      }
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, real_output_cnode->input(1).get());
    } else {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, output_node_with_index.first.get());
    }
  }
}

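// Assign static memory for graph input parameters, skipping invalid inputs, forward
// output parameters, and parameters that already own device memory.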
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (graph.need_inline()) {
    return;
  }
  MS_LOG(INFO) << "AssignStaticMemoryInput start for graph " << graph.graph_id();
  auto graph_inputs = GetGraphInputs(graph);
  auto graph_valid_input = graph.valid_inputs();
  (void)graph_inputs.insert(graph_inputs.end(), graph.child_graph_result().begin(), graph.child_graph_result().end());
  std::vector<AnfNodePtr> need_alloc_nodes;
  auto add_need_alloc_nodes = [&need_alloc_nodes, this](const AnfNodePtr &node) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<Parameter>()) {
      return;
    }
    if (NodeOutputDeviceAddressExist(node, 0)) {
      const auto &address = AnfAlgo::GetOutputAddr(node, 0);
      MS_EXCEPTION_IF_NULL(address);
      if (address->GetPtr() != nullptr) {
        return;
      }
    }
    need_alloc_nodes.push_back(node);
  };

  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    auto input_node = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    if (common::AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
      auto outs = common::AnfAlgo::GetAllOutput(input_node);
      for (auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        add_need_alloc_nodes(out);
      }
    }
    add_need_alloc_nodes(input_node);
  }
  std::map<AnfNodePtr, AnfNodePtr> shadow_backend_node_map;
  GetShadowBackendNodeMap(graph, &shadow_backend_node_map);
  for (auto &item : need_alloc_nodes) {
    MS_EXCEPTION_IF_NULL(item);
    if (item->has_user_data(kForwardOutput)) {
      MS_LOG(DEBUG) << "Skip allocating memory for forward output parameter " << item->DebugString();
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // If a graph output is a weight that does not link to any cnode, its data type will be unknown.
      if (output_type_id == kTypeUnknown) {
        MS_LOG(INFO) << "It is not suggested to use a lonely weight parameter as the output of graph";
        continue;
      }
      // If the kernel has the flag kFlagEnableZeroCopyInGraph, the internal parameter and the corresponding
      // cnode cannot use the same device address.
      DeviceAddressPtr device_address =
        (graph.has_flag(kFlagEnableZeroCopyInGraph) ? nullptr : GetInternalDeviceAddress(graph, item));
      GetDeviceAddress(item, shadow_backend_node_map, index, graph, &device_address);
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
  MS_LOG(INFO) << "AssignStaticMemoryInput end";
}

void KernelRuntime::GetDeviceAddress(const AnfNodePtr &item,
                                     const std::map<AnfNodePtr, AnfNodePtr> shadow_backend_node_map, size_t index,
                                     const session::KernelGraph &graph, DeviceAddressPtr *device_address) {
  AnfNodePtr shadow_node = nullptr;
  auto iter = shadow_backend_node_map.find(item);
  if (iter != shadow_backend_node_map.end()) {
    shadow_node = iter->second;
  }
  if (*device_address == nullptr && shadow_node != nullptr) {
    auto conj_device_address = AnfAlgo::GetMutableOutputAddr(shadow_node, index);
    if (conj_device_address != nullptr && conj_device_address->GetDeviceType() == DeviceType::kAscend) {
      *device_address = conj_device_address;
    }
  } else if (*device_address == nullptr) {
    auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
    *device_address =
      CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
  }

  // Set the flag for a parameter that has no user, so no memory gets allocated for it.
  if ((*device_address != nullptr) && item->isa<Parameter>()) {
    auto input_param = item->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(input_param);
    // An unused address will not get memory allocated, which easily causes problems for weight nodes, so skip them.
    if (!common::AnfAlgo::IsParameterWeight(input_param) && !input_param->IsUsedByRealKernelInGraph(graph.graph_id())) {
      MS_LOG(INFO) << "Node:" << item->fullname_with_scope() << " debug name:" << item->DebugString()
                   << " is not used in the graph " << graph.graph_id();
      (*device_address)->UpdateFlag(kDeviceAddressFlagNotUsed);
      return;
    }
  }

  if (*device_address != nullptr && (*device_address)->GetPtr() == nullptr) {
    auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
    (*device_address)->set_host_shape(trans::GetRuntimePaddingShape(item, index));
    MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size
                 << " node:" << item->fullname_with_scope() << " debug:" << item->DebugString() << " index: " << index;
    if (!graph.has_flag(kFlagEnableZeroCopyInGraph)) {
      auto ret_ptr = mem_manager_->MallocMem(kStaticMem, tensor_size, *device_address, graph.graph_id());
      if (ret_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
      }
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "AllocStaticMemory", item->fullname_with_scope(),
                                                     graph.ToString());
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddCompileTimeMemInfo, "AllocStaticMemory", tensor_size, ret_ptr,
                                                     device::tracker::MemType::kWeight);
    }
  }
}

void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph &graph) {
  if (graph.need_inline()) {
    return;
  }
  MS_LOG(INFO) << "AssignStaticMemoryOutput start for graph " << graph.graph_id();
  auto nodes = common::AnfAlgo::GetAllOutput(graph.output(), {prim::kPrimTupleGetItem});
  std::vector<session::KernelWithIndex> non_communication_op;
  // Assign memory for communication ops first.
  for (const auto &node : nodes) {
    // Assign an output address to nop nodes whose "skip_nop_op_addr" attribute is false.
    auto is_skip = !common::AnfAlgo::IsNopNode(node) || common::AnfAlgo::IsNeedSkipNopOpAddr(node);
    auto kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(node, 0, is_skip);
    MS_EXCEPTION_IF_NULL(kernel_with_index.first);
    if (!kernel_with_index.first->isa<CNode>() || !AnfUtils::IsRealKernel(kernel_with_index.first)) {
      continue;
    }
    if (common::AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
      AssignCommunicationNodeMem(kStaticMem, kernel_with_index.first);
    } else {
      (void)non_communication_op.emplace_back(kernel_with_index);
    }
  }

  for (const auto &item_with_index : non_communication_op) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    MS_LOG(DEBUG) << "AssignNodeOutputMem for " << item_with_index.first->fullname_with_scope();
    AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  }
  MS_LOG(INFO) << "AssignStaticMemoryOutput end";
}

void KernelRuntime::UpdateSingleRefNodeMem(const CNodePtr &kernel, const session::KernelGraph &graph,
                                           bool reverse) const {
  MS_EXCEPTION_IF_NULL(kernel);
  auto output_num = AnfAlgo::GetOutputTensorNum(kernel);
  if (output_num == 0) {
    MS_LOG(DEBUG) << "This kernel has no output size.";
    return;
  }
  for (size_t i = 0; i < output_num; ++i) {
    session::AnfWithOutIndex out_pair(kernel, i);
    if (graph.IsInRefOutputMap(out_pair)) {
      auto origin_pair = graph.GetRefCorrespondOutput(out_pair);
      MS_EXCEPTION_IF_NULL(origin_pair.first);
      auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
      MS_EXCEPTION_IF_NULL(origin_node_output_addr);
      auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
      if (!reverse && origin_node_output_addr->GetPtr() == nullptr) {
        continue;
      }
      if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
        MS_LOG(DEBUG) << "REF addresses are not the same, the ref node output needs an address update";
        MS_LOG(DEBUG) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                      << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
        if (reverse) {
          AnfAlgo::SetOutputAddr(cur_node_output_addr, origin_pair.second, origin_pair.first.get());
        } else {
          if (!cur_node_output_addr->host_shape().empty()) {
            origin_node_output_addr->set_host_shape(cur_node_output_addr->host_shape());
          }
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}

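// Propagate addresses across ref-node pairs in two sweeps over the execution order:
// forward, so a ref output shares its origin's address, then backward (reverse == true),
// so addresses bound late flow back to the origin node.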
void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph &graph) const {
  auto &kernels = graph.execution_order();
  for (auto &kernel : kernels) {
    UpdateSingleRefNodeMem(kernel, graph, false);
  }
  for (auto it = kernels.rbegin(); it != kernels.rend(); ++it) {
    auto &kernel = *it;
    UpdateSingleRefNodeMem(kernel, graph, true);
  }
}

void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
  if (!reuse_communication_address_.empty()) {
    type = kDynamicMem;
  }
  AssignCommunicationNodeInputMem(type, node);
  AssignCommunicationNodeOutputMem(type, node);
  AssignWorkSpaceMem(type, node);
}

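// Allocate one contiguous block for all outputs of a communication node and slice it per
// output. When kAttrReuseCommunication is set, the block pointer is cached in
// reuse_communication_address_ and shared by nodes carrying the same reuse index.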
void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  size_t output_index = 0;
  std::vector<size_t> align_size_list;
  for (uint64_t mem_size : output_sizes) {
    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
      MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address";
      return;
    }
    if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
      mem_size = MemoryManager::GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    (void)align_size_list.emplace_back(mem_size);
  }

  if (align_size_list.empty()) {
    return;
  }

  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }

  uint8_t *output_ptr = nullptr;
  int64_t valid_reuse_index = -1;
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (common::AnfAlgo::HasNodeAttr(kAttrReuseCommunication, cnode)) {
    auto reuse_index = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrReuseCommunication);
    auto it = reuse_communication_address_.find(reuse_index);
    if (it != reuse_communication_address_.end()) {
      valid_reuse_index = reuse_index;
      output_ptr = it->second.second;
    }
  }

  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, {node, j});
    MS_EXCEPTION_IF_NULL(address);
    if (output_ptr == nullptr) {
      output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
      MS_EXCEPTION_IF_NULL(output_ptr);
      if (valid_reuse_index != -1) {
        auto &it = reuse_communication_address_[valid_reuse_index];
        it.second = output_ptr;
      }
    } else {
      address->set_ptr(output_ptr);
    }
    address->set_host_shape(trans::GetRuntimePaddingShape(node, j));
    AnfAlgo::SetOutputAddr(address, j, node.get());
    output_ptr += align_size_list[j];
  }
}

bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return false;
}

DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (common::AnfAlgo::IsNopNode(anf_node)) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, index);
    return PreAssignCNodeMemory(input_node_with_index.first, input_node_with_index.second);
  }

  auto output_size = AnfAlgo::GetOutputTensorMemSize(anf_node, index);
  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  auto address = CreateDeviceAddress(nullptr, output_size, output_format, output_type, {anf_node, index});
  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  return address;
}

void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<std::pair<DeviceAddressPtr, size_t>> addr_size;
  size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, true);
    auto input_node = input_node_with_index.first;
    MS_EXCEPTION_IF_NULL(input_node);
    if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
      MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address";
      return;
    }
    DeviceAddressPtr address = nullptr;

    address = PreAssignCNodeMemory(input_node, input_node_with_index.second);

    MS_EXCEPTION_IF_NULL(address);
    auto mem_size = MemoryManager::GetCommonAlignSize(address->size());
    total_size += mem_size;
    (void)addr_size.emplace_back(address, mem_size);
  }
  if (addr_size.empty()) {
    return;
  }
  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input.";
    }
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() < kMinInputSize) {
    // A communication node's inputs should contain the primitive itself and at least one input.
    MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
    return;
  }

  int64_t valid_reuse_index = -1;
  uint8_t *input_ptr = nullptr;
  if (common::AnfAlgo::HasNodeAttr(kAttrReuseCommunication, cnode)) {
    auto reuse_index = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrReuseCommunication);
    auto it = reuse_communication_address_.find(reuse_index);
    if (it != reuse_communication_address_.end()) {
      valid_reuse_index = reuse_index;
      input_ptr = it->second.first;
    }
  }

  if (input_ptr == nullptr) {
    auto first_input_node = cnode->input(1);
    auto prenode_index = common::AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true);
    input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size,
                                              addr_size[0].first, true);
    if (valid_reuse_index != -1) {
      auto &it = reuse_communication_address_[valid_reuse_index];
      it.first = input_ptr;
    }
  }

  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}

void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);

  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }

  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(node, i)) {
      MS_LOG(DEBUG) << "Already malloc index:" << i;
      continue;
    }
    MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope() << " output memory size:" << output_sizes[i];
    if (type == kStaticMem) {
      MS_LOG(INFO) << "Assign Static Memory for Output node, size:" << output_sizes[i]
                   << " node:" << node->fullname_with_scope();
    }
    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {node, i});
    MS_EXCEPTION_IF_NULL(device_address);

    // In subgraph sink mode, graph outputs should not have memory allocated here.
    if (IsNeedAllocMem(node, i)) {
      uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
      if (ptr == nullptr && type == kSomasReuseDynamicMem) {
        MS_LOG(INFO) << "node: " << node->fullname_with_scope() << " could be a RefNode, please check it"
                     << " output index: " << i << " memory type: " << type;
      } else {
        MS_EXCEPTION_IF_NULL(ptr);
      }
    } else {
      MS_LOG(DEBUG) << "Skip mem alloc for device address:" << device_address << " node:" << node->DebugString();
    }
    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
    AnfAlgo::SetOutputAddr(device_address, i, node.get());
  }
}

DeviceAddressPtr KernelRuntime::AssignExtraStaticMem(const TensorPtr &tensor, const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope()
                << " Assign Static Memory for Output node, size:" << tensor_address->size();
  auto device_address = CreateDeviceAddress(nullptr, tensor_address->size(), tensor_address->format(),
                                            tensor_address->type_id(), {node, index});
  MS_EXCEPTION_IF_NULL(device_address);
  uint8_t *ptr = mem_manager_->MallocOutputMem(node, index, kStaticMem, tensor_address->size(), device_address, false);
  MS_EXCEPTION_IF_NULL(ptr);
  return device_address;
}

void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
                                          size_t output_idx) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::vector<tensor::BaseTensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);
  // The graph id should be passed to record static memory if profiling is enabled.
  auto kernel_info = dynamic_cast<device::KernelInfo *>(value_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  uint32_t graph_id = kernel_info->graph_id();
  for (const auto &tensor : tensors) {
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }
    auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr && output_address->GetDeviceType() == GetTargetDeviceType()) {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
                             value_node.get());
      continue;
    }

    DeviceAddressPtr address = nullptr;
    size_t node_size = 0;
    if (node_value->isa<Scalar>()) {
      auto scalar_value = node_value->cast<ScalarPtr>();
      MS_EXCEPTION_IF_NULL(scalar_value);
      TypePtr data_type = scalar_value->type();
      MS_EXCEPTION_IF_NULL(data_type);
      TypeId type_id = data_type->type_id();
      node_size = GetTypeByte(TypeIdToType(type_id));
      address = CreateDeviceAddress(nullptr, node_size, kOpFormat_DEFAULT, type_id, {value_node, output_idx});
    } else {
      node_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
      if (output_type_id == kTypeUnknown) {
        output_type_id = common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
      }
      auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
      address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx});
    }
    MS_EXCEPTION_IF_NULL(address);
    address->set_host_shape(trans::GetRuntimePaddingShape(value_node, output_idx));
    address->set_from_persistent_mem(true);
    if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
        !mem_manager_->MallocMemFromMemPool(address, node_size)) {
      MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << node_size;
    } else {
      MS_LOG(INFO) << "Assign Static Memory for Value node, size:" << node_size
                   << " node:" << value_node->fullname_with_scope();
      if (mem_manager_->MallocMem(kStaticMem, node_size, address, graph_id) == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
      }
    }
    AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
    size_t tensor_size = LongToSize(tensor->data().nbytes());
    std::string format = "DefaultFormat";
    if (tensor->isa<tensor::Tensor>()) {
      format = std::dynamic_pointer_cast<tensor::Tensor>(tensor)->device_info().host_format_;
    }
    if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
                                   format, tensor->data_ptr())) {
      MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail! " << value_node->DebugString()
                                   << " node format is " << AnfAlgo::GetOutputFormat(value_node, output_idx)
                                   << " node dtype is "
                                   << common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
  }
}

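// Assign static memory for every value node in the graph. Nodes are visited in
// fullname order, which presumably keeps allocation deterministic across runs.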
void KernelRuntime::AssignStaticMemoryValueNode(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  MS_LOG(DEBUG) << "AssignStaticMemoryValueNode start for graph " << graph.graph_id();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // Order the value nodes by full name.
  std::map<std::string, ValueNodePtr> value_nodes_map;
  for (auto &node : graph.graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    value_nodes_map[node->fullname_with_scope()] = node;
  }

  for (auto &item : value_nodes_map) {
    auto value_node = item.second;
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeOutputDeviceAddressExist(value_node, 0)) {
      MS_LOG(DEBUG) << "value_node[" << value_node->DebugString() << "] address already exists";
      auto device_address = AnfAlgo::GetMutableOutputAddr(value_node, 0);
      if (device_address->GetDevicePtr() == nullptr) {
        if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
          if (!mem_manager_->MallocMemFromMemPool(device_address, device_address->GetSize())) {
            MS_LOG(EXCEPTION) << "MallocMemFromMemPool failed";
          }
        } else {
          if (mem_manager_->MallocMem(kStaticMem, device_address->GetSize(), device_address, graph.graph_id()) ==
              nullptr) {
            MS_LOG(EXCEPTION) << "MallocStaticMem failed";
          }
        }
      }
      continue;
    }
    auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    MS_LOG(DEBUG) << "Malloc memory for " << value_node->fullname_with_scope();
    if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>() || node_value->isa<Scalar>()) {
      AssignValueNodeTensor(value_node, node_value, 0);
    } else if (node_value->isa<StringImm>()) {
      const bool use_mem_from_memory_pool = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) ||
                                            ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
      auto address = CreateDeviceAddressForStringValue(node_value, use_mem_from_memory_pool, graph.graph_id());
      MS_EXCEPTION_IF_NULL(address);
      address->set_from_persistent_mem(true);
      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
    }
  }
  MS_LOG(DEBUG) << "AssignStaticMemoryValueNode end";
}

1128 
DeviceAddressPtr KernelRuntime::CreateDeviceAddressForStringValue(const ValuePtr &value, bool use_mem_pool,
                                                                  uint32_t graph_id) {
  auto value_string = GetValue<std::string>(value);
  size_t tensor_size = value_string.size();
  DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  MS_EXCEPTION_IF_NULL(address);
  address->set_from_persistent_mem(true);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (use_mem_pool) {
    if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
      MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;
    }
  } else {
    MS_LOG(INFO) << "Assign static memory for string value node, size:" << tensor_size;
    if (mem_manager_->MallocMem(kStaticMem, tensor_size, address, graph_id) == nullptr) {
      MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
    }
  }
  ShapeVector shape = {1, SizeToLong(tensor_size)};
  if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value_string.data(), "DefaultFormat")) {
    MS_LOG(EXCEPTION) << "ValueNode SyncHostToDevice fail!";
  }
  return address;
}

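// MemScheduler hook run before a kernel launch: sync the stream ahead of communication ops,
// run the scheduler's pre-compute step for this position, then resolve the kernel's launch
// addresses. During the mock (recording) pass, outputs of kernels tagged kAttrOffload are
// registered with the scheduler as offload candidates instead of being launched.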
bool KernelRuntime::MemSchedulerPreCompute(const AnfNodePtr &kernel, const std::shared_ptr<MemScheduler> &mem_scheduler,
                                           void *stream, bool mock, KernelLaunchInfo *kernel_launch_info) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  MS_EXCEPTION_IF_NULL(stream);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  if (!mock && common::AnfAlgo::IsCommunicationOp(kernel) && !SyncStream()) {
    MS_LOG(ERROR) << "SyncStream failed";
    return false;
  }
  bool ret = mem_scheduler->PreCompute(stream);
  if (!ret) {
    return ret;
  }
  AssignKernelAddress(mem_scheduler, kernel, kernel_launch_info);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (mock && common::AnfAlgo::HasNodeAttr(kAttrOffload, cnode) &&
      common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
    for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
      auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
      mem_scheduler->SetOffload(device_address);
    }
  }
  return true;
}

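// MemScheduler hook run after a kernel launch: sync updated output tensors back to host (real
// runs only) and run the scheduler's post-compute step so memory this step no longer needs can
// be released or swapped out.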
bool KernelRuntime::MemSchedulerPostCompute(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                                            const std::shared_ptr<MemScheduler> &mem_scheduler, void *stream,
                                            bool mock) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  MS_EXCEPTION_IF_NULL(stream);
  if (!mock) {
    SyncNodeOutputTensors(mem_scheduler, graph, kernel);
  }
  bool ret = mem_scheduler->PostCompute(stream);
  if (!ret) {
    return ret;
  }
  if (!mock && common::AnfAlgo::IsCommunicationOp(kernel) && !SyncStream()) {
    MS_LOG(ERROR) << "SyncStream failed";
    return false;
  }
  return true;
}

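// Assign dynamic memory for all kernels in the graph. When memory reuse (SOMAS) is enabled, the
// whole-graph reuse plan is allocated up front; otherwise each output and workspace gets its own
// block. Communication nodes are handled first because their buffers may need to be contiguous.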
void KernelRuntime::AssignDynamicMemory(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_enable_mem_reuse = EnvConfigParser::GetInstance().GetSysMemreuse();
  auto mem_type = kDynamicMem;
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (dump_json_parser.e2e_dump_enabled() && dump_json_parser.dump_mode() == 0) {
    mindspore::EnvConfigParser::GetInstance().SetSysMemreuse(false);
    is_enable_mem_reuse = false;
    MS_LOG(INFO) << "Disable memory reuse when e2e dump is enabled and dump mode is set to dump all kernels";
  }

  if (is_enable_mem_reuse) {
    MS_LOG(INFO) << "Memory reuse is enabled";
    mem_manager_->MallocSomasDynamicMem(graph);
    mem_type = kSomasReuseDynamicMem;
  } else {
    MS_LOG(INFO) << "Memory reuse is disabled";
  }
  auto &execution_nodes = graph.execution_order();
  std::vector<CNodePtr> compute_nodes;
  // Communication nodes first.
  for (auto &node : execution_nodes) {
    if (common::AnfAlgo::IsCommunicationOp(node)) {
      // Skipped internally if the memory is already allocated.
      AssignCommunicationNodeMem(mem_type, node);
    } else {
      (void)compute_nodes.emplace_back(node);
    }
  }

  // Then compute nodes.
  for (auto &node : compute_nodes) {
    AssignNodeOutputMem(mem_type, node, kGetAllOuts);
    AssignWorkSpaceMem(mem_type, node);
  }
}

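// Allocate workspace memory for each workspace slot reported by the kernel mod. If any slot
// already has a workspace address, the node is treated as already assigned and skipped.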
void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  size_t index = 0;
  for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
    if (AnfAlgo::WorkspaceAddrExist(node, index)) {
      MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address";
      return;
    }
    auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
    AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
    index++;
  }
}

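// Collect the input/output/workspace kernel tensors for a kernel launch. For the MemSet op the
// arguments come from the atomic-clean attributes of its predecessor instead of the normal edges.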
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
                                  KernelLaunchInfo *kernel_launch_info) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_launch_info);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
  if (cnode_name == kMemSetOpName) {
    return GenKernelTensorLaunchArgs(cnode, &(kernel_launch_info->inputs_));
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto skip_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
  for (size_t i = 0; i < input_num; ++i) {
    auto real_input = AnfAlgo::GetInputGraphIdxByKernelIdx(kernel, i);
    auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input, skip_nop_node);
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &input = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(input);
    (void)kernel_launch_info->inputs_.emplace_back(input.get());
  }

  for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetOutputAddr(kernel, i, skip_nop_node);
    const auto &output = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(output);
    (void)kernel_launch_info->outputs_.emplace_back(output.get());
  }

  for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    const auto &workspace = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(workspace);
    (void)kernel_launch_info->workspaces_.emplace_back(workspace.get());
  }
}

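// The MemScheduler (memory offload) is only used in graph mode with MS_CTX_ENABLE_MEM_OFFLOAD
// set; single-op (PyNative) execution manages memory per launch.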
bool KernelRuntime::UseMemScheduler() {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD)) {
    return false;
  }
  // Don't use the MemScheduler when running single ops.
  return (!context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
          (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode));
}

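// Build pre-/post-run event lists that overlap communication with compute: before a communication
// kernel, the communication stream waits on an event recorded on the compute stream; after it,
// the nearest downstream compute consumer (or the communication kernel itself, when no consumer
// is found) waits on an event recorded on the communication stream.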
void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  auto kernel_events = std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>,
                                 std::map<AnfNodePtr, std::vector<std::function<void()>>>>();
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &kernel = kernels[i];
    if (!common::AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto pre_event = CreateDeviceEvent();
    auto post_event = CreateDeviceEvent();
    MS_EXCEPTION_IF_NULL(pre_event);
    MS_EXCEPTION_IF_NULL(post_event);
    pre_event->set_wait_stream(communication_stream_);
    pre_event->set_record_stream(stream_);
    post_event->set_wait_stream(stream_);
    post_event->set_record_stream(communication_stream_);
    (void)kernel_pre_run_events[kernel].emplace_back([pre_event]() {
      pre_event->RecordEvent();
      pre_event->WaitEvent();
    });
    (void)kernel_post_run_events[kernel].emplace_back([post_event]() { post_event->RecordEvent(); });
    bool found_nearest_child = false;
    for (size_t j = i + 1; j < kernels.size(); ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      if (common::AnfAlgo::IsCommunicationOp(child)) {
        continue;
      }
      auto input_size = child->size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        auto kernel_index =
          common::AnfAlgo::VisitKernelWithReturnType(common::AnfAlgo::GetInputNode(child, k), 0, true);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        (void)kernel_pre_run_events[child].emplace_back([post_event]() { post_event->WaitEvent(); });
        break;
      }
    }
    if (!found_nearest_child) {
      (void)kernel_post_run_events[kernel].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}

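// Build the launch arguments for an atomic-clean (MemSet) node: the addresses to clean are the
// outputs and workspaces of the predecessor listed in its kAttrAtomicOutputIndexs and
// kAttrAtomicWorkspaceIndexs attributes.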
void KernelRuntime::GenKernelTensorLaunchArgs(const CNodePtr &cnode, std::vector<kernel::KernelTensor *> *kernel_inputs,
                                              const std::shared_ptr<MemScheduler> &mem_scheduler) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  if (cnode->size() != kAtomicCleanInputSize) {
    MS_LOG(EXCEPTION) << "Input size of the atomic address clean node is not equal to 2.";
  }
  MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
  auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_node);
  // Set the clean output addresses.
  if (common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
#if defined(__APPLE__)
    auto clean_output_indexes = common::AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicOutputIndexs);
#else
    auto clean_output_indexes = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
#endif
    for (auto index : clean_output_indexes) {
      auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
      MS_EXCEPTION_IF_NULL(device_address);
      const auto &input = device_address->kernel_tensor();
      MS_EXCEPTION_IF_NULL(input);
      if (mem_scheduler != nullptr) {
        GetOrMallocAddress(mem_scheduler, device_address, input);
      }
      auto real_output_size = AnfAlgo::GetOutputTensorMemSize(pre_node, index);
      if (device_address->GetSize() != real_output_size) {
        MS_LOG(DEBUG) << "The node:" << pre_node->fullname_with_scope() << " real output size is " << real_output_size;
        input->set_size(real_output_size);
      }
      (void)kernel_inputs->emplace_back(input.get());
    }
    MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexes.size();
  }
  // Set the clean workspace addresses.
  if (common::AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
#if defined(__APPLE__)
    auto clean_workspaces_indexes =
      common::AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicWorkspaceIndexs);
#else
    auto clean_workspaces_indexes =
      common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
#endif
    for (const auto &index : clean_workspaces_indexes) {
      auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
      const auto &workspace = device_address->kernel_tensor();
      MS_EXCEPTION_IF_NULL(workspace);
      if (mem_scheduler != nullptr) {
        GetOrMallocAddress(mem_scheduler, device_address, workspace);
      }
      (void)kernel_inputs->emplace_back(workspace.get());
    }
  }
}

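// Run the pre- or post-launch callbacks registered for this node, if any.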
void KernelRuntime::LaunchKernelEvent(const std::map<AnfNodePtr, std::vector<std::function<void()>>> &kernel_events,
                                      const AnfNodePtr &node) const {
  if (kernel_events.find(node) == kernel_events.end()) {
    return;
  }

  for (auto &event : kernel_events.at(node)) {
    event();
  }
}

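// Launch a kernel bracketed by a pair of device time events and log the elapsed device time.
// Used when PyNative profiling is enabled.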
bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
                                                      const KernelLaunchInfo &kernel_launch_info, void *stream) {
  MS_EXCEPTION_IF_NULL(kernel_mod);
  MS_EXCEPTION_IF_NULL(stream);
  float cost_time = 0;
  auto start = CreateDeviceTimeEvent();
  auto end = CreateDeviceTimeEvent();
  MS_EXCEPTION_IF_NULL(start);
  MS_EXCEPTION_IF_NULL(end);
  start->set_record_stream(stream);
  end->set_record_stream(stream);
  start->RecordEvent();
  // Launch on the same stream the timing events are recorded on.
  bool ret = kernel_mod->Launch(kernel_launch_info.inputs_, kernel_launch_info.workspaces_, kernel_launch_info.outputs_,
                                stream);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Launch kernel failed, kernel name: " << op_name;
  }
  end->RecordEvent();
  start->SyncEvent();
  end->SyncEvent();
  start->ElapsedTime(&cost_time, end.get());
  MS_LOG(DEBUG) << "Launch kernel:" << op_name << " cost:" << cost_time / kBasicTimeTransferUnit;
  return ret;
}

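// In PyNative synchronous mode, block until the stream finishes so kernel errors surface at the
// op that caused them.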
void KernelRuntime::DebugStreamSync(const CNodePtr &kernel) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto enable_sync_run = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
  if (enable_sync_run) {
    if (!SyncStream()) {
      MS_LOG(EXCEPTION) << "Op " << kernel->fullname_with_scope() << " run failed!";
    }
  }
}

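// Point the kernel tensor at the device memory behind the address; when the address has no
// device pointer yet, ask the MemScheduler to materialize (swap in or allocate) the block first.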
void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                       const DeviceAddress *device_address,
                                       const kernel::KernelTensorPtr &kernel_tensor) {
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetDevicePtr() != nullptr) {
    kernel_tensor->set_device_ptr(device_address->GetDevicePtr());
  } else {
    kernel_tensor->set_device_ptr(mem_scheduler->GetOrMalloc(device_address, device_address->GetSize()));
  }
}

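// MemScheduler variant of GenLaunchArgs: resolve every input/output/workspace address through the
// scheduler so offloaded blocks are swapped in before launch. Ref-tensor parameter inputs of
// parameter-update kernels are promoted to high-priority memory.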
void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
                                        KernelLaunchInfo *kernel_launch_info) const {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_launch_info);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
  if (cnode_name == kMemSetOpName) {
    return GenKernelTensorLaunchArgs(cnode, &(kernel_launch_info->inputs_), mem_scheduler);
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
  const auto update_parameter = common::AnfAlgo::IsUpdateParameterKernel(cnode);
  for (size_t j = 0; j < input_num; ++j) {
    auto real_input = AnfAlgo::GetInputGraphIdxByKernelIdx(kernel, j);
    auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
    auto index = kernel_with_index.second;
    auto &input_node = kernel_with_index.first;
    auto device_address = AnfAlgo::GetOutputAddr(input_node, index, true);
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &input = device_address->kernel_tensor();
    GetOrMallocAddress(mem_scheduler, device_address, input);
    (void)kernel_launch_info->inputs_.emplace_back(input.get());
    if (update_parameter && input_node->isa<Parameter>()) {
      auto param = input_node->cast<ParameterPtr>();
      auto abstract = param->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      if (abstract->isa<abstract::AbstractRefTensor>()) {
        mem_scheduler->UpdateHighPriorityMem(device_address);
      }
    }
  }

  for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
    auto device_address = AnfAlgo::GetOutputAddr(kernel, j, true);
    const auto &output = device_address->kernel_tensor();
    GetOrMallocAddress(mem_scheduler, device_address, output);
    (void)kernel_launch_info->outputs_.emplace_back(output.get());
  }

  for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    const auto &workspace = device_address->kernel_tensor();
    GetOrMallocAddress(mem_scheduler, device_address, workspace);
    (void)kernel_launch_info->workspaces_.emplace_back(workspace.get());
  }
}

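// After a kernel runs, sync the host tensors bound to its parameter inputs and to all of its
// outputs, so the host copies stay valid before the scheduler may evict the device memory.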
void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                          const session::KernelGraph &graph, const AnfNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto inputs = AnfAlgo::GetOrCreateAllInputKernelTensors(kernel);
  for (size_t input_idx = 0; input_idx < inputs.size(); ++input_idx) {
    const auto input_node_index = common::AnfAlgo::GetPrevNodeOutput(kernel, input_idx, true);
    if (input_node_index.first != nullptr && input_node_index.first->isa<Parameter>()) {
      SyncNodeOutputTensor(mem_scheduler, input_node_index, graph);
    }
  }
  for (size_t output_idx = 0; output_idx < kernel_mod->GetOutputSizeList().size(); ++output_idx) {
    SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph);
  }
}

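// Sync one node output back to its bound host tensor. The device pointer is resolved through the
// scheduler for the copy and restored afterwards, and the tensor is detached from the device
// address so the memory stays reusable; it is re-marked as needing host-to-device sync.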
void KernelRuntime::SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                         const KernelWithIndex &node_output_index, const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  if (node_output_index.first == nullptr) {
    return;
  }
  auto device_address = AnfAlgo::GetMutableOutputAddr(node_output_index, true);
  auto tensor = graph.GetNodeOutputTensor(node_output_index);
  if (tensor == nullptr) {
    return;
  }
  if (device_address == nullptr) {
    tensor->data_sync(false);
    tensor->set_device_address(nullptr);
    tensor->set_sync_status(kNeedSyncHostToDevice);
    return;
  }
  if (!SyncStream()) {
    MS_LOG(EXCEPTION) << "SyncStream failed";
  }
  auto origin_ptr = device_address->GetDevicePtr();
  if (device_address->GetDevicePtr() == nullptr) {
    device_address->SetDevicePtr(mem_scheduler->GetOrMalloc(device_address.get(), device_address->GetSize()));
  }
  tensor->set_device_address(device_address);
  tensor->data_sync(false);
  tensor->set_device_address(nullptr);
  device_address->SetDevicePtr(origin_ptr);
  tensor->set_sync_status(kNeedSyncHostToDevice);
}

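// Register every graph input tensor with the MemScheduler before a run: decide whether the host
// data must be copied to device now (or deferred via AddMemNeedInit when no device pointer exists
// yet) and give weights and updated parameters high memory priority.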
void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                          const session::KernelGraph &graph) const {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  auto &input_nodes = graph.input_nodes();
  auto &input_tensors = graph.input_tensors();
  if (input_tensors.size() != input_nodes.size()) {
    MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
  }
  mem_scheduler->ClearMemNeedInit();
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    auto input_node = input_nodes[i];
    if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
      continue;
    }
    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
    auto tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    const auto tensor_size = LongToSize(tensor->data().nbytes());
    bool need_sync = false;
    if (tensor->NeedSyncHostToDevice()) {
      need_sync = true;
    } else if (tensor_address != device_address) {
      tensor->data_sync(false);
      need_sync = true;
    }
    if (mem_scheduler->HasDeviceMem(device_address.get())) {
      device_address->set_ptr(nullptr);
    }
    if (need_sync) {
      const auto &shape = trans::GetRuntimePaddingShape(input_node, 0);
      if (device_address->GetPtr() != nullptr) {
        (void)device_address->SyncHostToDevice(shape, LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                               tensor->device_info().host_format_, tensor->data_ptr());
      } else {
        mem_scheduler->AddMemNeedInit(device_address.get());
      }
    }
    MemPriority priority = kMemPriorityLow;
    const auto &parameter = input_node->cast<ParameterPtr>();
    if (common::AnfAlgo::IsParameterWeight(parameter) || graph.IsUpdatedParameter(parameter)) {
      priority = kMemPriorityHigh;
    }
    mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority);
    tensor->set_sync_status(kNoNeedSync);
  }
}

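// Tell the MemScheduler which addresses of each communication kernel must live in one contiguous
// block, for both the fused inputs and the fused outputs, keyed by the kernel's position in the
// execution order.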
void KernelRuntime::AddCommunicationMemInfo(const session::KernelGraph &graph) {
  const auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
  for (size_t compute_index = 0; compute_index < graph.execution_order().size(); ++compute_index) {
    const auto &kernel = graph.execution_order()[compute_index];
    MS_EXCEPTION_IF_NULL(kernel);
    if (!common::AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto device_address_to_key = [](const DeviceAddressPtr &device_address) -> void * { return device_address.get(); };
    size_t input_total_size = 0;
    DeviceAddressPtrList input_address_list;
    std::vector<size_t> input_align_size_list;
    GetCommunicationInputInfo(kernel, &input_total_size, &input_address_list, &input_align_size_list);
    if (input_address_list.size() > 1) {
      std::vector<const void *> input_address_key_list;
      (void)std::transform(input_address_list.begin(), input_address_list.end(),
                           std::back_inserter(input_address_key_list), device_address_to_key);
      mem_scheduler->AddContinuousMemInfo(true, compute_index, input_total_size, input_align_size_list,
                                          input_address_key_list);
    }
    size_t output_total_size = 0;
    DeviceAddressPtrList output_address_list;
    std::vector<size_t> output_align_size_list;
    GetCommunicationOutputInfo(kernel, &output_total_size, &output_address_list, &output_align_size_list);
    if (output_address_list.size() > 1) {
      std::vector<const void *> output_address_key_list;
      (void)std::transform(output_address_list.begin(), output_address_list.end(),
                           std::back_inserter(output_address_key_list), device_address_to_key);
      mem_scheduler->AddContinuousMemInfo(false, compute_index, output_total_size, output_align_size_list,
                                          output_address_key_list);
    }
  }
}

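// Launch one kernel. With a MemScheduler, the pre-/post-compute hooks manage swap-in/swap-out
// around the launch; in mock mode (the scheduler's recording pass) the launch itself is skipped.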
bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                                 const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  KernelLaunchInfo kernel_launch_info;
  auto stream = GetKernelStream(kernel);
  MS_EXCEPTION_IF_NULL(stream);
  bool ret = true;
  if (mem_scheduler != nullptr) {
    ret = MemSchedulerPreCompute(kernel, mem_scheduler, stream, mock, &kernel_launch_info);
    if (!ret) {
      return ret;
    }
  } else {
    GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
  }
  if (!mock) {
    if (pynative_mode_profiling_flag_) {
      ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_launch_info, stream);
    } else {
      ret = kernel_mod->Launch(kernel_launch_info.inputs_, kernel_launch_info.workspaces_, kernel_launch_info.outputs_,
                               stream);
    }
    if (!ret) {
      return ret;
    }
  }
  if (mem_scheduler != nullptr) {
    ret = MemSchedulerPostCompute(graph, kernel, mem_scheduler, stream, mock);
  }
  return ret;
}

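// Launch every kernel in the graph's execution order. Dynamic-shape kernels are re-inferred and
// resized before launch and synced afterwards; nop transpose kernels just forward their input
// address; pre-/post-run events interleave compute with communication kernels.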
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::shared_ptr<MemScheduler> mem_scheduler = nullptr;

  if (UseMemScheduler()) {
    mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
    MS_EXCEPTION_IF_NULL(mem_scheduler);
    mem_scheduler->Reset();
    mem_scheduler->Update();
    InitGraphInputTensors(mem_scheduler, graph);
  }

  const auto &kernels = graph.execution_order();
  std::map<AnfNodePtr, std::vector<std::function<void()>>> kernel_pre_run_events;
  std::map<AnfNodePtr, std::vector<std::function<void()>>> kernel_post_run_events;
  auto events_iter = graph_kernel_events_map_.find(graph.graph_id());
  if (events_iter != graph_kernel_events_map_.end()) {
    kernel_pre_run_events = events_iter->second.first;
    kernel_post_run_events = events_iter->second.second;
  }
  for (size_t i = 0; i < kernels.size(); ++i) {
    LaunchKernelEvent(kernel_pre_run_events, kernels[i]);
    auto &kernel = kernels[i];
    MS_EXCEPTION_IF_NULL(kernel);
    if (common::AnfAlgo::IsDynamicShape(kernel)) {
      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
      MS_EXCEPTION_IF_NULL(kernel_mod);
      opt::InferOp(kernel);
      auto inputs = AnfAlgo::GetOrCreateAllInputKernelTensors(kernel);
      auto outputs = AnfAlgo::GetOrCreateAllOutputKernelTensors(kernel);
      if (kernel_mod->Resize(inputs, outputs) == static_cast<int>(kernel::KRET_RESIZE_FAILED)) {
        MS_LOG(EXCEPTION) << "Node " << kernel->fullname_with_scope() << " Resize failed.";
      }
      KernelLaunchInfo kernel_launch_info;
      device::KernelRuntime::GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
      // Allocate workspace memory.
      std::vector<KernelTensor *> workspaces;
      if (AnfAlgo::GetKernelType(kernel) == KernelType::TBE_KERNEL) {
        MS_EXCEPTION_IF_NULL(tbe_call_);
        tbe_call_(kernel, kernel_mod, &workspaces);
      } else {
        workspaces = kernel_launch_info.workspaces_;
      }

      auto ret = kernel_mod->Launch(kernel_launch_info.inputs_, workspaces, kernel_launch_info.outputs_, stream_);
      if (!ret) {
        MS_LOG(ERROR) << "Launch kernel failed, kernel full name: " << kernel->fullname_with_scope();
        return false;
      }

      if (!SyncStream()) {
        MS_LOG(ERROR) << "SyncStream failed";
        return false;
      }
      kernel::UpdateNodeShape(kernel);
    } else {
      // Skip transpose kernels with the "nop_op" attr that were not hidden or removed in the PyNative
      // infer scenario. Such transpose kernels, which are not supposed to be executed, are generated
      // in TransDataSplit to support specific TransData cases. This hard-coded check should be removed
      // once the new TransData scheme is implemented.
      if (common::AnfAlgo::HasNodeAttr(kAttrNopOp, kernel)) {
        for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(kernel); idx += 1) {
          auto real_input = AnfAlgo::GetInputGraphIdxByKernelIdx(kernel, idx);
          auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input);
          AnfAlgo::SetOutputAddr(device_address, idx, kernel.get());
        }
        continue;
      }
      auto ret = LaunchKernel(graph, kernel, mem_scheduler, mock);
      if (!ret) {
        MS_LOG(ERROR) << "Launch kernel failed.";
        return false;
      }
      KernelLaunchProfiling(kernel->fullname_with_scope());
      DebugStreamSync(kernel);
    }
    LaunchKernelEvent(kernel_post_run_events, kernels[i]);
  }
  if (UseMemScheduler() && !mock) {
    SyncParameter(graph, mem_scheduler);
  }
  return true;
}

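// After a MemScheduler run, rebind weight and updated-parameter tensors to their device addresses
// and mark them as needing a device-to-host sync, so later host reads fetch fresh values.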
void KernelRuntime::SyncParameter(const session::KernelGraph &graph,
                                  const std::shared_ptr<MemScheduler> &mem_scheduler) const {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  auto &input_nodes = graph.input_nodes();
  auto &input_tensors = graph.input_tensors();
  if (input_tensors.size() != input_nodes.size()) {
    MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
  }

  for (size_t i = 0; i < input_tensors.size(); ++i) {
    auto input_node = input_nodes[i];
    if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
      continue;
    }
    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
    MS_EXCEPTION_IF_NULL(device_address);
    auto parameter = input_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(parameter);
    if (!common::AnfAlgo::IsParameterWeight(parameter) && !graph.IsUpdatedParameter(parameter)) {
      continue;
    }
    auto tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(tensor);
    if (mem_scheduler->HasDeviceMem(device_address.get())) {
      auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh);
      device_address->set_ptr(device_ptr);
      tensor->set_device_address(device_address);
      tensor->set_sync_status(kNeedSyncDeviceToHost);
    }
    if (graph.IsUpdatedParameter(parameter)) {
      tensor->SetIsUpdateByDevice();
    }
  }
}

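// Prepare the MemScheduler for a graph on first use: record the memory-access trace with a mock
// run, then optimize the swap plan. Throws if the graph cannot fit even with offloading.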
void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!UseMemScheduler()) {
    return;
  }
  auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  if (mem_scheduler->optimized()) {
    return;
  }
  mem_scheduler->SetMemHandler(std::make_shared<MemHandler>(mem_manager_));
  mem_scheduler->SetTotalStep(graph.execution_order().size());

  if (mem_scheduler->need_record_event()) {
    (void)LaunchKernelMod(graph, true);
    mem_scheduler->set_need_record_event(false);
  }
  auto ret = mem_scheduler->Optimize();
  if (!ret) {
    MS_LOG_EXCEPTION << "Can't run graph " << graph.graph_id() << " within the memory limit.";
  }
}

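// Entry point for executing a compiled graph: set up memory offloading if needed, launch all
// kernels, and synchronize the stream in graph mode.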
bool KernelRuntime::LaunchKernels(const session::KernelGraph &graph) {
  UseMemSchedulerIfNeeded(graph);
  if (!LaunchKernelMod(graph)) {
    MS_LOG(ERROR) << "LaunchKernelMod failed!";
    return false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    if (!SyncStream()) {
      MS_LOG(ERROR) << "SyncStream failed";
      return false;
    }
  }
  return true;
}

void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
  MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}
}  // namespace device
}  // namespace mindspore