/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/kernel_runtime.h"
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_graph.h"
#include "common/trans.h"
#include "debug/data_dump/dump_json_parser.h"
#include "frontend/operator/ops.h"
#include "ir/value.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "utils/shape_utils.h"
#include "utils/utils.h"
#include "frontend/parallel/context.h"
#include "debug/env_config_parser.h"
#include "pipeline/pynative/pynative_profiling.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#endif

using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;

namespace mindspore {
namespace device {
constexpr float kMaxMemReuseFactor = 0.8;
constexpr float kMinMemReuseFactor = 0.5;
constexpr float kRetryFactor = 0.1;
constexpr size_t kAtomicCleanInputSize = 2;
namespace {
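// Returns the graph inputs plus any Parameter read by a kernel that is not
// already listed in graph.inputs() (e.g. a parameter shared in from another
// graph); inputs_set keeps the result free of duplicates.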
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
  auto graph_inputs = graph.inputs();
  std::vector<AnfNodePtr> result(graph_inputs.begin(), graph_inputs.end());
  std::set<AnfNodePtr> inputs_set(graph_inputs.begin(), graph_inputs.end());
  auto kernels = graph.execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto input_num = AnfAlgo::GetInputTensorNum(kernel);
    for (size_t i = 0; i < input_num; ++i) {
      auto input_node = kernel->input(i + 1);
      auto input_real_node = AnfAlgo::VisitKernelWithReturnType(input_node, 0).first;
      MS_EXCEPTION_IF_NULL(input_real_node);
      if (input_real_node->isa<Parameter>() && inputs_set.find(input_real_node) == inputs_set.end()) {
        (void)inputs_set.insert(input_real_node);
        (void)result.emplace_back(input_real_node);
      }
    }
  }
  return result;
}
}  // namespace
constexpr size_t kMinInputSize = 2;
KernelRuntime::~KernelRuntime() {
  stream_ = nullptr;
  independent_stream_ = nullptr;
  communication_stream_ = nullptr;
}

bool KernelRuntime::Load(const session::KernelGraph &, bool) {
  MS_LOG(INFO) << "Call default load.";
  return true;
}

bool KernelRuntime::LoadData(const session::KernelGraph &) {
  MS_LOG(INFO) << "Call default load data.";
  return false;
}

bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
  MS_EXCEPTION_IF_NULL(kernel);
  if (AnfAlgo::OutputAddrExist(kernel, index)) {
    const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
    MS_EXCEPTION_IF_NULL(address);
    return address->DeviceType() == GetTargetDeviceAddressType();
  }
  return false;
}

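// Entry point for whole-graph memory assignment. With the memory scheduler
// enabled, only value nodes get static memory up front and the remaining
// addresses are created without backing memory; otherwise static memory
// (inputs, value nodes, graph outputs) and dynamic memory (kernel outputs
// and workspaces) are both assigned here.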
void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
  if (enable_mem_scheduler) {
    AssignStaticMemoryValueNode(graph);
    ResetNodeAddress(graph);
  } else {
    MS_EXCEPTION_IF_NULL(mem_manager_);
    mem_manager_->ResetDynamicMemory();
    AssignStaticMemory(graph);
    AssignDynamicMemory(graph);
  }
  UpdateRefNodeOutputMem(graph);
}

void KernelRuntime::RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
                                                   std::vector<DeviceAddressPtr> *address_list,
                                                   std::vector<size_t> *align_size_list) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(total_size);
  MS_EXCEPTION_IF_NULL(address_list);
  MS_EXCEPTION_IF_NULL(align_size_list);
  size_t input_num = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i, true);
    auto input_node = input_node_with_index.first;
    MS_EXCEPTION_IF_NULL(input_node);
    DeviceAddressPtr address = nullptr;
    if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
      address = AnfAlgo::GetMutableOutputAddr(input_node, input_node_with_index.second);
    } else {
      if (input_node->isa<CNode>()) {
        address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
      } else {
        MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
      }
    }
    MS_EXCEPTION_IF_NULL(address);
    auto align_size = MemoryManager::GetCommonAlignSize(address->size());
    *total_size += align_size;
    address_list->emplace_back(address);
    align_size_list->emplace_back(align_size);
  }
}

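// Communication ops require their inputs to sit in one contiguous device
// buffer, so the per-input addresses are gathered first and then backed by a
// single continuous allocation from the memory pool.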
void KernelRuntime::RunOpAssignCommunicationInput(const AnfNodePtr &node) const {
  if (!AnfAlgo::IsCommunicationOp(node)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<DeviceAddressPtr> address_list;
  std::vector<size_t> align_size_list;
  RunOpGetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list);
  if (address_list.empty()) {
    return;
  }

  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().size() < kMinInputSize) {
    MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
    return;
  }

  if (!mem_manager_->MallocContinuousMemFromMemPool(address_list, total_size, align_size_list)) {
    MS_LOG(EXCEPTION) << "Allocate continuous memory failed, total_size:" << total_size;
  }
}

void KernelRuntime::RunOpGetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size,
                                                    std::vector<size_t> *align_size_list,
                                                    std::vector<DeviceAddressPtr> *device_address_list) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(total_size);
  MS_EXCEPTION_IF_NULL(align_size_list);
  MS_EXCEPTION_IF_NULL(device_address_list);
  auto runtime_info = node->user_data<session::OpRuntimeInfo>();
  auto output_num = AnfAlgo::GetOutputTensorNum(node);
  for (size_t i = 0; i < output_num; ++i) {
    MS_EXCEPTION_IF_NULL(runtime_info);
    DeviceAddressPtr address = nullptr;
    if (AnfAlgo::OutputAddrExist(node, i)) {
      address = AnfAlgo::GetMutableOutputAddr(node, i);
    } else {
      std::string output_format = runtime_info->output_format(i);
      auto output_type = runtime_info->output_type(i);
      address =
        CreateDeviceAddress(nullptr, runtime_info->output_tensor_size(i), output_format, output_type, {node, i});
    }
    MS_EXCEPTION_IF_NULL(address);
    auto align_size = MemoryManager::GetCommonAlignSize(address->size());
    *total_size += align_size;
    align_size_list->emplace_back(align_size);
    device_address_list->emplace_back(address);
  }
}

void KernelRuntime::RunOpAssignCommunicationOutput(const AnfNodePtr &node) const {
  if (!AnfAlgo::IsCommunicationOp(node)) {
    return;
  }

  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);

  size_t total_size = 0;
  std::vector<size_t> align_size_list;
  std::vector<DeviceAddressPtr> device_address_list;
  RunOpGetCommunicationOutputInfo(node, &total_size, &align_size_list, &device_address_list);

  if (align_size_list.empty()) {
    return;
  }

  if (!mem_manager_->MallocContinuousMemFromMemPool(device_address_list, total_size, align_size_list)) {
    MS_LOG(EXCEPTION) << "Allocate continuous memory failed, total_size:" << total_size;
  }
}

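// Pre-allocation pass for single-op (PyNative) execution: device addresses
// are created without a backing pointer here, and the real device memory is
// bound later, after kernel build finishes.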
void KernelRuntime::RunOpMallocPre(const session::KernelGraph &graph,
                                   const std::vector<tensor::TensorPtr> &input_tensors) {
  const auto &nodes = graph.execution_order();
  // Malloc for node outputs.
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    auto output_num = AnfAlgo::GetOutputTensorNum(node);
    for (size_t i = 0; i < output_num; ++i) {
      auto runtime_info = node->user_data<session::OpRuntimeInfo>();
      MS_EXCEPTION_IF_NULL(runtime_info);
      auto const &output_format = runtime_info->output_format(i);
      auto output_type = runtime_info->output_type(i);
      auto tensor_size = runtime_info->output_tensor_size(i);
      // Create the DeviceAddress without a ptr.
      // The real device ptr is obtained after KernelBuild finishes.
      auto device_address = CreateDeviceAddress(nullptr, tensor_size, output_format, output_type);
      MS_EXCEPTION_IF_NULL(device_address);
      device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
      AnfAlgo::SetOutputAddr(device_address, i, node.get());
    }
  }

  // Malloc for graph inputs.
  if (input_tensors.size() != graph.inputs().size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
                      << " should be equal to graph input parameter size " << graph.inputs().size();
  }
  for (size_t input_index = 0; input_index < graph.inputs().size(); ++input_index) {
    auto item = graph.inputs()[input_index];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      auto current_tensor = input_tensors[input_index];
      MS_EXCEPTION_IF_NULL(current_tensor);
      auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(current_tensor->device_address());
      if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
        AnfAlgo::SetOutputAddr(output_address, index, item.get());
        continue;
      }
      auto op_runtime_info = item->user_data<session::OpRuntimeInfo>();
      MS_EXCEPTION_IF_NULL(op_runtime_info);
      TypeId output_type_id = op_runtime_info->output_type(index);
      auto output_tensor_size = op_runtime_info->output_tensor_size(index);
      auto output_format = op_runtime_info->output_format(index);
      auto device_address =
        CreateDeviceAddress(nullptr, output_tensor_size, output_format, output_type_id, {item, index});
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
      current_tensor->set_device_address(device_address);
      current_tensor->set_sync_status(kNeedSyncHostToDevice);
    }
  }
}

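// Used when the memory scheduler is enabled: recreate ptr-less device
// addresses for every kernel input, output and workspace, leaving the actual
// memory to be bound by the scheduler at launch time.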
void KernelRuntime::ResetNodeAddress(const session::KernelGraph &kernel_graph) {
  auto kernels = kernel_graph.execution_order();
  for (auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
    for (size_t j = 0; j < input_num; ++j) {
      auto input_index = AnfAlgo::GetRealInputIndex(kernel, j);
      KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, input_index, true);
      auto index = kernel_with_index.second;
      auto &input_node = kernel_with_index.first;
      if (NodeOutputDeviceAddressExist(input_node, index)) {
        continue;
      }
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, index);
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not recommended to use an isolated weight parameter as the output of a graph";
        continue;
      }
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
      auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index),
                                                output_type_id, {input_node, index});
      AnfAlgo::SetOutputAddr(device_address, index, input_node.get());
    }

    auto output_sizes = kernel_mod->GetOutputSizeList();
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
      AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i,
                             kernel.get());
    }
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_sizes.size(); ++i) {
      AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32),
                                i, kernel.get());
    }
  }
}

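// Memory assignment for single-op execution: contiguous communication buffers
// first, then graph inputs and value nodes, then per-kernel outputs and
// workspaces, with ref-node outputs remapped at the end.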
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                      const session::KernelGraph &graph,
                                      const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->ResetDynamicMemory();

  for (const auto &node : graph.execution_order()) {
    RunOpAssignCommunicationOutput(node);
    RunOpAssignCommunicationInput(node);
  }

  RunOpAssignInputMemory(input_tensors, graph);
  AssignStaticMemoryValueNode(graph);
  for (const auto &node : graph.execution_order()) {
    RunOpAssignOutputMemory(node, tensor_to_node);
    RunOpAssignWorkSpaceMemory(node);
  }
  UpdateRefNodeOutputMem(graph);
}

void KernelRuntime::RunOpClearMemory(const session::KernelGraph &graph) const {
  // Clear input parameter memory resources.
  for (const auto &input_node : graph.inputs()) {
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
  }
  // Clear input value node memory resources.
  for (const auto &value_node : graph.graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  }
  for (const auto &cnode : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(cnode);
    // Clear output memory resources.
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t index = 0; index < output_num; ++index) {
      AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
    }
    // Clear workspace memory resources.
    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
    for (size_t index = 0; index < workspace_lists.size(); ++index) {
      AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
    }
  }
}

#ifdef ENABLE_DEBUGGER
bool KernelRuntime::DumpDataEnabled() {
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  return dump_json_parser.e2e_dump_enabled();
}

bool KernelRuntime::DumpDataEnabledIteration() {
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (!dump_json_parser.e2e_dump_enabled()) {
    return false;
  }

  auto cur_iter = dump_json_parser.cur_dump_iter();
  return dump_json_parser.IsDumpIter(cur_iter);
}
#endif

void KernelRuntime::AssignStaticMemory(const session::KernelGraph &graph) {
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  AssignStaticMemoryOutput(graph);
}

void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                           const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (input_tensors.size() != graph.inputs().size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
                      << " should be equal to graph input parameter size " << graph.inputs().size();
  }

  for (size_t input_index = 0; input_index < graph.inputs().size(); ++input_index) {
    auto item = graph.inputs()[input_index];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      auto current_tensor = input_tensors[input_index];
      MS_EXCEPTION_IF_NULL(current_tensor);
      auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(current_tensor->device_address());
      if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
        if (output_address->ptr_ == nullptr) {
          if (!mem_manager_->MallocMemFromMemPool(output_address, output_address->size())) {
            MS_LOG(EXCEPTION) << "Allocate memory failed, size:" << output_address->size();
          }
        }

        AnfAlgo::SetOutputAddr(output_address, index, item.get());
        continue;
      }
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      auto device_address =
        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
      MS_EXCEPTION_IF_NULL(device_address);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Device memory is not enough, allocation failed, alloc size:" << tensor_size;
      }
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}

void KernelRuntime::RunOpAssignOutputMemory(
  const AnfNodePtr &kernel, const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }

  // Use the device addresses allocated in RunOpMallocPre.
  for (auto &iter : tensor_to_node) {
    auto device_address = iter.first->device_address();
    AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(device_address), iter.second.second,
                           iter.second.first.get());
  }

  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if (AnfAlgo::OutputAddrExist(kernel, i, false)) {
      auto address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
      MS_EXCEPTION_IF_NULL(address);
      if (address->ptr() == nullptr) {
        if (!mem_manager_->MallocMemFromMemPool(address, address->size())) {
          MS_LOG(EXCEPTION) << "Allocate memory failed, size:" << address->size();
        }
      }
      continue;
    }
    if (AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName) {
      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
      continue;
    }
    std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i});
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
    auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Device memory is not enough, allocation failed, alloc size:" << output_sizes[i];
    }
    AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  }
}

void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (kernel->isa<CNode>()) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_lists.size(); ++i) {
      auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
      MS_EXCEPTION_IF_NULL(device_address);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Device memory is not enough, allocation failed, alloc size:" << workspace_lists[i];
      }
      AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
    }
  }
}

void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, const session::KernelGraph &graph) {
  if (pre_output_value == nullptr) {
    return;
  }
  std::vector<tensor::TensorPtr> pre_output_tensors;
  TensorValueToTensor(pre_output_value, &pre_output_tensors);
  auto output_nodes = graph.outputs();
  if (pre_output_tensors.size() != output_nodes.size()) {
    MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size()
                      << "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]";
  }
  // Share output addresses with the previous output tensors.
  for (size_t i = 0; i < output_nodes.size(); ++i) {
    auto output_node_with_index = AnfAlgo::VisitKernel(output_nodes[i], 0);
    auto output_node = output_node_with_index.first;
    MS_EXCEPTION_IF_NULL(output_node);
    if (!output_node->isa<CNode>()) {
      if (output_node->isa<Parameter>()) {
        auto param = output_node->cast<ParameterPtr>();
        if (param != nullptr && !param->has_default()) {
          MS_LOG(EXCEPTION) << "The output parameter should be a real parameter!";
        }
      }
      continue;
    }
    auto real_output_cnode = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(real_output_cnode);
    MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
    if (pre_output_tensors[i]->device_address() == nullptr) {
      MS_LOG(INFO) << "The address of pre output tensor [" << i << "] is a nullptr!";
      continue;
    }
    if (opt::IsNopNode(real_output_cnode)) {
      if (real_output_cnode->inputs().size() < kMinInputSize) {
        MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString()
                          << " should be larger than one!";
      }
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, real_output_cnode->input(1).get());
    } else {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, output_node_with_index.first.get());
    }
  }
}

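// Static memory for graph input parameters. Parameters not used by any real
// kernel in this graph are skipped; in CPU builds with the parameter-server
// cache, embedding-table parameters reuse the cached hash-table address
// instead of receiving a fresh static allocation.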
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  MS_LOG(INFO) << "AssignStaticMemoryInput start for graph " << graph.graph_id();
  auto graph_inputs = GetGraphInputs(graph);
  auto graph_valid_input = graph.valid_inputs();
  graph_inputs.insert(graph_inputs.end(), graph.child_graph_result().begin(), graph.child_graph_result().end());
  std::vector<AnfNodePtr> need_alloc_nodes;
  auto add_need_alloc_nodes = [&need_alloc_nodes, graph, this](const AnfNodePtr &node) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<Parameter>()) {
      return;
    }
    if (NodeOutputDeviceAddressExist(node, 0)) {
      return;
    }
    auto input_param = node->cast<ParameterPtr>();
    if (input_param != nullptr && !input_param->IsUsedByRealKernelInGraph(graph.graph_id())) {
      return;
    }
    need_alloc_nodes.push_back(node);
  };

  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    auto input_node = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
      auto outs = AnfAlgo::GetAllOutput(input_node);
      for (auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        add_need_alloc_nodes(out);
      }
    }
    add_need_alloc_nodes(input_node);
  }
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  bool ps_cache_check = false;
#endif
  for (auto &item : need_alloc_nodes) {
    MS_EXCEPTION_IF_NULL(item);
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // If a graph output is a weight that does not link to any cnode, its data type will be unknown.
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not recommended to use an isolated weight parameter as the output of a graph";
        continue;
      }
      DeviceAddressPtr device_address = nullptr;
#if ((defined ENABLE_CPU) && (!defined _WIN32))
      const std::string &param_name = item->fullname_with_scope();
      if (ps::ps_cache_instance.IsHashTable(param_name)) {
        MS_LOG(INFO) << "Parameter(" << param_name << ")"
                     << " enables the embeddingLookup cache in parameter server training mode.";
        // PS embeddingLookup cache check.
        if (!ps_cache_check) {
          CheckIfSupportPSEmbeddingCache(graph);
          ps_cache_check = true;
        }
        const auto &address = ps::ps_cache_instance.QueryHashTableAddr(param_name);
        MS_EXCEPTION_IF_NULL(address.addr);
        device_address = CreateDeviceAddress(address.addr, address.size, AnfAlgo::GetOutputFormat(item, index),
                                             output_type_id, {item, index});
        AnfAlgo::SetOutputAddr(device_address, index, item.get());
        continue;
      }
#endif
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      device_address =
        CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
      MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size
                   << " node:" << item->fullname_with_scope() << " index: " << index;
      if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph.graph_id()) == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
      }
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
  MS_LOG(INFO) << "AssignStaticMemoryInput end";
}

void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph &graph) {
  MS_LOG(INFO) << "AssignStaticMemoryOutput start for graph " << graph.graph_id();
  auto nodes = AnfAlgo::GetAllOutput(graph.output(), {prim::kPrimTupleGetItem});
  std::vector<session::KernelWithIndex> non_communication_op;
  // Assign communication op memory first.
  for (const auto &node : nodes) {
    auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
    MS_EXCEPTION_IF_NULL(kernel_with_index.first);
    if (!kernel_with_index.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_with_index.first)) {
      continue;
    }
    if (AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
      AssignCommunicationNodeMem(kStaticMem, kernel_with_index.first);
    } else {
      non_communication_op.emplace_back(kernel_with_index);
    }
  }

  for (const auto &item_with_index : non_communication_op) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    MS_LOG(DEBUG) << "AssignNodeOutputMem for " << item_with_index.first->fullname_with_scope();
    AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  }
  MS_LOG(INFO) << "AssignStaticMemoryOutput end";
}

void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto output_num = AnfAlgo::GetOutputTensorNum(kernel);
    if (output_num == 0) {
      MS_LOG(DEBUG) << "This kernel has no output.";
      continue;
    }
    for (size_t i = 0; i < output_num; ++i) {
      session::AnfWithOutIndex out_pair(kernel, i);
      if (graph.IsInRefOutputMap(out_pair)) {
        auto origin_pair = graph.GetRefCorrespondOutput(out_pair);
        MS_EXCEPTION_IF_NULL(origin_pair.first);
        auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
        MS_EXCEPTION_IF_NULL(origin_node_output_addr);
        auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
        if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
          MS_LOG(DEBUG) << "REF addresses are not the same, the ref node output needs an address update";
          MS_LOG(DEBUG) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                        << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}

void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
  AssignCommunicationNodeInputMem(type, node);
  AssignCommunicationNodeOutputMem(type, node);
  AssignWorkSpaceMem(type, node);
}

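// Build per-kernel pre/post launch events to order communication kernels on
// communication_stream_ against compute kernels on stream_: the pre-event
// makes the communication stream wait for the compute stream, and the
// post-event makes the nearest downstream consumer (or, if none is found,
// the communication op itself) wait for the communication result.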
void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  auto kernel_events =
    std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>();
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  kernel_pre_run_events.resize(kernels.size());
  kernel_post_run_events.resize(kernels.size());
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &kernel = kernels[i];
    if (!AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto pre_event = CreateDeviceEvent();
    auto post_event = CreateDeviceEvent();
    MS_EXCEPTION_IF_NULL(pre_event);
    MS_EXCEPTION_IF_NULL(post_event);
    pre_event->set_wait_stream(communication_stream_);
    pre_event->set_record_stream(stream_);
    post_event->set_wait_stream(stream_);
    post_event->set_record_stream(communication_stream_);
    kernel_pre_run_events[i].emplace_back([pre_event]() {
      pre_event->RecordEvent();
      pre_event->WaitEvent();
    });
    kernel_post_run_events[i].emplace_back([post_event]() { post_event->RecordEvent(); });
    bool found_nearest_child = false;
    for (size_t j = i + 1; j < kernels.size(); ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      if (AnfAlgo::IsCommunicationOp(child)) {
        continue;
      }
      auto input_size = child->inputs().size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        auto kernel_index = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(child, k), 0, true);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        kernel_pre_run_events[j].emplace_back([post_event]() { post_event->WaitEvent(); });
        break;
      }
    }
    if (!found_nearest_child) {
      kernel_post_run_events[i].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}

void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  size_t output_index = 0;
  std::vector<size_t> align_size_list;
  for (uint64_t mem_size : output_sizes) {
    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
      MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address";
      return;
    }
    if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
      mem_size = MemoryManager::GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    align_size_list.emplace_back(mem_size);
  }

  if (align_size_list.empty()) {
    return;
  }

  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable memory reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }

  // Allocate one contiguous block for all outputs, then slice it: the first
  // address owns the block and the rest point into it at aligned offsets.
  uint8_t *output_ptr = nullptr;
  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, {node, j});
    MS_EXCEPTION_IF_NULL(address);
    if (output_ptr == nullptr) {
      output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
      MS_EXCEPTION_IF_NULL(output_ptr);
    } else {
      address->set_ptr(output_ptr);
    }
    AnfAlgo::SetOutputAddr(address, j, node.get());
    output_ptr += align_size_list[j];
  }
}

bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return false;
}

DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "anf_node should be a CNode";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (opt::IsNopNode(cnode)) {
    const size_t kNopNodeInputSize = 2;
    if (cnode->size() != kNopNodeInputSize) {
      MS_LOG(EXCEPTION) << cnode->fullname_with_scope() << " has invalid input size: " << cnode->size();
    }
    // A nop node forwards its single real input, so assign memory to that input instead.
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
    return PreAssignCNodeMemory(input_node_with_index.first, input_node_with_index.second);
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(anf_node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.size() <= index) {
    MS_LOG(EXCEPTION) << "Previous node output size " << output_sizes.size() << " <= node index " << index;
  }
  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  auto address = CreateDeviceAddress(nullptr, output_sizes[index], output_format, output_type, {anf_node, index});
  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  return address;
}

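// Mirror of the output case: a communication node's inputs must also be
// contiguous, so the producers' output memory is laid out in one block
// anchored at the first input's producer.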
void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<std::pair<DeviceAddressPtr, size_t>> addr_size;
  size_t input_num = AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i, true);
    auto input_node = input_node_with_index.first;
    MS_EXCEPTION_IF_NULL(input_node);
    if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
      MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address";
      return;
    }
    DeviceAddressPtr address = nullptr;
    if (input_node->isa<CNode>()) {
      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
    } else {
      MS_LOG(EXCEPTION) << "Communication node inputs only support CNode";
    }
    MS_EXCEPTION_IF_NULL(address);
    auto mem_size = MemoryManager::GetCommonAlignSize(address->size());
    total_size += mem_size;
    addr_size.emplace_back(address, mem_size);
  }
  if (addr_size.empty()) {
    return;
  }
  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable memory reuse for " << node->fullname_with_scope() << "'s input.";
    }
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().size() < kMinInputSize) {
    // A communication node's inputs should contain the primitive itself and at least one real input.
    MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
    return;
  }
  auto first_input_node = cnode->input(1);
  auto prenode_index = AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true);
  uint8_t *input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size,
                                                     addr_size[0].first, true);
  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}

void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);

  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable memory reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }

  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(node, i)) {
      MS_LOG(INFO) << "Memory is already allocated for output index:" << i;
      continue;
    }
    MS_LOG(DEBUG) << "Assign node:" << node->fullname_with_scope() << " output memory size:" << output_sizes[i];
    if (type == kStaticMem) {
      MS_LOG(INFO) << "Assign Static Memory for Output node, size:" << output_sizes[i]
                   << " node:" << node->fullname_with_scope();
    }
    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {node, i});
    MS_EXCEPTION_IF_NULL(device_address);
    uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
    MS_EXCEPTION_IF_NULL(ptr);
    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
    AnfAlgo::SetOutputAddr(device_address, i, node.get());
  }
}

DeviceAddressPtr KernelRuntime::AssignExtraStaticMem(const TensorPtr &tensor, const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_NULL(tensor_address);
  MS_LOG(DEBUG) << "Assign node:" << node->fullname_with_scope()
                << " static memory for output node, size:" << tensor_address->size();
  auto device_address = CreateDeviceAddress(nullptr, tensor_address->size(), tensor_address->format(),
                                            tensor_address->type_id(), {node, index});
  MS_EXCEPTION_IF_NULL(device_address);
  uint8_t *ptr = mem_manager_->MallocOutputMem(node, index, kStaticMem, tensor_address->size(), device_address, false);
  MS_EXCEPTION_IF_NULL(ptr);
  return device_address;
}

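// Assign device memory to the tensors held by a value node and copy their
// host data to the device. In PyNative infer mode the memory comes from the
// memory pool; otherwise it comes from the static region.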
void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
                                          size_t output_idx) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::vector<tensor::TensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);
  // The graph id should be passed to record static memory if profiling is enabled.
  auto kernel_info = dynamic_cast<device::KernelInfo *>(value_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  uint32_t graph_id = kernel_info->graph_id();
  for (const auto &tensor : tensors) {
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }
    auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
      // The tensor already holds a device address of this target; reuse it.
      AnfAlgo::SetOutputAddr(output_address, output_idx++, value_node.get());
      continue;
    }
    size_t tensor_size = LongToSize(tensor->data().nbytes());
    auto node_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
    if (output_type_id == kTypeUnknown) {
      output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
    auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
    DeviceAddressPtr address =
      CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx});
    MS_EXCEPTION_IF_NULL(address);
    if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
      if (!mem_manager_->MallocMemFromMemPool(address, node_size)) {
        MS_LOG(EXCEPTION) << "Device memory is not enough, allocation failed, alloc size:" << node_size;
      }
    } else {
      MS_LOG(INFO) << "Assign Static Memory for Value node, size:" << node_size
                   << " node:" << value_node->fullname_with_scope();
      if (mem_manager_->MallocMem(kStaticMem, node_size, address, graph_id) == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
      }
    }
    AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
    if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
                                   tensor->data_c(), tensor->device_info().host_format_)) {
      MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice failed! " << value_node->DebugString()
                                   << ", node format is " << AnfAlgo::GetOutputFormat(value_node, output_idx)
                                   << ", node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
  }
}

void KernelRuntime::AssignStaticMemoryValueNode(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  MS_LOG(DEBUG) << "AssignStaticMemoryValueNode start for graph " << graph.graph_id();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // Order the value nodes by name to make the allocation order deterministic.
  std::map<std::string, ValueNodePtr> value_nodes_map;
  for (auto &node : graph.graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    value_nodes_map[node->fullname_with_scope()] = node;
  }

  for (auto &item : value_nodes_map) {
    auto value_node = item.second;
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeOutputDeviceAddressExist(value_node, 0)) {
      MS_LOG(DEBUG) << "value_node[" << value_node->DebugString() << "] address already exists";
      auto device_address = AnfAlgo::GetMutableOutputAddr(value_node, 0);
      MS_EXCEPTION_IF_NULL(device_address);
      if (device_address->ptr_ == nullptr) {
        if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
          if (!mem_manager_->MallocMemFromMemPool(device_address, device_address->size_)) {
            MS_LOG(EXCEPTION) << "MallocMemFromMemPool failed";
          }
        } else {
          if (mem_manager_->MallocMem(kStaticMem, device_address->size_, device_address, graph.graph_id()) ==
              nullptr) {
            MS_LOG(EXCEPTION) << "MallocMem kStaticMem failed";
          }
        }
      }
      continue;
    }
    auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    MS_LOG(DEBUG) << "Malloc memory for " << value_node->fullname_with_scope();
    if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>()) {
      AssignValueNodeTensor(value_node, node_value, 0);
    } else if (node_value->isa<StringImm>()) {
      // A string value is stored on the device as raw uint8 bytes.
      auto value = GetValue<std::string>(node_value);
      size_t tensor_size = value.size();
      DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
      MS_EXCEPTION_IF_NULL(address);
      if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
        if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
          MS_LOG(EXCEPTION) << "Device memory is not enough, allocation failed, alloc size:" << tensor_size;
        }
      } else {
        MS_LOG(INFO) << "Assign Static Memory for Value node, size:" << tensor_size
                     << " node:" << value_node->fullname_with_scope();
        if (mem_manager_->MallocMem(kStaticMem, tensor_size, address, graph.graph_id()) == nullptr) {
          MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem
                            << ", tensor size is: " << tensor_size;
        }
      }
      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
      ShapeVector shape = {1, SizeToLong(tensor_size)};
      if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
        MS_LOG(EXCEPTION) << "ValueNode SyncHostToDevice failed!";
      }
    }
  }
  MS_LOG(DEBUG) << "AssignStaticMemoryValueNode end";
}

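// Dynamic (per-iteration) memory. When SOMAS reuse is enabled the whole graph
// is planned as one reusable block; communication nodes are assigned first
// because their inputs and outputs need contiguous buffers.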
void KernelRuntime::AssignDynamicMemory(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_enable_mem_reuse = EnvConfigParser::GetInstance().GetSysMemreuse();
  auto mem_type = kDynamicMem;
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (dump_json_parser.e2e_dump_enabled() && dump_json_parser.dump_mode() == 0) {
    mindspore::EnvConfigParser::GetInstance().SetSysMemreuse(false);
    is_enable_mem_reuse = false;
    MS_LOG(INFO) << "Disable memory reuse when e2e dump is enabled and dump mode is set to dump all kernels";
  }

  if (is_enable_mem_reuse) {
    MS_LOG(INFO) << "Memory reuse is enabled";
    mem_manager_->MallocSomasDynamicMem(graph);
    mem_type = kSomasReuseDynamicMem;
  } else {
    MS_LOG(INFO) << "Memory reuse is disabled";
  }
  auto &execution_nodes = graph.execution_order();
  std::vector<CNodePtr> compute_nodes;
  // Communication nodes first.
  for (auto &node : execution_nodes) {
    if (AnfAlgo::IsCommunicationOp(node)) {
      // Skips nodes whose memory is already allocated.
      AssignCommunicationNodeMem(mem_type, node);
    } else {
      compute_nodes.emplace_back(node);
    }
  }

  // Then compute nodes.
  for (auto &node : compute_nodes) {
    AssignNodeOutputMem(mem_type, node, kGetAllOuts);
    AssignWorkSpaceMem(mem_type, node);
  }
}

void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  size_t index = 0;
  for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
    if (AnfAlgo::WorkspaceAddrExist(node, index)) {
      MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address";
      return;
    }
    auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
    AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
    index++;
  }
}

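// Translate a kernel's input/output/workspace device addresses into the
// Address lists consumed by KernelMod::Launch. Nop nodes are visited through
// except in PyNative mode, and some dynamic ops (DynamicRNN, DynamicGRUV2)
// carry placeholder inputs that have no device address and are skipped.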
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
                                  AddressPtrList *kernel_inputs, AddressPtrList *const kernel_workspaces,
                                  AddressPtrList *kernel_outputs) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  MS_EXCEPTION_IF_NULL(kernel_workspaces);
  MS_EXCEPTION_IF_NULL(kernel_outputs);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
    return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto visit_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
  for (size_t i = 0; i < input_num; ++i) {
    auto op_name = AnfAlgo::GetCNodeName(cnode);
    constexpr auto none_placeholder_index = 3;
    if (op_name == kDynamicRNNOpName && i == none_placeholder_index) {
      continue;
    }
    if (op_name == kDynamicGRUV2OpName) {
      auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
      auto item = std::find(none_index.begin(), none_index.end(), i);
      if (item != none_index.end()) {
        continue;
      }
    }
    auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
    auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input, visit_nop_node);
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr input = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(input);
    input->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(input->addr);
    input->size = device_address->size_;
    kernel_inputs->emplace_back(input);
  }

  for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetOutputAddr(kernel, i, visit_nop_node);
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr output = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(output);
    output->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(output->addr);
    output->size = device_address->size_;
    kernel_outputs->emplace_back(output);
  }

  for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
    MS_EXCEPTION_IF_NULL(workspace);
    workspace->addr = device_address->ptr_;
    MS_EXCEPTION_IF_NULL(workspace->addr);
    workspace->size = device_address->size_;
    kernel_workspaces->emplace_back(workspace);
  }
}

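// The AtomicAddrClean kernel has no regular inputs of its own: it zeroes
// selected output and workspace buffers of its predecessor, so its launch
// inputs are the addresses recorded in the predecessor's atomic-clean
// attributes.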
void KernelRuntime::GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs,
                                           const std::shared_ptr<MemScheduler> &mem_scheduler) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  if (cnode->inputs().size() != kAtomicCleanInputSize) {
    MS_LOG(EXCEPTION) << "The input size of the atomic addr clean node is not equal to " << kAtomicCleanInputSize;
  }
  MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
  auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_node);
  // Set the clean output addresses.
  if (AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
#if defined(__APPLE__)
    auto clean_output_indexes = AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicOutputIndexs);
#else
    auto clean_output_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
#endif
    for (auto index : clean_output_indexes) {
      auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
      MS_EXCEPTION_IF_NULL(device_address);
      kernel::AddressPtr input = std::make_shared<kernel::Address>();
      MS_EXCEPTION_IF_NULL(input);
      if (mem_scheduler != nullptr) {
        GetOrMallocAddress(mem_scheduler, device_address, input);
      } else {
        input->addr = device_address->ptr_;
        MS_EXCEPTION_IF_NULL(input->addr);
      }
      input->size = device_address->size_;
      kernel_inputs->emplace_back(input);
    }
    MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexes.size();
  }
  // Set the clean workspace addresses.
  if (AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
#if defined(__APPLE__)
    auto clean_workspaces_indexes = AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicWorkspaceIndexs);
#else
    auto clean_workspaces_indexes = AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
#endif
    for (const auto &index : clean_workspaces_indexes) {
      auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
      MS_EXCEPTION_IF_NULL(device_address);
      kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
      MS_EXCEPTION_IF_NULL(workspace);
      if (mem_scheduler != nullptr) {
        GetOrMallocAddress(mem_scheduler, device_address, workspace);
      } else {
        workspace->addr = device_address->ptr_;
        MS_EXCEPTION_IF_NULL(workspace->addr);
      }
      workspace->size = device_address->size_;
      kernel_inputs->emplace_back(workspace);
    }
  }
}

void KernelRuntime::LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &kernel_events,
                                      size_t index) const {
  if (index >= kernel_events.size()) {
    return;
  }
  for (auto &event : kernel_events[index]) {
    event();
  }
}

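// Launch a kernel bracketed by device time events so the PyNative profiler
// can record both the launch time window and the device-side cost of the op.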
bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
                                                      const std::vector<AddressPtr> &inputs,
                                                      const std::vector<AddressPtr> &workspace,
                                                      const std::vector<AddressPtr> &outputs, void *stream) {
  MS_EXCEPTION_IF_NULL(kernel_mod);
  MS_EXCEPTION_IF_NULL(stream);
  float cost_time = 0;
  auto start = CreateDeviceTimeEvent();
  auto end = CreateDeviceTimeEvent();
  MS_EXCEPTION_IF_NULL(start);
  MS_EXCEPTION_IF_NULL(end);
  start->set_record_stream(stream);
  end->set_record_stream(stream);
  start->RecordEvent();
  bool ret = kernel_mod->Launch(inputs, workspace, outputs, stream);
  end->RecordEvent();
  start->SyncEvent();
  end->SyncEvent();
  start->ElapsedTime(&cost_time, end.get());
  auto launch_end_time = GetTime();
  double launch_start_time = launch_end_time - cost_time / kBasicTimeTransferUnit;
  auto op_launch_start_time_end_time = std::make_pair(launch_start_time, launch_end_time);
  PynativeProfiler::SetDeviceOpNameAndLaunchTimePoint(std::make_pair(op_name, op_launch_start_time_end_time));
  PynativeProfiler::SetDeviceOpNameAndLaunchCostTime(std::make_pair(op_name, cost_time / kBasicTimeTransferUnit));
  if (!ret) {
    MS_LOG(EXCEPTION) << "Launch kernel failed, kernel name: " << op_name;
  }
  return ret;
}

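// In PyNative synchronous-run mode, block until the stream finishes so that kernel errors surface at the op that
// produced them.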
void KernelRuntime::DebugStreamSync(const CNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto enable_sync_run = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
  if (enable_sync_run) {
    if (!SyncStream()) {
      MS_LOG(EXCEPTION) << "Op " << kernel->fullname_with_scope() << " run failed!";
    }
  }
}

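// Resolve the runtime address for a device address: reuse the pointer if it is already allocated, otherwise ask the
// memory scheduler for one; for high-priority memory the resolved pointer is cached back on the device address.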
void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                       const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  MS_EXCEPTION_IF_NULL(device_address);
  MS_EXCEPTION_IF_NULL(kernel_addr);
  if (device_address->ptr_ != nullptr) {
    kernel_addr->addr = device_address->ptr_;
  } else {
    kernel_addr->addr = mem_scheduler->GetOrMalloc(device_address, device_address->size_);
    if (mem_scheduler->IsHighPriorityMem(device_address)) {
      device_address->ptr_ = kernel_addr->addr;
    }
  }
}

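// Build the input/workspace/output address lists for a kernel under the memory scheduler. Atomic clean nodes are
// special-cased because their inputs are the outputs and workspaces of the node they clean.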
void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
                                        AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
                                        AddressPtrList *kernel_outputs) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  MS_EXCEPTION_IF_NULL(kernel_workspaces);
  MS_EXCEPTION_IF_NULL(kernel_outputs);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetCNodeName(cnode) == kAtomicAddrCleanOpName) {
    return GenAddrCleanLaunchArgs(cnode, kernel_inputs, mem_scheduler);
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  size_t input_num = AnfAlgo::GetInputTensorNum(kernel);
  for (size_t j = 0; j < input_num; ++j) {
    auto real_input = AnfAlgo::GetRealInputIndex(kernel, j);
    auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
    auto index = kernel_with_index.second;
    auto &input_node = kernel_with_index.first;
    auto device_address = AnfAlgo::GetOutputAddr(input_node, index, true);
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr input = std::make_shared<kernel::Address>();
    GetOrMallocAddress(mem_scheduler, device_address, input);
    input->size = device_address->size_;
    kernel_inputs->emplace_back(input);
  }

  for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
    auto device_address = AnfAlgo::GetOutputAddr(kernel, j, true);
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr output = std::make_shared<kernel::Address>();
    GetOrMallocAddress(mem_scheduler, device_address, output);
    output->size = device_address->size_;
    kernel_outputs->emplace_back(output);
  }

  for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    MS_EXCEPTION_IF_NULL(device_address);
    kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
    GetOrMallocAddress(mem_scheduler, device_address, workspace);
    workspace->size = device_address->size_;
    kernel_workspaces->emplace_back(workspace);
  }
}

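// Sync the bound output tensors of a kernel back to host. In mock (pre-run) mode only the memory priority of
// internal outputs is adjusted; in a real run each tensor is synced through a temporarily resolved device pointer
// and then marked as needing host-to-device sync again.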
void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                          const session::KernelGraph &graph, const AnfNodePtr &kernel, bool mock) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
    auto tensor = graph.GetNodeOutputTensor(std::make_pair(kernel, j));
    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, j, true);
    if (mock) {
      if (graph.IsInternalOutput(kernel, j) && device_address != nullptr) {
        mem_scheduler->SetMemPriority(device_address.get(), kMemPriorityHigh);
      }
      continue;
    }
    if (tensor != nullptr) {
      if (device_address == nullptr) {
        tensor->data_sync(false);
        tensor->set_device_address(nullptr);
        tensor->set_sync_status(kNeedSyncHostToDevice);
        continue;
      }
      if (!SyncStream()) {
        MS_LOG(ERROR) << "SyncStream failed";
      }
      auto origin_ptr = device_address->ptr_;
      if (origin_ptr == nullptr) {
        device_address->ptr_ = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size_);
      }
      tensor->set_device_address(device_address);
      tensor->data_sync(false);
      tensor->set_device_address(nullptr);
      if (origin_ptr == nullptr) {
        device_address->ptr_ = nullptr;
      }
      tensor->set_sync_status(kNeedSyncHostToDevice);
    }
  }
}

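// Register every graph input tensor with the memory scheduler. Inputs whose tensor already lives at a different
// device address are synced to host first and registered with low priority.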
void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                          const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  auto &input_nodes = graph.input_nodes();
  auto &input_tensors = graph.input_tensors();
  if (input_tensors.size() != input_nodes.size()) {
    MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
  }
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    auto tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (!input_node->isa<Parameter>()) {
      continue;
    }
    if (AnfAlgo::OutputAddrExist(input_node, 0)) {
      auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
      MemPriority priority = kMemPriorityHigh;
      auto tensor_address = tensor->device_address();
      if (tensor_address != nullptr && tensor_address != device_address) {
        tensor->data_sync(false);
        priority = kMemPriorityLow;
      }
      auto tensor_size = LongToSize(tensor->data().nbytes());
      mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority);
    }
  }
}

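// Launch a single kernel. When a memory scheduler is attached, device memory is resolved right before the launch
// (PreCompute) and released or synced right after it (PostCompute); in mock mode the launch itself is skipped so the
// scheduler can record memory usage without running anything on the device.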
bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                                 const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  AddressPtrList kernel_inputs;
  AddressPtrList kernel_workspaces;
  AddressPtrList kernel_outputs;
  auto stream = kernel_mod->GetStream();
  if (stream == nullptr) {
    if (AnfAlgo::IsCommunicationOp(kernel)) {
      stream = communication_stream_;
    } else {
      stream = stream_;
    }
  }
  bool ret = true;
  if (mem_scheduler != nullptr) {
    ret = mem_scheduler->PreCompute(stream);
    if (!ret) {
      return ret;
    }
    AssignKernelAddress(mem_scheduler, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  } else if (!kernel_mod->GetInputsAddr().empty() || !kernel_mod->GetOutputsAddr().empty()) {
    kernel_inputs = kernel_mod->GetInputsAddr();
    kernel_outputs = kernel_mod->GetOutputsAddr();
    kernel_workspaces = kernel_mod->GetWorkSpacesAddr();
  } else {
    GenLaunchArgs(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  }
  if (!mock) {
    if (pynative_mode_profiling_flag_) {
      ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_inputs,
                                              kernel_workspaces, kernel_outputs, stream);
    } else {
      ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream);
    }
  }
  if (mem_scheduler != nullptr) {
    SyncNodeOutputTensors(mem_scheduler, graph, kernel, mock);
    ret = mem_scheduler->PostCompute(stream);
    if (!ret) {
      return ret;
    }
  }
  return ret;
}

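// Launch every kernel of the graph in execution order. Dynamic-shape kernels go through the
// InferShape/UpdateArgs/Execute path, "nop_op" transpose kernels only forward their input addresses, and all other
// kernels go through LaunchKernel; pre/post run events recorded for the graph are fired around each kernel.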
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::shared_ptr<MemScheduler> mem_scheduler = nullptr;
  auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
  if (enable_mem_scheduler) {
    mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
    MS_EXCEPTION_IF_NULL(mem_scheduler);
    mem_scheduler->SetMemHandler(mem_manager_);
    mem_scheduler->RecordMemUsage();
    InitGraphInputTensors(mem_scheduler, graph);
  }
  const auto &kernels = graph.execution_order();
  std::vector<DynamicKernelPtr> dynamic_kernel_list;
  auto iter = graph_dynamic_kernel_map_.find(graph.graph_id());
  if (iter != graph_dynamic_kernel_map_.end()) {
    dynamic_kernel_list = iter->second;
  }
  if (!dynamic_kernel_list.empty() && dynamic_kernel_list.size() != kernels.size()) {
    MS_LOG(EXCEPTION) << "The size of dynamic kernels " << dynamic_kernel_list.size()
                      << " should be equal to the size of kernels " << kernels.size();
  }
  std::vector<std::vector<std::function<void()>>> kernel_pre_run_events;
  std::vector<std::vector<std::function<void()>>> kernel_post_run_events;
  auto events_iter = graph_kernel_events_map_.find(graph.graph_id());
  if (events_iter != graph_kernel_events_map_.end()) {
    kernel_pre_run_events = events_iter->second.first;
    kernel_post_run_events = events_iter->second.second;
  }
  for (size_t i = 0; i < kernels.size(); ++i) {
    LaunchKernelEvent(kernel_pre_run_events, i);
    if (!dynamic_kernel_list.empty() && dynamic_kernel_list[i] != nullptr &&
        dynamic_kernel_list[i]->is_dynamic_shape()) {
      dynamic_kernel_list[i]->InferShape();
      dynamic_kernel_list[i]->UpdateArgs();
      dynamic_kernel_list[i]->Execute();
      if (!SyncStream()) {
        MS_LOG(ERROR) << "SyncStream failed";
        return false;
      }
      dynamic_kernel_list[i]->PostExecute();
    } else {
      auto &kernel = kernels[i];
      MS_EXCEPTION_IF_NULL(kernel);

      // Skip transpose kernels with the "nop_op" attr, which are not hidden or removed in the PyNative infer
      // scenario. Such a transpose kernel is generated in TransDataSplit to support specific TransData cases and is
      // not supposed to be executed; this hard-coded check should be removed once the new TransData scheme is
      // implemented.
      if (AnfAlgo::HasNodeAttr("nop_op", kernel)) {
        for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(kernel); idx += 1) {
          auto real_input = AnfAlgo::GetRealInputIndex(kernel, idx);
          auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input);
          AnfAlgo::SetOutputAddr(device_address, idx, kernel.get());
        }
        continue;
      }
      auto ret = LaunchKernel(graph, kernel, mem_scheduler, mock);
      if (!ret) {
        MS_LOG(ERROR) << "Launch kernel failed.";
        return false;
      }
      KernelLaunchProfiling(kernel->fullname_with_scope());
      DebugStreamSync(kernel);
    }
    LaunchKernelEvent(kernel_post_run_events, i);
  }
  if (mem_scheduler != nullptr) {
    mem_scheduler->OptMemUsage();
  }
  return true;
}

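// When the memory scheduler is enabled, first mock-run the graph to record its memory usage, then retry mock runs
// with a memory-reuse factor that starts at kMaxMemReuseFactor and shrinks by kRetryFactor until a run succeeds
// (or the factor drops below kMinMemReuseFactor), at which point the schedule is marked optimized.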
void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
  if (enable_mem_scheduler) {
    auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
    MS_EXCEPTION_IF_NULL(mem_scheduler);
    if (mem_scheduler->need_record_event()) {
      (void)LaunchKernelMod(graph, true);
    }
    float mem_used_factor = kMaxMemReuseFactor;
    while (!mem_scheduler->optimized() && mem_used_factor >= kMinMemReuseFactor) {
      mem_scheduler->SetMemUsedFactor(mem_used_factor);
      bool ret = LaunchKernelMod(graph, true);
      if (ret) {
        mem_scheduler->SetOptimized(true);
      } else {
        mem_used_factor -= kRetryFactor;
      }
    }
  }
}

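// Entry point for running a graph: optimize memory scheduling if required, launch all kernels, and in graph mode
// wait for the stream to drain before returning.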
bool KernelRuntime::LaunchKernels(const session::KernelGraph &graph) {
  UseMemSchedulerIfNeeded(graph);
  if (!LaunchKernelMod(graph)) {
    MS_LOG(ERROR) << "LaunchKernelMod failed!";
    return false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    if (!SyncStream()) {
      MS_LOG(ERROR) << "SyncStream failed";
      return false;
    }
  }
  return true;
}

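// The default implementation only logs; device-specific runtimes are expected to override this to actually release
// the resources owned by the graph.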
void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
  MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}

#if ((defined ENABLE_CPU) && (!defined _WIN32))
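// Find the first GatherV2/SparseGatherV2 kernel whose parameter has a PS embedding cache, and report where its
// input indices come from (skipping Cast nodes) together with the configured cache size. Later lookups are
// validated against this first one.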
void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph &graph,
                                             AnfNodePtr *const first_cache_input_index,
                                             size_t *const first_cache_size) {
  MS_EXCEPTION_IF_NULL(first_cache_input_index);
  MS_EXCEPTION_IF_NULL(first_cache_size);
  for (const auto &kernel : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
      continue;
    }
    auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
    auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
    MS_EXCEPTION_IF_NULL(input_param.first);
    MS_EXCEPTION_IF_NULL(input_index.first);
    auto param_name = input_param.first->fullname_with_scope();
    if (!ps::ps_cache_instance.IsHashTable(param_name)) {
      continue;
    }
    auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
      MS_EXCEPTION_IF_NULL(input_index.first);
    }
    auto cnode =
      AnfAlgo::IsGraphKernel(input_index.first) ? AnfAlgo::GetOutputOfGraphkernel(input_index) : input_index.first;
    MS_EXCEPTION_IF_NULL(cnode);
    if (!cnode->isa<CNode>()) {
      MS_LOG(EXCEPTION) << "The input index of the embeddingLookup should be a CNode, but got "
                        << cnode->fullname_with_scope();
    }
    auto input_index_node_name = AnfAlgo::GetCNodeName(cnode);
    if (input_index_node_name != kGetNextOpName) {
      bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
      if ((!full_batch && (input_index_node_name != kUniqueOpName)) ||
          (full_batch && (input_index_node_name != kMinimumOpName))) {
        MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
                      << ") cache is from " << cnode->fullname_with_scope();
        MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
                             "parameter server training mode.";
      }
    }
    *first_cache_input_index = cnode;
    *first_cache_size = size;
    MS_LOG(INFO) << "The input index of the first embeddingLookup cache is from " << cnode->fullname_with_scope()
                 << ", the cache size is " << size;
    return;
  }
}

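// For SparseGatherV2 with an embedding cache, the input indices must flow GetNext -> (Cast...) -> Unique -> kernel;
// anything else is rejected, since the PS cache relies on unique indices produced directly from the dataset.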
void KernelRuntime::CheckSparsePSEmbeddingCache(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto pre_node = AnfAlgo::GetPrevNodeOutput(node, 1, true);
  MS_EXCEPTION_IF_NULL(pre_node.first);
  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
    MS_EXCEPTION_IF_NULL(pre_node.first);
  }
  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kUniqueOpName)) {
    MS_LOG(EXCEPTION) << "The input_indices of kernel[SparseGatherV2] must be unique in parameter server cache mode";
  }

  pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
  MS_EXCEPTION_IF_NULL(pre_node.first);
  while (pre_node.first->isa<CNode>() && (AnfAlgo::GetCNodeName(pre_node.first) == kCastOpName)) {
    pre_node = AnfAlgo::GetPrevNodeOutput(pre_node.first, 0, true);
    MS_EXCEPTION_IF_NULL(pre_node.first);
  }
  if (!(pre_node.first->isa<CNode>()) || (AnfAlgo::GetCNodeName(pre_node.first) != kGetNextOpName)) {
    MS_LOG(EXCEPTION) << "The input indices of kernel[Unique] must be produced from dataset directly and the indices "
                         "value can not be changed before delivering to kernel[Unique] in parameter server cache "
                         "mode.";
  }
}

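// Validate the whole graph for PS embedding-cache support: every embedding lookup fed by the dataset must enable
// the cache, all cached lookups must share the same cache size, and sparse lookups get the extra pattern check in
// CheckSparsePSEmbeddingCache.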
void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph &graph) {
  AnfNodePtr first_cache_input_index = nullptr;
  size_t first_cache_size = 0;
  GetFirstPSEmbeddingCache(graph, &first_cache_input_index, &first_cache_size);
  MS_EXCEPTION_IF_NULL(first_cache_input_index);
  for (const auto &kernel : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto kernel_name = AnfAlgo::GetCNodeName(kernel);
    if (kernel_name != kGatherV2OpName && kernel_name != kSparseGatherV2OpName) {
      continue;
    }
    auto input_param = AnfAlgo::GetPrevNodeOutput(kernel, 0, true);
    auto input_index = AnfAlgo::GetPrevNodeOutput(kernel, 1, true);
    MS_EXCEPTION_IF_NULL(input_param.first);
    MS_EXCEPTION_IF_NULL(input_index.first);
    if (!input_param.first->isa<Parameter>()) {
      continue;
    }
    auto param_name = input_param.first->fullname_with_scope();
    if (ps::ps_cache_instance.IsHashTable(param_name) && (kernel_name == kSparseGatherV2OpName)) {
      CheckSparsePSEmbeddingCache(kernel);
    }
    while (input_index.first->isa<CNode>() && (AnfAlgo::GetCNodeName(input_index.first) == kCastOpName)) {
      input_index = AnfAlgo::GetPrevNodeOutput(input_index.first, 0, true);
      MS_EXCEPTION_IF_NULL(input_index.first);
    }
    auto cnode =
      AnfAlgo::IsGraphKernel(input_index.first) ? AnfAlgo::GetOutputOfGraphkernel(input_index) : input_index.first;
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode == first_cache_input_index) {
      if (!ps::ps_cache_instance.IsHashTable(param_name)) {
        MS_LOG(ERROR) << "The embeddingLookup(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
        MS_LOG(EXCEPTION) << "All the embeddingLookups whose input indices are from dataset must enable cache at the "
                             "same time when one of them enables cache in parameter server training mode.";
      }
      auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
      if (size != first_cache_size) {
        MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope()
                      << ") is not the same as other embeddingLookup cache size(" << first_cache_size << ").";
        MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode.";
      }
    } else if (ps::ps_cache_instance.IsHashTable(param_name)) {
      MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
                    << ") cache is from " << cnode->fullname_with_scope();
      MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
                           "parameter server training mode.";
    } else if (cnode->isa<CNode>() && (AnfAlgo::GetCNodeName(cnode) == kGetNextOpName)) {
      MS_LOG(ERROR) << "The EmbeddingLookup kernel(" << kernel->fullname_with_scope() << ") doesn't enable cache.";
      MS_LOG(EXCEPTION) << "All EmbeddingLookup kernels whose input indices are from dataset must enable cache at "
                           "the same time and parameter 'sparse' must be equal to the value of 'enable_sparse' in "
                           "context setting in parameter server training mode.";
    }
  }
}
#endif
}  // namespace device
}  // namespace mindspore