/**
 * Copyright 2021-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/graph_scheduler/graph_compiler.h"
#include <numeric>
#include <map>
#include <utility>
#include <algorithm>
#include <functional>
#include <list>
#include "runtime/graph_scheduler/graph_scheduler.h"
#include "runtime/device/device_address_utils.h"
#include "runtime/pynative/op_executor.h"
#include "include/backend/device_address.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/pynative/op_compiler.h"
#include "include/common/utils/convert_utils.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/optimizer/common_backend_optimization.h"
#include "utils/ms_context.h"
#include "ir/tensor.h"
#include "kernel/framework_utils.h"
#include "include/backend/debug/profiler/profiling.h"
#include "include/backend/optimizer/helper.h"
#include "base/base_ref_utils.h"
#include "include/common/debug/dump_proto.h"
#include "include/common/utils/parallel_context.h"
#include "plugin/device/cpu/hal/hardware/cpu_device_context.h"
#ifdef ENABLE_DEBUGGER
#include "include/backend/debug/debugger/debugger.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "include/common/debug/anf_ir_dump.h"
#endif
#ifndef ENABLE_SECURITY
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "include/backend/optimizer/graph_optimizer.h"
#endif
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/ps_context.h"
#include "runtime/graph_scheduler/embedding_cache_scheduler.h"
#endif
#include "include/common/profiler.h"
#include "include/common/utils/compile_cache_context.h"
#include "utils/phase.h"
#include "pipeline/jit/ps/base.h"
#include "ops/framework_ops.h"

namespace mindspore {
namespace runtime {
namespace {
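// Keep summary node outputs alive by pinning their device addresses with the maximum ref count.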
void SetSummaryNodesRefCount(const KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {
    return;
  }

  const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes = graph->summary_nodes();
  if (summary_nodes.empty()) {
    return;
  }

  for (const auto &item : summary_nodes) {
    const AnfNodePtr &node = item.second.first;
    size_t index = IntToSize(item.second.second);
    auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false);
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->set_original_ref_count(SIZE_MAX);
    device_address->ResetRefCount();
  }
}

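// The backend compile cache is only usable for the front root graph on Ascend with a non-GE backend,
// outside restricted scenarios and with cell reuse disabled.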
bool EnableBackendCompileCache(const FuncGraphPtr &func_graph, const device::DeviceType &device_type) {
  if (!CompileCacheEnable()) {
    return false;
  }
  auto &context = CompileCacheContext::GetInstance();
  if (context.FrontGraph() != func_graph) {
    return false;
  }
  if (context.RestrictedScenarios()) {
    return false;
  }
  if (MsContext::GetInstance()->backend_policy() == "ge") {
    return false;
  }
  if (device_type != device::DeviceType::kAscend) {
    return false;
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->CellReuseLevel() != CellReuseLevel::kNoCellReuse) {
    return false;
  }
  return true;
}

bool UseCacheToCompileGraph(const FuncGraphPtr &func_graph, const device::DeviceType &device_type) {
  if (!EnableBackendCompileCache(func_graph, device_type)) {
    return false;
  }
  auto &context = CompileCacheContext::GetInstance();
  if (!context.UseCompileCache()) {
    return false;
  }
  return true;
}

bool ExportCompileCache(const FuncGraphPtr &func_graph, const device::DeviceType &device_type) {
  if (!EnableBackendCompileCache(func_graph, device_type)) {
    return false;
  }
  auto &context = CompileCacheContext::GetInstance();
  if (context.UseCompileCache()) {
    return false;
  }
  return true;
}

// Fetch the real input of the nop node recursively.
AnfNodePtr FetchRealNodeByNopNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if ((!node->isa<CNode>()) || (!common::AnfAlgo::IsNopNode(node))) {
    return node;
  }

  const auto &cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  const auto &inputs = cnode->inputs();
  if (inputs.size() <= 1) {
    MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, cnode)
      << "#dmsg#Runtime error info:#dmsg#Invalid cnode:" << cnode->DebugString();
  }
  return FetchRealNodeByNopNode(inputs[1]);
}

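// Instead of deleting eligible nop nodes, mark them as ref nodes of their real inputs so that the
// KernelTensor shape information stays correct in KernelMod::Launch.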
void OptimizeNopNode(KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<CNodePtr> nop_nodes_need_set_ref;

  // Skip the graph mode.
  if (graph->is_graph_run_mode()) {
    return;
  }

  const auto &output_node = graph->output();
  const auto &ref_map = graph->GetRefMap();
  std::set<std::pair<AnfNodePtr, size_t>> ref_out_value;
  for (const auto &iter : ref_map) {
    ref_out_value.insert(iter.second);
  }
  MS_EXCEPTION_IF_NULL(output_node);
  const auto &graph_outputs = common::AnfAlgo::GetAllOutputWithIndex(output_node);
  // Collect all the nop nodes that can be eliminated.
  for (const auto &cnode : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(cnode);
    if ((!common::AnfAlgo::IsNopNode(cnode)) || ref_map.count({cnode, 0}) != 0 ||
        ref_out_value.count({cnode, 0}) != 0 ||
        std::find_if(graph_outputs.begin(), graph_outputs.end(),
                     [&cnode](const KernelWithIndex &output) {
                       const auto &real_output = common::AnfAlgo::FetchRealNodeSkipMonadControl(output);
                       return real_output == KernelWithIndex(cnode, 0);
                     }) != graph_outputs.end() ||
        std::find_if(cnode->inputs().begin(), cnode->inputs().end(), [](const auto &input) {
          return common::AnfAlgo::CheckPrimitiveType(input, prim::kPrimConditionGather) ||
                 common::AnfAlgo::CheckPrimitiveType(common::AnfAlgo::VisitKernelWithReturnType(input, 0).first,
                                                     prim::kPrimConditionGather);
        }) != cnode->inputs().end()) {
      continue;
    }
    // A nop node that does not meet the above conditions is set as a ref node rather than deleted from the graph,
    // to avoid incorrect shape information of the KernelTensor obtained in KernelMod::Launch.
    (void)nop_nodes_need_set_ref.emplace_back(cnode);
  }

  // Add the ref node pairs; this must happen after elimination to avoid using eliminated nodes.
  for (auto &ref_node : nop_nodes_need_set_ref) {
    MS_EXCEPTION_IF_NULL(ref_node);
    auto input_node = common::AnfAlgo::GetInputNode(ref_node, 0);
    MS_EXCEPTION_IF_NULL(input_node);
    // Record the original information of the ref node.
    auto origin_pair = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
    MS_EXCEPTION_IF_NULL(origin_pair.first);
    // The device address of a parameter used as input may not be the one used at runtime in heterogeneous or
    // control flow scenarios, so do not set the ref node in that case.
    if (origin_pair.first->isa<Parameter>() || origin_pair.first->isa<ValueNode>()) {
      continue;
    }
    // The ref node cannot be set for node pairs with different device targets (appears in the kernel backoff scene).
    if (AnfAlgo::FetchDeviceTarget(origin_pair.first, graph) != AnfAlgo::FetchDeviceTarget(ref_node, graph)) {
      continue;
    }
    MS_LOG(INFO) << "The reference relation of nopnode " << ref_node->fullname_with_scope() << ", index: " << 0
                 << " to input " << origin_pair.first->fullname_with_scope() << ", index: " << origin_pair.second;
    graph->AddRefCorrespondPairs(std::make_pair(ref_node, 0), origin_pair);
  }
}

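// Zero copy is only enabled for the subgraph sink mode; it is disabled in PyNative mode, in PS cache mode,
// and in auto/semi-auto/hybrid/data parallel modes unless the GE backend is used.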
bool IsEnableZeroCopy(bool run_in_pynative) {
  if (run_in_pynative) {
    return false;
  }

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  bool task_sink = ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  bool is_multi_graph_sink = ms_context->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK);
  // If the run mode is not subgraph sink, the flag should not be set.
  if (!task_sink || is_multi_graph_sink) {
    return false;
  }

  // In PS cache mode, whole graph sink has set multi_graph_sink to false, so zero copy cannot be enabled.
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->cache_enable()) {
    return false;
  }
#endif

  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  bool is_parallel_mode = parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel ||
                          parallel_mode == parallel::kHybridParallel || parallel_mode == parallel::kDataParallel;
  // If auto parallel is used in the graph, the flag should not be set: in parallel mode the contiguous memory of
  // communication ops does not support address changes.
  // Zero copy is forced when the GE backend is used.
  bool is_enable_ge = ms_context->backend_policy() == "ge";
  if (is_parallel_mode && !is_enable_ge) {
    return false;
  }
  return true;
}

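// Mark the graph to run by single op when a kernel cannot be resized ahead of time in the
// PyNative-run-in-graph scenario, or when the graph contains a bprop_cut node executed in the backend.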
void SetRunGraphBySingleOpFlag(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  for (auto &node : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(node);
    MS_EXCEPTION_IF_NULL(node->input(0));
    bool enable = false;
    if (!AnfAlgo::NodeValueIsFuncGraph(node->input(0))) {
      if (!kernel::CheckResizeCondition(node) && graph->has_flag(kFlagPyNativeRunInGraph)) {
        MS_LOG(INFO) << "Enable Run Graph By Single Op";
        enable = true;
      }
    }
    // The bprop graph contains a bprop_cut node.
    auto contain_bprop_cut = common::AnfAlgo::IsBpropCutOpExecInBackend(node);
    if (enable || contain_bprop_cut) {
      MS_LOG(INFO) << "Set kFlagEnableRunGraphBySingleOp: NeedSkipResize:" << enable
                   << ", BpGraph contain bprop_cut node:" << contain_bprop_cut;
      graph->set_flag(kFlagEnableRunGraphBySingleOp, true);
      break;
    }
  }
}

void UseCacheToCompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);

  auto &compile_cache_context = CompileCacheContext::GetInstance();
  (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCreateKernel, 1, 0, 0);
  compile_cache_context.SetFusionOpBuildInfoFlag(true);
  device_context->GetKernelExecutor(false)->CreateKernel(graph->execution_order());
  compile_cache_context.SetFusionOpBuildInfoFlag(false);
  (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCreateKernel, 1, 0, 1);
  // Kernels that are not supported by other devices can be backed off and rebuilt on the CPU.
#ifdef WITH_BACKEND
  if (!graph->is_from_single_op()) {
    auto cpu_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {kCPUDevice, device_context->device_context_key().device_id_});
    MS_EXCEPTION_IF_NULL(cpu_context);
    auto cpu_executor = dynamic_cast<device::cpu::CPUKernelExecutor *>(cpu_context->GetKernelExecutor(false).get());
    MS_EXCEPTION_IF_NULL(cpu_executor);
    cpu_executor->RebuildKernelSelectBackoffOp(graph->execution_order());
  }
#endif
#ifndef ENABLE_SECURITY
  // Update needed dump kernels for mindRT.
  DumpJsonParser::GetInstance().UpdateNeedDumpKernels(*graph.get());
#endif
  if (graph->is_dynamic_shape()) {
    auto profiler_manage_inst = profiler::ProfilerManager::GetInstance();
    MS_EXCEPTION_IF_NULL(profiler_manage_inst);
    profiler_manage_inst->SetNetDynamicShapeStatus();
  }
}

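// A value sequence is valid for the kernel graph only if it is flat (no nested sequences) and all of its
// elements share the same type; an empty sequence is treated as valid.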
bool IsValidSequence(const ValueSequencePtr &sequence_value) {
  MS_EXCEPTION_IF_NULL(sequence_value);
  const auto &values = sequence_value->value();
  if (values.empty()) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(values[0]);
  if (values[0]->isa<ValueSequence>()) {
    return false;
  }
  if (values[0]->type() == nullptr) {
    MS_LOG(DEBUG) << "Failed to get type from value tuple:" << sequence_value->ToString();
    return false;
  }
  TypeId base_type = values[0]->type()->type_id();
  for (size_t i = 1; i < values.size(); ++i) {
    MS_EXCEPTION_IF_NULL(values[i]);
    MS_EXCEPTION_IF_NULL(values[i]->type());
    TypeId type = values[i]->type()->type_id();
    if (type != base_type) {
      MS_LOG(DEBUG) << "Invalid value type for value:" << sequence_value->ToString();
      return false;
    }
  }
  return true;
}

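// Re-collect the value nodes of the graph, skipping primitives, nodes without kernel info and invalid
// value sequences.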
void CollectValueNodeForKernelGraph(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  graph->ClearAllValueNode();
  const auto &nodes = TopoSort(graph->get_return());
  for (const auto &node : nodes) {
    MS_EXCEPTION_IF_NULL(node);
    if (!node->isa<ValueNode>() || node->kernel_info() == nullptr) {
      continue;
    }
    const auto &value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<Primitive>() ||
        (value->isa<ValueSequence>() && (!IsValidSequence(value->cast<ValueSequencePtr>())))) {
      continue;
    }
    MS_LOG(DEBUG) << "Add value node:" << node->DebugString() << " for kernel graph:" << graph->ToString();
    graph->AddValueNodeToGraph(value_node);
  }
}

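// Pre-construct a kernel graph whose input types are not yet determined: mark dynamic-shape parameters,
// give outputs without build info a default kernel build info, and create the device addresses needed later.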
GraphId CompileAnyTypeInputGraph(const KernelGraphPtr &graph, const AnfNodePtrList &outputs,
                                 const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &input : graph->inputs()) {
    MS_EXCEPTION_IF_NULL(input);
    MS_LOG(DEBUG) << "input node:" << input->DebugString()
                  << " abstract:" << (input->abstract() == nullptr ? "null" : input->abstract()->ToString());
  }
  MS_LOG(DEBUG) << "Pre construct any type input kernel graph:" << graph->ToString();
  graph->set_is_any_type_input(true);
  opt::OptimizationForAnyTypeKernelGraph(graph);
  graph->SetInputNodes();
  for (const auto &input : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(input);
    MS_LOG(DEBUG) << "input node:" << input->DebugString()
                  << " abstract:" << (input->abstract() == nullptr ? "null" : input->abstract()->ToString());
    if (!input->isa<Parameter>()) {
      continue;
    }
    const auto &parameter = input->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(parameter);
    const auto &shape = parameter->Shape();
    if (shape != nullptr &&
        ((shape->isa<abstract::Shape>() && shape->IsDynamic()) || shape->isa<abstract::DynamicSequenceShape>())) {
      parameter->set_has_dynamic_shape(true);
    }
  }
  auto backend_output = graph->output();
  MS_EXCEPTION_IF_NULL(backend_output);
  graph->CacheGraphOutputToFrontNodeWithIndex({backend_output}, outputs);
  graph->UpdateInternalParameter();
  DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);

  auto output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
  for (const auto &output_with_index : output_with_indexs) {
    const auto &output = output_with_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (common::AnfAlgo::IsBpropCutOpExecInBackend(output) || HasAbstractMonad(output)) {
      continue;
    }
    if (output->kernel_info() == nullptr) {
      output->set_kernel_info(std::make_shared<device::KernelInfo>());
    }
    auto kernel_info = dynamic_cast<device::KernelInfo *>(output->kernel_info());
    MS_EXCEPTION_IF_NULL(kernel_info);
    // select_kernel_build_info() has checked whether return pointer is null
    auto build_info = kernel_info->select_kernel_build_info();
    if (build_info != nullptr) {
      continue;
    }
    size_t output_num = 1;
    if (output->abstract() != nullptr) {
      output_num = common::AnfAlgo::GetOutputNumByAbstract(output->abstract());
    }
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
    builder.SetOutputsDeviceType(std::vector<TypeId>(output_num, kTypeUnknown));
    builder.SetOutputsKernelObjectType(
      std::vector<kernel::KernelObjectType>(output_num, kernel::KernelObjectType::TENSOR));
    AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), output.get());
    MS_LOG(DEBUG) << "Set kernel build info for node:" << output->DebugString() << " output num:" << output_num;
  }
  CollectValueNodeForKernelGraph(graph);
  DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateGraphOutputDeviceAddress(device_context, graph);
  return graph->graph_id();
}
}  // namespace

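// Compile a graph segment: construct the kernel graph from the segment nodes and delegate to the
// kernel graph overload below.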
GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment,
                                    const std::pair<AnfNodePtrList, AnfNodePtrList> &io_nodes,
                                    const DeviceContext *device_context, device::RunMode run_mode,
                                    bool run_in_pynative) {
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_LOG(INFO) << "Status record: start compile graph.";
  auto nodes = segment->nodes_;
  auto device_target = device_context->GetDeviceType();
  // Generate kernel graph.
  (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageConstructKernelGraph, 1, 0, 0);
  auto kernel_graph =
    session_->ConstructKernelGraph(nodes, io_nodes.second, device_target, true, IsEnableZeroCopy(run_in_pynative));
  (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageConstructKernelGraph, 1, 0, 1);
  SetGraphDependency(kernel_graph, segment);
  return CompileGraph(kernel_graph, io_nodes, device_context, run_mode, run_in_pynative);
}

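// Compile a single kernel graph: run the backend-independent optimization and MindIR unification passes,
// then do the full backend compilation unless running in PyNative mode.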
GraphId GraphCompiler::CompileGraph(const KernelGraphPtr &kernel_graph,
                                    const std::pair<AnfNodePtrList, AnfNodePtrList> &io_nodes,
                                    const DeviceContext *device_context, device::RunMode run_mode,
                                    bool run_in_pynative) {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(kernel_graph);

  const auto &outputs = io_nodes.second;
  if (common::AnfAlgo::IsAnyTypeInput(io_nodes.first)) {
    return CompileAnyTypeInputGraph(kernel_graph, outputs, device_context);
  }
  kernel_graph->erase_flag(kFlagPyNativeRunInGraph);
  SetRunGraphBySingleOpFlag(kernel_graph);
  kernel_graph->UpdateGraphAquireGilAttr();
  if (run_mode == device::RunMode::kUnknown) {
    kernel_graph->set_run_mode(device_context->GetRunMode(kernel_graph));
  } else {
    kernel_graph->set_run_mode(run_mode);
  }
  auto manager = MakeManager({kernel_graph});
  if (manager) {
    manager->AddFuncGraph(kernel_graph);
    kernel_graph->set_manager(manager);
  }

  opt::OptimizationWithoutBackend(kernel_graph);
  // Unify the MindIR; this must be done before the kernel_graph optimization.
  auto kernel_executor = device_context->GetKernelExecutor(false);
  if (kernel_executor != nullptr) {
    kernel_executor->AddMindIRPass(kernel_graph);
  }
  kernel_graph->SetInputNodes();
  auto context_ptr = MsContext::GetInstance();
  session_->SetInputNodeUsage(kernel_graph, manager);
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->backend_policy() == "ge" && device_context->GetDeviceType() == device::DeviceType::kAscend &&
      !IsEnableRefMode()) {
    MS_EXCEPTION_IF_NULL(device_context->graph_executor_);
    if (!device_context->graph_executor_->CompileGraph(kernel_graph, {})) {
      MS_LOG(EXCEPTION) << "Compile kernel_graph failed: " << kernel_graph->graph_id();
    }
    kernel_graph->UpdateInternalParameter();
    kernel_graph->CacheGraphOutputToFrontNodeWithIndex({kernel_graph->output()}, outputs);
    kernel_graph->set_front_outputs(outputs);
    return kernel_graph->graph_id();
  }
  kernel_graph->SetOptimizerFlag();

  GraphId graph_id = 0;
  if (run_in_pynative) {
    MS_EXCEPTION_IF_NULL(session_);
    // Graph kernel does not support PyNative mode for now; print a warning here.
    graphkernel::GraphKernelFlags::GetInstance().CheckSupport();
    graph_id = kernel_graph->graph_id();
  } else {
    graph_id = CompileGraphImpl(kernel_graph, device_context, run_in_pynative);
  }

  kernel_graph->set_front_outputs(outputs);

  kernel_graph->set_root_graph_id(graph_id);
  session_->DumpGraphs({kernel_graph});

  // The kernel_graph is not compiled yet in PyNative mode.
  // Need to cache the output later when the kernel_graph is compiled.
  if (!run_in_pynative) {
    // Cache the backend kernel_graph output nodes to front nodes with output index.
    auto backend_node = kernel_graph->output();
    MS_EXCEPTION_IF_NULL(backend_node);
    kernel_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs);
  }
  AnfAlgo::UpdateGraphValidRefPair(kernel_graph);

  MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
  return graph_id;
}

GraphCompilerInfo::~GraphCompilerInfo() {
  GraphScheduler::GetInstance().Clear(name_, graphs_, origin_parameters_order_, control_node_parser_);
}

GraphId GraphCompiler::CompileDynamicGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs,
                                           const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_LOG(INFO) << "Status record: start compile graph.";

  auto nodes = segment->nodes_;
  auto device_target = device_context->GetDeviceType();
  // Generate kernel graph.
  (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageConstructKernelGraph, 1, 0, 0);
  const auto &kernel_graph = session_->ConstructKernelGraph(nodes, outputs, device_target, true, false);
  return CompileDynamicGraph(kernel_graph, device_context);
}

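// Compile a kernel graph with mutable kernels (dynamic shape or dynamic graph structure): the graph is
// marked to run by single op and no backend kernel compilation is performed here.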
GraphId GraphCompiler::CompileDynamicGraph(const KernelGraphPtr &kernel_graph, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(device_context);
  // Dynamic shape or dynamic graph structure flag.
  kernel_graph->set_flag(kAttrMutableKernel, true);
  MS_LOG(INFO) << "Set kFlagEnableRunGraphBySingleOp: Dynamic shape or dynamic graph structure flag";
  kernel_graph->set_flag(kFlagEnableRunGraphBySingleOp, true);

  kernel_graph->UpdateGraphAquireGilAttr();
  kernel_graph->SetInputNodes();
  auto manager = Manage(kernel_graph);
  if (manager) {
    manager->AddFuncGraph(kernel_graph);
    kernel_graph->set_manager(manager);
  }
  session_->SetInputNodeUsage(kernel_graph, manager);
  kernel_graph->SetOptimizerFlag();
  kernel_graph->set_run_mode(device::RunMode::kKernelMode);

  // Graph kernel does not support PyNative mode for now; print a warning here.
  graphkernel::GraphKernelFlags::GetInstance().CheckSupport();

  GraphId graph_id = kernel_graph->graph_id();
  kernel_graph->set_root_graph_id(graph_id);
  session_->DumpGraphs({kernel_graph});

  MS_LOG(INFO) << "Status record: end compile kernel_graph. kernel_graph id: " << graph_id;
  return graph_id;
}

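// Construct the root kernel graph and all subgraphs for graph (sink) run mode; with the GE backend the whole
// graph is compiled here and *need_return_ahead is set to true.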
KernelGraphPtr GraphCompiler::ConstructKernelGraphForGraphRunMode(const FuncGraphPtr &func_graph,
                                                                  const DeviceContext *device_context,
                                                                  std::vector<KernelGraphPtr> *const all_graphs,
                                                                  bool *const need_return_ahead) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(all_graphs);
  auto device_target = device_context->GetDeviceType();
  KernelGraphPtr root_graph = session_->ConstructKernelGraph(func_graph, all_graphs, device_target);
  MS_EXCEPTION_IF_NULL(root_graph);
  for (const auto &graph : *all_graphs) {
    MS_EXCEPTION_IF_NULL(graph);
    MS_LOG(INFO) << "Set root graph for graph: " << graph->graph_id() << " to: " << root_graph->graph_id() << ".";
    graph->set_root_graph_id(root_graph->graph_id());
    graph->set_run_mode(device::RunMode::kGraphMode);
    graph->set_is_loop_count_sink(true);
    graph->set_attrs(func_graph->attrs());
    opt::OptimizationWithoutBackend(graph);
  }

  // Unify the MindIR; this must be done before the graph optimization.
  auto kernel_executor = device_context->GetKernelExecutor(false);
  if (kernel_executor != nullptr) {
    kernel_executor->AddMindIRPass(root_graph);
  }

  // TODO: waiting for GraphExecutor.
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  if (MsContext::GetInstance()->backend_policy() == "ge") {
    auto manager = MakeManager();
    MS_EXCEPTION_IF_NULL(manager);
    for (const auto &graph : *all_graphs) {
      MS_EXCEPTION_IF_NULL(graph);
      graph->set_flag(kFlagEnableZeroCopyInGraph, true);
      manager->AddFuncGraph(graph);
      graph->set_manager(manager);
      graph->SetInputNodes();
    }
    root_graph->SetInputNodes();
    MS_EXCEPTION_IF_NULL(device_context->graph_executor_);
    if (!device_context->graph_executor_->CompileGraph(root_graph, {})) {
      MS_LOG(EXCEPTION) << "Compile graph failed: " << root_graph->graph_id();
    }
    root_graph->CacheGraphOutputToFrontNodeWithIndex({root_graph->output()}, {func_graph->output()});
    *need_return_ahead = true;
  }
  if (*need_return_ahead) {
    return root_graph;
  }
  // Set executing sink to true in graph mode.
  root_graph->set_run_mode(device::RunMode::kGraphMode);
  root_graph->set_is_loop_count_sink(true);
#if defined(__linux__) && defined(WITH_BACKEND)
  // The embedding cache needs the global step of the compute graph, so loop sink cannot be enabled; loop control is
  // moved to the loop count actor.
  if (ps::PSContext::instance()->cache_enable()) {
    root_graph->set_is_loop_count_sink(false);
    for (const auto &graph : *all_graphs) {
      MS_EXCEPTION_IF_NULL(graph);
      graph->set_is_loop_count_sink(false);
    }
  }
#endif
  root_graph->SetInputNodes();
  return root_graph;
}

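// Compile the whole func graph in graph run mode, optionally using or exporting the backend compile cache.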
GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func_graph,
                                                        const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_LOG(INFO) << "Status record: start compile graph.";
  // Generate kernel graph.
  std::vector<KernelGraphPtr> all_graphs;
  auto device_target = device_context->GetDeviceType();
  KernelGraphPtr root_graph;
  bool need_return_ahead = false;
  if (UseCacheToCompileGraph(func_graph, device_target)) {
    root_graph = session_->ConstructKernelGraph(&all_graphs);
    use_cache_to_compile_graph_ = true;
  } else {
    root_graph = ConstructKernelGraphForGraphRunMode(func_graph, device_context, &all_graphs, &need_return_ahead);
  }
  GraphId graph_id = root_graph->graph_id();
  if (need_return_ahead) {
    return graph_id;
  }
  if (ExportCompileCache(func_graph, device_target)) {
    export_compile_cache_ = true;
  }
  if (!func_graph->has_flag(kFlagPyNativeRunInGraph)) {
    graph_id = CompileGraphImpl(root_graph, device_context);
  }
  if (CompileCacheEnable()) {
    CompileCacheContext::GetInstance().Clear();
  }

  // Dump all graphs for Ascend mindRT.
  session_->DumpGraphs(all_graphs);

  if (!func_graph->has_flag(kFlagPyNativeRunInGraph)) {
    // Cache the backend graph output nodes to front nodes with output index.
    auto output = func_graph->output();
    MS_EXCEPTION_IF_NULL(output);
    auto backend_node = root_graph->output();
    MS_EXCEPTION_IF_NULL(backend_node);
    root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output});
    AnfAlgo::UpdateGraphValidRefPair(root_graph);
  } else {
    for (auto &node : root_graph->execution_order()) {
      if (common::AnfAlgo::IsBpropCutOpExecInBackend(node)) {
        MS_LOG(INFO) << "Set kFlagEnableRunGraphBySingleOp: IsBpropCutOpExecInBackend";
        root_graph->set_flag(kFlagEnableRunGraphBySingleOp, true);
      }
    }
    root_graph->set_front_outputs({func_graph->output()});
  }
  MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
  return graph_id;
}

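// Common backend compilation: graph optimization, kernel creation, nop node optimization, device address
// creation and pre-run preprocessing; when the compile cache is used, graph optimization is skipped and
// only the kernels are re-created.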
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context,
                                        bool run_in_pynative) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(session_);
  const auto &context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (use_cache_to_compile_graph_) {
    UseCacheToCompileGraphImpl(graph, device_context);
  } else {
#ifdef ENABLE_DUMP_IR
    if (context->CanDump(kIntroductory)) {
      // Dump .pb graph before graph optimization.
      DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
    }
#endif
    MS_EXCEPTION_IF_NULL(device_context->GetKernelExecutor(false));
    // Execute optimization pass.
    (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageOptimizeGraph, 1, 0, 0);
    PROF_START(OptimizeGraph);
    device_context->GetKernelExecutor(false)->OptimizeGraph(graph);
    PROF_END(OptimizeGraph);
    (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageOptimizeGraph, 1, 0, 1);
    // Generate a 'KernelMod' for every kernel and set it into the kernel;
    // 'KernelMod' is the real executive object of the kernel.
    (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCreateKernel, 1, 0, 0);
    PROF_START(CreateKernel);
    device_context->GetKernelExecutor(false)->CreateKernel(graph->execution_order());
    PROF_END(CreateKernel);
    (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCreateKernel, 1, 0, 1);

    // Kernels that are not supported by other devices can be backed off and rebuilt on the CPU.
#ifdef WITH_BACKEND
    if (!graph->is_from_single_op()) {
      auto cpu_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
        {kCPUDevice, device_context->device_context_key().device_id_});
      MS_EXCEPTION_IF_NULL(cpu_device_context);
      auto cpu_executor =
        dynamic_cast<device::cpu::CPUKernelExecutor *>(cpu_device_context->GetKernelExecutor(false).get());
      MS_EXCEPTION_IF_NULL(cpu_executor);
      cpu_executor->RebuildKernelSelectBackoffOp(graph->execution_order());
    }
#endif

    // Read the output-input ref map and set it into the kernel graph.
    AnfAlgo::AddOutInRefToGraph(graph);

    // Optimize the nop node.
    if (!run_in_pynative) {
      OptimizeNopNode(graph.get());
#ifdef ENABLE_DUMP_IR
      if (context->CanDump(kIntroductory)) {
        DumpIR("hwopt_comm_after_eliminate_nopnode_" + graph->ToString() + ".ir", graph, true);
      }
#endif
    }

#ifndef ENABLE_SECURITY
    session_->RecurseSetSummaryNodesForAllGraphs(graph.get());
    // Update needed dump kernels for mindRT.
    DumpJsonParser::GetInstance().UpdateNeedDumpKernels(*graph.get());
#endif

    // Dynamic shape pass of graph mode.
    if (graph->is_dynamic_shape()) {
      if (!graph->is_graph_run_mode()) {
        // Temporarily disable CustomActor for asynchronous InferShape and Resize for the dynamic shape scenario,
        // and implement the corresponding capability through direct InferShape and Resize in KernelActor.
        // opt::DynamicShapeConvertPass(graph);
      }
      auto profiler_manage_inst = profiler::ProfilerManager::GetInstance();
      MS_EXCEPTION_IF_NULL(profiler_manage_inst);
      profiler_manage_inst->SetNetDynamicShapeStatus();
    }
  }

  if (export_compile_cache_) {
    session_->CacheKernelGraph(graph);
  }
  // Adjust the kernel graph before running it.
  PROF_START(PreprocessBeforeRun);
  device_context->GetKernelExecutor(false)->PreprocessBeforeRun(graph);
  PROF_END(PreprocessBeforeRun);
  graph->UpdateInternalParameter();
  // Set device target for parameter affinity.
  AnfAlgo::SetParameterDeviceTarget(graph);

  // Create device addresses for all anf nodes of the graph.
  CreateDeviceAddress(graph, device_context);

#if defined(__linux__) && defined(WITH_BACKEND)
  // Set device addresses for embedding cache parameters; only enabled in embedding cache mode.
  // `CreateDeviceAddress` should execute before this step.
  EmbeddingCacheScheduler::GetInstance().SetEmbedCachedParamAddress(device_context, graph);
#endif

  SetSummaryNodesRefCount(graph.get());
#ifdef ENABLE_DUMP_IR
  // Dump .pb graph after graph optimization.
  if (context->CanDump(kIntroductory)) {
    DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
  }
#endif

#ifdef ENABLE_DEBUGGER
  auto debugger = Debugger::GetInstance();
  MS_EXCEPTION_IF_NULL(debugger);
  // Dump graph for GPU mindRT if dump is enabled.
  debugger->DumpInGraphCompiler(graph);
  if (debugger && debugger->DebuggerBackendEnabled()) {
    // Load graphs for GPU and Ascend mindRT.
    debugger->LoadGraphs(graph);
  }
#endif

  graph->EnableRuntimeCache();
  return graph->graph_id();
}

KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetGraph(graph_id);
}

void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Status record: start create device address. graph id: " << graph->graph_id();
  DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, false);
  DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
  DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
  DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
  MS_LOG(INFO) << "Status record: end create device address. graph id: " << graph->graph_id();
}

void GraphCompiler::GetParamAndOutputIndex(
  const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs, VectorRef *const outputs,
  std::map<AnfNodePtr, size_t> *parameter_index,
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetParameterIndex(graph.get(), inputs, parameter_index);
  session_->CreateOutputPlaceholder(graph, inputs, outputs, output_indexes);
}

void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel,
                                            const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
                                            const std::map<AnfNodePtr, size_t> &parameter_index,
                                            const std::vector<TensorPtr> &graph_inputs, bool is_run_pyboost,
                                            InputInfo *const input_info) {
  MS_EXCEPTION_IF_NULL(session_);
  if (is_run_pyboost) {
    session_->GetOpInputTensorsFromCNode(kernel, op_output, parameter_index, graph_inputs, input_info);
  } else {
    session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_info);
  }
}

tensor::BaseTensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(
  const CNodePtr &kernel, const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
  const std::map<AnfNodePtr, size_t> &parameter_index, const std::vector<TensorPtr> &graph_inputs,
  InputInfo *const input_info, size_t input_index) {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_info, input_index);
}

void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputInfo &input_info,
                                                   bool use_dynamic_shape_process,
                                                   session::BackendOpRunInfoPtr *op_run_info,
                                                   const GraphOutputInfo *const graph_output_info) {
  MS_EXCEPTION_IF_NULL(session_);
  *op_run_info = session_->GetSingleOpRunInfo(kernel, input_info, graph_output_info);
  (*op_run_info)->base_op_run_info.use_dynamic_shape_process = use_dynamic_shape_process;
}

void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetRefCount(graph.get(), ref_count);
}

void GraphCompiler::CalculateForwardOpOutputCount(const KernelGraphPtr &graph,
                                                  const std::vector<tensor::TensorPtr> &inputs,
                                                  std::map<std::string, size_t> *forward_op_output_tensor_id,
                                                  const std::map<AnfNodePtr, size_t> &parameter_index) const {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
  forward_op_output_tensor_id->clear();
  session_->GetForwardOpOutputRefCount(graph.get(), inputs, forward_op_output_tensor_id, parameter_index);
}

void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
                                   std::map<KernelWithIndex, size_t> *ref_count,
                                   std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map);
}

void GraphCompiler::UpdateForwardOpOutputRefCount(const std::vector<ValuePtr> &input_values,
                                                  std::map<std::string, size_t> *forward_op_output_tensor_id) const {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
  session_->ReleaseForwardOpOutput(input_values, forward_op_output_tensor_id);
}

void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                                       const std::map<KernelWithIndex, size_t> &ref_count,
                                       std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map,
                                       GraphOutputInfo *const graph_output_info) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpOutputs(kernel, op_outputs, ref_count, op_output_map, graph_output_info);
}

void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
  MS_EXCEPTION_IF_NULL(session_);
#ifndef ENABLE_SECURITY
  session_->RegisterSummaryCallBackFunc(callback);
#endif
}

void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
  MS_EXCEPTION_IF_NULL(session_);
  for (const auto &graph : graphs) {
#ifndef ENABLE_SECURITY
    session_->Summary(graph.get());
#endif
  }
}

void GraphCompiler::SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(segment);
  segment->graph_id_ = graph->graph_id();
  for (auto &pre_segment : segment->pre_segments_) {
    MS_EXCEPTION_IF_NULL(pre_segment);
    auto pre_graph = Fetch(pre_segment->graph_id_);
    MS_EXCEPTION_IF_NULL(pre_graph);
    pre_graph->AddPostGraph(graph);
    graph->AddPreGraph(pre_graph);
    MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph->graph_id();
  }
}
}  // namespace runtime
}  // namespace mindspore