/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/framework/graph_compiler.h"
#include <numeric>
#include <map>
#include <utility>
#include "runtime/framework/graph_scheduler.h"
#include "runtime/device/device_address.h"
#include "common/trans.h"
#include "utils/convert_utils.h"
#include "ir/tensor.h"
#include "backend/optimizer/common/helper.h"
#include "base/base_ref_utils.h"
#include "debug/dump_proto.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/anf_ir_dump.h"
#include "debug/rdr/running_data_recorder.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif

namespace mindspore {
namespace runtime {
namespace {
// Check whether the anf node output already has a device address and whether its device
// address type matches the device context; for example, a device address of type
// DeviceAddressType::kGPU should be used on a GPU device.
bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(device_context);
  if (AnfAlgo::OutputAddrExist(kernel, index)) {
    const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
    MS_EXCEPTION_IF_NULL(address);
    return address->DeviceType() == device_context->GetDeviceAddressType();
  }
  return false;
}

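// Create device addresses for the parameter nodes of the graph, covering both the graph
// inputs (including parameters flattened out of MakeTuple inputs) and the child graph
// results; parameters that already own a matching device address are skipped.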
void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  const std::vector<bool> &graph_valid_input = graph->valid_inputs();
  (void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());

  // Anf nodes which need to create device addresses.
  std::vector<AnfNodePtr> nodes_list;
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    AnfNodePtr item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }

    if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
      std::vector<AnfNodePtr> outs = AnfAlgo::GetAllOutput(item);
      for (const auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
          continue;
        }
        nodes_list.push_back(out);
      }
    }
    if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
      continue;
    }
    nodes_list.push_back(item);
  }

  // Create device addresses for the anf nodes in nodes_list.
  for (const auto &item : nodes_list) {
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
      }

      size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
                                                                AnfAlgo::GetOutputFormat(item, index), output_type_id);
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}

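// Create device addresses for the tensors held by a value node. A tensor that already
// carries an address of the matching device type is reused in graph mode or pynative
// infer mode; otherwise a fresh device address is created from the node's output info.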
void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
                                       size_t output_idx, const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::vector<TensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);

  for (const auto &tensor : tensors) {
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }
    auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) {
      bool is_pynative_infer = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
      bool is_graph_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode);
      if (is_graph_mode || is_pynative_infer) {
        AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
                               value_node.get());
      }
      continue;
    }

    size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
    TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
    if (output_type_id == kTypeUnknown) {
      output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    }
    std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);

    device::DeviceAddressPtr address =
      device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
    MS_EXCEPTION_IF_NULL(address);
    AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
  }
}

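// Create device addresses for the value nodes of the graph: tensor and tuple values are
// delegated to CreateDeviceAddressForTensorValue, and string values get a UInt8 buffer
// sized to the string; other value kinds get no device address here.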
void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeDeviceAddressExist(device_context, value_node, 0)) {
      continue;
    }

    const auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
      CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
    } else if (node_value->isa<StringImm>()) {
      auto value = GetValue<std::string>(node_value);
      size_t tensor_size = value.size();
      auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
      MS_EXCEPTION_IF_NULL(address);

      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
    }
  }
}

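// Create device addresses for every kernel output in the graph's execution order, sized
// by the output size list of each kernel's KernelMod; control ops executed in the
// backend and outputs that already have an address are skipped.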
void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (AnfAlgo::IsControlOpExecInBackend(kernel)) {
      continue;
    }
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto output_sizes = kernel_mod->GetOutputSizeList();
    for (size_t i = 0; i < output_sizes.size(); ++i) {
      if (AnfAlgo::OutputAddrExist(kernel, i)) {
        continue;
      }

      std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
      auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
    }
  }
}

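// Create workspace device addresses for every kernel, one per entry in the workspace
// size list of its KernelMod; a workspace has no format and no concrete data type.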
void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (AnfAlgo::IsControlOpExecInBackend(kernel)) {
      continue;
    }
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_sizes.size(); ++i) {
      auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
      AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
    }
  }
}

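// Make the nodes of each inplace group share the device address of the group's first
// node so that they write to the same memory; groups are keyed by the 'inplace_group'
// attribute and the shared output is selected by 'inplace_output_index'.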
void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Collect the inplace groups.
  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    if (!AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
      continue;
    }
    auto primitive = AnfAlgo::GetCNodePrimitive(kernel);
    MS_EXCEPTION_IF_NULL(primitive);
    auto inplace_group_attr = primitive->GetAttr("inplace_group");
    MS_EXCEPTION_IF_NULL(inplace_group_attr);
    auto group_id = GetValue<uint32_t>(inplace_group_attr);
    (void)inplace_groups[group_id].emplace_back(kernel);
  }

  const size_t kMinInplaceGroupSize = 2;
  for (const auto &inplace_group : inplace_groups) {
    auto &group_nodes = inplace_group.second;
    if (group_nodes.size() < kMinInplaceGroupSize) {
      continue;
    }
    // Get the device address of the first node in the inplace group.
    auto node_primitive = AnfAlgo::GetCNodePrimitive(group_nodes[0]);
    MS_EXCEPTION_IF_NULL(node_primitive);
    auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
    auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
    MS_EXCEPTION_IF_NULL(device_address);

    // Update the device addresses of the other nodes to share the device address of the first node in the group.
    for (size_t i = 1; i < group_nodes.size(); ++i) {
      auto &group_node = group_nodes[i];
      auto prim = AnfAlgo::GetCNodePrimitive(group_node);
      MS_EXCEPTION_IF_NULL(prim);
      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
      AnfAlgo::SetOutputAddr(device_address, index, group_node.get());
      // Update the reference count of the shared device address.
      device_address->IncreaseOriginalRefCount();
      device_address->ResetRefCount();
    }
  }
}

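// Set the original reference count of summary node outputs to SIZE_MAX, effectively
// pinning their device memory while summaries are collected.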
void SetSummaryNodesRefCount(const KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->summary_node_exist()) {
    return;
  }

  const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes = graph->summary_nodes();
  if (summary_nodes.empty()) {
    return;
  }

  for (const auto &item : summary_nodes) {
    const AnfNodePtr &node = item.second.first;
    size_t index = IntToSize(item.second.second);
    auto device_address = AnfAlgo::GetMutableOutputAddr(node, index, false);
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->set_original_ref_count(SIZE_MAX);
    device_address->ResetRefCount();
  }
}

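// Likewise pin the device addresses of graph outputs (original reference count
// SIZE_MAX) so that output memory is not reclaimed before it is fetched.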
void UpdateRefCountForGraphOutput(const std::vector<KernelWithIndex> &output_with_index) {
  for (const auto &item_with_index : output_with_index) {
    if (!AnfAlgo::OutputAddrExist(item_with_index.first, item_with_index.second, false)) {
      continue;
    }
    auto device_address = AnfAlgo::GetMutableOutputAddr(item_with_index.first, item_with_index.second, false);
    MS_EXCEPTION_IF_NULL(device_address);
    device_address->set_original_ref_count(SIZE_MAX);
    device_address->ResetRefCount();
  }
}
}  // namespace

GraphCompilerInfo::~GraphCompilerInfo() { GraphScheduler::GetInstance().Clear(name_, graphs_); }

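// Compile a front-end node segment into a kernel graph: construct the graph, cache the
// mapping from backend output nodes to front-end outputs, then run the common pipeline
// in CompileGraphImpl.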
GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs,
                                    const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(session_);
  // Generate kernel graph.
  KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs);
  MS_EXCEPTION_IF_NULL(graph);

  // Cache the backend graph output nodes to front nodes with output index.
  for (auto &output : outputs) {
    auto backend_node = graph->GetBackendAnfByFrontAnf(output);
    if (backend_node != nullptr) {
      graph->CacheGraphOutputToFrontNodeWithIndex(backend_node, output);
    }
  }

  return CompileGraphImpl(graph, device_context);
}

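// The common compilation pipeline: optimize the graph, create kernels, preprocess for
// execution, refresh the graph output map, create device addresses in graph mode, and
// run the configured dump/debug hooks before returning the graph id.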
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
#ifdef ENABLE_DUMP_IR
  bool save_graphs = ms_context->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  // Dump .pb graph before graph optimization.
  if (save_graphs) {
    DumpIRProto(graph, "before_opt_" + std::to_string(graph->graph_id()));
  }
#endif

  MS_LOG(INFO) << "Get graph outputs before optimizer, graph id: " << graph->graph_id();
  auto outputs_before_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());

  // Execute optimization pass.
  device_context->OptimizeGraph(graph);

  // Generate 'KernelMod' for all kernels and set 'KernelMod' into kernel;
  // 'KernelMod' is the real executable object of the kernel.
  device_context->CreateKernel(graph->execution_order());

  // Adjust kernel graph before run graph.
  device_context->PreprocessBeforeRunGraph(graph);

  MS_LOG(INFO) << "Get graph outputs after optimizer, graph id: " << graph->graph_id();
  auto outputs_after_optimizer = AnfAlgo::GetAllOutputWithIndex(graph->output());
  // Update the output map of the kernel graph by the modified output nodes.
  graph->UpdateGraphOutputMap(outputs_before_optimizer, outputs_after_optimizer);

  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    // Create device address for all anf nodes of graph.
    CreateDeviceAddress(graph, device_context);
  }

  graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));

  MS_EXCEPTION_IF_NULL(session_);
  session_->InitAllBucket(graph, device_context);
#ifndef ENABLE_SECURITY
  session_->SetSummaryNodes(graph.get());
#endif
  SetSummaryNodesRefCount(graph.get());
#ifdef ENABLE_DUMP_IR
  // Dump .pb graph after graph optimization.
  if (save_graphs) {
    DumpIRProto(graph, "after_opt_" + std::to_string(graph->graph_id()));
  }
#endif

#ifdef ENABLE_DEBUGGER
  auto debugger = Debugger::GetInstance();
  // Check the debugger instance before dereferencing it.
  MS_EXCEPTION_IF_NULL(debugger);
  debugger->DumpInGraphCompiler(graph);
  if (debugger->DebuggerBackendEnabled()) {
    debugger->LoadGraphs(graph);
  }
#endif

#ifdef ENABLE_DUMP_IR
  std::string name = "graph_build";
  DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
  (void)mindspore::RDR::RecordAnfGraph(SubModuleId::SM_SESSION, name, graph, dump_params, ".ir,.pb");
  auto &kernels = graph->execution_order();
  std::string exec_order_name = "graph_exec_order." + std::to_string(graph->graph_id());
  (void)mindspore::RDR::RecordGraphExecOrder(SubModuleId::SM_SESSION, exec_order_name, kernels);
#endif

  session_->DumpGraph(graph);
  return graph->graph_id();
}

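// Compile a single-op graph for PyNative op-by-op execution, cached by graph_info.
// On a cache hit the existing graph id is returned; otherwise the op graph is built,
// optimized, given kernels and device addresses, and its output nodes are recorded
// with pinned reference counts.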
GraphId GraphCompiler::CompileGraph(const session::OpRunInfo &op_run_info, const GraphInfo &graph_info,
                                    const std::vector<int64_t> *tensors_mask,
                                    std::vector<TensorPtr> *const input_tensors, bool *single_op_cache_hit,
                                    const DeviceContext *device_context) {
  // Check if the graph cache exists.
  auto iter = run_op_graphs_.find(graph_info);
  if (iter != run_op_graphs_.end()) {
    const auto &graph = iter->second;
    MS_EXCEPTION_IF_NULL(graph);
    *single_op_cache_hit = true;
    return graph->graph_id();
  }
  *single_op_cache_hit = false;
  // Generate kernel graph.
  MS_EXCEPTION_IF_NULL(session_);
  KernelGraphPtr graph = session_->ConstructSingleOpGraph(op_run_info, *input_tensors, *tensors_mask);
  MS_EXCEPTION_IF_NULL(graph);

  MS_EXCEPTION_IF_NULL(device_context);
  device_context->OptimizeSingleOpGraph(graph);

  // Generate 'KernelMod' for the kernels in the graph.
  device_context->CreateKernel(graph->execution_order());

  device_context->PreprocessBeforeRunSingleOpGraph(graph);

  // Create device address for all anf nodes of graph.
  CreateDeviceAddress(graph, device_context);

  graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
  run_op_graphs_[graph_info] = graph;

  auto output_nodes = graph->outputs();
  auto &outputs_with_index = run_op_graph_output_nodes_[graph->graph_id()];
  for (auto &node : output_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    (void)outputs_with_index.emplace_back(AnfAlgo::VisitKernelWithReturnType(node, 0, false));
  }

  UpdateRefCountForGraphOutput(outputs_with_index);

  return graph->graph_id();
}

KernelGraphPtr GraphCompiler::Fetch(GraphId graph_id) const {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetGraph(graph_id);
}

KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
  auto iter = run_op_graphs_.find(graph_info);
  if (iter == run_op_graphs_.end()) {
    MS_LOG(ERROR) << "Can't find graph for: " << graph_info;
    return nullptr;
  }
  return iter->second;
}

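// Create all device addresses of a graph: parameters, value nodes, kernel outputs and
// kernel workspaces, followed by the inplace-node address sharing fixup.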
void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) const {
  CreateParameterDeviceAddress(device_context, graph);
  CreateValueNodeDeviceAddress(device_context, graph);
  CreateKernelOutputDeviceAddress(device_context, graph);
  CreateKernelWorkspaceDeviceAddress(device_context, graph);
  UpdateDeviceAddressForInplaceNode(graph);
}

void GraphCompiler::GetParamAndOutputIndex(
  const KernelGraphPtr &graph, const std::vector<TensorPtr> &inputs, VectorRef *const outputs,
  std::map<AnfNodePtr, size_t> *parameter_index,
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetParameterIndex(graph.get(), inputs, parameter_index);
  session_->CreateOutputPlaceholder(graph, inputs, outputs, output_indexes);
}

void GraphCompiler::GetSingleOpInputTensors(const CNodePtr &kernel,
                                            const std::map<KernelWithIndex, TensorPtr> &op_output,
                                            const std::map<AnfNodePtr, size_t> &parameter_index,
                                            const std::vector<TensorPtr> &graph_inputs,
                                            InputTensorInfo *const input_tensor_info) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetOpInputTensors(kernel, op_output, parameter_index, graph_inputs, input_tensor_info);
}

TensorPtr GraphCompiler::GetSingleOpInputTensorByIndex(const CNodePtr &kernel,
                                                       const std::map<KernelWithIndex, TensorPtr> &op_output,
                                                       const std::map<AnfNodePtr, size_t> &parameter_index,
                                                       const std::vector<TensorPtr> &graph_inputs,
                                                       InputTensorInfo *const input_tensor_info, size_t input_index) {
  MS_EXCEPTION_IF_NULL(session_);
  return session_->GetOpInputTensorByIndex(kernel, op_output, parameter_index, graph_inputs, input_tensor_info,
                                           input_index);
}

void GraphCompiler::GetSingleOpRunInfoAndGraphInfo(const CNodePtr &kernel, const std::vector<TensorPtr> &input_tensors,
                                                   OpRunInfo *const run_info, GraphInfo *const graph_info) {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(graph_info);
  session_->GetSingleOpRunInfo(kernel, run_info);
  *graph_info = session_->GetSingleOpGraphInfo(kernel, input_tensors);
}

void GraphCompiler::CalculateRefCount(const KernelGraphPtr &graph, std::map<KernelWithIndex, size_t> *ref_count) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->GetRefCount(graph.get(), ref_count);
}

void GraphCompiler::UpdateRefCount(const std::set<KernelWithIndex> &input_kernels_with_index,
                                   std::map<KernelWithIndex, size_t> *ref_count,
                                   std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpInputs(input_kernels_with_index, ref_count, op_output_map);
}

void GraphCompiler::RecoverGraphOutput(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                                       const std::map<KernelWithIndex, size_t> &ref_count,
                                       std::map<KernelWithIndex, TensorPtr> *op_output_map,
                                       GraphOutputInfo *const graph_output_info) const {
  MS_EXCEPTION_IF_NULL(session_);
  session_->HandleOpOutputs(kernel, op_outputs, ref_count, op_output_map, graph_output_info);
}

void GraphCompiler::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->AddGradAddrToBucket(graph_id, grad_tensor);
}

void GraphCompiler::ClearAllBucket(const GraphId &graph_id) {
  MS_EXCEPTION_IF_NULL(session_);
  session_->ClearAllBucket(graph_id);
}

const std::vector<KernelWithIndex> &GraphCompiler::GetGraphOutputNodes(GraphId graph_id) const {
  const auto &iter = run_op_graph_output_nodes_.find(graph_id);
  if (iter == run_op_graph_output_nodes_.end()) {
    MS_LOG(EXCEPTION) << "Can not find output nodes for graph id: " << graph_id;
  }
  return iter->second;
}

void GraphCompiler::RegisterSummaryCallBackFunc(const CallBackFunc &callback) const {
  MS_EXCEPTION_IF_NULL(session_);
#ifndef ENABLE_SECURITY
  session_->RegisterSummaryCallBackFunc(callback);
#endif
}

void GraphCompiler::Summary(const std::vector<KernelGraphPtr> &graphs) const {
  MS_EXCEPTION_IF_NULL(session_);
  for (const auto &graph : graphs) {
#ifndef ENABLE_SECURITY
    session_->Summary(graph.get());
#endif
  }
}

void GraphCompiler::EraseSingleOpCache(const GraphInfo &graph_info, const GraphId &graph_id) {
  (void)run_op_graphs_.erase(graph_info);
  (void)run_op_graph_output_nodes_.erase(graph_id);
}
}  // namespace runtime
}  // namespace mindspore