1 /**
2  * Copyright 2021-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "runtime/graph_scheduler/graph_compiler.h"
18 #include <numeric>
19 #include <map>
20 #include <utility>
21 #include <algorithm>
22 #include <functional>
23 #include <list>
24 #include "runtime/graph_scheduler/graph_scheduler.h"
25 #include "runtime/device/device_address_utils.h"
26 #include "runtime/pynative/op_executor.h"
27 #include "include/backend/device_address.h"
28 #include "runtime/device/ms_device_shape_transfer.h"
29 #include "runtime/pynative/op_runtime_info.h"
30 #include "runtime/pynative/op_compiler.h"
31 #include "include/common/utils/convert_utils.h"
32 #include "backend/common/graph_kernel/graph_kernel_flags.h"
33 #include "backend/common/optimizer/common_backend_optimization.h"
34 #include "utils/ms_context.h"
35 #include "ir/tensor.h"
36 #include "kernel/framework_utils.h"
37 #include "include/backend/debug/profiler/profiling.h"
38 #include "include/backend/optimizer/helper.h"
39 #include "base/base_ref_utils.h"
40 #include "include/common/debug/dump_proto.h"
41 #include "include/common/utils/parallel_context.h"
42 #include "plugin/device/cpu/hal/hardware/cpu_device_context.h"
43 #ifdef ENABLE_DEBUGGER
44 #include "include/backend/debug/debugger/debugger.h"
45 #endif
46 #ifdef ENABLE_DUMP_IR
47 #include "include/common/debug/anf_ir_dump.h"
48 #endif
49 #ifndef ENABLE_SECURITY
50 #include "include/backend/debug/data_dump/dump_json_parser.h"
51 #include "include/backend/optimizer/graph_optimizer.h"
52 #endif
53 #if defined(__linux__) && defined(WITH_BACKEND)
54 #include "include/backend/distributed/ps/ps_context.h"
55 #include "runtime/graph_scheduler/embedding_cache_scheduler.h"
56 #endif
57 #include "include/common/profiler.h"
58 #include "include/common/utils/compile_cache_context.h"
59 #include "utils/phase.h"
60 #include "pipeline/jit/ps/base.h"
61 #include "ops/framework_ops.h"
62 
63 namespace mindspore {
64 namespace runtime {
65 namespace {
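// Pin the device addresses of summary node outputs (ref count = SIZE_MAX) so they are not released by ref counting
// before the summary callback reads them.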
66 void SetSummaryNodesRefCount(const KernelGraph *graph) {
67   MS_EXCEPTION_IF_NULL(graph);
68   if (!graph->summary_node_exist()) {
69     return;
70   }
71 
72   const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes = graph->summary_nodes();
73   if (summary_nodes.empty()) {
74     return;
75   }
76 
77   for (const auto &item : summary_nodes) {
78     const AnfNodePtr &node = item.second.first;
79     size_t index = IntToSize(item.second.second);
80     auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false);
81     MS_EXCEPTION_IF_NULL(device_address);
82     device_address->set_original_ref_count(SIZE_MAX);
83     device_address->ResetRefCount();
84   }
85 }
86 
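// The backend compile cache is usable only when compile cache is enabled, the given graph is the front graph,
// the scenario is not restricted, the backend policy is not "ge", the target is Ascend, and cell reuse is disabled.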
87 bool EnableBackendCompileCache(const FuncGraphPtr &func_graph, const device::DeviceType &device_type) {
88   if (!CompileCacheEnable()) {
89     return false;
90   }
91   auto &context = CompileCacheContext::GetInstance();
92   if (context.FrontGraph() != func_graph) {
93     return false;
94   }
95   if (context.RestrictedScenarios()) {
96     return false;
97   }
98   if (MsContext::GetInstance()->backend_policy() == "ge") {
99     return false;
100   }
101   if (device_type != device::DeviceType::kAscend) {
102     return false;
103   }
104   auto ms_context = MsContext::GetInstance();
105   MS_EXCEPTION_IF_NULL(ms_context);
106   if (ms_context->CellReuseLevel() != CellReuseLevel::kNoCellReuse) {
107     return false;
108   }
109   return true;
110 }
111 
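// Load path: the backend compile cache is enabled and a cached compilation result exists for this graph.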
112 bool UseCacheToCompileGraph(const FuncGraphPtr &func_graph, const device::DeviceType &device_type) {
113   if (!EnableBackendCompileCache(func_graph, device_type)) {
114     return false;
115   }
116   auto &context = CompileCacheContext::GetInstance();
117   if (!context.UseCompileCache()) {
118     return false;
119   }
120   return true;
121 }
122 
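// Store path: the backend compile cache is enabled but no cached result exists yet, so this compilation should be exported.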
123 bool ExportCompileCache(const FuncGraphPtr &func_graph, const device::DeviceType &device_type) {
124   if (!EnableBackendCompileCache(func_graph, device_type)) {
125     return false;
126   }
127   auto &context = CompileCacheContext::GetInstance();
128   if (context.UseCompileCache()) {
129     return false;
130   }
131   return true;
132 }
133 
134 // Fetch the real input of the nop node recursively.
135 AnfNodePtr FetchRealNodeByNopNode(const AnfNodePtr &node) {
136   MS_EXCEPTION_IF_NULL(node);
137   if ((!node->isa<CNode>()) || (!common::AnfAlgo::IsNopNode(node))) {
138     return node;
139   }
140 
141   const auto &cnode = node->cast<CNodePtr>();
142   MS_EXCEPTION_IF_NULL(cnode);
143 
144   const auto &inputs = cnode->inputs();
145   if (inputs.size() <= 1) {
146     MS_LOG_WITH_NODE(INTERNAL_EXCEPTION, cnode)
147       << "#dmsg#Runtime error info:#dmsg#Invalid cnode:" << cnode->DebugString();
148   }
149   return FetchRealNodeByNopNode(inputs[1]);
150 }
151 
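// Register qualifying nop nodes as ref nodes of their real input so they reuse the input's device address.
// Nop nodes that are ref nodes, graph outputs, or connected to ConditionGather are skipped.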
152 void OptimizeNopNode(KernelGraph *graph) {
153   MS_EXCEPTION_IF_NULL(graph);
154   std::vector<CNodePtr> nop_nodes_need_set_ref;
155 
156   // Skip graphs running in graph (sink) run mode.
157   if (graph->is_graph_run_mode()) {
158     return;
159   }
160 
161   const auto &output_node = graph->output();
162   const auto &ref_map = graph->GetRefMap();
163   std::set<std::pair<AnfNodePtr, size_t>> ref_out_value;
164   for (const auto &iter : ref_map) {
165     ref_out_value.insert(iter.second);
166   }
167   MS_EXCEPTION_IF_NULL(output_node);
168   const auto &graph_outputs = common::AnfAlgo::GetAllOutputWithIndex(output_node);
169   // Collect all the nopnodes that can be eliminated.
170   for (const auto &cnode : graph->execution_order()) {
171     MS_EXCEPTION_IF_NULL(cnode);
172     if ((!common::AnfAlgo::IsNopNode(cnode)) || ref_map.count({cnode, 0}) != 0 ||
173         ref_out_value.count({cnode, 0}) != 0 ||
174         std::find_if(graph_outputs.begin(), graph_outputs.end(),
175                      [&cnode](const KernelWithIndex &output) {
176                        const auto &real_output = common::AnfAlgo::FetchRealNodeSkipMonadControl(output);
177                        return real_output == KernelWithIndex(cnode, 0);
178                      }) != graph_outputs.end() ||
179         std::find_if(cnode->inputs().begin(), cnode->inputs().end(), [](const auto &input) {
180           return common::AnfAlgo::CheckPrimitiveType(input, prim::kPrimConditionGather) ||
181                  common::AnfAlgo::CheckPrimitiveType(common::AnfAlgo::VisitKernelWithReturnType(input, 0).first,
182                                                      prim::kPrimConditionGather);
183         }) != cnode->inputs().end()) {
184       continue;
185     }
186     // A nop node that does not meet the above conditions is set as a ref node and is not deleted from the graph, to
187     // avoid incorrect shape information of the KernelTensor obtained in KernelMod::Launch.
188     (void)nop_nodes_need_set_ref.emplace_back(cnode);
189   }
190 
191   // Add the ref node pairs; this must happen after elimination to avoid referencing eliminated nodes.
192   for (auto &ref_node : nop_nodes_need_set_ref) {
193     MS_EXCEPTION_IF_NULL(ref_node);
194     auto input_node = common::AnfAlgo::GetInputNode(ref_node, 0);
195     MS_EXCEPTION_IF_NULL(input_node);
196     // Record the original information of ref node.
197     auto origin_pair = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
198     MS_EXCEPTION_IF_NULL(origin_pair.first);
199     // In heterogeneous or control flow scenarios, the device address of a parameter input may not be the one
200     // actually used at runtime, so do not set the ref node.
201     if (origin_pair.first->isa<Parameter>() || origin_pair.first->isa<ValueNode>()) {
202       continue;
203     }
204     // The ref node cannot be set for node pairs with different device targets (appears in the kernel backoff scenario).
205     if (AnfAlgo::FetchDeviceTarget(origin_pair.first, graph) != AnfAlgo::FetchDeviceTarget(ref_node, graph)) {
206       continue;
207     }
208     MS_LOG(INFO) << "The reference relation of nopnode " << ref_node->fullname_with_scope() << ", index: " << 0
209                  << " to input " << origin_pair.first->fullname_with_scope() << ", index: " << origin_pair.second;
210     graph->AddRefCorrespondPairs(std::make_pair(ref_node, 0), origin_pair);
211   }
212 }
213 
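// Decide whether zero copy can be enabled: never in PyNative mode, only for task sink without multi graph sink,
// not in PS cache mode, and not in parallel modes unless the GE backend is used.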
214 bool IsEnableZeroCopy(bool run_in_pynative) {
215   if (run_in_pynative) {
216     return false;
217   }
218 
219   auto ms_context = MsContext::GetInstance();
220   MS_EXCEPTION_IF_NULL(ms_context);
221   bool task_sink = ms_context->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
222   bool is_multi_graph_sink = ms_context->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK);
223   // If the run mode is not subgraph sink, the flag should not be set.
224   if (!task_sink || is_multi_graph_sink) {
225     return false;
226   }
227 
228 // In PS cache mode, whole graph sink sets multi_graph_sink to false, so zero copy cannot be enabled.
229 #if defined(__linux__) && defined(WITH_BACKEND)
230   if (ps::PSContext::instance()->cache_enable()) {
231     return false;
232   }
233 #endif
234 
235   auto parallel_context = parallel::ParallelContext::GetInstance();
236   MS_EXCEPTION_IF_NULL(parallel_context);
237   auto parallel_mode = parallel_context->parallel_mode();
238   bool is_parallel_mode = parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel ||
239                           parallel_mode == parallel::kHybridParallel || parallel_mode == parallel::kDataParallel;
240   // If auto parallel is used in the graph, the flag should not be set: in parallel mode, the continuous memory used
241   // by communication ops does not support address changes.
242   // Force zero copy when the GE backend is used.
243   bool is_enable_ge = ms_context->backend_policy() == "ge";
244   if (is_parallel_mode && !is_enable_ge) {
245     return false;
246   }
247   return true;
248 }
249 
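// Set kFlagEnableRunGraphBySingleOp when the graph runs PyNative-in-graph and a kernel fails the resize condition
// check, or when any kernel is a bprop_cut op executed in the backend.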
250 void SetRunGraphBySingleOpFlag(const KernelGraphPtr &graph) {
251   MS_EXCEPTION_IF_NULL(graph);
252   for (auto &node : graph->execution_order()) {
253     MS_EXCEPTION_IF_NULL(node);
254     MS_EXCEPTION_IF_NULL(node->input(0));
255     bool enable = false;
256     if (!AnfAlgo::NodeValueIsFuncGraph(node->input(0))) {
257       if (!kernel::CheckResizeCondition(node) && graph->has_flag(kFlagPyNativeRunInGraph)) {
258         MS_LOG(INFO) << "Enable Run Graph By Single Op";
259         enable = true;
260       }
261     }
262     // The bprop graph contains a bprop_cut node.
263     auto contain_bprop_cut = common::AnfAlgo::IsBpropCutOpExecInBackend(node);
264     if (enable || contain_bprop_cut) {
265       MS_LOG(INFO) << "Set kFlagEnableRunGraphBySingleOp: NeedSkipResize:" << enable
266                    << ", BpGraph contain bprop_cut node:" << contain_bprop_cut;
267       graph->set_flag(kFlagEnableRunGraphBySingleOp, true);
268       break;
269     }
270   }
271 }
272 
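// Finish compiling a graph restored from the compile cache: create kernels, rebuild backoff kernels on the CPU,
// update the kernels to dump, and record the dynamic shape status for the profiler.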
273 void UseCacheToCompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) {
274   MS_EXCEPTION_IF_NULL(graph);
275   MS_EXCEPTION_IF_NULL(device_context);
276 
277   auto &compile_cache_context = CompileCacheContext::GetInstance();
278   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCreateKernel, 1, 0, 0);
279   compile_cache_context.SetFusionOpBuildInfoFlag(true);
280   device_context->GetKernelExecutor(false)->CreateKernel(graph->execution_order());
281   compile_cache_context.SetFusionOpBuildInfoFlag(false);
282   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCreateKernel, 1, 0, 1);
283   // Kernels that are not supported by other devices can be backed off and rebuilt on the CPU.
284 #ifdef WITH_BACKEND
285   if (!graph->is_from_single_op()) {
286     auto cpu_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
287       {kCPUDevice, device_context->device_context_key().device_id_});
288     MS_EXCEPTION_IF_NULL(cpu_context);
289     auto cpu_executor = dynamic_cast<device::cpu::CPUKernelExecutor *>(cpu_context->GetKernelExecutor(false).get());
290     MS_EXCEPTION_IF_NULL(cpu_executor);
291     cpu_executor->RebuildKernelSelectBackoffOp(graph->execution_order());
292   }
293 #endif
294 #ifndef ENABLE_SECURITY
295   // Update needed dump kernels for mindRT.
296   DumpJsonParser::GetInstance().UpdateNeedDumpKernels(*graph.get());
297 #endif
298   if (graph->is_dynamic_shape()) {
299     auto profiler_manage_inst = profiler::ProfilerManager::GetInstance();
300     MS_EXCEPTION_IF_NULL(profiler_manage_inst);
301     profiler_manage_inst->SetNetDynamicShapeStatus();
302   }
303 }
304 
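// A value sequence is valid only if it is flat (no nested sequences) and all elements share the same type;
// an empty sequence is considered valid.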
305 bool IsValidSequence(const ValueSequencePtr &sequence_value) {
306   MS_EXCEPTION_IF_NULL(sequence_value);
307   const auto &values = sequence_value->value();
308   if (values.empty()) {
309     return true;
310   }
311   MS_EXCEPTION_IF_NULL(values[0]);
312   if (values[0]->isa<ValueSequence>()) {
313     return false;
314   }
315   if (values[0]->type() == nullptr) {
316     MS_LOG(DEBUG) << "Failed to get type from value tuple:" << sequence_value->ToString();
317     return false;
318   }
319   TypeId base_type = values[0]->type()->type_id();
320   for (size_t i = 1; i < values.size(); ++i) {
321     MS_EXCEPTION_IF_NULL(values[i]);
322     MS_EXCEPTION_IF_NULL(values[i]->type());
323     TypeId type = values[i]->type()->type_id();
324     if (type != base_type) {
325       MS_LOG(DEBUG) << "Invalid value type for value:" << sequence_value->ToString();
326       return false;
327     }
328   }
329   return true;
330 }
331 
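// Rebuild the graph's value node set from a topological sort; value nodes without kernel info, primitives, and
// invalid value sequences are skipped.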
332 void CollectValueNodeForKernelGraph(const KernelGraphPtr &graph) {
333   MS_EXCEPTION_IF_NULL(graph);
334   graph->ClearAllValueNode();
335   const auto &nodes = TopoSort(graph->get_return());
336   for (const auto &node : nodes) {
337     MS_EXCEPTION_IF_NULL(node);
338     if (!node->isa<ValueNode>() || node->kernel_info() == nullptr) {
339       continue;
340     }
341     const auto &value_node = node->cast<ValueNodePtr>();
342     MS_EXCEPTION_IF_NULL(value_node);
343     const auto &value = value_node->value();
344     MS_EXCEPTION_IF_NULL(value);
345     if (value->isa<Primitive>() ||
346         (value->isa<ValueSequence>() && (!IsValidSequence(value->cast<ValueSequencePtr>())))) {
347       continue;
348     }
349     MS_LOG(DEBUG) << "Add value node:" << node->DebugString() << " for kernel graph:" << graph->ToString();
350     graph->AddValueNodeToGraph(value_node);
351   }
352 }
353 
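// Lightweight compilation for graphs with any-type inputs: mark dynamic shape parameters, map backend outputs to
// front nodes, fill a default kernel build info for outputs that lack one, and create the needed device addresses.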
354 GraphId CompileAnyTypeInputGraph(const KernelGraphPtr &graph, const AnfNodePtrList &outputs,
355                                  const DeviceContext *device_context) {
356   MS_EXCEPTION_IF_NULL(graph);
357   for (const auto &input : graph->inputs()) {
358     MS_EXCEPTION_IF_NULL(input);
359     MS_LOG(DEBUG) << "input node:" << input->DebugString()
360                   << " abstract:" << (input->abstract() == nullptr ? "null" : input->abstract()->ToString());
361   }
362   MS_LOG(DEBUG) << "Pre construct any type input kernel graph:" << graph->ToString();
363   graph->set_is_any_type_input(true);
364   opt::OptimizationForAnyTypeKernelGraph(graph);
365   graph->SetInputNodes();
366   for (const auto &input : graph->input_nodes()) {
367     MS_EXCEPTION_IF_NULL(input);
368     MS_LOG(DEBUG) << "input node:" << input->DebugString()
369                   << " abstract:" << (input->abstract() == nullptr ? "null" : input->abstract()->ToString());
370     if (!input->isa<Parameter>()) {
371       continue;
372     }
373     const auto &parameter = input->cast<ParameterPtr>();
374     MS_EXCEPTION_IF_NULL(parameter);
375     const auto &shape = parameter->Shape();
376     if (shape != nullptr &&
377         ((shape->isa<abstract::Shape>() && shape->IsDynamic()) || shape->isa<abstract::DynamicSequenceShape>())) {
378       parameter->set_has_dynamic_shape(true);
379     }
380   }
381   auto backend_output = graph->output();
382   MS_EXCEPTION_IF_NULL(backend_output);
383   graph->CacheGraphOutputToFrontNodeWithIndex({backend_output}, outputs);
384   graph->UpdateInternalParameter();
385   DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
386 
387   auto output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
388   for (const auto &output_with_index : output_with_indexs) {
389     const auto &output = output_with_index.first;
390     MS_EXCEPTION_IF_NULL(output);
391     if (common::AnfAlgo::IsBpropCutOpExecInBackend(output) || HasAbstractMonad(output)) {
392       continue;
393     }
394     if (output->kernel_info() == nullptr) {
395       output->set_kernel_info(std::make_shared<device::KernelInfo>());
396     }
397     auto kernel_info = dynamic_cast<device::KernelInfo *>(output->kernel_info());
398     MS_EXCEPTION_IF_NULL(kernel_info);
399     // select_kernel_build_info() has already checked whether the returned pointer is null.
400     auto build_info = kernel_info->select_kernel_build_info();
401     if (build_info != nullptr) {
402       continue;
403     }
404     size_t output_num = 1;
405     if (output->abstract() != nullptr) {
406       output_num = common::AnfAlgo::GetOutputNumByAbstract(output->abstract());
407     }
408     kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
409     builder.SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
410     builder.SetOutputsDeviceType(std::vector<TypeId>(output_num, kTypeUnknown));
411     builder.SetOutputsKernelObjectType(
412       std::vector<kernel::KernelObjectType>(output_num, kernel::KernelObjectType::TENSOR));
413     AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), output.get());
414     MS_LOG(DEBUG) << "Set kernel build info for node:" << output->DebugString() << " output num:" << output_num;
415   }
416   CollectValueNodeForKernelGraph(graph);
417   DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
418   DeviceAddressUtils::CreateGraphOutputDeviceAddress(device_context, graph);
419   return graph->graph_id();
420 }
421 }  // namespace
422 
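// Compile a graph segment: construct the kernel graph from the segment's nodes and delegate to the kernel graph
// overload below.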
423 GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment,
424                                     const std::pair<AnfNodePtrList, AnfNodePtrList> &io_nodes,
425                                     const DeviceContext *device_context, device::RunMode run_mode,
426                                     bool run_in_pynative) {
427   MS_EXCEPTION_IF_NULL(segment);
428   MS_EXCEPTION_IF_NULL(device_context);
429   MS_LOG(INFO) << "Status record: start compile graph.";
430   auto nodes = segment->nodes_;
431   auto device_target = device_context->GetDeviceType();
432   // Generate kernel graph.
433   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageConstructKernelGraph, 1, 0, 0);
434   auto kernel_graph =
435     session_->ConstructKernelGraph(nodes, io_nodes.second, device_target, true, IsEnableZeroCopy(run_in_pynative));
436   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageConstructKernelGraph, 1, 0, 1);
437   SetGraphDependency(kernel_graph, segment);
438   return CompileGraph(kernel_graph, io_nodes, device_context, run_mode, run_in_pynative);
439 }
440 
441 GraphId GraphCompiler::CompileGraph(const KernelGraphPtr &kernel_graph,
442                                     const std::pair<AnfNodePtrList, AnfNodePtrList> &io_nodes,
443                                     const DeviceContext *device_context, device::RunMode run_mode,
444                                     bool run_in_pynative) {
445   MS_EXCEPTION_IF_NULL(session_);
446   MS_EXCEPTION_IF_NULL(device_context);
447   MS_EXCEPTION_IF_NULL(kernel_graph);
448 
449   const auto &outputs = io_nodes.second;
450   if (common::AnfAlgo::IsAnyTypeInput(io_nodes.first)) {
451     return CompileAnyTypeInputGraph(kernel_graph, outputs, device_context);
452   }
453   kernel_graph->erase_flag(kFlagPyNativeRunInGraph);
454   SetRunGraphBySingleOpFlag(kernel_graph);
455   kernel_graph->UpdateGraphAquireGilAttr();
456   if (run_mode == device::RunMode::kUnknown) {
457     kernel_graph->set_run_mode(device_context->GetRunMode(kernel_graph));
458   } else {
459     kernel_graph->set_run_mode(run_mode);
460   }
461   auto manager = MakeManager({kernel_graph});
462   if (manager) {
463     manager->AddFuncGraph(kernel_graph);
464     kernel_graph->set_manager(manager);
465   }
466 
467   opt::OptimizationWithoutBackend(kernel_graph);
468   // Unify the MindIR; this must be done before the kernel_graph optimization.
469   auto kernel_executor = device_context->GetKernelExecutor(false);
470   if (kernel_executor != nullptr) {
471     kernel_executor->AddMindIRPass(kernel_graph);
472   }
473   kernel_graph->SetInputNodes();
474   auto context_ptr = MsContext::GetInstance();
475   session_->SetInputNodeUsage(kernel_graph, manager);
476   MS_EXCEPTION_IF_NULL(context_ptr);
477   if (context_ptr->backend_policy() == "ge" && device_context->GetDeviceType() == device::DeviceType::kAscend &&
478       !IsEnableRefMode()) {
479     MS_EXCEPTION_IF_NULL(device_context->graph_executor_);
480     if (!device_context->graph_executor_->CompileGraph(kernel_graph, {})) {
481       MS_LOG(EXCEPTION) << "Compile kernel_graph failed: " << kernel_graph->graph_id();
482     }
483     kernel_graph->UpdateInternalParameter();
484     kernel_graph->CacheGraphOutputToFrontNodeWithIndex({kernel_graph->output()}, outputs);
485     kernel_graph->set_front_outputs(outputs);
486     return kernel_graph->graph_id();
487   }
488   kernel_graph->SetOptimizerFlag();
489 
490   GraphId graph_id = 0;
491   if (run_in_pynative) {
492     MS_EXCEPTION_IF_NULL(session_);
493     // Graph kernel does not support PyNative mode for now; print a warning here.
494     graphkernel::GraphKernelFlags::GetInstance().CheckSupport();
495     graph_id = kernel_graph->graph_id();
496   } else {
497     graph_id = CompileGraphImpl(kernel_graph, device_context, run_in_pynative);
498   }
499 
500   kernel_graph->set_front_outputs(outputs);
501 
502   kernel_graph->set_root_graph_id(graph_id);
503   session_->DumpGraphs({kernel_graph});
504 
505   // The kernel_graph is not compiled yet in PyNative Mode.
506   // Need to cache the output later when the kernel_graph is compiled.
507   if (!run_in_pynative) {
508     // Cache the backend kernel_graph output nodes to front nodes with output index.
509     auto backend_node = kernel_graph->output();
510     MS_EXCEPTION_IF_NULL(backend_node);
511     kernel_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs);
512   }
513   AnfAlgo::UpdateGraphValidRefPair(kernel_graph);
514 
515   MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
516   return graph_id;
517 }
518 
519 GraphCompilerInfo::~GraphCompilerInfo() {
520   GraphScheduler::GetInstance().Clear(name_, graphs_, origin_parameters_order_, control_node_parser_);
521 }
522 
523 GraphId GraphCompiler::CompileDynamicGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs,
524                                            const DeviceContext *device_context) {
525   MS_EXCEPTION_IF_NULL(segment);
526   MS_EXCEPTION_IF_NULL(device_context);
527   MS_LOG(INFO) << "Status record: start compile graph.";
528 
529   auto nodes = segment->nodes_;
530   auto device_target = device_context->GetDeviceType();
531   // Generate kernel graph.
532   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageConstructKernelGraph, 1, 0, 0);
533   const auto &kernel_graph = session_->ConstructKernelGraph(nodes, outputs, device_target, true, false);
534   return CompileDynamicGraph(kernel_graph, device_context);
535 }
536 
537 GraphId GraphCompiler::CompileDynamicGraph(const KernelGraphPtr &kernel_graph, const DeviceContext *device_context) {
538   MS_EXCEPTION_IF_NULL(kernel_graph);
539   MS_EXCEPTION_IF_NULL(device_context);
540   // Dynamic shape or dynamic graph structure flag.
541   kernel_graph->set_flag(kAttrMutableKernel, true);
542   MS_LOG(INFO) << "Set kFlagEnableRunGraphBySingleOp: Dynamic shape or dynamic graph structure flag";
543   kernel_graph->set_flag(kFlagEnableRunGraphBySingleOp, true);
544 
545   kernel_graph->UpdateGraphAquireGilAttr();
546   kernel_graph->SetInputNodes();
547   auto manager = Manage(kernel_graph);
548   if (manager) {
549     manager->AddFuncGraph(kernel_graph);
550     kernel_graph->set_manager(manager);
551   }
552   session_->SetInputNodeUsage(kernel_graph, manager);
553   kernel_graph->SetOptimizerFlag();
554   kernel_graph->set_run_mode(device::RunMode::kKernelMode);
555 
556   // Graph kernel does not support PyNative mode for now; print a warning here.
557   graphkernel::GraphKernelFlags::GetInstance().CheckSupport();
558 
559   GraphId graph_id = kernel_graph->graph_id();
560   kernel_graph->set_root_graph_id(graph_id);
561   session_->DumpGraphs({kernel_graph});
562 
563   MS_LOG(INFO) << "Status record: end compile kernel_graph. kernel_graph id: " << graph_id;
564   return graph_id;
565 }
566 
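// Construct the root kernel graph and all subgraphs for graph (sink) run mode. With the GE backend the whole graph
// is compiled here through the graph executor and *need_return_ahead is set so the caller can return early.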
567 KernelGraphPtr GraphCompiler::ConstructKernelGraphForGraphRunMode(const FuncGraphPtr &func_graph,
568                                                                   const DeviceContext *device_context,
569                                                                   std::vector<KernelGraphPtr> *const all_graphs,
570                                                                   bool *const need_return_ahead) {
571   MS_EXCEPTION_IF_NULL(func_graph);
572   MS_EXCEPTION_IF_NULL(device_context);
573   MS_EXCEPTION_IF_NULL(all_graphs);
574   auto device_target = device_context->GetDeviceType();
575   KernelGraphPtr root_graph = session_->ConstructKernelGraph(func_graph, all_graphs, device_target);
576   MS_EXCEPTION_IF_NULL(root_graph);
577   for (const auto &graph : *all_graphs) {
578     MS_EXCEPTION_IF_NULL(graph);
579     MS_LOG(INFO) << "Set root graph for graph: " << graph->graph_id() << " to: " << root_graph->graph_id() << ".";
580     graph->set_root_graph_id(root_graph->graph_id());
581     graph->set_run_mode(device::RunMode::kGraphMode);
582     graph->set_is_loop_count_sink(true);
583     graph->set_attrs(func_graph->attrs());
584     opt::OptimizationWithoutBackend(graph);
585   }
586 
587   // Unify the MindIR; this must be done before the graph optimization.
588   auto kernel_executor = device_context->GetKernelExecutor(false);
589   if (kernel_executor != nullptr) {
590     kernel_executor->AddMindIRPass(root_graph);
591   }
592 
593   // todo: waiting for GraphExecutor
594   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
595   if (MsContext::GetInstance()->backend_policy() == "ge") {
596     auto manager = MakeManager();
597     MS_EXCEPTION_IF_NULL(manager);
598     for (const auto &graph : *all_graphs) {
599       MS_EXCEPTION_IF_NULL(graph);
600       graph->set_flag(kFlagEnableZeroCopyInGraph, true);
601       manager->AddFuncGraph(graph);
602       graph->set_manager(manager);
603       graph->SetInputNodes();
604     }
605     root_graph->SetInputNodes();
606     MS_EXCEPTION_IF_NULL(device_context->graph_executor_);
607     if (!device_context->graph_executor_->CompileGraph(root_graph, {})) {
608       MS_LOG(EXCEPTION) << "Compile graph failed: " << root_graph->graph_id();
609     }
610     root_graph->CacheGraphOutputToFrontNodeWithIndex({root_graph->output()}, {func_graph->output()});
611     *need_return_ahead = true;
612   }
613   if (*need_return_ahead) {
614     return root_graph;
615   }
616   // Set executing sink to true in graph mode.
617   root_graph->set_run_mode(device::RunMode::kGraphMode);
618   root_graph->set_is_loop_count_sink(true);
619 #if defined(__linux__) && defined(WITH_BACKEND)
620   // The embedding cache needs the global step of the compute graph, so loop sink cannot be enabled; loop control is moved to the loop count actor.
621   if (ps::PSContext::instance()->cache_enable()) {
622     root_graph->set_is_loop_count_sink(false);
623     for (const auto &graph : *all_graphs) {
624       MS_EXCEPTION_IF_NULL(graph);
625       graph->set_is_loop_count_sink(false);
626     }
627   }
628 #endif
629   root_graph->SetInputNodes();
630   return root_graph;
631 }
632 
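// Compile the whole func graph in graph run mode, either by restoring the kernel graphs from the compile cache or
// by constructing them from scratch, and export the result to the compile cache when required.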
633 GraphId GraphCompiler::CompileWholeGraphForGraphRunMode(const FuncGraphPtr &func_graph,
634                                                         const DeviceContext *device_context) {
635   MS_EXCEPTION_IF_NULL(session_);
636   MS_EXCEPTION_IF_NULL(func_graph);
637   MS_EXCEPTION_IF_NULL(device_context);
638   MS_LOG(INFO) << "Status record: start compile graph.";
639   // Generate kernel graph.
640   std::vector<KernelGraphPtr> all_graphs;
641   auto device_target = device_context->GetDeviceType();
642   KernelGraphPtr root_graph;
643   bool need_return_ahead = false;
644   if (UseCacheToCompileGraph(func_graph, device_target)) {
645     root_graph = session_->ConstructKernelGraph(&all_graphs);
646     use_cache_to_compile_graph_ = true;
647   } else {
648     root_graph = ConstructKernelGraphForGraphRunMode(func_graph, device_context, &all_graphs, &need_return_ahead);
649   }
650   GraphId graph_id = root_graph->graph_id();
651   if (need_return_ahead) {
652     return graph_id;
653   }
654   if (ExportCompileCache(func_graph, device_target)) {
655     export_compile_cache_ = true;
656   }
657   if (!func_graph->has_flag(kFlagPyNativeRunInGraph)) {
658     graph_id = CompileGraphImpl(root_graph, device_context);
659   }
660   if (CompileCacheEnable()) {
661     CompileCacheContext::GetInstance().Clear();
662   }
663 
664   // Dump all graphs.
665   // For Ascend mindRT.
666   session_->DumpGraphs(all_graphs);
667 
668   if (!func_graph->has_flag(kFlagPyNativeRunInGraph)) {
669     // Cache the backend graph output nodes to front nodes with output index.
670     auto output = func_graph->output();
671     MS_EXCEPTION_IF_NULL(output);
672     auto backend_node = root_graph->output();
673     MS_EXCEPTION_IF_NULL(backend_node);
674     root_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {output});
675     AnfAlgo::UpdateGraphValidRefPair(root_graph);
676   } else {
677     for (auto &node : root_graph->execution_order()) {
678       if (common::AnfAlgo::IsBpropCutOpExecInBackend(node)) {
679         MS_LOG(INFO) << "Set kFlagEnableRunGraphBySingleOp: IsBpropCutOpExecInBackend";
680         root_graph->set_flag(kFlagEnableRunGraphBySingleOp, true);
681       }
682     }
683     root_graph->set_front_outputs({func_graph->output()});
684   }
685   MS_LOG(INFO) << "Status record: end compile graph. graph id: " << graph_id;
686   return graph_id;
687 }
688 
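// Core backend compilation: graph optimization, kernel creation (or the compile cache fast path), CPU backoff
// rebuild, nop node optimization, ref pair collection, PreprocessBeforeRun and device address creation.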
689 GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context,
690                                         bool run_in_pynative) const {
691   MS_EXCEPTION_IF_NULL(graph);
692   MS_EXCEPTION_IF_NULL(device_context);
693   MS_EXCEPTION_IF_NULL(session_);
694   const auto &context = MsContext::GetInstance();
695   MS_EXCEPTION_IF_NULL(context);
696   if (use_cache_to_compile_graph_) {
697     UseCacheToCompileGraphImpl(graph, device_context);
698   } else {
699 #ifdef ENABLE_DUMP_IR
700     if (context->CanDump(kIntroductory)) {
701       // Dump .pb graph before graph optimization.
702       DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
703     }
704 #endif
705     MS_EXCEPTION_IF_NULL(device_context->GetKernelExecutor(false));
706     // Execute optimization pass.
707     (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageOptimizeGraph, 1, 0, 0);
708     PROF_START(OptimizeGraph);
709     device_context->GetKernelExecutor(false)->OptimizeGraph(graph);
710     PROF_END(OptimizeGraph);
711     (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageOptimizeGraph, 1, 0, 1);
712     // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel,
713     // 'KernelMod' is real executive object of kernel.
714     (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCreateKernel, 1, 0, 0);
715     PROF_START(CreateKernel);
716     device_context->GetKernelExecutor(false)->CreateKernel(graph->execution_order());
717     PROF_END(CreateKernel);
718     (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCreateKernel, 1, 0, 1);
719 
720     // Kernels that are not supported by other devices can be backed off and rebuilt on the CPU.
721 #ifdef WITH_BACKEND
722     if (!graph->is_from_single_op()) {
723       auto cpu_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
724         {kCPUDevice, device_context->device_context_key().device_id_});
725       MS_EXCEPTION_IF_NULL(cpu_device_context);
726       auto cpu_executor =
727         dynamic_cast<device::cpu::CPUKernelExecutor *>(cpu_device_context->GetKernelExecutor(false).get());
728       MS_EXCEPTION_IF_NULL(cpu_executor);
729       cpu_executor->RebuildKernelSelectBackoffOp(graph->execution_order());
730     }
731 #endif
732 
733     // Read the output-input ref map and set it on the kernel graph.
734     AnfAlgo::AddOutInRefToGraph(graph);
735 
736     // Optimize the nop node.
737     if (!run_in_pynative) {
738       OptimizeNopNode(graph.get());
739 #ifdef ENABLE_DUMP_IR
740       if (context->CanDump(kIntroductory)) {
741         DumpIR("hwopt_comm_after_eliminate_nopnode_" + graph->ToString() + ".ir", graph, true);
742       }
743 #endif
744     }
745 
746 #ifndef ENABLE_SECURITY
747     session_->RecurseSetSummaryNodesForAllGraphs(graph.get());
748     // Update needed dump kernels for mindRT.
749     DumpJsonParser::GetInstance().UpdateNeedDumpKernels(*graph.get());
750 #endif
751 
752     // Dynamic shape pass for graph mode.
753     if (graph->is_dynamic_shape()) {
754       if (!graph->is_graph_run_mode()) {
755         // Temporarily disable CustomActor for asynchronous InferShape and Resize for the dynamic shape scenario,
756         // and implement the corresponding capability through direct InferShape and Resize in KernelActor.
757         // opt::DynamicShapeConvertPass(graph);
758       }
759       auto profiler_manage_inst = profiler::ProfilerManager::GetInstance();
760       MS_EXCEPTION_IF_NULL(profiler_manage_inst);
761       profiler_manage_inst->SetNetDynamicShapeStatus();
762     }
763   }
764 
765   if (export_compile_cache_) {
766     session_->CacheKernelGraph(graph);
767   }
768   // Adjust kernel graph before run graph.
769   PROF_START(PreprocessBeforeRun);
770   device_context->GetKernelExecutor(false)->PreprocessBeforeRun(graph);
771   PROF_END(PreprocessBeforeRun);
772   graph->UpdateInternalParameter();
773   // Set device target for parameter affinity.
774   AnfAlgo::SetParameterDeviceTarget(graph);
775 
776   // Create device address for all anf nodes of graph.
777   CreateDeviceAddress(graph, device_context);
778 
779 #if defined(__linux__) && defined(WITH_BACKEND)
780   // Set the device address for embedding cache parameters; this only takes effect when embedding cache mode is enabled.
781   // `CreateDeviceAddress` should execute before this step.
782   EmbeddingCacheScheduler::GetInstance().SetEmbedCachedParamAddress(device_context, graph);
783 #endif
784 
785   SetSummaryNodesRefCount(graph.get());
786 #ifdef ENABLE_DUMP_IR
787   // Dump .pb graph after graph optimization.
788   if (context->CanDump(kIntroductory)) {
789     DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
790   }
791 #endif
792 
793 #ifdef ENABLE_DEBUGGER
794   auto debugger = Debugger::GetInstance();
795   MS_EXCEPTION_IF_NULL(debugger);
796   // Dump graph for GPU mindRT if dump is enabled.
797   debugger->DumpInGraphCompiler(graph);
798   if (debugger && debugger->DebuggerBackendEnabled()) {
799     // Load graphs for GPU and Ascend mindRT.
800     debugger->LoadGraphs(graph);
801   }
802 #endif
803 
804   graph->EnableRuntimeCache();
805   return graph->graph_id();
806 }
807 
808 KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
809   MS_EXCEPTION_IF_NULL(session_);
810   return session_->GetGraph(graph_id);
811 }
812 
813 void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
814   MS_EXCEPTION_IF_NULL(graph);
815   MS_LOG(INFO) << "Status record: start create device address. graph id: " << graph->graph_id();
816   DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
817   DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
818   DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, false);
819   DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
820   DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
821   DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
822   MS_LOG(INFO) << "Status record: end create device address. graph id: " << graph->graph_id();
823 }
824 
825 void GraphCompiler::GetParamAndOutputIndex(
826   const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs, VectorRef *const outputs,
827   std::map<AnfNodePtr, size_t> *parameter_index,
828   std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
829   MS_EXCEPTION_IF_NULL(session_);
830   session_->GetParameterIndex(graph.get(), inputs, parameter_index);
831   session_->CreateOutputPlaceholder(graph, inputs, outputs, output_indexes);
832 }
833 
834 void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel,
835                                             const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
836                                             const std::map<AnfNodePtr, size_t> &parameter_index,
837                                             const std::vector<TensorPtr> &graph_inputs, bool is_run_pyboost,
838                                             InputInfo *const input_info) {
839   MS_EXCEPTION_IF_NULL(session_);
840   if (is_run_pyboost) {
841     session_->GetOpInputTensorsFromCNode(kernel, op_output, parameter_index, graph_inputs, input_info);
842   } else {
843     session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_info);
844   }
845 }
846 
847 tensor::BaseTensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(
848   const CNodePtr &kernel, const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
849   const std::map<AnfNodePtr, size_t> &parameter_index, const std::vector<TensorPtr> &graph_inputs,
850   InputInfo *const input_info, size_t input_index) {
851   MS_EXCEPTION_IF_NULL(session_);
852   return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_info, input_index);
853 }
854 
855 void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const InputInfo &input_info,
856                                                    bool use_dynamic_shape_process,
857                                                    session::BackendOpRunInfoPtr *op_run_info,
858                                                    const GraphOutputInfo *const graph_output_info) {
859   MS_EXCEPTION_IF_NULL(session_);
860   *op_run_info = session_->GetSingleOpRunInfo(kernel, input_info, graph_output_info);
861   (*op_run_info)->base_op_run_info.use_dynamic_shape_process = use_dynamic_shape_process;
862 }
863 
864 void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
865   MS_EXCEPTION_IF_NULL(session_);
866   session_->GetRefCount(graph.get(), ref_count);
867 }
868 
869 void GraphCompiler::CalculateForwardOpOutputCount(const KernelGraphPtr &graph,
870                                                   const std::vector<tensor::TensorPtr> &inputs,
871                                                   std::map<std::string, size_t> *forward_op_output_tensor_id,
872                                                   const std::map<AnfNodePtr, size_t> &parameter_index) const {
873   MS_EXCEPTION_IF_NULL(session_);
874   MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
875   forward_op_output_tensor_id->clear();
876   session_->GetForwardOpOutputRefCount(graph.get(), inputs, forward_op_output_tensor_id, parameter_index);
877 }
878 
879 void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
880                                    std::map<KernelWithIndex, size_t> *ref_count,
881                                    std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map) const {
882   MS_EXCEPTION_IF_NULL(session_);
883   session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map);
884 }
885 
886 void GraphCompiler::UpdateForwardOpOutputRefCount(const std::vector<ValuePtr> &input_values,
887                                                   std::map<std::string, size_t> *forward_op_output_tensor_id) const {
888   MS_EXCEPTION_IF_NULL(session_);
889   MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
890   session_->ReleaseForwardOpOutput(input_values, forward_op_output_tensor_id);
891 }
892 
893 void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
894                                        const std::map<KernelWithIndex, size_t> &ref_count,
895                                        std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map,
896                                        GraphOutputInfo *const graph_output_info) const {
897   MS_EXCEPTION_IF_NULL(session_);
898   session_->HandleOpOutputs(kernel, op_outputs, ref_count, op_output_map, graph_output_info);
899 }
900 
901 void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
902   MS_EXCEPTION_IF_NULL(session_);
903 #ifndef ENABLE_SECURITY
904   session_->RegisterSummaryCallBackFunc(callback);
905 #endif
906 }
907 
908 void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
909   MS_EXCEPTION_IF_NULL(session_);
910   for (const auto &graph : graphs) {
911 #ifndef ENABLE_SECURITY
912     session_->Summary(graph.get());
913 #endif
914   }
915 }
916 
917 void GraphCompiler::SetGraphDependency(const KernelGraphPtr &graph, const GraphSegmentPtr &segment) const {
918   MS_EXCEPTION_IF_NULL(graph);
919   MS_EXCEPTION_IF_NULL(segment);
920   segment->graph_id_ = graph->graph_id();
921   for (auto &pre_segment : segment->pre_segments_) {
922     MS_EXCEPTION_IF_NULL(pre_segment);
923     auto pre_graph = Fetch(pre_segment->graph_id_);
924     MS_EXCEPTION_IF_NULL(pre_graph);
925     pre_graph->AddPostGraph(graph);
926     graph->AddPreGraph(pre_graph);
927     MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph->graph_id();
928   }
929 }
930 }  // namespace runtime
931 }  // namespace mindspore
932