/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/graph_compiler.h"
#include <numeric>
#include <map>
#include <utility>
#include "runtime/framework/graph_scheduler.h"
#include "runtime/device/device_address.h"
#include "common/trans.h"
#include "utils/convert_utils.h"
#include "ir/tensor.h"
#include "backend/optimizer/common/helper.h"
#include "base/base_ref_utils.h"
#include "debug/dump_proto.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/anf_ir_dump.h"
#include "debug/rdr/running_data_recorder.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif

namespace mindspore {
namespace runtime {
namespace {
// Check whether the device address of the anf node already exists and whether its device address type is
// consistent with the device type; for example, DeviceAddressType::kGPU should be used on a GPU device.
bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(device_context);
  if (AnfAlgo::OutputAddrExist(kernel, index)) {
    const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
    MS_EXCEPTION_IF_NULL(address);
    return address->DeviceType() == device_context->GetDeviceAddressType();
  }
  return false;
}

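// Create device addresses for the graph inputs: collect the valid parameter nodes, including parameters
// expanded from kPrimMakeTuple inputs, and create an empty device address (no device memory bound yet)
// for each output of those nodes that does not already have one.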
void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  const std::vector<bool> &graph_valid_input = graph->valid_inputs();
  (void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());

  // Anf nodes for which a device address needs to be created.
  std::vector<AnfNodePtr> nodes_list;
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    AnfNodePtr item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }

    if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
      std::vector<AnfNodePtr> outs = AnfAlgo::GetAllOutput(item);
      for (const auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
          continue;
        }
        nodes_list.push_back(out);
      }
    }
    if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
      continue;
    }
    nodes_list.push_back(item);
  }

  // Create device addresses for the anf nodes in nodes_list.
  for (const auto &item : nodes_list) {
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }

      size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
                                                                AnfAlgo::GetOutputFormat(item, index), output_type_id);
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}

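// Create device addresses for the tensors held by a value node. If a tensor already carries a device address
// of the matching device type, that address is reused (in graph mode or PyNative infer mode); otherwise a new
// empty device address is created and bound to the corresponding output index of the value node.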
void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
                                       size_t output_idx, const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::vector<TensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);

  for (const auto &tensor : tensors) {
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }
    auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) {
      bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
      bool is_graph_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
      if (is_graph_mode || is_pynative_infer) {
        AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
                               value_node.get());
      }
      continue;
    }

    size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
    if (output_type_id == kTypeUnknown) {
      output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
    std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);

    device::DeviceAddressPtr address =
      device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
    MS_EXCEPTION_IF_NULL(address);
    AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
  }
}

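// Create device addresses for the value nodes of the graph: tensor and tuple values are handled by
// CreateDeviceAddressForTensorValue, while string values get a default-format uint8 address sized by the string.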
void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeDeviceAddressExist(device_context, value_node, 0)) {
      continue;
    }

    const auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
      CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
    } else if (node_value->isa<StringImm>()) {
      auto value = GetValue<std::string>(node_value);
      size_t tensor_size = value.size();
      auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
      MS_EXCEPTION_IF_NULL(address);

      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
    }
  }
}

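// Create an output device address for every output of each kernel in the execution order, using the output sizes
// reported by the kernel's KernelMod. Control ops executed in the backend and outputs that already have an
// address are skipped.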
void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (AnfAlgo::IsControlOpExecInBackend(kernel)) {
      continue;
    }
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto output_sizes = kernel_mod->GetOutputSizeList();
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      if (AnfAlgo::OutputAddrExist(kernel, i)) {
        continue;
      }

      std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
      auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
    }
  }
}

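// Create workspace device addresses for each kernel in the execution order, one per workspace size reported by
// the kernel's KernelMod.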
void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (AnfAlgo::IsControlOpExecInBackend(kernel)) {
      continue;
    }
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_sizes.size(); ++i) {
      auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
      AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
    }
  }
}

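// Make the kernels of an inplace group share one output device address: the address of the first node in the
// group is propagated to the other nodes, and its reference count is increased accordingly.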
void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Collect the inplace groups.
  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
      continue;
    }
    auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
    MS_EXCEPTION_IF_NULL(primitive);
    auto inplace_group_attr = primitive->GetAttr("inplace_group");
    MS_EXCEPTION_IF_NULL(inplace_group_attr);
    auto group_id = GetValue<uint32_t>(inplace_group_attr);
    (void)inplace_groups[group_id].emplace_back(kernel);
  }

  const size_t kMinInplaceGroupSize = 2;
  for (const auto &inplace_group : inplace_groups) {
    auto &group_nodes = inplace_group.second;
    if (group_nodes.size() < kMinInplaceGroupSize) {
      continue;
    }
    // Get the device address of the first node in the inplace group.
    auto node_primitive = AnfAlgo::GetCNodePrimitive(group_nodes[0]);
    MS_EXCEPTION_IF_NULL(node_primitive);
    auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
    auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
    MS_EXCEPTION_IF_NULL(device_address);

    // Update the device addresses of the other nodes with the device address of the first node in the inplace group.
    for (size_t i = 1; i < group_nodes.size(); ++i) {
      auto &group_node = group_nodes[i];
      auto prim = AnfAlgo::GetCNodePrimitive(group_node);
      MS_EXCEPTION_IF_NULL(prim);
      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
      AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
      // Update the reference count of the shared device address.
      device_address->IncreaseOriginalRefCount();
      device_address->ResetRefCount();
    }
  }
}

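// Pin the device addresses of summary nodes by setting their original reference count to SIZE_MAX, so that
// reference counting never releases their memory.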
void SetSummaryNodesRefCount(const KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {
    return;
  }

  const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes = graph->summary_nodes();
  if (summary_nodes.empty()) {
    return;
  }

  for (const auto &item : summary_nodes) {
    const AnfNodePtr &node = item.second.first;
    size_t index = IntToSize(item.second.second);
    auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false);
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->set_original_ref_count(SIZE_MAX);
    device_address->ResetRefCount();
  }
}

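// Pin the device addresses of the graph output nodes in the same way, so that graph outputs stay valid after
// execution.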
void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_with_index) {
  for (const auto &item_with_index : output_with_index) {
    if (!AnfAlgo::OutputAddrExist(item_with_index.first, item_with_index.second, false)) {
      continue;
    }
    auto device_address = AnfAlgo::GetMutableOutputAddr(item_with_index.first, item_with_index.second, false);
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->set_original_ref_count(SIZE_MAX);
    device_address->ResetRefCount();
  }
}
}  // namespace

GraphCompilerInfo::~GraphCompilerInfo() { GraphScheduler::GetInstance().Clear(name_, graphs_); }

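// Compile a whole graph: construct the kernel graph from the front-end nodes and outputs, cache the mapping
// from backend output nodes to front-end output nodes, and run the shared compilation pipeline.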
GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs,
                                    const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(session_);
  // Generate the kernel graph.
  KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs);
  MS_EXCEPTION_IF_NULL(graph);

  // Cache the mapping from the backend graph output nodes to the front nodes with output index.
  for (auto &output : outputs) {
    auto backend_node = graph->GetBackendAnfByFrontAnf(output);
    if (backend_node != nullptr) {
      graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output);
    }
  }

  return CompileGraphImpl(graph, device_context);
}

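// Shared compilation pipeline for a kernel graph: hardware-specific optimization, kernel (KernelMod) creation,
// graph preprocessing, output-map update, device address creation in graph mode, and the summary, debugger and
// IR-dump hooks.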
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
#ifdef ENABLE_DUMP_IR
  bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  // Dump the .pb graph before graph optimization.
  if (save_graphs) {
    DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
  }
#endif

  MS_LOG(INFO) << "Get graph outputs before optimizer, graph id: " << graph->graph_id();
  auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());

  // Execute the optimization pass.
  device_context->OptimizeGraph(graph);

  // Generate a 'KernelMod' for every kernel and set it into the kernel node; the 'KernelMod' is the real
  // executable object of the kernel.
  device_context->CreateKernel(graph->execution_order());

  // Adjust the kernel graph before running it.
  device_context->PreprocessBeforeRunGraph(graph);

  MS_LOG(INFO) << "Get graph outputs after optimizer, graph id: " << graph->graph_id();
  auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
  // Update the output map of the kernel graph with the modified output nodes.
  graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);

  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    // Create device addresses for all anf nodes of the graph.
    CreateDeviceAddress(graph, device_context);
  }

  graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));

  MS_EXCEPTION_IF_NULL(session_);
  session_->InitAllBucket(graph, device_context);
#ifndef ENABLE_SECURITY
  session_->SetSummaryNodes(graph.get());
#endif
  SetSummaryNodesRefCount(graph.get());
#ifdef ENABLE_DUMP_IR
  // Dump the .pb graph after graph optimization.
  if (save_graphs) {
    DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
  }
#endif

#ifdef ENABLE_DEBUGGER
  auto debugger = Debugger::GetInstance();
  // Check the debugger before using it, instead of dereferencing it first and checking afterwards.
  MS_EXCEPTION_IF_NULL(debugger);
  debugger->DumpInGraphCompiler(graph);
  if (debugger->DebuggerBackendEnabled()) {
    debugger->LoadGraphs(graph);
  }
#endif

#ifdef ENABLE_DUMP_IR
  std::string name = "graph_build";
  DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
  (void)mindspore::RDR::RecordAnfGraph(SubModuleId::SM_SESSION, name, graph, dump_params, ".ir,.pb");
  auto &kernels = graph->execution_order();
  std::string exec_order_name = "graph_exec_order." + std::to_string(graph->graph_id());
  (void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels);
#endif

  session_->DumpGraph(graph);
  return graph->graph_id();
}

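// Compile a single-operator graph and cache it by graph_info; a cache hit returns the cached graph id and sets
// *single_op_cache_hit, so the graph is not rebuilt.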
GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, const GraphInfo &graph_info,
                                    const std::vector<int64_t> *tensors_mask,
                                    std::vector<TensorPtr> *const input_tensors, bool *single_op_cache_hit,
                                    const DeviceContext *device_context) {
  // Check whether the graph cache exists.
  auto iter = run_op_graphs_.find(graph_info);
  if (iter != run_op_graphs_.end()) {
    const auto &graph = iter->second;
    MS_EXCEPTION_IF_NULL(graph);
    *single_op_cache_hit = true;
    return graph->graph_id();
  }
  *single_op_cache_hit = false;
  // Generate the kernel graph.
  MS_EXCEPTION_IF_NULL(session_);
  KernelGraphPtr graph = session_->ConstructSingleOpGraph(op_run_info, *input_tensors, *tensors_mask);
  MS_EXCEPTION_IF_NULL(graph);

  MS_EXCEPTION_IF_NULL(device_context);
  device_context->OptimizeSingleOpGraph(graph);

  // Generate a 'KernelMod' for each kernel in the graph.
  device_context->CreateKernel(graph->execution_order());

  device_context->PreprocessBeforeRunSingleOpGraph(graph);

  // Create device addresses for all anf nodes of the graph.
  CreateDeviceAddress(graph, device_context);

  graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  run_op_graphs_[graph_info] = graph;

  auto output_nodes = graph->outputs();
  auto &outputs_with_index = run_op_graph_output_nodes_[graph->graph_id()];
  for (auto &node : output_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    (void)outputs_with_index.emplace_back(AnfAlgo::VisitKernelWithReturnType(node, 0, false));
  }

  UpdateRefCountForGraphOutput(outputs_with_index);

  return graph->graph_id();
}

KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetGraph(graph_id);
}

KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
  auto iter = run_op_graphs_.find(graph_info);
  if (iter == run_op_graphs_.end()) {
    MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
    return nullptr;
  }
  return iter->second;
}

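// Create device addresses for all node kinds in the graph: parameters, value nodes, kernel outputs and
// workspaces. The inplace update runs last because it reuses the kernel output addresses created just before.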
void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  CreateParameterDeviceAddress(device_context, graph);
  CreateValueNodeDeviceAddress(device_context, graph);
  CreateKernelOutputDeviceAddress(device_context, graph);
  CreateKernelWorkspaceDeviceAddress(device_context, graph);
  UpdateDeviceAddressForInplaceNode(graph);
}

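// Fill parameter_index with the mapping from graph parameters to input-tensor positions and build the output
// placeholders with their index mapping; both steps are delegated to the session.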
void GraphCompiler::GetParamAndOutputIndex(
  const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs, VectorRef *const outputs,
  std::map<AnfNodePtr, size_t> *parameter_index,
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetParameterIndex(graph.get(), inputs, parameter_index);
  session_->CreateOutputPlaceholder(graph, inputs, outputs, output_indexes);
}

void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel,
                                            const std::map<KernelWithIndex, TensorPtr> &op_output,
                                            const std::map<AnfNodePtr, size_t> &parameter_index,
                                            const std::vector<TensorPtr> &graph_inputs,
                                            InputTensorInfo *const input_tensor_info) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_tensor_info);
}

TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
                                                       const std::map<KernelWithIndex, TensorPtr> &op_output,
                                                       const std::map<AnfNodePtr, size_t> &parameter_index,
                                                       const std::vector<TensorPtr> &graph_inputs,
                                                       InputTensorInfo *const input_tensor_info, size_t input_index) {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_tensor_info,
                                           input_index);
}

void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector<TensorPtr> &input_tensors,
                                                   OpRunInfo *const run_info, GraphInfo *const graph_info) {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(graph_info);
  session_->GetSingleOpRunInfo(kernel, run_info);
  *graph_info = session_->GetSingleOpGraphInfo(kernel, input_tensors);
}

void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetRefCount(graph.get(), ref_count);
}

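// Update the reference counts for the inputs of an executed op and maintain the op output map; delegated to the
// session's HandleOpInputs.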
void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
                                   std::map<KernelWithIndex, size_t> *ref_count,
                                   std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map);
}

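// Recover the graph outputs for the front end from the outputs of an executed op, updating the op output map;
// delegated to the session's HandleOpOutputs.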
void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                                       const std::map<KernelWithIndex, size_t> &ref_count,
                                       std::map<KernelWithIndex, TensorPtr> *op_output_map,
                                       GraphOutputInfo *const graph_output_info) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpOutputs(kernel, op_outputs, ref_count, op_output_map, graph_output_info);
}

void GraphCompiler::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->AddGradAddrToBucket(graph_id, grad_tensor);
}

void GraphCompiler::ClearAllBucket(const GraphId &graph_id) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->ClearAllBucket(graph_id);
}

const std::vector<KernelWithIndex> &GraphCompiler::GetGraphOutputNodes(GraphId graph_id) const {
  const auto &iter = run_op_graph_output_nodes_.find(graph_id);
  if (iter == run_op_graph_output_nodes_.end()) {
    MS_LOG(EXCEPTION) << "Can not find output nodes for graph id: " << graph_id;
  }
  return iter->second;
}

void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
  MS_EXCEPTION_IF_NULL(session_);
#ifndef ENABLE_SECURITY
  session_->RegisterSummaryCallBackFunc(callback);
#endif
}

void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
  MS_EXCEPTION_IF_NULL(session_);
  for (const auto &graph : graphs) {
#ifndef ENABLE_SECURITY
    session_->Summary(graph.get());
#endif
  }
}

void GraphCompiler::EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id) {
  (void)run_op_graphs_.erase(graph_info);
  (void)run_op_graph_output_nodes_.erase(graph_id);
}
}  // namespace runtime
}  // namespace mindspore