/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/kernel_runtime.h"
#include <functional>
#include <utility>
#include <vector>
#include <set>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "common/trans.h"
#include "debug/data_dump/dump_json_parser.h"
#include "frontend/operator/ops.h"
#include "ir/value.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "utils/shape_utils.h"
#include "utils/utils.h"
#include "frontend/parallel/context.h"
#include "debug/env_config_parser.h"
#include "pipeline/pynative/pynative_profiling.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#endif

using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;

namespace mindspore {
namespace device {
constexpr float kMaxMemReuseFactor = 0.8;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
constexpr size_t kAtomicCleanInputSize = 2;
namespace {
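// Collect the graph inputs plus any Parameter that is consumed by a kernel in the
// execution order but is missing from graph.inputs(), de-duplicated via a lookup set.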
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
  auto graph_inputs = graph.inputs();
  std::vector<AnfNodePtr> result(graph_inputs.begin(), graph_inputs.end());
  std::set<AnfNodePtr> inputs_set(graph_inputs.begin(), graph_inputs.end());
  auto kernels = graph.execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto input_num = AnfAlgo::GetInputTensorNum(kernel);
    for (size_t i = 0; i < input_num; ++i) {
      auto input_node = kernel->input(i + 1);
      auto input_real_node = AnfAlgo::VisitKernelWithReturnType(input_node, 0).first;
      MS_EXCEPTION_IF_NULL(input_real_node);
      if (input_real_node->isa<Parameter>() && inputs_set.find(input_real_node) == inputs_set.end()) {
        (void)inputs_set.insert(input_real_node);
        (void)result.emplace_back(input_real_node);
      }
    }
  }
  return result;
}
}  // namespace
constexpr size_t kMinInputSize = 2;
KernelRuntime::~KernelRuntime() {
  stream_ = nullptr;
  independent_stream_ = nullptr;
  communication_stream_ = nullptr;
}

bool KernelRuntime::Load(const session::KernelGraph &, bool) {
  MS_LOG(INFO) << "Call default load.";
  return true;
}

bool KernelRuntime::LoadData(const session::KernelGraph &) {
  MS_LOG(INFO) << "Call default load data.";
  return false;
}

bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
  MS_EXCEPTION_IF_NULL(kernel);
  if (AnfAlgo::OutputAddrExist(kernel, index)) {
    const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
    MS_EXCEPTION_IF_NULL(address);
    return address->DeviceType() == GetTargetDeviceAddressType();
  }
  return false;
}

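// Whole-graph memory assignment entry point. With the memory scheduler enabled, only
// value nodes get static memory here and node addresses are rebuilt as empty
// placeholders; otherwise static and dynamic regions are assigned up front.
// Ref-node outputs are re-bound either way.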
void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
  if (enable_mem_scheduler) {
    AssignStaticMemoryValueNode(graph);
    ResetNodeAddress(graph);
  } else {
    MS_EXCEPTION_IF_NULL(mem_manager_);
    mem_manager_->ResetDynamicMemory();
    AssignStaticMemory(graph);
    AssignDynamicMemory(graph);
  }
  UpdateRefNodeOutputMem(graph);
}

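// Gather the device addresses and aligned sizes of a communication node's inputs so
// the caller can back them with one contiguous block; CNode inputs without an address
// are pre-assigned on the fly.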
void KernelRuntime::RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
                                                   std::vector<DeviceAddressPtr> *address_list,
                                                   std::vector<size_t> *align_size_list) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(total_size);
  MS_EXCEPTION_IF_NULL(address_list);
  MS_EXCEPTION_IF_NULL(align_size_list);
  size_t input_num = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i, true);
    auto input_node = input_node_with_index.first;
    MS_EXCEPTION_IF_NULL(input_node);
    DeviceAddressPtr address = nullptr;
    if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
      address = AnfAlgo::GetMutableOutputAddr(input_node, input_node_with_index.second);
    } else {
      if (input_node->isa<CNode>()) {
        address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
      } else {
        MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
      }
    }
    MS_EXCEPTION_IF_NULL(address);
    auto align_size = MemoryManager::GetCommonAlignSize(address->size());
    *total_size += align_size;
    address_list->emplace_back(address);
    align_size_list->emplace_back(align_size);
  }
}

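// Allocate one contiguous device block from the memory pool for all inputs of a
// communication op (e.g. AllReduce), so the op can consume them without a gather.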
void KernelRuntime::RunOpAssignCommunicationInput(const AnfNodePtr &node) const {
  if (!AnfAlgo::IsCommunicationOp(node)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<DeviceAddressPtr> address_list;
  std::vector<size_t> align_size_list;
  RunOpGetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list);
  if (address_list.empty()) {
    return;
  }

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().size() < kMinInputSize) {
    MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
    return;
  }

  if (!mem_manager_->MallocContinuousMemFromMemPool(address_list, total_size, align_size_list)) {
    MS_LOG(EXCEPTION) << "Allocate continuous memory failed, total_size:" << total_size;
  }
}

void KernelRuntime::RunOpGetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size,
                                                    std::vector<size_t> *align_size_list,
                                                    std::vector<DeviceAddressPtr> *device_address_list) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(total_size);
  MS_EXCEPTION_IF_NULL(align_size_list);
  MS_EXCEPTION_IF_NULL(device_address_list);
  auto runtime_info = node->user_data<session::OpRuntimeInfo>();
  auto output_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t i = 0; i < output_num; ++i) {
    MS_EXCEPTION_IF_NULL(runtime_info);
    DeviceAddressPtr address = nullptr;
    if (AnfAlgo::OutputAddrExist(node, i)) {
      address = AnfAlgo::GetMutableOutputAddr(node, i);
    } else {
      std::string output_format = runtime_info->output_format(i);
      auto output_type = runtime_info->output_type(i);
      address =
        CreateDeviceAddress(nullptr, runtime_info->output_tensor_size(i), output_format, output_type, {node, i});
    }
    MS_EXCEPTION_IF_NULL(address);
    auto align_size = MemoryManager::GetCommonAlignSize(address->size());
    *total_size += align_size;
    align_size_list->emplace_back(align_size);
    device_address_list->emplace_back(address);
  }
}

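// Mirror of the input path: collect (or create) the output device addresses of a
// communication op and back them with one contiguous allocation from the memory pool.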
void KernelRuntime::RunOpAssignCommunicationOutput(const AnfNodePtr &node) const {
  if (!AnfAlgo::IsCommunicationOp(node)) {
    return;
  }

  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);

  size_t total_size = 0;
  std::vector<size_t> align_size_list;
  std::vector<DeviceAddressPtr> device_address_list;
  RunOpGetCommunicationOutputInfo(node, &total_size, &align_size_list, &device_address_list);

  if (align_size_list.empty()) {
    return;
  }

  if (!mem_manager_->MallocContinuousMemFromMemPool(device_address_list, total_size, align_size_list)) {
    MS_LOG(EXCEPTION) << "Allocate continuous memory failed, total_size:" << total_size;
  }
}

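// Pre-create placeholder device addresses (ptr == nullptr) for every kernel output and
// graph input before kernel build; real device pointers are attached later, and input
// tensors that already hold a matching device address are reused directly.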
void KernelRuntime::RunOpMallocPre(const session::KernelGraph &graph,
                                   const std::vector<tensor::TensorPtr> &input_tensors) {
  const auto &nodes = graph.execution_order();
  // Malloc for Node output
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    auto output_num = AnfAlgo::GetOutputTensorNum(node);
    for (size_t i = 0; i < output_num; ++i) {
      auto runtime_info = node->user_data<session::OpRuntimeInfo>();
      MS_EXCEPTION_IF_NULL(runtime_info);
      auto const &output_format = runtime_info->output_format(i);
      auto output_type = runtime_info->output_type(i);
      auto tensor_size = runtime_info->output_tensor_size(i);
      // Create DeviceAddress without ptr.
      // Get real device ptr after KernelBuild finish.
      auto device_address = CreateDeviceAddress(nullptr, tensor_size, output_format, output_type);
      device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
      AnfAlgo::SetOutputAddr(device_address, i, node.get());
    }
  }

  // Malloc for graph input
  if (input_tensors.size() != graph.inputs().size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
                      << " should be equal to graph input parameter size " << graph.inputs().size();
  }
  for (size_t input_index = 0; input_index < graph.inputs().size(); ++input_index) {
    auto item = graph.inputs()[input_index];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      auto current_tensor = input_tensors[input_index];
      MS_EXCEPTION_IF_NULL(current_tensor);
      auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(current_tensor->device_address());
      if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
        AnfAlgo::SetOutputAddr(output_address, index, item.get());
        continue;
      }
      auto op_runtime_info = item->user_data<session::OpRuntimeInfo>();
      MS_EXCEPTION_IF_NULL(op_runtime_info);
      TypeId output_type_id = op_runtime_info->output_type(index);
      auto output_tensor_size = op_runtime_info->output_tensor_size(index);
      auto output_format = op_runtime_info->output_format(index);
      auto device_address =
        CreateDeviceAddress(nullptr, output_tensor_size, output_format, output_type_id, {item, index});
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
      current_tensor->set_device_address(device_address);
      current_tensor->set_sync_status(kNeedSyncHostToDevice);
    }
  }
}

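// Rebuild empty device addresses for every kernel input, output and workspace in the
// graph; used by AssignMemory when the memory scheduler is enabled, where no real
// memory is attached at assignment time.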
void KernelRuntime::ResetNodeAddress(const session::KernelGraph &kernel_graph) {
  auto kernels = kernel_graph.execution_order();
  for (auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
    for (size_t j = 0; j < input_num; ++j) {
      auto input_index = AnfAlgo::GetRealInputIndex(kernel, j);
      KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, input_index, true);
      auto index = kernel_with_index.second;
      auto &input_node = kernel_with_index.first;
      if (NodeOutputDeviceAddressExist(input_node, index)) {
        continue;
      }
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, index);
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        continue;
      }
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
      auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index),
                                                output_type_id, {input_node, index});
      AnfAlgo::SetOutputAddr(device_address, index, input_node.get());
    }

    auto output_sizes = kernel_mod->GetOutputSizeList();
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
      AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i,
                             kernel.get());
    }
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_sizes.size(); ++i) {
      AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32),
                                i, kernel.get());
    }
  }
}

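// Single-op (PyNative) memory assignment: communication nodes first (they need
// contiguous blocks), then inputs, value nodes, outputs and workspaces, and finally
// ref-node output rebinding.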
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                      const session::KernelGraph &graph,
                                      const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->ResetDynamicMemory();

  for (const auto &node : graph.execution_order()) {
    RunOpAssignCommunicationOutput(node);
    RunOpAssignCommunicationInput(node);
  }

  RunOpAssignInputMemory(input_tensors, graph);
  AssignStaticMemoryValueNode(graph);
  for (const auto &node : graph.execution_order()) {
    RunOpAssignOutputMemory(node, tensor_to_node);
    RunOpAssignWorkSpaceMemory(node);
  }
  UpdateRefNodeOutputMem(graph);
}

void KernelRuntime::RunOpClearMemory(const session::KernelGraph &graph) const {
  // clear input parameter memory resource
  for (const auto &input_node : graph.inputs()) {
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
  }
  // clear input value node memory resource
  for (const auto &value_node : graph.graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  }
  for (const auto &cnode : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(cnode);
    // clear output memory resource
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t index = 0; index < output_num; ++index) {
      AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
    }
    // clear workspace memory resource
    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
    for (size_t index = 0; index < workspace_lists.size(); ++index) {
      AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
    }
  }
}

#ifdef ENABLE_DEBUGGER
bool KernelRuntime::DumpDataEnabled() {
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  return dump_json_parser.e2e_dump_enabled();
}

bool KernelRuntime::DumpDataEnabledIteration() {
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (!dump_json_parser.e2e_dump_enabled()) {
    return false;
  }

  auto cur_iter = dump_json_parser.cur_dump_iter();
  if (dump_json_parser.IsDumpIter(cur_iter)) {
    return true;
  }
  return false;
}
#endif

void KernelRuntime::AssignStaticMemory(const session::KernelGraph &graph) {
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  AssignStaticMemoryOutput(graph);
}

void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                           const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (input_tensors.size() != graph.inputs().size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
                      << " should be equal to graph input parameter size " << graph.inputs().size();
  }

  for (size_t input_index = 0; input_index < graph.inputs().size(); ++input_index) {
    auto item = graph.inputs()[input_index];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      auto current_tensor = input_tensors[input_index];
      MS_EXCEPTION_IF_NULL(current_tensor);
      auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(current_tensor->device_address());
      if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
        if (output_address->ptr_ == nullptr) {
          if (!mem_manager_->MallocMemFromMemPool(output_address, output_address->size())) {
            MS_LOG(EXCEPTION) << "Allocate memory failed, size:" << output_address->size();
          }
        }

        AnfAlgo::SetOutputAddr(output_address, index, item.get());
        continue;
      }
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      auto device_address =
        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
      MS_EXCEPTION_IF_NULL(device_address);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;
      }
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}

void KernelRuntime::RunOpAssignOutputMemory(
  const AnfNodePtr &kernel, const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }

  // Use device_address Allocated in RunOpMallocPre.
  for (auto &iter : tensor_to_node) {
    auto device_address = iter.first->device_address();
    AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(device_address), iter.second.second,
                           iter.second.first.get());
  }

  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if (AnfAlgo::OutputAddrExist(kernel, i, false)) {
      auto address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
      MS_EXCEPTION_IF_NULL(address);
      if (address->ptr() == nullptr) {
        if (!mem_manager_->MallocMemFromMemPool(address, address->size())) {
          MS_LOG(EXCEPTION) << "Allocate memory failed, size:" << address->size();
        }
      }
      continue;
    }
    if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) {
      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
      continue;
    }
    std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i});
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
    auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << output_sizes[i];
    }
    AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  }
}

void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (kernel->isa<CNode>()) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_lists.size(); ++i) {
      auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
      MS_EXCEPTION_IF_NULL(device_address);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << workspace_lists[i];
      }
      AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
    }
  }
}

void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, const session::KernelGraph &graph) {
  if (pre_output_value == nullptr) {
    return;
  }
  std::vector<tensor::TensorPtr> pre_output_tensors;
  TensorValueToTensor(pre_output_value, &pre_output_tensors);
  auto output_nodes = graph.outputs();
  if (pre_output_tensors.size() != output_nodes.size()) {
    MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size()
                      << "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]";
  }
  // share output address with pre output tensors
  for (size_t i = 0; i < output_nodes.size(); ++i) {
    auto output_node_with_index = AnfAlgo::VisitKernel(output_nodes[i], 0);
    auto output_node = output_node_with_index.first;
    MS_EXCEPTION_IF_NULL(output_node);
    if (!output_node->isa<CNode>()) {
      if (output_node->isa<Parameter>()) {
        auto param = output_node->cast<ParameterPtr>();
        if (param != nullptr && !param->has_default()) {
          MS_LOG(EXCEPTION) << "The output parameter should be a real parameter!";
        }
      }
      continue;
    }
    auto real_output_cnode = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(real_output_cnode);
    MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
    if (pre_output_tensors[i]->device_address() == nullptr) {
      MS_LOG(INFO) << "The address of pre output tensor [" << i << "] is a nullptr!";
      continue;
    }
    if (opt::IsNopNode(real_output_cnode)) {
      if (real_output_cnode->inputs().size() < kMinInputSize) {
        MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString()
                          << " should be larger than one!";
      }
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, real_output_cnode->input(1).get());
    } else {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, output_node_with_index.first.get());
    }
  }
}

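// Assign static memory for graph input parameters, expanding MakeTuple inputs and
// skipping parameters that are unused or already hold a device address. On CPU builds
// the parameter-server embedding cache is honored instead of a fresh allocation.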
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  MS_LOG(INFO) << "AssignStaticMemoryInput start for graph " << graph.graph_id();
  auto graph_inputs = GetGraphInputs(graph);
  auto graph_valid_input = graph.valid_inputs();
  graph_inputs.insert(graph_inputs.end(), graph.child_graph_result().begin(), graph.child_graph_result().end());
  std::vector<AnfNodePtr> need_alloc_nodes;
  auto add_need_alloc_nodes = [&need_alloc_nodes, graph, this](const AnfNodePtr &node) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<Parameter>()) {
      return;
    }
    if (NodeOutputDeviceAddressExist(node, 0)) {
      return;
    }
    auto input_param = node->cast<ParameterPtr>();
    if (input_param != nullptr && !input_param->IsUsedByRealKernelInGraph(graph.graph_id())) {
      return;
    }
    need_alloc_nodes.push_back(node);
  };

  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    auto input_node = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
      auto outs = AnfAlgo::GetAllOutput(input_node);
      for (auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        add_need_alloc_nodes(out);
      }
    }
    add_need_alloc_nodes(input_node);
  }
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  bool ps_cache_check = false;
#endif
  for (auto &item : need_alloc_nodes) {
    MS_EXCEPTION_IF_NULL(item);
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // if graph output is a weight and doesn't link to any cnode, its data type will be unknown
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        continue;
      }
      DeviceAddressPtr device_address = nullptr;
#if ((defined ENABLE_CPU) && (!defined _WIN32))
      const std::string &param_name = item->fullname_with_scope();
      if (ps::ps_cache_instance.IsHashTable(param_name)) {
        MS_LOG(INFO) << "Parameter(" << param_name << ")"
                     << " enables the embeddingLookup cache in parameter server training mode.";
        // PS embeddingLookup cache check.
        if (!ps_cache_check) {
          CheckIfSupportPSEmbeddingCache(graph);
          ps_cache_check = true;
        }
        const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
        MS_EXCEPTION_IF_NULL(address.addr);
        device_address = CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index),
                                             output_type_id, {item, index});
        AnfAlgo::SetOutputAddr(device_address, index, item.get());
        continue;
      }
#endif
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      device_address =
        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
      MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size
                   << " node:" << item->fullname_with_scope() << " index: " << index;
      if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph.graph_id()) == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
      }
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
  MS_LOG(INFO) << "AssignStaticMemoryInput end";
}

void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph &graph) {
  MS_LOG(INFO) << "AssignStaticMemoryOutput start for graph " << graph.graph_id();
  auto nodes = AnfAlgo::GetAllOutput(graph.output(), {prim::kPrimTupleGetItem});
  std::vector<session::KernelWithIndex> non_communication_op;
  // Assign communication op memory first.
  for (const auto &node : nodes) {
    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
    MS_EXCEPTION_IF_NULL(kernel_with_index.first);
    if (!kernel_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_with_index.first)) {
      continue;
    }
    if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
      AssignCommunicationNodeMem(kStaticMem, kernel_with_index.first);
    } else {
      non_communication_op.emplace_back(kernel_with_index);
    }
  }

  for (const auto &item_with_index : non_communication_op) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    MS_LOG(DEBUG) << "AssignNodeOutputMem for " << item_with_index.first->fullname_with_scope();
    AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  }
  MS_LOG(INFO) << "AssignStaticMemoryOutput end";
}

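// For every kernel output recorded in the graph's ref-output map, make the output share
// the device address of the origin node it references.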
void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto output_num = AnfAlgo::GetOutputTensorNum(kernel);
    if (output_num == 0) {
      MS_LOG(DEBUG) << "This kernel has no output.";
      continue;
    }
    for (size_t i = 0; i < output_num; ++i) {
      session::AnfWithOutIndex out_pair(kernel, i);
      if (graph.IsInRefOutputMap(out_pair)) {
        auto origin_pair = graph.GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
        MS_EXCEPTION_IF_NULL(origin_node_output_addr);
        auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
        if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
          MS_LOG(DEBUG) << "REF address is not the same, ref node output needs an address update";
          MS_LOG(DEBUG) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                        << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}

void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
  AssignCommunicationNodeInputMem(type, node);
  AssignCommunicationNodeOutputMem(type, node);
  AssignWorkSpaceMem(type, node);
}

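// Build per-kernel pre/post-run event closures that fence communication kernels against
// the compute stream: a pre event blocks the communication stream until the compute
// stream arrives, and a post event is waited on by the nearest downstream consumer
// (or by the communication kernel itself if no consumer is found).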
void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  auto kernel_events =
    std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  kernel_pre_run_events.resize(kernels.size());
  kernel_post_run_events.resize(kernels.size());
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &kernel = kernels[i];
    if (!AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto pre_event = CreateDeviceEvent();
    auto post_event = CreateDeviceEvent();
    MS_EXCEPTION_IF_NULL(pre_event);
    MS_EXCEPTION_IF_NULL(post_event);
    pre_event->set_wait_stream(communication_stream_);
    pre_event->set_record_stream(stream_);
    post_event->set_wait_stream(stream_);
    post_event->set_record_stream(communication_stream_);
    kernel_pre_run_events[i].emplace_back([pre_event]() {
      pre_event->RecordEvent();
      pre_event->WaitEvent();
    });
    kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });
    bool found_nearest_child = false;
    for (size_t j = i + 1; j < kernels.size(); ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      if (AnfAlgo::IsCommunicationOp(child)) {
        continue;
      }
      auto input_size = child->inputs().size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
        break;
      }
    }
    if (!found_nearest_child) {
      kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}

void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  size_t output_index = 0;
  std::vector<size_t> align_size_list;
  for (uint64_t mem_size : output_sizes) {
    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
      MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address";
      return;
    }
    if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
      mem_size = MemoryManager::GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    align_size_list.emplace_back(mem_size);
  }

  if (align_size_list.empty()) {
    return;
  }

  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }

  uint8_t *output_ptr = nullptr;
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, {node, j});
    MS_EXCEPTION_IF_NULL(address);
    if (output_ptr == nullptr) {
      output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
      MS_EXCEPTION_IF_NULL(output_ptr);
    } else {
      address->set_ptr(output_ptr);
    }
    AnfAlgo::SetOutputAddr(address, j, node.get());
    output_ptr += align_size_list[j];
  }
}

bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return false;
}

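// Create (without allocating) the device address for a CNode output so it can join a
// contiguous communication allocation; nop nodes are looked through recursively.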
DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "anf_node should be a CNode";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (opt::IsNopNode(cnode)) {
    const size_t kNopNodeInputSize = 2;
    if (cnode->size() != kNopNodeInputSize) {
      MS_LOG(EXCEPTION) << cnode->fullname_with_scope() << " has invalid input size: " << cnode->size();
    }
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
    return PreAssignCNodeMemory(input_node_with_index.first, input_node_with_index.second);
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.size() <= index) {
    MS_LOG(EXCEPTION) << "Previous node output size " << output_sizes.size() << " <= node index " << index;
  }
  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, {anf_node, index});
  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  return address;
}

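// Allocate the inputs of a communication node as one contiguous block: sizes are
// common-aligned, the block is charged to the first input's producer, and each input
// address gets its offset into the block.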
void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<std::pair<DeviceAddressPtr, size_t>> addr_size;
  size_t input_num = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i, true);
    auto input_node = input_node_with_index.first;
    MS_EXCEPTION_IF_NULL(input_node);
    if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
      MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address";
      return;
    }
    DeviceAddressPtr address = nullptr;
    if (input_node->isa<CNode>()) {
      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
    } else {
      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
    }
    MS_EXCEPTION_IF_NULL(address);
    auto mem_size = MemoryManager::GetCommonAlignSize(address->size());
    total_size += mem_size;
    addr_size.emplace_back(address, mem_size);
  }
  if (addr_size.empty()) {
    return;
  }
  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input.";
    }
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().size() < kMinInputSize) {
    // a communication node's inputs should contain the primitive itself and at least one input
    MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
    return;
  }
  auto first_input_node = cnode->input(1);
  auto prenode_index = AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true);
  uint8_t *input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size,
                                                     addr_size[0].first, true);
  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}

void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);

  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }

  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(node, i)) {
      MS_LOG(INFO) << "Already malloc index:" << i;
      continue;
    }
    MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope() << " output memory size:" << output_sizes[i];
    if (type == kStaticMem) {
      MS_LOG(INFO) << "Assign Static Memory for Output node, size:" << output_sizes[i]
                   << " node:" << node->fullname_with_scope();
    }
    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {node, i});
    MS_EXCEPTION_IF_NULL(device_address);
    uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
    MS_EXCEPTION_IF_NULL(ptr);
    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
    AnfAlgo::SetOutputAddr(device_address, i, node.get());
  }
}

DeviceAddressPtr KernelRuntime::AssignExtraStaticMem(const TensorPtr &tensor, const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_NULL(tensor_address);
  MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope()
                << " Assign Static Memory for Output node, size:" << tensor_address->size();
  auto device_address = CreateDeviceAddress(nullptr, tensor_address->size(), tensor_address->format(),
                                            tensor_address->type_id(), {node, index});
  MS_EXCEPTION_IF_NULL(device_address);
  uint8_t *ptr = mem_manager_->MallocOutputMem(node, index, kStaticMem, tensor_address->size(), device_address, false);
  MS_EXCEPTION_IF_NULL(ptr);
  return device_address;
}

void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
                                          size_t output_idx) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::vector<tensor::TensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);
  // Graph id should be passed to record static memory if profiling is enabled.
  auto kernel_info = dynamic_cast<device::KernelInfo *>(value_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  uint32_t graph_id = kernel_info->graph_id();
  for (const auto &tensor : tensors) {
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }
    auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
                             value_node.get());
      continue;
    }
    size_t tensor_size = LongToSize(tensor->data().nbytes());
    auto node_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
    if (output_type_id == kTypeUnknown) {
      output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
    auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
    DeviceAddressPtr address =
      CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx});
    MS_EXCEPTION_IF_NULL(address);
    if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
        !mem_manager_->MallocMemFromMemPool(address, node_size)) {
      MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << node_size;
    } else {
      MS_LOG(INFO) << "Assign Static Memory for Value node, size:" << node_size
                   << " node:" << value_node->fullname_with_scope();
      if (mem_manager_->MallocMem(kStaticMem, node_size, address, graph_id) == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
      }
    }
    AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
    if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
                                   tensor->data_c(), tensor->device_info().host_format_)) {
      MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail! " << value_node->DebugString()
                                   << " node format is " << AnfAlgo::GetOutputFormat(value_node, output_idx)
                                   << " node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
  }
}

void KernelRuntime::AssignStaticMemoryValueNode(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  MS_LOG(DEBUG) << "AssignStaticMemoryValueNode start for graph " << graph.graph_id();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // order the value nodes
  std::map<std::string, ValueNodePtr> value_nodes_map;
  for (auto &node : graph.graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    value_nodes_map[node->fullname_with_scope()] = node;
  }

  for (auto &item : value_nodes_map) {
    auto value_node = item.second;
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeOutputDeviceAddressExist(value_node, 0)) {
      MS_LOG(DEBUG) << "value_node[" << value_node->DebugString() << "] address already exists";
      auto device_address = AnfAlgo::GetMutableOutputAddr(value_node, 0);
      MS_EXCEPTION_IF_NULL(device_address);
      if (device_address->ptr_ == nullptr) {
        if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
          if (!mem_manager_->MallocMemFromMemPool(device_address, device_address->size_)) {
            MS_LOG(EXCEPTION) << "MallocMemFromMemPool failed";
          }
        } else {
          if (mem_manager_->MallocMem(kStaticMem, device_address->size_, device_address, graph.graph_id()) ==
              nullptr) {
            MS_LOG(EXCEPTION) << "MallocMem kStaticMem failed";
          }
        }
      }
      continue;
    }
    auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    MS_LOG(DEBUG) << "Malloc memory for " << value_node->fullname_with_scope();
    if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>()) {
      AssignValueNodeTensor(value_node, node_value, 0);
    } else if (node_value->isa<StringImm>()) {
      auto value = GetValue<std::string>(node_value);
      size_t tensor_size = value.size();
      DeviceAddressPtr address = nullptr;
      address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
      MS_EXCEPTION_IF_NULL(address);
      if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
          !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
        MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;
      } else {
        MS_LOG(INFO) << "Assign Static Memory for Value node, size:" << tensor_size
                     << " node:" << value_node->fullname_with_scope();
        if (mem_manager_->MallocMem(kStaticMem, tensor_size, address, graph.graph_id()) == nullptr) {
          MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem
                            << ", tensor size is: " << tensor_size;
        }
      }
      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
      ShapeVector shape = {1, SizeToLong(tensor_size)};
      if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
        MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!";
      }
    }
  }
  MS_LOG(DEBUG) << "AssignStaticMemoryValueNode end";
}

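// Assign dynamic memory for the whole graph. When memory reuse is enabled the SOMAS
// planner is used (and forced off while full e2e dump is active); communication nodes
// are placed before ordinary compute nodes because they need contiguous buffers.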
void KernelRuntime::AssignDynamicMemory(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_enable_mem_reuse = EnvConfigParser::GetInstance().GetSysMemreuse();
  auto mem_type = kDynamicMem;
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (dump_json_parser.e2e_dump_enabled() && dump_json_parser.dump_mode() == 0) {
    mindspore::EnvConfigParser::GetInstance().SetSysMemreuse(false);
    is_enable_mem_reuse = false;
    MS_LOG(INFO) << "Disable memory reuse when e2e dump is enabled and dump mode is set to dump all kernels";
  }

  if (is_enable_mem_reuse) {
    MS_LOG(INFO) << "Memory reuse is enabled";
    mem_manager_->MallocSomasDynamicMem(graph);
    mem_type = kSomasReuseDynamicMem;
  } else {
    MS_LOG(INFO) << "Memory reuse is disabled";
  }
  auto &execution_nodes = graph.execution_order();
  std::vector<CNodePtr> compute_nodes;
  // communication nodes first
  for (auto &node : execution_nodes) {
    if (AnfAlgo::IsCommunicationOp(node)) {
      // skip if the memory is already allocated
      AssignCommunicationNodeMem(mem_type, node);
    } else {
      compute_nodes.emplace_back(node);
    }
  }

  // then compute nodes
  for (auto &node : compute_nodes) {
    AssignNodeOutputMem(mem_type, node, kGetAllOuts);
    AssignWorkSpaceMem(mem_type, node);
  }
}

void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  size_t index = 0;
  for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
    if (AnfAlgo::WorkspaceAddrExist(node, index)) {
      MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address";
      return;
    }
    auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
    AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
    index++;
  }
}

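// Translate a kernel's input/workspace/output device addresses into the flat Address
// lists expected by KernelMod::Launch, skipping placeholder inputs of DynamicRNN and
// DynamicGRUV2 and delegating atomic-clean nodes to GenAddrCleanLaunchArgs.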
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
                                  AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
                                  AddressPtrList *kernel_outputs) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  MS_EXCEPTION_IF_NULL(kernel_workspaces);
  MS_EXCEPTION_IF_NULL(kernel_outputs);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
    return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto visit_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
  for (size_t i = 0; i < input_num; ++i) {
    auto op_name = AnfAlgo::GetCNodeName(cnode);
    constexpr auto none_placeholder_index = 3;
    if (op_name == kDynamicRNNOpName && i == none_placeholder_index) {
      continue;
    }
    if (op_name == kDynamicGRUV2OpName) {
      auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
      auto item = std::find(none_index.begin(), none_index.end(), i);
      if (item != none_index.end()) {
        continue;
      }
    }
    auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
    auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input, visit_nop_node);
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr input = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(input);
    input->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(input->addr);
    input->size = device_address->size_;
    kernel_inputs->emplace_back(input);
  }

  for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetOutputAddr(kernel, i, visit_nop_node);
    kernel::AddressPtr output = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(output);
    output->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(output->addr);
    output->size = device_address->size_;
    kernel_outputs->emplace_back(output);
  }

  for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(workspace);
    workspace->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(workspace->addr);
    workspace->size = device_address->size_;
    kernel_workspaces->emplace_back(workspace);
  }
}

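// Builds the launch inputs of an atomic address clean node: the addresses to be
// zeroed are the outputs / workspaces of the preceding node, selected by the
// kAttrAtomicOutputIndexs / kAttrAtomicWorkspaceIndexs attributes. When a
// MemScheduler is given, addresses are resolved through it instead of being
// read from the device address directly.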
void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs,
                                           const std::shared_ptr<MemScheduler> &mem_scheduler) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  if (cnode->inputs().size() != kAtomicCleanInputSize) {
    MS_LOG(EXCEPTION) << "The number of atomic address clean node inputs must be " << kAtomicCleanInputSize << ".";
  }
  MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
  auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_node);
  // Set the clean output addresses.
  if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
#if defined(__APPLE__)
    auto clean_output_indexes = AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicOutputIndexs);
#else
    auto clean_output_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
#endif
    for (auto index : clean_output_indexes) {
      auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
      kernel::AddressPtr input = std::make_shared<kernel::Address>();
      MS_EXCEPTION_IF_NULL(input);
      if (mem_scheduler != nullptr) {
        GetOrMallocAddress(mem_scheduler, device_address, input);
      } else {
        input->addr = device_address->ptr_;
        MS_EXCEPTION_IF_NULL(input->addr);
      }
      input->size = device_address->size_;
      kernel_inputs->emplace_back(input);
    }
    MS_LOG(DEBUG) << "AtomicAddrClean clean output size: " << clean_output_indexes.size();
  }
  // Set the clean workspace addresses.
  if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
#if defined(__APPLE__)
    auto clean_workspaces_indexes = AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicWorkspaceIndexs);
#else
    auto clean_workspaces_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
#endif
    for (const auto &index : clean_workspaces_indexes) {
      auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
      kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
      MS_EXCEPTION_IF_NULL(workspace);
      if (mem_scheduler != nullptr) {
        GetOrMallocAddress(mem_scheduler, device_address, workspace);
      } else {
        workspace->addr = device_address->ptr_;
        MS_EXCEPTION_IF_NULL(workspace->addr);
      }
      workspace->size = device_address->size_;
      kernel_inputs->emplace_back(workspace);
    }
  }
}

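// Runs all pre- or post-launch callbacks registered for the kernel at the given
// position in the execution order; out-of-range indices are ignored.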
void KernelRuntime::LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &kernel_events,
                                      size_t index) const {
  if (index >= kernel_events.size()) {
    return;
  }
  for (auto &event : kernel_events[index]) {
    event();
  }
}

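// Launches a kernel bracketed by a pair of device time events and reports the
// result to the PyNative profiler. The device-side cost is measured as
//   cost_time = ElapsedTime(start_event, end_event)
// and the host-side launch start is reconstructed backwards from the host end
// time: launch_start_time = GetTime() - cost_time / kBasicTimeTransferUnit.
// That start point is therefore an approximation: it assumes the host and
// device clocks do not drift over the duration of the launch.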
bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
                                                      const std::vector<AddressPtr> &inputs,
                                                      const std::vector<AddressPtr> &workspace,
                                                      const std::vector<AddressPtr> &outputs, void *stream) {
  MS_EXCEPTION_IF_NULL(kernel_mod);
  MS_EXCEPTION_IF_NULL(stream);
  float cost_time = 0;
  auto start = CreateDeviceTimeEvent();
  auto end = CreateDeviceTimeEvent();
  MS_EXCEPTION_IF_NULL(start);
  MS_EXCEPTION_IF_NULL(end);
  start->set_record_stream(stream);
  end->set_record_stream(stream);
  start->RecordEvent();
  bool ret = kernel_mod->Launch(inputs, workspace, outputs, stream);
  end->RecordEvent();
  start->SyncEvent();
  end->SyncEvent();
  start->ElapsedTime(&cost_time, end.get());
  auto launch_end_time = GetTime();
  double launch_start_time = launch_end_time - cost_time / kBasicTimeTransferUnit;
  auto op_launch_start_time_end_time = std::make_pair(launch_start_time, launch_end_time);
  PynativeProfiler::SetDeviceOpNameAndLaunchTimePoint(std::make_pair(op_name, op_launch_start_time_end_time));
  PynativeProfiler::SetDeviceOpNameAndLaunchCostTime(std::make_pair(op_name, cost_time / kBasicTimeTransferUnit));
  if (!ret) {
    MS_LOG(EXCEPTION) << "Launch kernel failed, kernel name: " << op_name;
  }
  return ret;
}

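// Synchronizes the stream after each kernel when MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE
// is set, so that a failing op is reported at the kernel that caused it.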
void KernelRuntime::DebugStreamSync(const CNodePtr &kernel) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto enable_sync_run = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
  if (enable_sync_run) {
    if (!SyncStream()) {
      MS_LOG(EXCEPTION) << "Op " << kernel->fullname_with_scope() << " run failed!";
    }
  }
}

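// Resolves the runtime address of a device address managed by the MemScheduler:
// if the address is already bound to device memory it is used as is; otherwise
// the scheduler allocates (or swaps in) memory, and for high-priority addresses
// the resulting pointer is cached back into the device address.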
void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                       const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr) {
  if (device_address->ptr_ != nullptr) {
    kernel_addr->addr = device_address->ptr_;
  } else {
    kernel_addr->addr = mem_scheduler->GetOrMalloc(device_address, device_address->size_);
    if (mem_scheduler->IsHighPriorityMem(device_address)) {
      device_address->ptr_ = kernel_addr->addr;
    }
  }
}

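// MemScheduler counterpart of GenLaunchArgs: gathers the input, output and
// workspace addresses of a kernel, letting the scheduler allocate or swap in
// any memory that is not currently resident on the device.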
void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
                                        AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
                                        AddressPtrList *kernel_outputs) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  MS_EXCEPTION_IF_NULL(kernel_workspaces);
  MS_EXCEPTION_IF_NULL(kernel_outputs);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
    return GenAddrCleanLaunchArgs(cnode, kernel_inputs, mem_scheduler);
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
  for (size_t j = 0; j < input_num; ++j) {
    auto real_input = AnfAlgo::GetRealInputIndex(kernel, j);
    auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
    auto index = kernel_with_index.second;
    auto &input_node = kernel_with_index.first;
    auto device_address = AnfAlgo::GetOutputAddr(input_node, index, true);
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr input = std::make_shared<kernel::Address>();
    GetOrMallocAddress(mem_scheduler, device_address, input);
    input->size = device_address->size_;
    kernel_inputs->emplace_back(input);
  }

  for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
    auto device_address = AnfAlgo::GetOutputAddr(kernel, j, true);
    kernel::AddressPtr output = std::make_shared<kernel::Address>();
    GetOrMallocAddress(mem_scheduler, device_address, output);
    output->size = device_address->size_;
    kernel_outputs->emplace_back(output);
  }

  for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
    GetOrMallocAddress(mem_scheduler, device_address, workspace);
    workspace->size = device_address->size_;
    kernel_workspaces->emplace_back(workspace);
  }
}

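// Syncs the bound output tensors of a kernel back to host after launch. In mock
// runs only memory priorities are recorded; in real runs the device data is
// copied into each tensor and the tensor is then detached from its device
// address so that it is re-synced host-to-device on the next use.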
void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                          const session::KernelGraph &graph, const AnfNodePtr &kernel, bool mock) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
    auto tensor = graph.GetNodeOutputTensor(std::make_pair(kernel, j));
    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, j, true);
    if (mock) {
      if (graph.IsInternalOutput(kernel, j) && device_address != nullptr) {
        mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh);
      }
      continue;
    }
    if (tensor != nullptr) {
      if (device_address == nullptr) {
        tensor->data_sync(false);
        tensor->set_device_address(nullptr);
        tensor->set_sync_status(kNeedSyncHostToDevice);
        continue;
      }
      if (!SyncStream()) {
        MS_LOG(ERROR) << "SyncStream failed";
      }
      auto origin_ptr = device_address->ptr_;
      if (origin_ptr == nullptr) {
        device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_);
      }
      tensor->set_device_address(device_address);
      tensor->data_sync(false);
      tensor->set_device_address(nullptr);
      if (origin_ptr == nullptr) {
        device_address->ptr_ = nullptr;
      }
      tensor->set_sync_status(kNeedSyncHostToDevice);
    }
  }
}

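// Registers each graph input tensor with the MemScheduler. Inputs whose tensor
// already lives on another device address are synced to host first and then
// registered with low priority, leaving the scheduler free to swap them out.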
void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                          const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  auto &input_nodes = graph.input_nodes();
  auto &input_tensors = graph.input_tensors();
  if (input_tensors.size() != input_nodes.size()) {
    MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
  }
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    auto tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    if (!input_node->isa<Parameter>()) {
      continue;
    }
    if (AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
      MemPriority priority = kMemPriorityHigh;
      auto tensor_address = tensor->device_address();
      if (tensor_address != nullptr && tensor_address != device_address) {
        tensor->data_sync(false);
        priority = kMemPriorityLow;
      }
      auto tensor_size = LongToSize(tensor->data().nbytes());
      mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority);
    }
  }
}

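// Launches a single kernel. The launch stream defaults to the communication
// stream for communication ops and to the compute stream otherwise. With a
// MemScheduler the addresses are resolved through the scheduler and bracketed
// by PreCompute/PostCompute; in mock mode the actual Launch call is skipped so
// that only the memory behaviour is simulated.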
bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                                 const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  AddressPtrList kernel_inputs;
  AddressPtrList kernel_workspaces;
  AddressPtrList kernel_outputs;
  auto stream = kernel_mod->GetStream();
  if (stream == nullptr) {
    if (AnfAlgo::IsCommunicationOp(kernel)) {
      stream = communication_stream_;
    } else {
      stream = stream_;
    }
  }
  bool ret = true;
  if (mem_scheduler != nullptr) {
    ret = mem_scheduler->PreCompute(stream);
    if (!ret) {
      return ret;
    }
    AssignKernelAddress(mem_scheduler, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  } else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
    kernel_inputs = kernel_mod->GetInputsAddr();
    kernel_outputs = kernel_mod->GetOutputsAddr();
    kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
  } else {
    GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  }
  if (!mock) {
    if (pynative_mode_profiling_flag_) {
      ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
                                              kernel_workspaces, kernel_outputs, stream);
    } else {
      ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream);
    }
  }
  if (mem_scheduler != nullptr) {
    SyncNodeOutputTensors(mem_scheduler, graph, kernel, mock);
    ret = mem_scheduler->PostCompute(stream);
    if (!ret) {
      return ret;
    }
  }
  return ret;
}

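// Iterates over the execution order of a graph and launches every kernel,
// firing any registered pre- and post-run events around each launch. Dynamic
// shape kernels follow the InferShape -> UpdateArgs -> Execute -> PostExecute
// path with a stream sync in between; nop transpose kernels only forward their
// input addresses to their outputs.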
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
  auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
  if (enable_mem_scheduler) {
    mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
    MS_EXCEPTION_IF_NULL(mem_scheduler);
    mem_scheduler->SetMemHandler(mem_manager_);
    mem_scheduler->RecordMemUsage();
    InitGraphInputTensors(mem_scheduler, graph);
  }
  const auto &kernels = graph.execution_order();
  std::vector<DynamicKernelPtr> dynamic_kernel_list;
  auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
  if (iter != graph_dynamic_kernel_map_.end()) {
    dynamic_kernel_list = iter->second;
  }
  if (!dynamic_kernel_list.empty() && dynamic_kernel_list.size() != kernels.size()) {
    MS_LOG(EXCEPTION) << "The size of dynamic kernels " << dynamic_kernel_list.size()
                      << " should be equal to the size of kernels " << kernels.size();
  }
  std::vector<std::vector<std::function<void()>>> kernel_pre_run_events;
  std::vector<std::vector<std::function<void()>>> kernel_post_run_events;
  auto events_iter = graph_kernel_events_map_.find(graph.graph_id());
  if (events_iter != graph_kernel_events_map_.end()) {
    kernel_pre_run_events = events_iter->second.first;
    kernel_post_run_events = events_iter->second.second;
  }
  for (size_t i = 0; i < kernels.size(); ++i) {
    LaunchKernelEvent(kernel_pre_run_events, i);
    if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr &&
        dynamic_kernel_list[i]->is_dynamic_shape()) {
      dynamic_kernel_list[i]->InferShape();
      dynamic_kernel_list[i]->UpdateArgs();
      dynamic_kernel_list[i]->Execute();
      if (!SyncStream()) {
        MS_LOG(ERROR) << "SyncStream failed";
        return false;
      }
      dynamic_kernel_list[i]->PostExecute();
    } else {
      auto &kernel = kernels[i];
      MS_EXCEPTION_IF_NULL(kernel);

      // Skip transpose kernels with the "nop_op" attr, which are not hidden or removed in the PyNative infer
      // scenario. Such transpose kernels are generated in TransDataSplit to support specific TransData cases and
      // are not supposed to be executed. This hard-coded handling should be removed once the new TransData scheme
      // is implemented.
      if (AnfAlgo::HasNodeAttr("nop_op", kernel)) {
        for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(kernel); idx += 1) {
          auto real_input = AnfAlgo::GetRealInputIndex(kernel, idx);
          auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input);
          AnfAlgo::SetOutputAddr(device_address, idx, kernel.get());
        }
        continue;
      }
      auto ret = LaunchKernel(graph, kernel, mem_scheduler, mock);
      if (!ret) {
        MS_LOG(ERROR) << "Launch kernel failed.";
        return false;
      }
      KernelLaunchProfiling(kernel->fullname_with_scope());
      DebugStreamSync(kernel);
    }
    LaunchKernelEvent(kernel_post_run_events, i);
  }
  if (mem_scheduler != nullptr) {
    mem_scheduler->OptMemUsage();
  }
  return true;
}

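// When the MemScheduler is enabled, searches for a workable memory-reuse factor
// by mock-running the graph: a first mock run records memory usage, then the
// factor starts at kMaxMemReuseFactor and is lowered by kRetryFactor after each
// failed mock launch until the launch succeeds or the factor falls below
// kMinMemReuseFactor.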
void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
  if (enable_mem_scheduler) {
    auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
    if (mem_scheduler->need_record_event()) {
      (void)LaunchKernelMod(graph, true);
    }
    float mem_used_factor = kMaxMemReuseFactor;
    while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
      mem_scheduler->SetMemUsedFactor(mem_used_factor);
      bool ret = LaunchKernelMod(graph, true);
      if (ret) {
        mem_scheduler->SetOptimized(true);
      } else {
        mem_used_factor -= kRetryFactor;
      }
    }
  }
}

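// Entry point for executing a graph: tunes the MemScheduler if needed, launches
// all kernels, and in graph mode synchronizes the stream before returning.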
bool KernelRuntime::LaunchKernels(const session::KernelGraph &graph) {
  UseMemSchedulerIfNeeded(graph);
  if (!LaunchKernelMod(graph)) {
    MS_LOG(ERROR) << "LaunchKernelMod failed!";
    return false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    if (!SyncStream()) {
      MS_LOG(ERROR) << "SyncStream failed";
      return false;
    }
  }
  return true;
}

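// Default implementation only logs; device-specific runtimes are presumably
// expected to override this to release the per-graph resources they hold.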
void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
  MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}

#if ((defined ENABLE_CPU) && (!defined _WIN32))
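// Finds the first GatherV2/SparseGatherV2 kernel whose parameter is a parameter
// server hash table, walks its index input back through Cast nodes, and returns
// that index-producing node together with the hash table size. Subsequent cache
// checks use this node and size as the reference.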
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph &graph,
                                             AnfNodePtr *const first_cache_input_index,
                                             size_t *const first_cache_size) {
  for (const auto &kernel : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
      continue;
    }
    auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
    auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
    MS_EXCEPTION_IF_NULL(input_param.first);
    MS_EXCEPTION_IF_NULL(input_index.first);
    auto param_name = input_param.first->fullname_with_scope();
    if (!ps::ps_cache_instance.IsHashTable(param_name)) {
      continue;
    }
    auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
      MS_EXCEPTION_IF_NULL(input_index.first);
    }
    auto cnode =
      AnfAlgo::IsGraphKernel(input_index.first) ? AnfAlgo::GetOutputOfGraphkernel(input_index) : input_index.first;
    MS_EXCEPTION_IF_NULL(cnode);
    if (!cnode->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "The input index of the embeddingLookup should be a CNode but got "
                        << cnode->fullname_with_scope();
    }
    auto input_index_node_name = AnfAlgo::GetCNodeName(cnode);
    if (input_index_node_name != kGetNextOpName) {
      bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
      if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
          (full_batch && (input_index_node_name != kMinimumOpName))) {
        MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
                      << ") cache is from " << cnode->fullname_with_scope();
        MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
                             "parameter server training mode.";
      }
    }
    *first_cache_input_index = cnode;
    *first_cache_size = size;
    MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from " << cnode->fullname_with_scope()
                 << ", the cache size is " << size;
    return;
  }
}

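// Validates a cached SparseGatherV2 kernel: its indices must come from a Unique
// node, and the inputs of that Unique node must be produced directly by GetNext
// (only Cast nodes are tolerated in between).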
void KernelRuntime::CheckSparsePSEmbeddingCache(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto pre_node = AnfAlgo::GetPrevNodeOutput(node, 1, true);
  MS_EXCEPTION_IF_NULL(pre_node.first);
  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
    MS_EXCEPTION_IF_NULL(pre_node.first);
  }
  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
    MS_LOG(EXCEPTION) << "The input indices of kernel[SparseGatherV2] must come from kernel[Unique] in parameter "
                         "server cache mode.";
  }

  pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
  MS_EXCEPTION_IF_NULL(pre_node.first);
  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName)) {
    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
    MS_EXCEPTION_IF_NULL(pre_node.first);
  }
  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kGetNextOpName)) {
    MS_LOG(EXCEPTION) << "The input indices of kernel[Unique] must be produced from dataset directly and the indices "
                         "value can not be changed before delivering to kernel[Unique] in parameter server cache mode.";
  }
}

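// Verifies that the embedding cache configuration of the whole graph is
// consistent: every embeddingLookup fed by the same dataset index node must
// enable caching with the same cache size, and lookups whose indices are not
// from the dataset must not enable caching at all.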
void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph &graph) {
  AnfNodePtr first_cache_input_index = nullptr;
  size_t first_cache_size = 0;
  GetFirstPSEmbeddingCache(graph, &first_cache_input_index, &first_cache_size);
  MS_EXCEPTION_IF_NULL(first_cache_input_index);
  for (const auto &kernel : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
      continue;
    }
    auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
    auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
    MS_EXCEPTION_IF_NULL(input_param.first);
    MS_EXCEPTION_IF_NULL(input_index.first);
    if (!input_param.first->isa<Parameter>()) {
      continue;
    }
    auto param_name = input_param.first->fullname_with_scope();
    if (ps::ps_cache_instance.IsHashTable(param_name) && (kernel_name == kSparseGatherV2OpName)) {
      CheckSparsePSEmbeddingCache(kernel);
    }
    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
      MS_EXCEPTION_IF_NULL(input_index.first);
    }
    auto cnode =
      AnfAlgo::IsGraphKernel(input_index.first) ? AnfAlgo::GetOutputOfGraphkernel(input_index) : input_index.first;
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode == first_cache_input_index) {
      if (!ps::ps_cache_instance.IsHashTable(param_name)) {
        MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
        MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the "
                             "same time when one of them enables cache in parameter server training mode.";
      }
      auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
      if (size != first_cache_size) {
        MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope()
                      << ") is not the same as the other embeddingLookup cache size(" << first_cache_size << ").";
        MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode.";
      }
    } else if (ps::ps_cache_instance.IsHashTable(param_name)) {
      MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
                    << cnode->fullname_with_scope();
      MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
                           "parameter server training mode.";
    } else if (cnode->isa<CNode>() && (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName)) {
      MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
      MS_LOG(EXCEPTION) << "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at "
                           "the same time and parameter 'sparse' must be equal to the value of 'enable_sparse' in "
                           "context setting in parameter server training mode.";
    }
  }
}
#endif
}  // namespace device
}  // namespace mindspore