/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/kernel_runtime.h"
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
#include <set>
#include <shared_mutex>
#include "ops/ascend_op_name.h"
#include "ops/nn_optimizer_op_name.h"
#include "ops/sequence_ops.h"
#include "include/backend/optimizer/helper.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "include/backend/kernel_graph.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#include "frontend/operator/ops.h"
#include "ir/value.h"
#include "utils/ms_context.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/parallel_context.h"
#include "include/common/debug/env_config_parser.h"
#include "kernel/framework_utils.h"

using mindspore::kernel::Address;
using mindspore::kernel::AddressPtr;

namespace mindspore {
namespace device {
constexpr size_t kAtomicCleanInputSize = 2;
namespace {
std::vector<AnfNodePtr> GetGraphInputs(const session::KernelGraph &graph) {
  auto graph_inputs = graph.inputs();
  std::vector<AnfNodePtr> result(graph_inputs.begin(), graph_inputs.end());
  std::set<AnfNodePtr> inputs_set(graph_inputs.begin(), graph_inputs.end());
  auto kernels = graph.execution_order();
  for (auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    auto input_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t i = 0; i < input_num; ++i) {
      auto input_node = kernel->input(i + 1);
      auto input_real_node = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0).first;
      MS_EXCEPTION_IF_NULL(input_real_node);
      if (input_real_node->isa<Parameter>() && inputs_set.find(input_real_node) == inputs_set.end()) {
        (void)inputs_set.insert(input_real_node);
        (void)result.emplace_back(input_real_node);
      }
    }
  }
  return result;
}

// Check whether mutex exists for a stream.
std::pair<bool, std::mutex *> CheckStreamMutexExist(
  const void *stream, const mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> &mtxs_for_streams,
  std::shared_mutex *shd_mtx) {
  MS_EXCEPTION_IF_NULL(stream);
  MS_EXCEPTION_IF_NULL(shd_mtx);
  std::shared_lock<std::shared_mutex> shd_lock(*shd_mtx);
  auto iter = mtxs_for_streams.find(stream);
  if (iter != mtxs_for_streams.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    return std::make_pair(true, iter->second.get());
  }
  return std::make_pair(false, nullptr);
}

// Create a mutex for stream.
std::mutex *CreateStreamMutex(const void *stream, std::shared_mutex *shd_mtx,
                              mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> *mtxs_for_streams) {
  MS_EXCEPTION_IF_NULL(stream);
  MS_EXCEPTION_IF_NULL(shd_mtx);
  MS_EXCEPTION_IF_NULL(mtxs_for_streams);

  std::unique_lock<std::shared_mutex> unq_lock(*shd_mtx);
  auto ret_pair = mtxs_for_streams->emplace(stream, std::make_shared<std::mutex>());

  MS_EXCEPTION_IF_NULL(ret_pair.first->second);
  return ret_pair.first->second.get();
}

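// Whether output `index` of `node` needs memory allocated here. In graphs with the
// kFlagEnableZeroCopyInGraph flag, graph outputs reuse externally provided memory, so a
// node/index pair that appears in the graph's output list needs no allocation.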
bool IsNeedAllocMem(const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &graph = node->func_graph();
  if (graph == nullptr) {
    return true;
  }
  if (!graph->has_flag(kFlagEnableZeroCopyInGraph)) {
    return true;
  }
  if (node->isa<Parameter>() && graph->output() == nullptr) {
    return false;
  }
  const auto &outputs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
  return std::find_if(outputs.begin(), outputs.end(), [&node, &index](const KernelWithIndex &output) {
           const auto &real_output = common::AnfAlgo::FetchRealNodeSkipMonadControl(output);
           return ((real_output.first == node) && (real_output.second == index));
         }) == outputs.end();
}
}  // namespace
constexpr size_t kMinInputSize = 2;
KernelRuntime::TbeLaunchKernelModCallBack KernelRuntime::tbe_call_ = nullptr;
KernelRuntime::~KernelRuntime() {
  stream_ = nullptr;
  copy_data_stream_ = nullptr;
  communication_stream_ = nullptr;
}

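// Return a lock guard that serializes work submitted to the given stream. Per-stream
// mutexes are created lazily and cached in a map guarded by a read-write lock, so lookups
// for already-known streams can run concurrently.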
std::lock_guard<std::mutex> KernelRuntime::LockRuntime(const void *stream) {
  MS_EXCEPTION_IF_NULL(stream);
  // Read-write lock for accessing the mtxs_for_streams map.
  // Once the lock for each stream has been created, mtxs_for_streams can be accessed concurrently to improve
  // performance.
  static std::shared_mutex shd_mtx;
  static mindspore::HashMap<const void *, std::shared_ptr<std::mutex>> mtxs_for_streams;

  std::mutex *stream_mtx = nullptr;
  // Check whether mutex exists for a stream.
  std::pair<bool, std::mutex *> ret_pair = CheckStreamMutexExist(stream, mtxs_for_streams, &shd_mtx);
  if (ret_pair.first) {
    stream_mtx = ret_pair.second;
  } else {
    // Create a mutex for the stream.
    stream_mtx = CreateStreamMutex(stream, &shd_mtx, &mtxs_for_streams);
  }

  MS_EXCEPTION_IF_NULL(stream_mtx);
  return std::lock_guard<std::mutex>(*stream_mtx);
}

bool KernelRuntime::Load(const session::KernelGraph &, bool) {
  MS_LOG(INFO) << "Call default load.";
  return true;
}

bool KernelRuntime::LoadData(const session::KernelGraph &) {
  MS_LOG(INFO) << "Call default load data.";
  return false;
}

bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
  MS_EXCEPTION_IF_NULL(kernel);
  if (AnfAlgo::OutputAddrExist(kernel, index)) {
    // In subgraph sink mode, if the kernel does not need to allocate memory, it cannot be skipped.
    const auto &address = AnfAlgo::GetOutputAddr(kernel, index, IsNeedAllocMem(kernel, index));
    MS_EXCEPTION_IF_NULL(address);
    return address->GetDeviceType() == GetTargetDeviceType() &&
           address->format() == AnfAlgo::GetOutputFormat(kernel, 0);
  }
  return false;
}

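// Entry point for whole-graph memory assignment. With the memory scheduler (mem offload)
// enabled, only value nodes get static memory and all kernel addresses are reset so the
// scheduler can manage them; otherwise static memory is assigned before dynamic memory.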
void KernelRuntime::AssignMemory(const session::KernelGraph &graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (UseMemScheduler()) {
    AssignStaticMemoryValueNode(graph);
    ResetNodeAddress(graph);
    AddCommunicationMemInfo(graph);
  } else {
    MS_EXCEPTION_IF_NULL(mem_manager_);
    mem_manager_->ResetDynamicMemory();
    AssignStaticMemory(graph);
    AssignDynamicMemory(graph);
  }
  UpdateRefNodeOutputMem(graph);
}

void KernelRuntime::GetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
                                              DeviceAddressPtrList *address_list,
                                              std::vector<size_t> *align_size_list) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(total_size);
  MS_EXCEPTION_IF_NULL(address_list);
  MS_EXCEPTION_IF_NULL(align_size_list);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, true);
    auto input_node = input_node_with_index.first;
    MS_EXCEPTION_IF_NULL(input_node);
    DeviceAddressPtr address = nullptr;
    if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
      address = AnfAlgo::GetMutableOutputAddr(input_node, input_node_with_index.second);
    } else {
      address = PreAssignCNodeMemory(input_node, input_node_with_index.second);
    }
    MS_EXCEPTION_IF_NULL(address);
    auto align_size = MemoryManager::GetCommonAlignSize(address->size());
    *total_size += align_size;
    (void)address_list->emplace_back(address);
    (void)align_size_list->emplace_back(align_size);
  }
}

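// Allocate the inputs of a communication op as one continuous block from the memory pool,
// since collective communication expects its input buffers to be laid out contiguously.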
void KernelRuntime::AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!common::AnfAlgo::IsCommunicationOp(node)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(mem_manager_);

  size_t total_size = 0;
  DeviceAddressPtrList address_list;
  std::vector<size_t> align_size_list;
  GetCommunicationInputInfo(node, &total_size, &address_list, &align_size_list);
  if (align_size_list.empty()) {
    MS_LOG(WARNING) << "No inputs for " << node->fullname_with_scope();
    return;
  }

  if (!mem_manager_->MallocContinuousMemFromMemPool(address_list, total_size, align_size_list,
                                                    AnfAlgo::GetStreamId(node))) {
    MS_LOG(EXCEPTION) << "Allocate continuous memory failed, total_size:" << total_size;
  }
}

void KernelRuntime::GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size,
                                               DeviceAddressPtrList *address_list,
                                               std::vector<size_t> *align_size_list) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(total_size);
  MS_EXCEPTION_IF_NULL(align_size_list);
  MS_EXCEPTION_IF_NULL(address_list);

  const auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  const auto output_size_list = kernel_mod->GetOutputSizeList();
  for (size_t i = 0; i < output_size_list.size(); ++i) {
    DeviceAddressPtr address = nullptr;
    if (AnfAlgo::OutputAddrExist(node, i)) {
      address = AnfAlgo::GetMutableOutputAddr(node, i);
    } else {
      const std::string output_format = AnfAlgo::GetOutputFormat(node, i);
      const auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
      const auto tensor_size = AnfAlgo::GetOutputTensorMemSize(node, i);
      address = CreateDeviceAddress(nullptr, tensor_size, output_format, output_type, {node, i});
      AnfAlgo::SetOutputAddr(address, i, node.get());
    }
    MS_EXCEPTION_IF_NULL(address);
    auto align_size = MemoryManager::GetCommonAlignSize(address->size());
    *total_size += align_size;
    (void)align_size_list->emplace_back(align_size);
    (void)address_list->emplace_back(address);
  }
}

void KernelRuntime::AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!common::AnfAlgo::IsCommunicationOp(node)) {
    return;
  }
  MS_EXCEPTION_IF_NULL(mem_manager_);

  size_t total_size = 0;
  std::vector<size_t> align_size_list;
  std::vector<DeviceAddressPtr> address_list;
  GetCommunicationOutputInfo(node, &total_size, &address_list, &align_size_list);
  if (align_size_list.empty()) {
    MS_LOG(WARNING) << "No output for " << node->fullname_with_scope();
    return;
  }

  if (!mem_manager_->MallocContinuousMemFromMemPool(address_list, total_size, align_size_list,
                                                    AnfAlgo::GetStreamId(node))) {
    MS_LOG(EXCEPTION) << "Allocate continuous memory failed, total_size:" << total_size;
  }
}

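// Recreate empty device addresses (no backing pointer) for every kernel's inputs, outputs
// and workspaces so that the memory scheduler can bind memory at run time.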
void KernelRuntime::ResetNodeAddress(const session::KernelGraph &kernel_graph) {
  auto kernels = kernel_graph.execution_order();
  for (auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t j = 0; j < input_num; ++j) {
      auto input_index = AnfAlgo::GetInputGraphIdxByKernelIdx(kernel, j);
      KernelWithIndex kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, input_index, true);
      auto index = kernel_with_index.second;
      auto &input_node = kernel_with_index.first;
      if (NodeOutputDeviceAddressExist(input_node, index)) {
        continue;
      }
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(input_node, index);
      if (output_type_id == kTypeUnknown) {
        MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
        continue;
      }
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
      auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(input_node, index),
                                                output_type_id, {input_node, index});
      AnfAlgo::SetOutputAddr(device_address, index, input_node.get());
    }

    auto output_sizes = kernel_mod->GetOutputSizeList();
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
      AnfAlgo::SetOutputAddr(CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type), i,
                             kernel.get());
    }
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_sizes.size(); ++i) {
      AnfAlgo::SetWorkspaceAddr(
        CreateDeviceAddress(nullptr, workspace_sizes[i], kOpFormat_DEFAULT, kNumberTypeFloat32), i, kernel.get());
    }
  }
}

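// Assign memory for a single-op (PyNative) graph: communication nodes first, then inputs,
// value nodes, outputs and workspaces, and finally ref-node address propagation.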
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                      const session::KernelGraph &graph, bool is_gradient_out,
                                      const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->ResetDynamicMemory();

  for (const auto &node : graph.execution_order()) {
    AssignCommunicationOutputFromMemoryPool(node);
    AssignCommunicationInputFromMemoryPool(node);
  }

  RunOpAssignInputMemory(input_tensors, graph);
  AssignStaticMemoryValueNode(graph);
  for (const auto &node : graph.execution_order()) {
    RunOpAssignOutputMemory(node, tensor_to_node, is_gradient_out);
    RunOpAssignWorkSpaceMemory(node);
  }
  UpdateRefNodeOutputMem(graph);
}

void KernelRuntime::RunOpClearMemory(const session::KernelGraph &graph) const {
  // clear input parameter memory resource
  for (const auto &input_node : graph.inputs()) {
    MS_EXCEPTION_IF_NULL(input_node);
    AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
  }
  // clear input value node memory resource
  for (const auto &value_node : graph.graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  }
  for (const auto &cnode : graph.execution_order()) {
    MS_EXCEPTION_IF_NULL(cnode);
    // clear output memory resource
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t index = 0; index < output_num; ++index) {
      AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
    }
    // clear workspace memory resource
    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
    for (size_t index = 0; index < workspace_lists.size(); ++index) {
      AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
    }
  }
}

#ifdef ENABLE_DEBUGGER
bool KernelRuntime::DumpDataEnabled() {
  // Returns true if e2e dump is enabled.
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  return dump_json_parser.e2e_dump_enabled();
}

bool KernelRuntime::DumpDataEnabledIteration() {
  // Returns true if e2e dump is enabled and current iteration must be dumped.
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (!dump_json_parser.e2e_dump_enabled()) {
    return false;
  }

  auto cur_iter = dump_json_parser.cur_dump_iter();
  if (dump_json_parser.IsDumpIter(cur_iter)) {
    return true;
  }
  return false;
}
#endif

void KernelRuntime::AssignStaticMemory(const session::KernelGraph &graph) {
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  AssignStaticMemoryOutput(graph);
}

void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                           const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (input_tensors.size() != graph.inputs().size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
                      << " should be equal to graph input parameter size " << graph.inputs().size();
  }

  for (size_t input_index = 0; input_index < graph.inputs().size(); ++input_index) {
    auto item = graph.inputs()[input_index];
    MS_EXCEPTION_IF_NULL(item);
    if (!item->isa<Parameter>()) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      auto current_tensor = input_tensors[input_index];
      MS_EXCEPTION_IF_NULL(current_tensor);
      auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(current_tensor->device_address());
      // The device address has already been created.
      if (output_address != nullptr && output_address->GetDeviceType() == GetTargetDeviceType()) {
        if (output_address->GetDevicePtr() == nullptr) {
          if (!mem_manager_->MallocMemFromMemPool(output_address, output_address->size())) {
            MS_LOG(EXCEPTION) << "Allocate memory failed, size:" << output_address->size();
          }
        }

        AnfAlgo::SetOutputAddr(output_address, index, item.get());
        continue;
      }
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        output_type_id = common::AnfAlgo::GetOutputInferDataType(item, index);
      }
      auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      // Create a new device address.
      auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index),
                                                output_type_id, {item, index});
      MS_EXCEPTION_IF_NULL(device_address);
      MS_EXCEPTION_IF_NULL(mem_manager_);
      device_address->set_from_persistent_mem(true);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, tensor_size);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;
      }
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}

void KernelRuntime::RunOpAssignOutputMemory(
  const AnfNodePtr &kernel, const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
  bool is_gradient_out) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }

  // Use the device addresses allocated in RunOpMallocPre.
  for (auto &iter : tensor_to_node) {
    auto device_address = iter.first->device_address();
    AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(device_address), iter.second.second,
                           iter.second.first.get());
  }

  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if (AnfAlgo::OutputAddrExist(kernel, i, false)) {
      auto address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
      MS_EXCEPTION_IF_NULL(address);
      if (address->GetDevicePtr() == nullptr) {
        MS_EXCEPTION_IF_NULL(mem_manager_);
        if (!mem_manager_->MallocMemFromMemPool(address, address->size())) {
          MS_LOG(EXCEPTION) << "Allocate memory failed, size:" << address->size();
        }
      }
      continue;
    }
    if (common::AnfAlgo::GetCNodeName(kernel) == kApplyMomentumOpName ||
        common::AnfAlgo::GetCNodeName(kernel) == kApplyMomentumDOpName) {
      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
      continue;
    }
    std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {kernel, i});
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->set_host_shape(trans::GetRuntimePaddingShape(kernel, i));
    if (is_gradient_out) {
      device_address->set_from_persistent_mem(true);
    }
    auto ret = mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << output_sizes[i];
    }
    AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  }
}

void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (kernel->isa<CNode>()) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_lists.size(); ++i) {
      auto device_address = CreateDeviceAddress(nullptr, workspace_lists[i], "", kTypeUnknown);
      MS_EXCEPTION_IF_NULL(device_address);
      auto ret = mem_manager_->MallocMemFromMemPool(device_address, workspace_lists[i]);
      if (!ret) {
        MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << workspace_lists[i];
      }
      AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
    }
  }
}

void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value,
                                                const session::KernelGraph &graph) const {
  if (pre_output_value == nullptr) {
    return;
  }
  std::vector<tensor::BaseTensorPtr> pre_output_tensors;
  TensorValueToTensor(pre_output_value, &pre_output_tensors);
  auto output_nodes = graph.outputs();
  if (pre_output_tensors.size() != output_nodes.size()) {
    MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size()
                      << "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]";
  }
  // Share output addresses with the pre output tensors.
  for (size_t i = 0; i < output_nodes.size(); ++i) {
    auto output_node_with_index = common::AnfAlgo::VisitKernel(output_nodes[i], 0);
    auto output_node = output_node_with_index.first;
    MS_EXCEPTION_IF_NULL(output_node);
    if (!output_node->isa<CNode>()) {
      if (output_node->isa<Parameter>()) {
        auto param = output_node->cast<ParameterPtr>();
        if (param != nullptr && !param->has_default()) {
          MS_LOG(EXCEPTION) << "The output parameter should be a real parameter!";
        }
      }
      continue;
    }
    auto real_output_cnode = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(real_output_cnode);
    MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
    if (pre_output_tensors[i]->device_address() == nullptr) {
      MS_LOG(INFO) << "The address of pre output tensor [" << i << "] is a nullptr!";
      continue;
    }
    if (common::AnfAlgo::IsNopNode(real_output_cnode)) {
      if (real_output_cnode->size() < kMinInputSize) {
        MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString()
                          << " should be larger than one!";
      }
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, real_output_cnode->input(1).get());
    } else {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
                             output_node_with_index.second, output_node_with_index.first.get());
    }
  }
}

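// Statically assign device memory for graph input parameters (including child graph
// results), skipping parameters that already hold a usable device address.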
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  if (graph.need_inline()) {
    return;
  }
  MS_LOG(INFO) << "AssignStaticMemoryInput start for graph " << graph.graph_id();
  auto graph_inputs = GetGraphInputs(graph);
  auto graph_valid_input = graph.valid_inputs();
  (void)graph_inputs.insert(graph_inputs.end(), graph.child_graph_result().begin(), graph.child_graph_result().end());
  std::vector<AnfNodePtr> need_alloc_nodes;
  auto add_need_alloc_nodes = [&need_alloc_nodes, this](const AnfNodePtr &node) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<Parameter>()) {
      return;
    }
    if (NodeOutputDeviceAddressExist(node, 0)) {
      const auto &address = AnfAlgo::GetOutputAddr(node, 0);
      MS_EXCEPTION_IF_NULL(address);
      if (address->GetPtr() != nullptr) {
        return;
      }
    }
    need_alloc_nodes.push_back(node);
  };

  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    auto input_node = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }
    if (common::AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
      auto outs = common::AnfAlgo::GetAllOutput(input_node);
      for (auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        add_need_alloc_nodes(out);
      }
    }
    add_need_alloc_nodes(input_node);
  }
  std::map<AnfNodePtr, AnfNodePtr> shadow_backend_node_map;
  GetShadowBackendNodeMap(graph, &shadow_backend_node_map);
  for (auto &item : need_alloc_nodes) {
    MS_EXCEPTION_IF_NULL(item);
    if (item->has_user_data(kForwardOutput)) {
      MS_LOG(DEBUG) << "Skip allocate memory for forward output parameter " << item->DebugString();
      continue;
    }
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      // If a graph output is a weight that does not link to any cnode, its data type will be unknown.
      if (output_type_id == kTypeUnknown) {
        MS_LOG(INFO) << "It is not suggested to use a lonely weight parameter as the output of graph";
        continue;
      }
      // If the kernel has the flag kFlagEnableZeroCopyInGraph, the internal parameter and the corresponding
      // cnode cannot use the same device address.
      DeviceAddressPtr device_address =
        (graph.has_flag(kFlagEnableZeroCopyInGraph) ? nullptr : GetInternalDeviceAddress(graph, item));
      GetDeviceAddress(item, shadow_backend_node_map, index, graph, &device_address);
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
  MS_LOG(INFO) << "AssignStaticMemoryInput end";
}

void KernelRuntime::GetDeviceAddress(const AnfNodePtr &item,
                                     const std::map<AnfNodePtr, AnfNodePtr> shadow_backend_node_map, size_t index,
                                     const session::KernelGraph &graph, DeviceAddressPtr *device_address) {
  AnfNodePtr shadow_node = nullptr;
  auto iter = shadow_backend_node_map.find(item);
  if (iter != shadow_backend_node_map.end()) {
    shadow_node = iter->second;
  }
  if (*device_address == nullptr && shadow_node != nullptr) {
    auto conj_device_address = AnfAlgo::GetMutableOutputAddr(shadow_node, index);
    if (conj_device_address != nullptr && conj_device_address->GetDeviceType() == DeviceType::kAscend) {
      *device_address = conj_device_address;
    }
  } else if (*device_address == nullptr) {
    auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
    *device_address =
      CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id, {item, index});
  }

  // Set the flag for a parameter that has no user and so needs no memory.
  if ((*device_address != nullptr) && item->isa<Parameter>()) {
    auto input_param = item->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(input_param);
    // An unused address allocates no memory, which easily causes problems for weight nodes, so skip weight nodes.
    if (!common::AnfAlgo::IsParameterWeight(input_param) && !input_param->IsUsedByRealKernelInGraph(graph.graph_id())) {
      MS_LOG(INFO) << "Node:" << item->fullname_with_scope() << " debug name:" << item->DebugString()
                   << " is not used in the graph " << graph.graph_id();
      (*device_address)->UpdateFlag(kDeviceAddressFlagNotUsed);
      return;
    }
  }

  if (*device_address != nullptr && (*device_address)->GetPtr() == nullptr) {
    auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
    (*device_address)->set_host_shape(trans::GetRuntimePaddingShape(item, index));
    MS_LOG(INFO) << "Assign Static Memory for Input node, size:" << tensor_size
                 << " node:" << item->fullname_with_scope() << " debug:" << item->DebugString() << " index: " << index;
    if (!graph.has_flag(kFlagEnableZeroCopyInGraph)) {
      auto ret_ptr = mem_manager_->MallocMem(kStaticMem, tensor_size, *device_address, graph.graph_id());
      if (ret_ptr == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
      }
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "AllocStaticMemory", item->fullname_with_scope(),
                                                     graph.ToString());
      device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddCompileTimeMemInfo, "AllocStaticMemory", tensor_size, ret_ptr,
                                                     device::tracker::MemType::kWeight);
    }
  }
}

void KernelRuntime::AssignStaticMemoryOutput(const session::KernelGraph &graph) {
  if (graph.need_inline()) {
    return;
  }
  MS_LOG(INFO) << "AssignStaticMemoryOutput start for graph " << graph.graph_id();
  auto nodes = common::AnfAlgo::GetAllOutput(graph.output(), {prim::kPrimTupleGetItem});
  std::vector<session::KernelWithIndex> non_communication_op;
  // Assign communication op memory first.
  for (const auto &node : nodes) {
    // Assign an output address to a nop node whose "skip_nop_op_addr" attribute is false.
    auto is_skip = !common::AnfAlgo::IsNopNode(node) || common::AnfAlgo::IsNeedSkipNopOpAddr(node);
    auto kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(node, 0, is_skip);
    MS_EXCEPTION_IF_NULL(kernel_with_index.first);
    if (!kernel_with_index.first->isa<CNode>() || !AnfUtils::IsRealKernel(kernel_with_index.first)) {
      continue;
    }
    if (common::AnfAlgo::IsCommunicationOp(kernel_with_index.first)) {
      AssignCommunicationNodeMem(kStaticMem, kernel_with_index.first);
    } else {
      (void)non_communication_op.emplace_back(kernel_with_index);
    }
  }

  for (const auto &item_with_index : non_communication_op) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    MS_LOG(DEBUG) << "AssignNodeOutputMem for " << item_with_index.first->fullname_with_scope();
    AssignNodeOutputMem(kStaticMem, item_with_index.first, SizeToInt(item_with_index.second));
  }
  MS_LOG(INFO) << "AssignStaticMemoryOutput end";
}

void KernelRuntime::UpdateSingleRefNodeMem(const CNodePtr &kernel, const session::KernelGraph &graph,
                                           bool reverse) const {
  MS_EXCEPTION_IF_NULL(kernel);
  auto output_num = AnfAlgo::GetOutputTensorNum(kernel);
  if (output_num == 0) {
    MS_LOG(DEBUG) << "This kernel has no output size.";
    return;
  }
  for (size_t i = 0; i < output_num; ++i) {
    session::AnfWithOutIndex out_pair(kernel, i);
    if (graph.IsInRefOutputMap(out_pair)) {
      auto origin_pair = graph.GetRefCorrespondOutput(out_pair);
      MS_EXCEPTION_IF_NULL(origin_pair.first);
      auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second);
      MS_EXCEPTION_IF_NULL(origin_node_output_addr);
      auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(kernel, i);
      if (!reverse && origin_node_output_addr->GetPtr() == nullptr) {
        continue;
      }
      if (origin_node_output_addr.get() != cur_node_output_addr.get()) {
        MS_LOG(DEBUG) << "REF addresses are not the same, the ref node output address needs to be updated";
        MS_LOG(DEBUG) << "REF origin op is " << origin_pair.first->DebugString() << ", output index is "
                      << origin_pair.second << ", cur op is " << kernel->DebugString() << ", out index is " << i;
        if (reverse) {
          AnfAlgo::SetOutputAddr(cur_node_output_addr, origin_pair.second, origin_pair.first.get());
        } else {
          if (!cur_node_output_addr->host_shape().empty()) {
            origin_node_output_addr->set_host_shape(cur_node_output_addr->host_shape());
          }
          AnfAlgo::SetOutputAddr(origin_node_output_addr, i, kernel.get());
        }
      }
    }
  }
}

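// Make ref-node pairs share one device address across the graph: a forward pass copies
// origin addresses to their ref outputs, then a reverse pass feeds addresses back to
// origin nodes that still have none.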
void KernelRuntime::UpdateRefNodeOutputMem(const session::KernelGraph &graph) const {
  auto &kernels = graph.execution_order();
  for (auto &kernel : kernels) {
    UpdateSingleRefNodeMem(kernel, graph, false);
  }
  for (auto it = kernels.rbegin(); it != kernels.rend(); ++it) {
    auto &kernel = *it;
    UpdateSingleRefNodeMem(kernel, graph, true);
  }
}

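// Assign input, output and workspace memory for a communication node; when communication
// address reuse is active, the memory always comes from the dynamic pool.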
void KernelRuntime::AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node) {
  if (!reuse_communication_address_.empty()) {
    type = kDynamicMem;
  }
  AssignCommunicationNodeInputMem(type, node);
  AssignCommunicationNodeOutputMem(type, node);
  AssignWorkSpaceMem(type, node);
}

void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    MS_LOG(INFO) << "This kernel[" << node->DebugString() << "] has no output size.";
    return;
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  size_t total_size = 0;
  size_t output_index = 0;
  std::vector<size_t> align_size_list;
  for (uint64_t mem_size : output_sizes) {
    if (AnfAlgo::OutputAddrExist(node, output_index++)) {
      MS_LOG(INFO) << "Communication op " << node->fullname_with_scope() << " has output device address";
      return;
    }
    if (context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
      mem_size = MemoryManager::GetCommonAlignSize(mem_size);
    }
    total_size += mem_size;
    (void)align_size_list.emplace_back(mem_size);
  }

  if (align_size_list.empty()) {
    return;
  }

  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }

  uint8_t *output_ptr = nullptr;
  int64_t valid_reuse_index = -1;
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (common::AnfAlgo::HasNodeAttr(kAttrReuseCommunication, cnode)) {
    auto reuse_index = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrReuseCommunication);
    auto it = reuse_communication_address_.find(reuse_index);
    if (it != reuse_communication_address_.end()) {
      valid_reuse_index = reuse_index;
      output_ptr = it->second.second;
    }
  }

  for (size_t j = 0; j < align_size_list.size(); ++j) {
    std::string output_format = AnfAlgo::GetOutputFormat(node, j);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, j);
    auto address = CreateDeviceAddress(nullptr, output_sizes[j], output_format, output_type, {node, j});
    MS_EXCEPTION_IF_NULL(address);
    if (output_ptr == nullptr) {
      output_ptr = mem_manager_->MallocOutputMem(node, 0, type, total_size, address, true);
      MS_EXCEPTION_IF_NULL(output_ptr);
      if (valid_reuse_index != -1) {
        auto &it = reuse_communication_address_[valid_reuse_index];
        it.second = output_ptr;
      }
    } else {
      address->set_ptr(output_ptr);
    }
    address->set_host_shape(trans::GetRuntimePaddingShape(node, j));
    AnfAlgo::SetOutputAddr(address, j, node.get());
    output_ptr += align_size_list[j];
  }
}

bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  return false;
}

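// Create a device address (without backing memory yet) for the given output, recursing
// through nop nodes to the real producer output.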
DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) const {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (common::AnfAlgo::IsNopNode(anf_node)) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(anf_node, index);
    return PreAssignCNodeMemory(input_node_with_index.first, input_node_with_index.second);
  }

  auto output_size = AnfAlgo::GetOutputTensorMemSize(anf_node, index);
  std::string output_format = AnfAlgo::GetOutputFormat(anf_node, index);
  auto output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, index);
  auto address = CreateDeviceAddress(nullptr, output_size, output_format, output_type, {anf_node, index});
  AnfAlgo::SetOutputAddr(address, index, anf_node.get());
  return address;
}

void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  size_t total_size = 0;
  std::vector<std::pair<DeviceAddressPtr, size_t>> addr_size;
  size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
  for (size_t i = 0; i < input_num; ++i) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(node, i, true);
    auto input_node = input_node_with_index.first;
    MS_EXCEPTION_IF_NULL(input_node);
    if (AnfAlgo::OutputAddrExist(input_node, input_node_with_index.second)) {
      MS_LOG(INFO) << "Communication op " << input_node->fullname_with_scope() << " has input device address";
      return;
    }
    DeviceAddressPtr address = nullptr;

    address = PreAssignCNodeMemory(input_node, input_node_with_index.second);

    MS_EXCEPTION_IF_NULL(address);
    auto mem_size = MemoryManager::GetCommonAlignSize(address->size());
    total_size += mem_size;
    (void)addr_size.emplace_back(address, mem_size);
  }
  if (addr_size.empty()) {
    return;
  }
  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s input.";
    }
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->size() < kMinInputSize) {
    // A communication node's inputs should contain the primitive itself and at least one real input.
    MS_LOG(ERROR) << "No inputs for " << cnode->fullname_with_scope();
    return;
  }

  int64_t valid_reuse_index = -1;
  uint8_t *input_ptr = nullptr;
  if (common::AnfAlgo::HasNodeAttr(kAttrReuseCommunication, cnode)) {
    auto reuse_index = common::AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrReuseCommunication);
    auto it = reuse_communication_address_.find(reuse_index);
    if (it != reuse_communication_address_.end()) {
      valid_reuse_index = reuse_index;
      input_ptr = it->second.first;
    }
  }

  if (input_ptr == nullptr) {
    auto first_input_node = cnode->input(1);
    auto prenode_index = common::AnfAlgo::VisitKernelWithReturnType(first_input_node, 0, true);
    input_ptr = mem_manager_->MallocOutputMem(prenode_index.first, prenode_index.second, type, total_size,
                                              addr_size[0].first, true);
    if (valid_reuse_index != -1) {
      auto &it = reuse_communication_address_[valid_reuse_index];
      it.first = input_ptr;
    }
  }

  for (const auto &iter : addr_size) {
    MS_EXCEPTION_IF_NULL(iter.first);
    iter.first->set_ptr(input_ptr);
    input_ptr += iter.second;
  }
}

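// Assign output memory for a non-communication node; index may be kGetAllOuts to cover
// every output. Outputs that already own an address are skipped, and zero-copy graph
// outputs get an address but no backing memory.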
void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);

  if (type == kSomasReuseDynamicMem) {
    bool not_reuse = KernelMemNotReuse(node);
    if (not_reuse) {
      type = kDynamicMem;
      MS_LOG(INFO) << "Disable Memory Reuse for " << node->fullname_with_scope() << "'s output.";
    }
  }

  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  if (output_sizes.empty()) {
    return;
  }
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    if ((kGetAllOuts != index) && (SizeToInt(i) != index)) {
      continue;
    }
    if (NodeOutputDeviceAddressExist(node, i)) {
      MS_LOG(DEBUG) << "Already malloc index:" << i;
      continue;
    }
    MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope() << " output memory size:" << output_sizes[i];
    if (type == kStaticMem) {
      MS_LOG(INFO) << "Assign Static Memory for Output node, size:" << output_sizes[i]
                   << " node:" << node->fullname_with_scope();
    }
    std::string output_format = AnfAlgo::GetOutputFormat(node, i);
    auto output_type = AnfAlgo::GetOutputDeviceDataType(node, i);
    auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type, {node, i});
    MS_EXCEPTION_IF_NULL(device_address);

    // In subgraph sink mode, graph output should not allocate memory.
    if (IsNeedAllocMem(node, i)) {
      uint8_t *ptr = mem_manager_->MallocOutputMem(node, i, type, output_sizes[i], device_address, false);
      if (ptr == nullptr && type == kSomasReuseDynamicMem) {
        MS_LOG(INFO) << "node: " << node->fullname_with_scope() << " could be a RefNode, please check it"
                     << " output index: " << i << " memory type: " << type;
      } else {
        MS_EXCEPTION_IF_NULL(ptr);
      }
    } else {
      MS_LOG(DEBUG) << "Skip mem alloc for device address:" << device_address << " node:" << node->DebugString();
    }
    device_address->set_host_shape(trans::GetRuntimePaddingShape(node, i));
    AnfAlgo::SetOutputAddr(device_address, i, node.get());
  }
}

DeviceAddressPtr KernelRuntime::AssignExtraStaticMem(const TensorPtr &tensor, const AnfNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_LOG(DEBUG) << "Assign Node:" << node->fullname_with_scope()
                << " Assign Static Memory for Output node, size:" << tensor_address->size();
  auto device_address = CreateDeviceAddress(nullptr, tensor_address->size(), tensor_address->format(),
                                            tensor_address->type_id(), {node, index});
  MS_EXCEPTION_IF_NULL(device_address);
  uint8_t *ptr = mem_manager_->MallocOutputMem(node, index, kStaticMem, tensor_address->size(), device_address, false);
  MS_EXCEPTION_IF_NULL(ptr);
  return device_address;
}

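// Allocate device memory for a value node holding tensors or a scalar and copy the host
// data to the device.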
void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
                                          size_t output_idx) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::vector<tensor::BaseTensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);
  // The graph id should be passed to record static memory if profiling is enabled.
  auto kernel_info = dynamic_cast<device::KernelInfo *>(value_node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  uint32_t graph_id = kernel_info->graph_id();
  for (const auto &tensor : tensors) {
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }
    auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr && output_address->GetDeviceType() == GetTargetDeviceType()) {
      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
                             value_node.get());
      continue;
    }

    DeviceAddressPtr address = nullptr;
    size_t node_size = 0;
    if (node_value->isa<Scalar>()) {
      auto scalar_value = node_value->cast<ScalarPtr>();
      MS_EXCEPTION_IF_NULL(scalar_value);
      TypePtr data_type = scalar_value->type();
      MS_EXCEPTION_IF_NULL(data_type);
      TypeId type_id = data_type->type_id();
      node_size = GetTypeByte(TypeIdToType(type_id));
      address = CreateDeviceAddress(nullptr, node_size, kOpFormat_DEFAULT, type_id, {value_node, output_idx});
    } else {
      node_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
      if (output_type_id == kTypeUnknown) {
        output_type_id = common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
      }
      auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
      address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id, {value_node, output_idx});
    }
    MS_EXCEPTION_IF_NULL(address);
    address->set_host_shape(trans::GetRuntimePaddingShape(value_node, output_idx));
    address->set_from_persistent_mem(true);
    if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
        !mem_manager_->MallocMemFromMemPool(address, node_size)) {
      MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << node_size;
    } else {
      MS_LOG(INFO) << "Assign Static Memory for Value node, size:" << node_size
                   << " node:" << value_node->fullname_with_scope();
      if (mem_manager_->MallocMem(kStaticMem, node_size, address, graph_id) == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << node_size;
      }
    }
    AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
    size_t tensor_size = LongToSize(tensor->data().nbytes());
    std::string format = "DefaultFormat";
    if (tensor->isa<tensor::Tensor>()) {
      format = std::dynamic_pointer_cast<tensor::Tensor>(tensor)->device_info().host_format_;
    }
    if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
                                   format, tensor->data_ptr())) {
      MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail! " << value_node->DebugString()
                                   << " node format is " << AnfAlgo::GetOutputFormat(value_node, output_idx)
                                   << " node dtype is "
                                   << common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
  }
}

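// Assign static memory for every value node in the graph. Tensor, tuple and scalar values
// go through AssignValueNodeTensor; string values get a dedicated UInt8 buffer.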
void KernelRuntime::AssignStaticMemoryValueNode(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  MS_LOG(DEBUG) << "AssignStaticMemoryValueNode start for graph " << graph.graph_id();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // Order the value nodes by name so the allocation order is deterministic.
  std::map<std::string, ValueNodePtr> value_nodes_map;
  for (auto &node : graph.graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    value_nodes_map[node->fullname_with_scope()] = node;
  }

  for (auto &item : value_nodes_map) {
    auto value_node = item.second;
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeOutputDeviceAddressExist(value_node, 0)) {
      MS_LOG(DEBUG) << "value_node[" << value_node->DebugString() << "] address already exist";
      auto device_address = AnfAlgo::GetMutableOutputAddr(value_node, 0);
      if (device_address->GetDevicePtr() == nullptr) {
        if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
          if (!mem_manager_->MallocMemFromMemPool(device_address, device_address->GetSize())) {
            MS_LOG(EXCEPTION) << "MallocMemFromMemPool failed";
          }
        } else {
          if (mem_manager_->MallocMem(kStaticMem, device_address->GetSize(), device_address, graph.graph_id()) ==
              nullptr) {
            MS_LOG(EXCEPTION) << "MallocStaticMem failed";
          }
        }
      }
      continue;
    }
    auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    MS_LOG(DEBUG) << "Malloc memory for " << value_node->fullname_with_scope();
    if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>() || node_value->isa<Scalar>()) {
      AssignValueNodeTensor(value_node, node_value, 0);
    } else if (node_value->isa<StringImm>()) {
      const bool use_mem_from_memory_pool = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) ||
                                            ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
      auto address = CreateDeviceAddressForStringValue(node_value, use_mem_from_memory_pool, graph.graph_id());
      MS_EXCEPTION_IF_NULL(address);
      address->set_from_persistent_mem(true);
      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
    }
  }
  MS_LOG(DEBUG) << "AssignStaticMemoryValueNode end";
}

DeviceAddressPtr KernelRuntime::CreateDeviceAddressForStringValue(const ValuePtr &value, bool use_mem_pool,
                                                                  uint32_t graph_id) {
  auto value_string = GetValue<std::string>(value);
  size_t tensor_size = value_string.size();
  DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
  MS_EXCEPTION_IF_NULL(address);
  address->set_from_persistent_mem(true);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (use_mem_pool && !mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
    MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size:" << tensor_size;
  } else {
    MS_LOG(INFO) << "Assign Static Memory for string Value node, size:" << tensor_size;
    if (mem_manager_->MallocMem(kStaticMem, tensor_size, address, graph_id) == nullptr) {
      MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size;
    }
  }
  ShapeVector shape = {1, SizeToLong(tensor_size)};
  if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value_string.data(), "DefaultFormat")) {
    MS_LOG(EXCEPTION) << "ValueNode SyncHostToDevice fail!";
  }
  return address;
}

bool KernelRuntime::MemSchedulerPreCompute(const AnfNodePtr &kernel, const std::shared_ptr<MemScheduler> &mem_scheduler,
                                           void *stream, bool mock, KernelLaunchInfo *kernel_launch_info) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  MS_EXCEPTION_IF_NULL(stream);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  if (!mock && common::AnfAlgo::IsCommunicationOp(kernel) && !SyncStream()) {
    MS_LOG(ERROR) << "SyncStream failed";
    return false;
  }
  bool ret = mem_scheduler->PreCompute(stream);
  if (!ret) {
    return ret;
  }
  AssignKernelAddress(mem_scheduler, kernel, kernel_launch_info);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (mock && common::AnfAlgo::HasNodeAttr(kAttrOffload, cnode) &&
      common::AnfAlgo::GetNodeAttr<bool>(cnode, kAttrOffload)) {
    for (size_t i = 0; i < kernel_mod->GetOutputSizeList().size(); ++i) {
      auto device_address = AnfAlgo::GetOutputAddr(kernel, i, true);
      mem_scheduler->SetOffload(device_address);
    }
  }
  return true;
}

bool KernelRuntime::MemSchedulerPostCompute(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                                            const std::shared_ptr<MemScheduler> &mem_scheduler, void *stream,
                                            bool mock) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  MS_EXCEPTION_IF_NULL(stream);
  if (!mock) {
    SyncNodeOutputTensors(mem_scheduler, graph, kernel);
  }
  bool ret = mem_scheduler->PostCompute(stream);
  if (!ret) {
    return ret;
  }
  if (!mock && common::AnfAlgo::IsCommunicationOp(kernel) && !SyncStream()) {
    MS_LOG(ERROR) << "SyncStream failed";
    return false;
  }
  return true;
}

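// Assign per-iteration memory for all kernels. With memory reuse enabled, somas plans the
// layout inside one shared block; communication nodes are assigned before compute nodes
// because their buffers must be contiguous.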
void KernelRuntime::AssignDynamicMemory(const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_enable_mem_reuse = EnvConfigParser::GetInstance().GetSysMemreuse();
  auto mem_type = kDynamicMem;
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  if (dump_json_parser.e2e_dump_enabled() && dump_json_parser.dump_mode() == 0) {
    mindspore::EnvConfigParser::GetInstance().SetSysMemreuse(false);
    is_enable_mem_reuse = false;
    MS_LOG(INFO) << "Disable Memory Reuse when e2e dump is enabled and dump mode is set to dump all kernels";
  }

  if (is_enable_mem_reuse) {
    MS_LOG(INFO) << "Memory Reuse is enabled...";
    mem_manager_->MallocSomasDynamicMem(graph);
    mem_type = kSomasReuseDynamicMem;
  } else {
    MS_LOG(INFO) << "Memory Reuse is disabled...";
  }
  auto &execution_nodes = graph.execution_order();
  std::vector<CNodePtr> compute_nodes;
  // Communication nodes first.
  for (auto &node : execution_nodes) {
    if (common::AnfAlgo::IsCommunicationOp(node)) {
      // Skipped internally if the memory is already allocated.
      AssignCommunicationNodeMem(mem_type, node);
    } else {
      (void)compute_nodes.emplace_back(node);
    }
  }

  // Then compute nodes.
  for (auto &node : compute_nodes) {
    AssignNodeOutputMem(mem_type, node, kGetAllOuts);
    AssignWorkSpaceMem(mem_type, node);
  }
}

void KernelRuntime::AssignWorkSpaceMem(MemType type, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto kernel_mod = AnfAlgo::GetKernelMod(node);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  size_t index = 0;
  for (auto &size : kernel_mod->GetWorkspaceSizeList()) {
    if (AnfAlgo::WorkspaceAddrExist(node, index)) {
      MS_LOG(INFO) << "Op " << node->fullname_with_scope() << " has workspace device address";
      return;
    }
    auto ptr = mem_manager_->MallocWorkSpaceMem(node, index, type, size);
    AnfAlgo::SetWorkspaceAddr(CreateDeviceAddress(ptr, size, "", kTypeUnknown), index, node.get());
    index++;
  }
}

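// Collect the input/output/workspace kernel tensors of a kernel into a KernelLaunchInfo.
// MemSet (atomic clean) nodes are special-cased: their launch inputs are the addresses to clean.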
void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const mindspore::AnfNodePtr &kernel,
                                  KernelLaunchInfo *kernel_launch_info) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_launch_info);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
  if (cnode_name == kMemSetOpName) {
    return GenKernelTensorLaunchArgs(cnode, &(kernel_launch_info->inputs_));
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto skip_nop_node = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
  for (size_t i = 0; i < input_num; ++i) {
    auto real_input = AnfAlgo::GetInputGraphIdxByKernelIdx(kernel, i);
    auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input, skip_nop_node);
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &input = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(input);
    (void)kernel_launch_info->inputs_.emplace_back(input.get());
  }

  for (size_t i = 0; i < kernel_mod.GetOutputSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetOutputAddr(kernel, i, skip_nop_node);
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &output = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(output);
    (void)kernel_launch_info->outputs_.emplace_back(output.get());
  }

  for (size_t i = 0; i < kernel_mod.GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &workspace = device_address->kernel_tensor();
    MS_EXCEPTION_IF_NULL(workspace);
    (void)kernel_launch_info->workspaces_.emplace_back(workspace.get());
  }
}

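// Whether the MemScheduler (memory offload) path is enabled: requires MS_CTX_ENABLE_MEM_OFFLOAD
// and is skipped for single-op (PyNative) execution.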
bool KernelRuntime::UseMemScheduler() {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_OFFLOAD)) {
    return false;
  }
  // Do not use the MemScheduler when running single ops.
  return (!context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER) &&
          (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode));
}

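// Generate pre/post device events for every communication kernel so that the communication stream
// and the compute stream are ordered: the pre event makes the communication stream wait for compute
// work recorded so far, and the post event makes the nearest downstream compute kernel (or, when no
// compute kernel consumes the result, the communication kernel's own post-run step) wait for completion.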
void KernelRuntime::GenKernelEvents(const session::KernelGraph &graph) {
  auto &kernels = graph.execution_order();
  if (kernels.empty() || graph_kernel_events_map_.find(graph.graph_id()) != graph_kernel_events_map_.end()) {
    return;
  }
  auto kernel_events = std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>,
                                 std::map<AnfNodePtr, std::vector<std::function<void()>>>>();
  auto &kernel_pre_run_events = kernel_events.first;
  auto &kernel_post_run_events = kernel_events.second;
  for (size_t i = 0; i < kernels.size(); ++i) {
    auto &kernel = kernels[i];
    if (!common::AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto pre_event = CreateDeviceEvent();
    auto post_event = CreateDeviceEvent();
    MS_EXCEPTION_IF_NULL(pre_event);
    MS_EXCEPTION_IF_NULL(post_event);
    pre_event->set_wait_stream(communication_stream_);
    pre_event->set_record_stream(stream_);
    post_event->set_wait_stream(stream_);
    post_event->set_record_stream(communication_stream_);
    (void)kernel_pre_run_events[kernel].emplace_back([pre_event]() {
      pre_event->RecordEvent();
      pre_event->WaitEvent();
    });
    (void)kernel_post_run_events[kernel].emplace_back([post_event]() { post_event->RecordEvent(); });
    bool found_nearest_child = false;
    for (size_t j = i + 1; j < kernels.size(); ++j) {
      auto &child = kernels[j];
      MS_EXCEPTION_IF_NULL(child);
      if (common::AnfAlgo::IsCommunicationOp(child)) {
        continue;
      }
      auto input_size = child->size() - 1;
      for (size_t k = 0; k < input_size; ++k) {
        auto kernel_index =
          common::AnfAlgo::VisitKernelWithReturnType(common::AnfAlgo::GetInputNode(child, k), 0, true);
        if (kernel_index.first == kernel) {
          found_nearest_child = true;
          break;
        }
      }
      if (found_nearest_child) {
        (void)kernel_pre_run_events[child].emplace_back([post_event]() { post_event->WaitEvent(); });
        break;
      }
    }
    if (!found_nearest_child) {
      (void)kernel_post_run_events[kernel].emplace_back([post_event]() { post_event->WaitEvent(); });
    }
  }
  graph_kernel_events_map_[graph.graph_id()] = std::move(kernel_events);
}

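// Build the launch inputs of an atomic-clean (MemSet) node from the output and workspace addresses
// of the preceding kernel that need to be zeroed. When a MemScheduler is given, device memory is
// fetched or allocated through it.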
void KernelRuntime::GenKernelTensorLaunchArgs(const CNodePtr &cnode, std::vector<kernel::KernelTensor *> *kernel_inputs,
                                              const std::shared_ptr<MemScheduler> &mem_scheduler) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(kernel_inputs);
  if (cnode->size() != kAtomicCleanInputSize) {
    MS_LOG(EXCEPTION) << "The input size of the atomic addr clean node is not equal to " << kAtomicCleanInputSize
                      << ".";
  }
  MS_EXCEPTION_IF_NULL(cnode->inputs()[1]);
  auto pre_node = (cnode->inputs()[1])->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_node);
  // Set the clean output addresses.
  if (common::AnfAlgo::HasNodeAttr(kAttrAtomicOutputIndexs, pre_node)) {
#if defined(__APPLE__)
    auto clean_output_indexes = common::AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicOutputIndexs);
#else
    auto clean_output_indexes = common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicOutputIndexs);
#endif
    for (auto index : clean_output_indexes) {
      auto device_address = AnfAlgo::GetOutputAddr(pre_node, index);
      MS_EXCEPTION_IF_NULL(device_address);
      const auto &input = device_address->kernel_tensor();
      MS_EXCEPTION_IF_NULL(input);
      if (mem_scheduler != nullptr) {
        GetOrMallocAddress(mem_scheduler, device_address, input);
      }
      auto real_output_size = AnfAlgo::GetOutputTensorMemSize(pre_node, index);
      if (device_address->GetSize() != real_output_size) {
        MS_LOG(DEBUG) << "The node:" << pre_node->fullname_with_scope() << " real output size is " << real_output_size;
        input->set_size(real_output_size);
      }
      (void)kernel_inputs->emplace_back(input.get());
    }
    MS_LOG(DEBUG) << "AtomicAddClean clean output size:" << clean_output_indexes.size();
  }
  // Set the clean workspace addresses.
  if (common::AnfAlgo::HasNodeAttr(kAttrAtomicWorkspaceIndexs, pre_node)) {
#if defined(__APPLE__)
    auto clean_workspaces_indexes =
      common::AnfAlgo::GetNodeAttr<std::vector<int>>(pre_node, kAttrAtomicWorkspaceIndexs);
#else
    auto clean_workspaces_indexes =
      common::AnfAlgo::GetNodeAttr<std::vector<size_t>>(pre_node, kAttrAtomicWorkspaceIndexs);
#endif
    for (const auto &index : clean_workspaces_indexes) {
      auto device_address = AnfAlgo::GetWorkspaceAddr(pre_node, index);
      MS_EXCEPTION_IF_NULL(device_address);
      const auto &workspace = device_address->kernel_tensor();
      MS_EXCEPTION_IF_NULL(workspace);
      if (mem_scheduler != nullptr) {
        GetOrMallocAddress(mem_scheduler, device_address, workspace);
      }
      (void)kernel_inputs->emplace_back(workspace.get());
    }
  }
}

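// Run all recorded event callbacks (if any) attached to the given node.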
void KernelRuntime::LaunchKernelEvent(const std::map<AnfNodePtr, std::vector<std::function<void()>>> &kernel_events,
                                      const AnfNodePtr &node) const {
  if (kernel_events.find(node) == kernel_events.end()) {
    return;
  }

  for (auto &event : kernel_events.at(node)) {
    event();
  }
}

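// Launch a kernel wrapped between two device time events and report the elapsed device time,
// used for per-op profiling in PyNative mode.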
bool KernelRuntime::LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
                                                      const KernelLaunchInfo &kernel_launch_info, void *stream) {
  MS_EXCEPTION_IF_NULL(kernel_mod);
  MS_EXCEPTION_IF_NULL(stream);
  float cost_time = 0;
  auto start = CreateDeviceTimeEvent();
  auto end = CreateDeviceTimeEvent();
  MS_EXCEPTION_IF_NULL(start);
  MS_EXCEPTION_IF_NULL(end);
  start->set_record_stream(stream);
  end->set_record_stream(stream);
  start->RecordEvent();
  bool ret = kernel_mod->Launch(kernel_launch_info.inputs_, kernel_launch_info.workspaces_,
                                kernel_launch_info.outputs_, nullptr);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Launch kernel failed, kernel name: " << op_name;
  }
  end->RecordEvent();
  start->SyncEvent();
  end->SyncEvent();
  start->ElapsedTime(&cost_time, end.get());
  MS_LOG(DEBUG) << "Launch kernel:" << op_name << " cost:" << cost_time / kBasicTimeTransferUnit;
  return ret;
}

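// Synchronize the stream after each op launch when PyNative synchronous execution is enabled,
// so that a failure surfaces at the op that caused it.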
void KernelRuntime::DebugStreamSync(const CNodePtr &kernel) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto enable_sync_run = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
  if (enable_sync_run) {
    if (!SyncStream()) {
      MS_LOG(EXCEPTION) << "Op " << kernel->fullname_with_scope() << " run failed!";
    }
  }
}

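// Bind a device pointer to the kernel tensor: reuse the device address's pointer when it is already
// allocated, otherwise obtain (possibly newly allocated) device memory from the MemScheduler.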
void KernelRuntime::GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                       const DeviceAddress *device_address,
                                       const kernel::KernelTensorPtr &kernel_tensor) {
  MS_EXCEPTION_IF_NULL(device_address);
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  if (device_address->GetDevicePtr() != nullptr) {
    kernel_tensor->set_device_ptr(device_address->GetDevicePtr());
  } else {
    MS_EXCEPTION_IF_NULL(mem_scheduler);
    kernel_tensor->set_device_ptr(mem_scheduler->GetOrMalloc(device_address, device_address->GetSize()));
  }
}

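// Fill a KernelLaunchInfo with input/output/workspace kernel tensors whose device pointers come from
// the MemScheduler. Parameter inputs of parameter-update kernels are promoted to high-priority memory.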
void KernelRuntime::AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
                                        KernelLaunchInfo *kernel_launch_info) const {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(kernel_launch_info);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
  if (cnode_name == kMemSetOpName) {
    return GenKernelTensorLaunchArgs(cnode, &(kernel_launch_info->inputs_), mem_scheduler);
  }
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
  const auto update_parameter = common::AnfAlgo::IsUpdateParameterKernel(cnode);
  for (size_t j = 0; j < input_num; ++j) {
    auto real_input = AnfAlgo::GetInputGraphIdxByKernelIdx(kernel, j);
    auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, real_input, true);
    auto index = kernel_with_index.second;
    auto &input_node = kernel_with_index.first;
    auto device_address = AnfAlgo::GetOutputAddr(input_node, index, true);
    MS_EXCEPTION_IF_NULL(device_address);
    const auto &input = device_address->kernel_tensor();
    GetOrMallocAddress(mem_scheduler, device_address, input);
    (void)kernel_launch_info->inputs_.emplace_back(input.get());
    if (update_parameter && input_node->isa<Parameter>()) {
      auto param = input_node->cast<ParameterPtr>();
      auto abstract = param->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      if (abstract->isa<abstract::AbstractRefTensor>()) {
        mem_scheduler->UpdateHighPriorityMem(device_address);
      }
    }
  }

  for (size_t j = 0; j < kernel_mod->GetOutputSizeList().size(); ++j) {
    auto device_address = AnfAlgo::GetOutputAddr(kernel, j, true);
    const auto &output = device_address->kernel_tensor();
    GetOrMallocAddress(mem_scheduler, device_address, output);
    (void)kernel_launch_info->outputs_.emplace_back(output.get());
  }

  for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
    auto device_address = AnfAlgo::GetWorkspaceAddr(kernel, i);
    const auto &workspace = device_address->kernel_tensor();
    GetOrMallocAddress(mem_scheduler, device_address, workspace);
    (void)kernel_launch_info->workspaces_.emplace_back(workspace.get());
  }
}

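// Sync to host the output tensors bound to this kernel's parameter inputs and to its own outputs,
// so host-side tensors observe the latest device data before the scheduler may reuse the memory.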
void KernelRuntime::SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                          const session::KernelGraph &graph, const AnfNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto inputs = AnfAlgo::GetOrCreateAllInputKernelTensors(kernel);
  for (size_t input_idx = 0; input_idx < inputs.size(); ++input_idx) {
    const auto input_node_index = common::AnfAlgo::GetPrevNodeOutput(kernel, input_idx, true);
    if (input_node_index.first != nullptr && input_node_index.first->isa<Parameter>()) {
      SyncNodeOutputTensor(mem_scheduler, input_node_index, graph);
    }
  }
  for (size_t output_idx = 0; output_idx < kernel_mod->GetOutputSizeList().size(); ++output_idx) {
    SyncNodeOutputTensor(mem_scheduler, std::make_pair(kernel, output_idx), graph);
  }
}

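// Sync a single node output tensor to host. The device address is bound to the tensor only for the
// duration of data_sync(); afterwards the original device pointer is restored and the tensor is
// marked as needing host-to-device sync again, since the scheduler may move the device memory.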
void KernelRuntime::SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                         const KernelWithIndex &node_output_index, const session::KernelGraph &graph) {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  if (node_output_index.first == nullptr) {
    return;
  }
  auto device_address = AnfAlgo::GetMutableOutputAddr(node_output_index, true);
  auto tensor = graph.GetNodeOutputTensor(node_output_index);
  if (tensor == nullptr) {
    return;
  }
  if (device_address == nullptr) {
    tensor->data_sync(false);
    tensor->set_device_address(nullptr);
    tensor->set_sync_status(kNeedSyncHostToDevice);
    return;
  }
  if (!SyncStream()) {
    MS_LOG(EXCEPTION) << "SyncStream failed";
  }
  auto origin_ptr = device_address->GetDevicePtr();
  if (device_address->GetDevicePtr() == nullptr) {
    device_address->SetDevicePtr(mem_scheduler->GetOrMalloc(device_address.get(), device_address->GetSize()));
  }
  tensor->set_device_address(device_address);
  tensor->data_sync(false);
  tensor->set_device_address(nullptr);
  device_address->SetDevicePtr(origin_ptr);
  tensor->set_sync_status(kNeedSyncHostToDevice);
}

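// Register every graph input tensor with the MemScheduler before a step runs: decide whether host
// data must be copied to device (directly when device memory is already allocated, otherwise
// deferred via AddMemNeedInit), and give weights and updated parameters high memory priority.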
void KernelRuntime::InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                          const session::KernelGraph &graph) const {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  auto &input_nodes = graph.input_nodes();
  auto &input_tensors = graph.input_tensors();
  if (input_tensors.size() != input_nodes.size()) {
    MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
  }
  mem_scheduler->ClearMemNeedInit();
  for (size_t i = 0; i < input_tensors.size(); ++i) {
    auto input_node = input_nodes[i];
    if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
      continue;
    }
    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
    auto tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto tensor_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    const auto tensor_size = LongToSize(tensor->data().nbytes());
    bool need_sync = false;
    if (tensor->NeedSyncHostToDevice()) {
      need_sync = true;
    } else if (tensor_address != device_address) {
      tensor->data_sync(false);
      need_sync = true;
    }
    if (mem_scheduler->HasDeviceMem(device_address.get())) {
      device_address->set_ptr(nullptr);
    }
    if (need_sync) {
      const auto &shape = trans::GetRuntimePaddingShape(input_node, 0);
      if (device_address->GetPtr() != nullptr) {
        (void)device_address->SyncHostToDevice(shape, tensor_size, tensor->data_type(),
                                               tensor->device_info().host_format_, tensor->data_ptr());
      } else {
        mem_scheduler->AddMemNeedInit(device_address.get());
      }
    }
    MemPriority priority = kMemPriorityLow;
    const auto &parameter = input_node->cast<ParameterPtr>();
    if (common::AnfAlgo::IsParameterWeight(parameter) || graph.IsUpdatedParameter(parameter)) {
      priority = kMemPriorityHigh;
    }
    mem_scheduler->Init(device_address.get(), tensor->data_c(), tensor_size, priority);
    tensor->set_sync_status(kNoNeedSync);
  }
}

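// Register continuous-memory requirements of communication kernels with the MemScheduler: when a
// kernel has more than one input (or output) address, those addresses must be allocated as one
// contiguous block with the recorded aligned sizes.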
void KernelRuntime::AddCommunicationMemInfo(const session::KernelGraph &graph) {
  const auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  for (size_t compute_index = 0; compute_index < graph.execution_order().size(); ++compute_index) {
    const auto &kernel = graph.execution_order()[compute_index];
    MS_EXCEPTION_IF_NULL(kernel);
    if (!common::AnfAlgo::IsCommunicationOp(kernel)) {
      continue;
    }
    auto device_address_to_key = [](const DeviceAddressPtr &device_address) -> void * { return device_address.get(); };
    size_t input_total_size = 0;
    DeviceAddressPtrList input_address_list;
    std::vector<size_t> input_align_size_list;
    GetCommunicationInputInfo(kernel, &input_total_size, &input_address_list, &input_align_size_list);
    if (input_address_list.size() > 1) {
      std::vector<const void *> input_address_key_list;
      (void)std::transform(input_address_list.begin(), input_address_list.end(),
                           std::back_inserter(input_address_key_list), device_address_to_key);
      mem_scheduler->AddContinuousMemInfo(true, compute_index, input_total_size, input_align_size_list,
                                          input_address_key_list);
    }
    size_t output_total_size = 0;
    DeviceAddressPtrList output_address_list;
    std::vector<size_t> output_align_size_list;
    GetCommunicationOutputInfo(kernel, &output_total_size, &output_address_list, &output_align_size_list);
    if (output_address_list.size() > 1) {
      std::vector<const void *> output_address_key_list;
      (void)std::transform(output_address_list.begin(), output_address_list.end(),
                           std::back_inserter(output_address_key_list), device_address_to_key);
      mem_scheduler->AddContinuousMemInfo(false, compute_index, output_total_size, output_align_size_list,
                                          output_address_key_list);
    }
  }
}

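// Launch a single kernel. With a MemScheduler, device memory is prepared in a pre-compute step and
// released or offloaded in a post-compute step; in mock mode the kernel launch itself is skipped so
// that only the memory-access pattern is recorded.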
bool KernelRuntime::LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                                 const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock) {
  MS_EXCEPTION_IF_NULL(kernel);
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  KernelLaunchInfo kernel_launch_info;
  auto stream = GetKernelStream(kernel);
  MS_EXCEPTION_IF_NULL(stream);
  bool ret = true;
  if (mem_scheduler != nullptr) {
    ret = MemSchedulerPreCompute(kernel, mem_scheduler, stream, mock, &kernel_launch_info);
    if (!ret) {
      return ret;
    }
  } else {
    GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
  }
  if (!mock) {
    if (pynative_mode_profiling_flag_) {
      ret = LaunchKernelWithPynativeProfiling(kernel_mod, kernel->fullname_with_scope(), kernel_launch_info, stream);
    } else {
      ret = kernel_mod->Launch(kernel_launch_info.inputs_, kernel_launch_info.workspaces_,
                               kernel_launch_info.outputs_, stream);
    }
    if (!ret) {
      return ret;
    }
  }
  if (mem_scheduler != nullptr) {
    ret = MemSchedulerPostCompute(graph, kernel, mem_scheduler, stream, mock);
  }
  return ret;
}

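// Launch all kernels of the graph in execution order. Dynamic-shape kernels are re-inferred and
// resized before launch; nop transpose kernels just forward their input addresses; event callbacks
// generated by GenKernelEvents run before and after each kernel. With mock=true, only the
// memory-access pattern is recorded for the MemScheduler and no kernel actually runs.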
bool KernelRuntime::LaunchKernelMod(const session::KernelGraph &graph, bool mock) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::shared_ptr<MemScheduler> mem_scheduler = nullptr;

  if (UseMemScheduler()) {
    mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
    MS_EXCEPTION_IF_NULL(mem_scheduler);
    mem_scheduler->Reset();
    mem_scheduler->Update();
    InitGraphInputTensors(mem_scheduler, graph);
  }

  const auto &kernels = graph.execution_order();
  std::map<AnfNodePtr, std::vector<std::function<void()>>> kernel_pre_run_events;
  std::map<AnfNodePtr, std::vector<std::function<void()>>> kernel_post_run_events;
  auto events_iter = graph_kernel_events_map_.find(graph.graph_id());
  if (events_iter != graph_kernel_events_map_.end()) {
    kernel_pre_run_events = events_iter->second.first;
    kernel_post_run_events = events_iter->second.second;
  }
  for (size_t i = 0; i < kernels.size(); ++i) {
    LaunchKernelEvent(kernel_pre_run_events, kernels[i]);
    auto &kernel = kernels[i];
    MS_EXCEPTION_IF_NULL(kernel);
    if (common::AnfAlgo::IsDynamicShape(kernel)) {
      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
      MS_EXCEPTION_IF_NULL(kernel_mod);
      opt::InferOp(kernel);
      auto inputs = AnfAlgo::GetOrCreateAllInputKernelTensors(kernel);
      auto outputs = AnfAlgo::GetOrCreateAllOutputKernelTensors(kernel);
      if (kernel_mod->Resize(inputs, outputs) == static_cast<int>(kernel::KRET_RESIZE_FAILED)) {
        MS_LOG(EXCEPTION) << "Node " << kernel->fullname_with_scope() << " Resize failed.";
      }
      KernelLaunchInfo kernel_launch_info;
      device::KernelRuntime::GenLaunchArgs(*kernel_mod, kernel, &kernel_launch_info);
      // Allocate workspace memory.
      std::vector<KernelTensor *> workspaces;
      if (AnfAlgo::GetKernelType(kernel) == KernelType::TBE_KERNEL) {
        MS_EXCEPTION_IF_NULL(tbe_call_);
        tbe_call_(kernel, kernel_mod, &workspaces);
      } else {
        workspaces = kernel_launch_info.workspaces_;
      }

      auto ret = kernel_mod->Launch(kernel_launch_info.inputs_, workspaces, kernel_launch_info.outputs_, stream_);
      if (!ret) {
        MS_LOG(ERROR) << "Launch kernel failed, kernel full name: " << kernel->fullname_with_scope();
        return false;
      }

      if (!SyncStream()) {
        MS_LOG(ERROR) << "SyncStream failed";
        return false;
      }
      kernel::UpdateNodeShape(kernel);
    } else {
      // Skip transpose kernels with the "nop_op" attr which are not hidden or removed in the PyNative
      // infer scenario. Such a transpose kernel, which is not supposed to be executed, is generated by
      // TransDataSplit to support specific TransData cases. This hard-coded branch should be removed
      // once the new TransData scheme is implemented.
      if (common::AnfAlgo::HasNodeAttr(kAttrNopOp, kernel)) {
        for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(kernel); idx += 1) {
          auto real_input = AnfAlgo::GetInputGraphIdxByKernelIdx(kernel, idx);
          auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, real_input);
          AnfAlgo::SetOutputAddr(device_address, idx, kernel.get());
        }
        continue;
      }
      auto ret = LaunchKernel(graph, kernel, mem_scheduler, mock);
      if (!ret) {
        MS_LOG(ERROR) << "Launch kernel failed.";
        return false;
      }
      KernelLaunchProfiling(kernel->fullname_with_scope());
      DebugStreamSync(kernel);
    }
    LaunchKernelEvent(kernel_post_run_events, kernels[i]);
  }
  if (UseMemScheduler() && !mock) {
    SyncParameter(graph, mem_scheduler);
  }
  return true;
}

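// After a real (non-mock) step, rebind weight and updated-parameter tensors to their device
// addresses so host copies can be synced from device later, and flag parameters updated on device.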
void KernelRuntime::SyncParameter(const session::KernelGraph &graph,
                                  const std::shared_ptr<MemScheduler> &mem_scheduler) const {
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  auto &input_nodes = graph.input_nodes();
  auto &input_tensors = graph.input_tensors();
  if (input_tensors.size() != input_nodes.size()) {
    MS_LOG_EXCEPTION << "Invalid input tensor size:" << input_tensors.size() << " vs node size:" << input_nodes.size();
  }

  for (size_t i = 0; i < input_tensors.size(); ++i) {
    auto input_node = input_nodes[i];
    if (!input_node->isa<Parameter>() || !AnfAlgo::OutputAddrExist(input_node, 0)) {
      continue;
    }
    auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
    MS_EXCEPTION_IF_NULL(device_address);
    auto parameter = input_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(parameter);
    if (!common::AnfAlgo::IsParameterWeight(parameter) && !graph.IsUpdatedParameter(parameter)) {
      continue;
    }
    auto tensor = input_tensors[i];
    MS_EXCEPTION_IF_NULL(tensor);
    if (mem_scheduler->HasDeviceMem(device_address.get())) {
      auto device_ptr = mem_scheduler->GetOrMalloc(device_address.get(), device_address->size(), kMemPriorityHigh);
      device_address->set_ptr(device_ptr);
      tensor->set_device_address(device_address);
      tensor->set_sync_status(kNeedSyncDeviceToHost);
    }
    if (graph.IsUpdatedParameter(parameter)) {
      tensor->SetIsUpdateByDevice();
    }
  }
}

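// One-time MemScheduler setup for a graph: attach the memory handler, run a mock pass to record the
// memory-access order, then optimize the offload schedule. Raises an exception if the graph cannot
// run even with offloading.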
void KernelRuntime::UseMemSchedulerIfNeeded(const session::KernelGraph &graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!UseMemScheduler()) {
    return;
  }
  auto mem_scheduler = mem_scheduler_manager_.GetOrCreateMemScheduler(graph.graph_id());
  MS_EXCEPTION_IF_NULL(mem_scheduler);
  if (mem_scheduler->optimized()) {
    return;
  }
  mem_scheduler->SetMemHandler(std::make_shared<MemHandler>(mem_manager_));
  mem_scheduler->SetTotalStep(graph.execution_order().size());

  if (mem_scheduler->need_record_event()) {
    (void)LaunchKernelMod(graph, true);
    mem_scheduler->set_need_record_event(false);
  }
  auto ret = mem_scheduler->Optimize();
  if (!ret) {
    MS_LOG_EXCEPTION << "Cannot run graph " << graph.graph_id() << " due to the memory limit.";
  }
}

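// Entry point for launching a whole graph: set up the MemScheduler if needed, launch all kernels,
// and synchronize the stream at the end in graph mode.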
bool KernelRuntime::LaunchKernels(const session::KernelGraph &graph) {
  UseMemSchedulerIfNeeded(graph);
  if (!LaunchKernelMod(graph)) {
    MS_LOG(ERROR) << "LaunchKernelMod failed!";
    return false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    if (!SyncStream()) {
      MS_LOG(ERROR) << "SyncStream failed";
      return false;
    }
  }
  return true;
}

void KernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
  MS_LOG(INFO) << "Clear graph:" << graph_id << " runtime resource";
}
}  // namespace device
}  // namespace mindspore