1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/graph_compiler/backend_base.h"
17 
18 #include <algorithm>
19 #include <map>
20 #include <vector>
21 #include <queue>
22 #if defined(_WIN32) || defined(_WIN64)
23 #include <windows.h>
24 #endif
25 
26 #include "pipeline/jit/ps/parse/data_converter.h"
27 #include "backend/graph_compiler/transform.h"
28 #include "backend/common/pass/erase_invalid_micro_depend.h"
29 #include "backend/common/pass/erase_not_cut_attr.h"
30 #include "backend/common/pass/switch_not_cut.h"
31 #include "include/backend/distributed/recovery/recovery_context.h"
32 #include "include/common/utils/callbacks.h"
33 #include "include/common/utils/scoped_long_running.h"
34 #include "include/common/debug/anf_ir_dump.h"
35 #include "include/backend/mem_reuse/mem_tracker.h"
36 #include "ir/anf.h"
37 #include "ops/framework_ops.h"
38 #include "ops/sequence_ops.h"
39 #include "ops/sparse_tensor_ops.h"
40 #include "ops/nn_ops.h"
41 #include "runtime/device/device_address_utils.h"
42 #include "runtime/device/multi_stream_controller.h"
43 #include "runtime/graph_scheduler/graph_compiler.h"
44 #include "runtime/pynative/graph_adapter.h"
45 #include "pybind_api/gil_scoped_long_running.h"
46 #include "utils/log_adapter.h"
47 #ifdef ENABLE_DEBUGGER
48 #include "include/backend/debug/debugger/debugger.h"
49 #endif
50 #include "include/backend/debug/profiler/profiling.h"
51 #if defined(__linux__) && defined(WITH_BACKEND)
52 #include "include/backend/distributed/ps/ps_context.h"
53 #endif
54 #include "backend/common/graph_kernel/graph_kernel_flags.h"
55 #include "include/common/symbol_engine/symbol_engine_impl.h"
56 
57 namespace mindspore {
58 namespace compile {
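// Convert a BaseRef condition value to a bool. ScopedLongRunning marks this as a potentially long-running call (the value may need to be synchronized from device).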
59 bool Backend::GetCond(const BaseRef &c, bool *value) {
60   mindspore::ScopedLongRunning long_running;
61   return BaseRefToBool(c, value);
62 }
63 bool Backend::GetIndex(const BaseRef &c, int64_t *value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
64 
65 Backend::Backend(const std::string &name) : name_(name), is_multi_graph_sink_(false) {
66   MS_LOG(DEBUG) << "Select backend:" << name;
67   convert_fn_ = MsVmConvert;
68 }
69 
70 void set_pydata_converter(const pyexecute::PyDataConverter &pydata_converter) {
71   pyexecute::set_pydata_converter(pydata_converter);
72 }
73 
74 namespace {
75 // Insert the tensors related to front_node into input_tensors.
76 void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
77                 std::vector<tensor::TensorPtr> *input_tensors) {
78   MS_EXCEPTION_IF_NULL(input_tensors);
79   const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
80   if (iter == parameters.end()) {
81     (void)((*input_tensors).emplace_back(nullptr));
82     return;
83   }
84   auto position = iter - parameters.begin();
85 
86   std::vector<tensor::TensorPtr> flatten_values;
87   AnfAlgo::FlattenInputArg(args[position], front_node, &flatten_values);
88   (void)std::copy(flatten_values.begin(), flatten_values.end(), std::back_inserter(*input_tensors));
89 }
90 
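// Push the element at `index` of the flattened tuple argument bound to front_node; flattened results are cached in *flatten_values so each args position is flattened only once.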
91 void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
92                      size_t index, std::map<size_t, std::vector<tensor::TensorPtr>> *flatten_values,
93                      std::vector<tensor::TensorPtr> *input_tensors) {
94   MS_EXCEPTION_IF_NULL(input_tensors);
95   MS_EXCEPTION_IF_NULL(flatten_values);
96 
97   const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
98   const size_t position = iter - parameters.begin();
99   // If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
100   // and there is no need to input a tensor.
101   if (position >= args.size()) {
102     MS_LOG(DEBUG) << "Position out of args range, position value is " << position << " and args size is " << args.size()
103                   << ".";
104     (void)input_tensors->emplace_back(nullptr);
105     return;
106   }
107 
108   // Avoid re-flattening the tuple for each args position.
109   auto &flatten_value = (*flatten_values)[position];
110   if (flatten_value.empty()) {
111     AnfAlgo::FlattenInputArg(args[position], front_node, &flatten_value);
112   }
113 
114   if (index >= flatten_value.size()) {
115     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Index out of flatten_value range, index value is "
116                                << index << " and flatten_value size is " << flatten_value.size() << ".";
117   }
118   auto tensor_input = flatten_value[index];
119   MS_EXCEPTION_IF_NULL(tensor_input);
120   input_tensors->push_back(tensor_input);
121 }
122 }  // namespace
123 
124 bool GetTensorFromForwardOutputParameter(const AnfNodePtr &input_node, std::vector<tensor::TensorPtr> *input_tensors) {
125   MS_EXCEPTION_IF_NULL(input_node);
126   // If input_node is a Parameter that carries a forward output value,
127   // push its Tensor to input_tensors.
128   if (input_node->isa<Parameter>()) {
129     auto parameter = input_node->cast<ParameterPtr>();
130     MS_EXCEPTION_IF_NULL(parameter);
131     if (parameter->has_user_data(kForwardOutput)) {
132       auto value = parameter->user_data<Value>(kForwardOutput);
133       auto tensor = value->cast<tensor::TensorPtr>();
134       MS_EXCEPTION_IF_NULL(tensor);
135       (void)input_tensors->emplace_back(tensor);
136       MS_LOG(DEBUG) << "Get forward output tensor " << tensor->ToString()
137                     << " for graph input, address:" << tensor->device_address().get();
138       return true;
139     }
140   }
141   return false;
142 }
143 
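// Collect the input tensors for every kernel graph and for the control nodes by flattening the user args according to the origin parameter order.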
144 std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompilerInfo &graph_compiler_info,
145                                                               const VectorRef &args) {
146   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kRuntime, runtime::ProfilerEvent::kInputProcess,
147                                      graph_compiler_info.name_);
148   const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
149   std::vector<std::vector<tensor::TensorPtr>> input_tensor_lists;
150   std::map<size_t, std::vector<tensor::TensorPtr>> flatten_values;
151 
152   for (const auto &kernel_graph : graph_compiler_info.graphs_) {
153     std::vector<tensor::TensorPtr> input_tensors;
154     MS_EXCEPTION_IF_NULL(kernel_graph);
155     bool is_pynative_bprop_kernel_graph = kernel_graph->has_flag(kFlagIsPyNativeBpropKernelGraph);
156     for (const auto &input_node : kernel_graph->input_nodes()) {
157       if (is_pynative_bprop_kernel_graph && GetTensorFromForwardOutputParameter(input_node, &input_tensors)) {
158         continue;
159       }
160 
161       auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
162       if (element_pair.first) {
163         PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &flatten_values,
164                         &input_tensors);
165       } else {
166         const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
167         // The kernel graph is compiled directly, so its own parameter is used as the front node.
168         if (front_node == nullptr && is_pynative_bprop_kernel_graph) {
169           PushTensor(args, origin_parameters, input_node, &input_tensors);
170           continue;
171         }
172         PushTensor(args, origin_parameters, front_node, &input_tensors);
173       }
174     }
175     (void)input_tensor_lists.emplace_back(input_tensors);
176   }
177 
178   // Input tensors of the control node.
179   std::vector<tensor::TensorPtr> input_tensors;
180   MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
181   // Get inputs of control node which come from the host actor.
182   const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
183   for (const auto &parameter_with_index : control_node_parameters) {
184     const auto &parameter = parameter_with_index.first;
185     MS_EXCEPTION_IF_NULL(parameter);
186     const auto &abs = parameter->abstract();
187     MS_EXCEPTION_IF_NULL(abs);
188     if (abs->isa<abstract::AbstractSequence>() && (!common::AnfAlgo::IsDynamicSequence(parameter))) {
189       MS_LOG(DEBUG) << "Fetch input tensor for tuple parameter:" << parameter->DebugString() << " in control flow.";
190       PushTupleTensor(args, origin_parameters, parameter, parameter_with_index.second, &flatten_values, &input_tensors);
191     } else {
192       PushTensor(args, origin_parameters, parameter, &input_tensors);
193     }
194   }
195   (void)input_tensor_lists.emplace_back(input_tensors);
196 
197   return input_tensor_lists;
198 }
199 
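// Record the output position(s) of each real output of the graph, keyed by (node, output index).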
200 runtime::KernelMapPosition FetchOriginOutputOrder(const AnfNodePtr &node) {
201   MS_EXCEPTION_IF_NULL(node);
202   runtime::KernelMapPosition outputs_order;
203   const auto &root_output = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
204   size_t position = 0;
205   auto outputs = common::AnfAlgo::GetAllOutputWithIndex(root_output);
206   for (const auto &output : outputs) {
207     if (outputs_order.count(output) == 0) {
208       outputs_order[output] = {position++};
209     } else {
210       (void)outputs_order[output].emplace_back(position++);
211     }
212   }
213   return outputs_order;
214 }
215 
216 MindRTBackendBase::MindRTBackendBase(const std::string &backend_name, const std::string &device_name,
217                                      uint32_t device_id)
218     : Backend(backend_name), device_name_(device_name), device_id_(device_id) {
219   root_graph_ = nullptr;
220   auto ms_context = MsContext::GetInstance();
221   const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
222   auto &cut_list = pynative_mode ? GetControlOps() : GetMsNonlinearOps();
223 
224   graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
225   graph_compiler_ = std::make_shared<GraphCompiler>();
226 
227   const auto &device_context =
228     device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
229   MS_EXCEPTION_IF_NULL(device_context);
230   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventDeviceInit, kStageDeviceInit, 1, 0, 0);
231   device_context->Initialize();
232   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventDeviceInit, kStageDeviceInit, 1, 0, 1);
233   device_id_ = device_context->device_context_key().device_id_;
234 #ifdef ENABLE_DEBUGGER
235   SetDebuggerInit();
236 #endif
237   runtime::GraphScheduler::GetInstance().Initialize();
238 }
239 
240 void MindRTBackendBase::ProcessNotSupportCnode(const FuncGraphPtr &func_graph,
241                                                const mindspore::device::DeviceType &old_target,
242                                                const mindspore::device::DeviceType &new_target) const {
243   const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
244   for (const auto &node : all_nodes) {
245     MS_EXCEPTION_IF_NULL(node);
246     if (!node->isa<CNode>()) {
247       continue;
248     }
249 
250     auto cnode = node->cast<CNodePtr>();
251     if (!common::AnfAlgo::HasNodeAttr(mindspore::kAttrNotSupportOpForDevice, cnode)) {
252       continue;
253     }
254 
255     auto not_support_device = common::AnfAlgo::GetNodeAttr<std::string>(node, mindspore::kAttrNotSupportOpForDevice);
256     if (device::GetDeviceTypeByName(not_support_device) != old_target) {
257       continue;
258     }
259 
260     common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(device::GetDeviceNameByType(new_target)), node);
261   }
262 }
263 
264 namespace {
265 int64_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
266   MS_EXCEPTION_IF_NULL(tuple_get_item);
267   if (tuple_get_item->size() != kTupleGetItemInputSize) {
268     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The node tuple_get_item must have 2 inputs!";
269   }
270   auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
271   MS_EXCEPTION_IF_NULL(output_index_value_node);
272   auto value_node = output_index_value_node->cast<ValueNodePtr>();
273   MS_EXCEPTION_IF_NULL(value_node);
274   auto value = value_node->value();
275   MS_EXCEPTION_IF_NULL(value);
276   auto idx = value->isa<Int64Imm>() ? GetValue<int64_t>(value) : GetValue<int>(value);
277   return idx;
278 }
279 
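// Skip TupleGetItem/MakeTuple wrappers to find the real producing node, counting the TupleGetItem nesting depth in *nest_level.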
280 KernelWithIndex VisitRealNodeWithNestLevel(const AnfNodePtr &anf_node, size_t index, size_t *nest_level) {
281   MS_EXCEPTION_IF_NULL(anf_node);
282   if (!anf_node->isa<CNode>()) {
283     return {anf_node, index};
284   }
285   auto cnode = anf_node->cast<CNodePtr>();
286   if (common::AnfAlgo::GetCNodeName(cnode) == mindspore::kTupleGetItemOpName) {
287     (*nest_level)++;
288     auto real_node_with_index = VisitRealNodeWithNestLevel(common::AnfAlgo::GetTupleGetItemRealInput(cnode),
289                                                            common::AnfAlgo::GetTupleGetItemOutIndex(cnode), nest_level);
290     auto real_node = real_node_with_index.first;
291     auto real_index = real_node_with_index.second;
292     MS_EXCEPTION_IF_NULL(real_node);
293     if (real_node->isa<CNode>() && common::AnfAlgo::GetCNodeName(real_node) == mindspore::kMakeTupleOpName) {
294       (*nest_level)--;
295       auto make_tuple = real_node->cast<CNodePtr>();
296       return VisitRealNodeWithNestLevel(make_tuple->input(real_index + 1), index, nest_level);
297     }
298     return real_node_with_index;
299   }
300   return common::AnfAlgo::VisitKernelWithReturnType(anf_node, index, false,
301                                                     {prim::kPrimMakeTuple, prim::kPrimTupleGetItem});
302 }
303 
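// A TupleGetItem must become RealTupleGetItem when its index is not a constant value node (or is negative), or when it reads a real kernel output through more than one level of tuple nesting.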
304 bool NeedConvertToRealTupleGetItem(const CNodePtr &cnode) {
305   if (cnode->size() != kTupleGetItemInputSize) {
306     return false;
307   }
308   if (!cnode->input(kInputNodeOutputIndexInTupleGetItem)->isa<ValueNode>() || GetTupleGetItemOutIndex(cnode) < 0) {
309     return true;
310   }
311   size_t nest_level = 0;
312   const size_t nest_limit = 1;
313   auto real_node = VisitRealNodeWithNestLevel(cnode, 0, &nest_level);
314   if (!common::AnfAlgo::IsCallNode(real_node.first) && AnfUtils::IsRealCNodeKernel(real_node.first) &&
315       nest_level > nest_limit) {
316     return true;
317   }
318   return false;
319 }
320 
321 // On Windows, create a child thread with an 8M stack to call `common::AnfAlgo::GetRealPrevNodesOutput`.
322 #if defined(_WIN32) || defined(_WIN64)
323 typedef struct {
324   const AnfNodePtr *anf_node_;
325   size_t input_idx_;
326   std::vector<KernelWithIndex> *nodes_ptr_;
327 } WinThreadParam;
328 
329 DWORD WINAPI WinThreadFunction(PVOID para) {
330   auto p = static_cast<WinThreadParam *>(para);
331   MS_EXCEPTION_IF_NULL(p->anf_node_);
332   MS_EXCEPTION_IF_NULL(p->nodes_ptr_);
333   const AnfNodePtr &anf_node = *(p->anf_node_);
334   std::vector<KernelWithIndex> *nodes_ptr = p->nodes_ptr_;
335   auto inputs = common::AnfAlgo::GetRealPrevNodesOutput(anf_node, p->input_idx_);
336   nodes_ptr->insert(nodes_ptr->end(), inputs.begin(), inputs.end());
337   return 0;
338 }
339 #endif
340 
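// If the node's abstract is an AbstractJoinedAny, re-throw the exception recorded in it.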
341 void CheckNodeValid(const AnfNodePtr &node) {
342   MS_EXCEPTION_IF_NULL(node);
343   // Check the joined any abstract.
344   const auto &node_abs = node->abstract();
345   if (node_abs != nullptr && node_abs->isa<abstract::AbstractJoinedAny>()) {
346     auto abs_joined_any = node_abs->cast<abstract::AbstractJoinedAnyPtr>();
347     if (abs_joined_any != nullptr) {
348       abs_joined_any->ThrowException();
349     }
350   }
351 }
352 
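// Prepare a PyNative bprop kernel graph for compilation: rebuild parameter and value-node kernel info and wrap the graph output in a MakeTuple. Returns false when the graph is (almost) empty.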
353 bool AddKernelGraphCompileInfo(const KernelGraphPtr &kernel_graph, const session::SessionPtr &session_ptr) {
354   const auto &parameters = kernel_graph->parameters();
355   // The graph has only a return node or is empty.
356   if ((kernel_graph->nodes().size() - parameters.size()) < kIndex2) {
357     return false;
358   }
359   // Update parameters info
360   const auto &manager = kernel_graph->manager();
361   MS_EXCEPTION_IF_NULL(manager);
362   const auto &users = manager->node_users();
363   for (const auto &p : parameters) {
364     // Exclude parameters that are not used in the graph, such as constant inputs.
365     if (users.find(p) != users.end()) {
366       (void)session_ptr->CreateNewParameterFromParameter(p, kernel_graph.get());
367       kernel_graph->SetKernelInfoForNode(p);
368     }
369   }
370 
371   // Running by single op creates kernel info in the single-op graph, so there is no need to do it here;
372   // but running by actor needs kernel info, so create it here.
373   bool run_by_single_op = kernel_graph->has_flag(kFlagEnableRunGraphBySingleOp);
374   if (!run_by_single_op) {
375     const auto &nodes = TopoSort(kernel_graph->get_return());
376     for (const auto &node : nodes) {
377       if (node->isa<CNode>()) {
378         const auto &cnode = node->cast<CNodePtr>();
379         // Bprop cut uses prim_py, so no change is needed.
380         if (auto prim = GetValueNode<PrimitivePtr>(cnode->input(kIndex0));
381             !IsPrimitiveEquals(prim, prim::kPrimBpropCut)) {
382           auto new_prim = std::make_shared<Primitive>(*prim);
383           cnode->set_input(kIndex0, NewValueNode(new_prim));
384         }
385         kernel_graph->PostNewCNode(cnode);
386       } else {
387         if (node->isa<ValueNode>()) {
388           session_ptr->CreateNewValueNode(node, kernel_graph.get());
389         }
390         // A value node newly created in the kernel graph already has kernel info, so only set it when still missing.
391         if (node->kernel_info() == nullptr) {
392           kernel_graph->SetKernelInfoForNode(node);
393         }
394       }
395     }
396   }
397   auto output_node = kernel_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), kernel_graph->output()});
398   AbstractBasePtrList output_abs_list{kernel_graph->output()->abstract()};
399   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(output_abs_list);
400   output_node->set_abstract(abstract_tuple);
401   kernel_graph->set_output(output_node);
402   MS_LOG(INFO) << "Insert make tuple for output";
403   return true;
404 }
405 
406 bool NeedCheckMultiTarget(const FuncGraphPtr &func_graph, int ms_execution_mode) {
407   if (ms_execution_mode == kGraphMode) {
408     return true;
409   }
410   bool run_in_dynamic = ms_execution_mode == kPynativeMode && func_graph->has_flag(kFlagEnableRunGraphBySingleOp);
411   bool is_call_graph = func_graph->has_flag(kFlagJitCallGraph);
412   bool is_control_flow = !func_graph->func_graphs_used_total().empty();
413   return (run_in_dynamic && is_call_graph) || is_control_flow;
414 }
415 
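// Normalize a CNode for the backend: rename List* ops to their Tuple counterparts and convert TupleGetItem/MakeTuple to RealTupleGetItem/RealMakeTuple where required.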
416 void UnifyIR(const CNodePtr &cnode, bool enable_run_graph_by_single_op) {
417   MS_EXCEPTION_IF_NULL(cnode);
418   static const std::map<std::string, std::string> kOpListToTupleNames = {
419     {mindspore::kMakeListNewOpName, mindspore::kMakeTupleOpName},
420     {mindspore::kListGetItemOpName, mindspore::kTupleGetItemOpName},
421     {mindspore::kListSetItemOpName, mindspore::kTupleSetItemOpName}};
422   // List name --> tuple name.
423   auto &&op_name = common::AnfAlgo::GetCNodeName(cnode);
424   auto iter = kOpListToTupleNames.find(op_name);
425   if (iter != kOpListToTupleNames.end()) {
426     common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
427     cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(iter->second)));
428     // Reset full scope name.
429     cnode->set_fullname_with_scope("");
430     MS_LOG(INFO) << "Rename op from " << iter->first << " to " << iter->second << " for op "
431                  << cnode->fullname_with_scope() << ", debug name:" << cnode->DebugString();
432     op_name = iter->second;
433   }
434 
435   // TupleGetItem --> RealTupleGetItem.
436   if (!enable_run_graph_by_single_op && op_name == mindspore::kTupleGetItemOpName &&
437       NeedConvertToRealTupleGetItem(cnode)) {
438     common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
439     cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(mindspore::kRealTupleGetItemOpName)));
440     // Reset full scope name.
441     cnode->set_fullname_with_scope("");
442     MS_LOG(INFO) << "Rename op from TupleGetItem to RealTupleGetItem for op " << cnode->fullname_with_scope()
443                  << ", debug name:" << cnode->DebugString();
444   }
445 
446   // MakeTuple --> RealMakeTuple
447   if (op_name == mindspore::kMakeTupleOpName && common::AnfAlgo::IsDynamicSequence(cnode)) {
448     common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
449     cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(mindspore::kRealMakeTupleOpName)));
450     // Reset full scope name.
451     cnode->set_fullname_with_scope("");
452     MS_LOG(INFO) << "Rename op from MakeTuple to RealMakeTuple for op " << cnode->fullname_with_scope()
453                  << ", debug name:" << cnode->DebugString();
454   }
455 }
456 
457 bool EnableSymbolEngine(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
458   // Currently, only the Graph Kernel Fusion dynamic-shape case needs to build a symbol engine.
459   if (run_mode != device::RunMode::kKernelMode) {
460     return false;
461   }
462   if (common::GetEnv("MS_SYMBOL_ENGINE_OPTIMIZE") == "off") {
463     return false;
464   }
465   if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
466     return false;
467   }
468   return common::AnfAlgo::IsDynamicGraph(func_graph);
469 }
470 
471 void BuildSymbolEngine(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
472   if (func_graph == nullptr) {
473     return;
474   }
475   MS_LOG(INFO) << "Status record: start build symbol engine for function graph: " << func_graph->ToString();
476   if (!EnableSymbolEngine(func_graph, run_mode)) {
477     MS_LOG(INFO) << "Status record: skip build symbol engine for function graph: " << func_graph->ToString();
478     return;
479   }
480   try {
481     MS_LOG_TRY_CATCH_SCOPE;
482     symshape::SymbolEngineImpl::Build(func_graph);
483   } catch (std::exception &e) {
484     MS_LOG(WARNING) << "A problem occurs when build symbol engine for function graph[" << func_graph->ToString()
485                     << "]: " << e.what();
486   }
487   MS_LOG(INFO) << "Status record: end build symbol engine for function graph: " << func_graph->ToString();
488 }
489 }  // namespace
490 
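// Compile a root func_graph into kernel graphs and an actor DAG, and return the actor info that RunGraph later uses to execute it.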
491 const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph) {
492   WaitTaskFinish();
493   MS_EXCEPTION_IF_NULL(graph_compiler_);
494   MS_EXCEPTION_IF_NULL(func_graph);
495   MS_LOG(INFO) << "Status record: start compile function graph: " << func_graph->ToString();
496   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCompileGraphs, 1, 0, 0);
497   PROF_START(compile_backend_graph);
498 
499   auto root_graph = WrapPrimitives(func_graph);
500   MS_EXCEPTION_IF_NULL(root_graph);
501   bool pynative_with_jit_call_graph = func_graph->has_flag(kFlagPyNativeWithJitCallGraph);
502   if (!pynative_with_jit_call_graph) {
503     UnifyMindIR(root_graph);
504   }
505   root_graph_ = root_graph;
506   // When using a kernel graph, the output may be changed by backend passes, so back up the output.
507   if (root_graph_->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
508     output_node_ = root_graph_->output();
509   }
510 
511   // Register a summary callback function, which is called in the final stages of summary.
512   graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
513 
514   auto context_ptr = MsContext::GetInstance();
515   MS_EXCEPTION_IF_NULL(context_ptr);
516   ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
517   func_graph->set_flag(kFlagPyNativeRunInGraph, ms_execution_mode_ == kPynativeMode);
518 
519   // Compile root graph.
520   graph_id_to_device_context_.clear();
521   func_graph_to_kernel_graph_ids_.clear();
522   control_nodes_.clear();
523 
524   const auto &device_context =
525     device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
526   MS_EXCEPTION_IF_NULL(device_context);
527   device_context->Initialize();
528   device_context->device_res_manager_->BindDeviceToCurrentThread(false);
529 
530   // Currently only Ascend needs to do the check in PartitionGraph.
531   bool all_support = device_context->PartitionGraph(func_graph);
532   PROF_START(CompileSubGraph);
533   if (all_support) {
534     auto run_mode = device_context->GetRunMode(func_graph);
535     if (run_mode == device::RunMode::kGraphMode && pynative::GraphAdapter::PyNativeEnableTaskSink(func_graph)) {
536       auto graph_id = graph_compiler_->CompileWholeGraphForGraphRunMode(func_graph, device_context);
537       graph_id_to_device_context_[graph_id] = device_context;
538     } else {
539       // Build symbol engine for root graph before partition graph
540       BuildSymbolEngine(func_graph, device::RunMode::kKernelMode);
541       CompileSubGraph(func_graph, device::RunMode::kKernelMode);
542     }
543   } else {
544     if (NeedCheckMultiTarget(func_graph, ms_execution_mode_)) {
545       ProcessNotSupportCnode(func_graph, device_context->GetDeviceType(), mindspore::device::DeviceType::kCPU);
546     }
547     // Build symbol engine for root graph before partition graph
548     BuildSymbolEngine(func_graph, device_context->GetRunMode(func_graph));
549     CompileSubGraph(func_graph);
550   }
551   PROF_END(CompileSubGraph);
552 
553   // Construct the graph compiler info.
554   auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
555   MS_EXCEPTION_IF_NULL(graph_compiler_info);
556   if ((ms_execution_mode_ == kGraphMode ||
557        (ms_execution_mode_ == kPynativeMode && pynative::GraphAdapter::IsPynativeGeGraphSink(root_graph_))) &&
558       ((!graph_compiler_info->graphs_.empty()) || graph_compiler_info->control_nodes_.size() > 1)) {
559     MS_LOG(DEBUG) << "Start transform";
560     PROF_START(GraphScheduler);
561     // Transform graph to actor DAG, and schedule the actor DAG.
562     ParseControlNodes(*graph_compiler_info);
563     const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
564     runtime::GraphScheduler::GetInstance().Schedule(actor_set);
565     PROF_END(GraphScheduler);
566   }
567   const ActorInfo &actor_info = graph_compiler_info->name_;
568   (void)actor_to_graph_compiler_info_.emplace(graph_compiler_info->name_, std::move(graph_compiler_info));
569   PROF_END(compile_backend_graph);
570 
571   for (const auto &graph_id_to_context : graph_id_to_device_context_) {
572     auto context = graph_id_to_context.second;
573     device::MultiStreamController::GetInstance()->Refresh(context);
574   }
575 
576   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCompileGraphs, 1, 0, 1);
577   MS_LOG(INFO) << "Status record: end compile function graph: " << func_graph->ToString()
578                << ", produce actor: " << actor_info;
579   return actor_info;
580 }
581 
582 namespace {
583 void DoUnifyMindIRPass(const FuncGraphPtr &graph, const std::shared_ptr<opt::GraphOptimizer> &optimizer) {
584   MS_EXCEPTION_IF_NULL(graph);
585   MS_EXCEPTION_IF_NULL(optimizer);
586   auto context_ptr = MsContext::GetInstance();
587   MS_EXCEPTION_IF_NULL(context_ptr);
588   MS_LOG(INFO) << "Do unify mindir pass for graph " << graph->ToString();
589 #ifdef ENABLE_DUMP_IR
590   if (context_ptr->CanDump(kIntroductory)) {
591     std::string file_name = "hwopt_before_mindrt_unify_mindir_graph_" + graph->ToString() + ".ir";
592     DumpIR(file_name, graph, true, kWholeStack);
593   }
594 #endif
595   (void)optimizer->Optimize(graph);
596 #ifdef ENABLE_DUMP_IR
597   if (context_ptr->CanDump(kIntroductory)) {
598     std::string file_name = "hwopt_end_mindrt_unify_mindir_graph_" + graph->ToString() + ".ir";
599     DumpIR(file_name, graph, true, kWholeStack);
600   }
601 #endif
602 }
603 
604 bool HasSwitchNode(const FuncGraphPtr &func_graph) {
605   if (func_graph == nullptr) {
606     return false;
607   }
608   const auto &nodes = TopoSort(func_graph->get_return());
609   return std::any_of(nodes.begin(), nodes.end(), [](const AnfNodePtr &node) {
610     return node != nullptr && node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch);
611   });
612 }
613 
614 bool IsNodeValid(const AnfNodePtr &node) {
615   if (node != nullptr && common::AnfAlgo::IsNodeOutputDynamicShape(node)) {
616     MS_LOG(INFO) << "Disable switch inline for dynamic shape node:" << node->DebugString();
617     return false;
618   } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
619     const auto &cnode = node->cast<CNodePtr>();
620     MS_EXCEPTION_IF_NULL(cnode);
621     if (cnode->size() <= 1 || cnode->input(1) == nullptr || !(IsValueNode<FuncGraph>(cnode->input(1)))) {
622       return true;
623     }
624     const auto &func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
625     MS_EXCEPTION_IF_NULL(func_graph);
626     if (std::any_of(func_graph->parameters().begin(), func_graph->parameters().end(), [](const AnfNodePtr &para) {
627           return para != nullptr && para->abstract() != nullptr &&
628                  para->abstract()->isa<abstract::AbstractSequence>() &&
629                  (para->abstract()->cast<abstract::AbstractSequencePtr>()->dynamic_len() ||
630                   para->abstract()->cast<abstract::AbstractSequencePtr>()->size() > 1);
631         })) {
632       MS_LOG(INFO) << "Disable switch inline for tuple input in graph:" << func_graph->ToString()
633                    << " for partial node:" << node->DebugString();
634       return false;
635     }
636   }
637   return true;
638 }
639 
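// Decide whether switch inline can be enabled for the graph: it requires the GE backend in kernel-by-kernel mode with sub graphs, and is disabled for recursion, unsupported call nodes, dynamic-shape nodes and tuple parameters of partial sub graphs.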
640 bool IsEnableControlFlowInline(const FuncGraphPtr &graph) {
641   auto context = MsContext::GetInstance();
642   MS_EXCEPTION_IF_NULL(context);
643   if (std::any_of(
644         graph->func_graphs_used_total().cbegin(), graph->func_graphs_used_total().cend(), [](const auto &sub_graph) {
645           return sub_graph != nullptr && sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && HasSwitchNode(sub_graph);
646         })) {
647     MS_LOG(INFO) << "Set reuse level from:" << context->CellReuseLevel() << " to:" << CellReuseLevel::kNoInline;
648     context->SetCellReuseLevel(CellReuseLevel::kNoInline);
649   }
650 
651   static const auto is_disable_switch_inline = common::IsDisableRuntimeConfig(common::kRuntimeSwitchInline);
652   if (is_disable_switch_inline) {
653     MS_LOG(INFO) << "Disable switch inline by runtime config.";
654     return false;
655   }
656 
657   // Only support ge backend, kernel by kernel mode and multi-funcgraph.
658   static const bool is_enable_ge = (context->backend_policy() == "ge");
659   if (!is_enable_ge || !context->IsKByKExecutorMode() || graph->func_graphs_used_total().empty()) {
660     MS_LOG(INFO) << "Disable switch inline, executor mode:" << context->IsKByKExecutorMode();
661     return false;
662   }
663 
664   MS_EXCEPTION_IF_NULL(graph);
665   // Not support recursive.
666   if (std::any_of(graph->func_graphs_used_total().cbegin(), graph->func_graphs_used_total().cend(),
667                   [](const auto &sub_graph) { return sub_graph->recursive(); })) {
668     MS_LOG(INFO) << "Disable switch inline for recursive.";
669     return false;
670   }
671 
672   if (context->CellReuseLevel() != CellReuseLevel::kLazyInline) {
673     auto is_include_no_switch_call = [](const FuncGraphPtr &graph) {
674       MS_EXCEPTION_IF_NULL(graph);
675       const auto &nodes = TopoSort(graph->get_return());
676       for (const auto &node : nodes) {
677         MS_EXCEPTION_IF_NULL(node);
678         if (common::AnfAlgo::IsCallNode(node)) {
679           const auto &cnode = node->cast<CNodePtr>();
680           if (!common::AnfAlgo::CheckPrimitiveType(cnode->input(0), prim::kPrimSwitch)) {
681             return true;
682           }
683         }
684       }
685       return false;
686     };
687     if (is_include_no_switch_call(graph)) {
688       MS_LOG(INFO) << "Disable switch inline for unsupported call node.";
689       return false;
690     }
691     if (std::any_of(graph->func_graphs_used_total().begin(), graph->func_graphs_used_total().end(),
692                     is_include_no_switch_call)) {
693       MS_LOG(INFO) << "Disable switch inline for unsupported call node.";
694       return false;
695     }
696   }
697   const auto &mng = graph->manager();
698   if (mng != nullptr && std::any_of(mng->all_nodes().begin(), mng->all_nodes().end(),
699                                     [](const AnfNodePtr &node) { return !IsNodeValid(node); })) {
700     return false;
701   }
702   MS_LOG(INFO) << "Enable switch inline.";
703   return true;
704 }
705 
706 void AddGraphDynamicShapeAttr(const KernelGraphPtr &kernel_graph) {
707   MS_EXCEPTION_IF_NULL(kernel_graph);
708   if (kernel_graph->is_dynamic_shape()) {
709     return;
710   }
711 
712   const auto &nodes = TopoSort(kernel_graph->output());
713   for (const auto &node : nodes) {
714     MS_EXCEPTION_IF_NULL(node);
715     if (node->isa<CNode>() && common::AnfAlgo::IsDynamicShape(node)) {
716       kernel_graph->SetGraphDynamicAttr(true);
717       break;
718     }
719   }
720 }
721 }  // namespace
722 
723 void MindRTBackendBase::UnifyMindIR(const FuncGraphPtr &root_graph) const {
724   MS_EXCEPTION_IF_NULL(root_graph);
725   MS_EXCEPTION_IF_NULL(root_graph->manager());
726   // When an input is an empty sequence, its input number is recorded as 0 and it cannot be expressed as a
727   // tensor, so the empty sequence is marked as dynamic length.
728   for (const auto &parameter : root_graph->parameters()) {
729     MS_EXCEPTION_IF_NULL(parameter);
730     const auto &abs = parameter->abstract();
731     if (abs != nullptr && abs->isa<abstract::AbstractSequence>()) {
732       const auto &sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
733       MS_EXCEPTION_IF_NULL(sequence_abs);
734       if ((!sequence_abs->dynamic_len()) && sequence_abs->empty()) {
735         MS_LOG(INFO) << "Set dynamic len flag for empty sequence input:" << parameter->DebugString();
736         sequence_abs->set_dynamic_len(true);
737       }
738     }
739   }
740   bool enable_run_graph_by_single_op = root_graph->has_flag(kFlagEnableRunGraphBySingleOp);
741   const auto &graphs = root_graph->manager()->func_graphs();
742   for (const auto &graph : graphs) {
743     MS_EXCEPTION_IF_NULL(graph);
744     auto output = graph->get_return();
745     if (!output->isa<CNode>()) {
746       continue;
747     }
748     auto seen = NewSeenGeneration();
749     std::queue<AnfNodePtr> to_visit;
750     to_visit.emplace(output);
751     while (!to_visit.empty()) {
752       auto node = to_visit.front();
753       to_visit.pop();
754       MS_EXCEPTION_IF_NULL(node);
755       CheckNodeValid(node);
756 
757       const auto &cnode = node->cast<CNodePtr>();
758       MS_EXCEPTION_IF_NULL(cnode);
759       UnifyIR(cnode, enable_run_graph_by_single_op);
760       for (auto &input : cnode->inputs()) {
761         MS_EXCEPTION_IF_NULL(input);
762         if (input->seen_ == seen || !input->isa<CNode>()) {
763           continue;
764         }
765         to_visit.emplace(input);
766         input->seen_ = seen;
767       }
768     }
769   }
770 
771   auto optimizer = std::make_shared<opt::GraphOptimizer>();
772   auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
773   unify_mindir_pm->AddPass(std::make_shared<opt::EraseInvalidMicroDepend>());
774   if (common::AnfAlgo::IsDynamicGraph(root_graph)) {
775     unify_mindir_pm->AddPass(std::make_shared<opt::EraseNotCutAttr>());
776   }
777   if (IsEnableControlFlowInline(root_graph)) {
778     unify_mindir_pm->AddPass(std::make_shared<opt::SwitchNotCut>());
779   }
780   optimizer->AddPassManager(unify_mindir_pm);
781 
782   DoUnifyMindIRPass(root_graph, optimizer);
783   const auto &sub_graphs = root_graph->manager()->func_graphs_used_total(root_graph);
784   for (const auto &sub_graph : sub_graphs) {
785     MS_EXCEPTION_IF_NULL(sub_graph);
786     DoUnifyMindIRPass(sub_graph, optimizer);
787   }
788 }
789 
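// Compile the root graph and then every reachable sub graph, skipping lazy-inline cell-reuse graphs, switch-inline graphs and jit call graphs.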
790 void MindRTBackendBase::CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
791   auto root_graph = func_graph;
792   if (!func_graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
793     root_graph = WrapPrimitives(func_graph);
794   }
795   MS_EXCEPTION_IF_NULL(root_graph);
796   auto manager = root_graph->manager();
797   CompileGraph(root_graph, run_mode);
798   auto context = MsContext::GetInstance();
799   MS_EXCEPTION_IF_NULL(context);
800   MS_EXCEPTION_IF_NULL(manager);
801   const auto &sub_graphs = manager->func_graphs_used_total(root_graph);
802   std::vector<FuncGraphPtr> cand_graph(sub_graphs.begin(), sub_graphs.end());
803   std::sort(cand_graph.begin(), cand_graph.end(),
804             [](const FuncGraphPtr &a, const FuncGraphPtr &b) { return a->ToString() < b->ToString(); });
805   for (const auto &sub_graph : cand_graph) {
806     MS_EXCEPTION_IF_NULL(sub_graph);
807     bool skip_inline_graph =
808       (sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && context->CellReuseLevel() == CellReuseLevel::kLazyInline) ||
809       sub_graph->has_flag(kFlagSwitchInline);
810     if (sub_graph != func_graph && sub_graph != nullptr && !sub_graph->has_flag(kFlagJitCallGraph) &&
811         !skip_inline_graph) {
812       MS_LOG(INFO) << "Compile sub graph " << sub_graph->ToString();
813       CompileGraph(sub_graph, run_mode);
814     }
815   }
816 }
817 
818 void MindRTBackendBase::CompileGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
819   MS_EXCEPTION_IF_NULL(func_graph);
820   if (!func_graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
821     (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageGraphPartition, 1, 0, 0);
822     // Split graph to segments.
823     MS_EXCEPTION_IF_NULL(graph_partition_);
824     const auto &segments = graph_partition_->Partition(func_graph);
825     (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageGraphPartition, 1, 0, 1);
826     MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size: " << segments.size();
827 
828     // Iterate over the segments to compile each graph.
829     for (const auto &segment : segments) {
830       CompileGraphFromSegment(segment, run_mode);
831     }
832   } else {
833     auto kernel_graph = func_graph->cast<KernelGraphPtr>();
834     AddGraphDynamicShapeAttr(kernel_graph);
835     MS_EXCEPTION_IF_NULL(kernel_graph);
836     const auto &session = graph_compiler_->session_ptr();
837     MS_EXCEPTION_IF_NULL(session);
838     session->SetKernelGraphId(kernel_graph);
839     MS_LOG(INFO) << "Compile graph: " << kernel_graph->ToString() << ", kernel graph";
840     if (AddKernelGraphCompileInfo(kernel_graph, session)) {
841       kernel_graph->SetExecOrderByDefault();
842       auto context_ptr = MsContext::GetInstance();
843       MS_EXCEPTION_IF_NULL(context_ptr);
844       auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
845         {context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET), device_id_});
846       MS_EXCEPTION_IF_NULL(device_context);
847       device_context->Initialize();
848       CompileKernelGraph(kernel_graph, std::make_pair(kernel_graph->inputs(), kernel_graph->outputs()), device_context,
849                          run_mode);
850     }
851   }
852 }
853 
854 void MindRTBackendBase::CompileGraphFromSegment(const GraphSegmentPtr &segment, device::RunMode run_mode) {
855   MS_EXCEPTION_IF_NULL(segment);
856   // Compile the normal nodes, which don't contain the cut node.
857   if (segment->nodes_.empty()) {
858     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The segments size is 0.";
859   }
860   if (!segment->is_cut_) {
861     MS_EXCEPTION_IF_NULL(segment->nodes_[0]);
862     MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->DebugString();
863 
864     // Get the device context.
865     const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]);
866     auto device_context =
867       device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_});
868     MS_EXCEPTION_IF_NULL(device_context);
869     device_context->Initialize();
870 
871     // Transform nodes to inputs and outputs.
872     FuncGraphPtr fg;
873     AnfNodePtrList inputs;
874     AnfNodePtrList outputs;
875     std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
876 
877     // Get segment run mode.
878     auto seg_run_mode = run_mode;
879     for (auto &node : outputs) {
880       if (node->isa<CNode>()) {
881         if (common::AnfAlgo::GetGraphSplitGroup(node) == kKernelGroup) {
882           seg_run_mode = device::RunMode::kKernelMode;
883           break;
884         }
885       }
886     }
887 
888     GraphId graph_id;
889     if (root_graph_->has_flag(kFlagEnableRunGraphBySingleOp)) {
890       graph_id = graph_compiler_->CompileDynamicGraph(segment, outputs, device_context);
891     } else {
892       graph_id = graph_compiler_->CompileGraph(segment, std::make_pair(inputs, outputs), device_context, seg_run_mode,
893                                                ms_execution_mode_ == kPynativeMode);
894       if (graph_compiler_->Fetch(graph_id)->has_flag(kFlagEnableRunGraphBySingleOp)) {
895         MS_LOG(INFO)
896           << "Set kFlagEnableRunGraphBySingleOp: require the root_graph and subgraph to have the same markings ";
897         root_graph_->set_flag(kFlagEnableRunGraphBySingleOp, true);
898       }
899     }
900     CacheFuncGraphWithKernelGraphId(segment->nodes_[0]->func_graph(), graph_id, device_context);
901   } else {
902     // Compile the cut node.
903     auto cut_node = segment->nodes_[0];
904     MS_EXCEPTION_IF_NULL(cut_node);
905     MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString();
906     control_nodes_.push_back(cut_node);
907     if (common::AnfAlgo::IsCallNode(cut_node) || common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitch) ||
908         common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitchLayer)) {
909       const auto &func_graph = cut_node->func_graph();
910       MS_EXCEPTION_IF_NULL(func_graph);
911       (void)func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>());
912     }
913   }
914 }
915 
916 void MindRTBackendBase::CompileKernelGraph(const KernelGraphPtr &kernel_graph,
917                                            const std::pair<AnfNodePtrList, AnfNodePtrList> &io_nodes,
918                                            DeviceContext *device_context, device::RunMode run_mode) {
919   GraphId graph_id;
920   if (root_graph_->has_flag(kFlagEnableRunGraphBySingleOp)) {
921     graph_id = graph_compiler_->CompileDynamicGraph(kernel_graph, device_context);
922   } else {
923     graph_id = graph_compiler_->CompileGraph(kernel_graph, io_nodes, device_context, run_mode,
924                                              ms_execution_mode_ == kPynativeMode);
925     if (graph_compiler_->Fetch(graph_id)->has_flag(kFlagEnableRunGraphBySingleOp)) {
926       MS_LOG(INFO)
927         << "Set kFlagEnableRunGraphBySingleOp: require the root_graph and subgraph to have the same markings ";
928       root_graph_->set_flag(kFlagEnableRunGraphBySingleOp, true);
929     }
930   }
931   CacheFuncGraphWithKernelGraphId(kernel_graph, graph_id, device_context);
932 }
933 
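// Remember which device context compiled the kernel graph and group kernel graph ids by their front func_graph.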
934 void MindRTBackendBase::CacheFuncGraphWithKernelGraphId(const FuncGraphPtr &func_graph, const GraphId &graph_id,
935                                                         DeviceContext *device_context) {
936   graph_id_to_device_context_[graph_id] = device_context;
937   if (func_graph_to_kernel_graph_ids_.find(func_graph) == func_graph_to_kernel_graph_ids_.end()) {
938     (void)func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>{graph_id});
939   } else {
940     (void)func_graph_to_kernel_graph_ids_[func_graph].back().emplace_back(graph_id);
941   }
942 }
943 
944 namespace {
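// Recursively expand a Value (tensor, scalar or value sequence) into a VectorRef of tensors, converting scalars to tensors.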
945 void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
946   MS_EXCEPTION_IF_NULL(value);
947   MS_EXCEPTION_IF_NULL(outputs);
948   if (value->isa<ValueSequence>()) {
949     auto value_tuple = value->cast<ValueSequencePtr>();
950     MS_EXCEPTION_IF_NULL(value_tuple);
951     for (size_t i = 0; i < value_tuple->size(); ++i) {
952       ValuePtr element = value_tuple->value()[i];
953       MS_EXCEPTION_IF_NULL(element);
954       if (element->isa<tensor::Tensor>()) {
955         auto tensor = element->cast<tensor::TensorPtr>();
956         MS_EXCEPTION_IF_NULL(tensor);
957         outputs->emplace_back(tensor);
958       } else if (element->isa<Scalar>()) {
959         auto scalar = element->cast<ScalarPtr>();
960         MS_EXCEPTION_IF_NULL(scalar);
961         outputs->emplace_back(ScalarToTensor(scalar));
962       } else if (element->isa<ValueSequence>()) {
963         VectorRef tuple;
964         TensorValueToVector(element, &tuple);
965         outputs->emplace_back(tuple);
966       }
967     }
968   } else if (value->isa<tensor::Tensor>()) {
969     auto tensor = value->cast<tensor::TensorPtr>();
970     MS_EXCEPTION_IF_NULL(tensor);
971     outputs->emplace_back(tensor);
972   } else if (value->isa<Scalar>()) {
973     auto scalar = value->cast<ScalarPtr>();
974     MS_EXCEPTION_IF_NULL(scalar);
975     outputs->emplace_back(ScalarToTensor(scalar));
976   }
977 }
978 
979 bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const VectorRef &args, VectorRef *outputs) {
980   MS_EXCEPTION_IF_NULL(graph_output);
981   MS_EXCEPTION_IF_NULL(outputs);
982   if (graph_output->isa<ValueNode>()) {
983     MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
984     VectorRef output_tmp;
985     ValuePtr value = GetValueNode(graph_output);
986     TensorValueToVector(value, &output_tmp);
987     MS_EXCEPTION_IF_NULL(value);
988     if (value->isa<ValueSequence>()) {
989       outputs->emplace_back(output_tmp);
990     } else if (value->isa<tensor::Tensor>() || value->isa<Scalar>()) {
991       *outputs = output_tmp;
992     } else {
993       MS_LOG(INFO) << "Graph output is empty!";
994     }
995     return true;
996   }
997 
998   if (graph_output->isa<Parameter>()) {
999     MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
1000     // Find the right parameter as ret_val.
1001     auto func_graph = graph_output->func_graph();
1002     MS_EXCEPTION_IF_NULL(func_graph);
1003     auto params = func_graph->parameters();
1004     if (args.size() != params.size()) {
1005       MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Input size " << args.size()
1006                                  << " is not equal to graph input size " << params.size();
1007     }
1008 
1009     auto it = std::find(params.begin(), params.end(), graph_output);
1010     if (it == params.end()) {
1011       MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
1012     }
1013     size_t index = it - params.cbegin();
1014     if (index >= args.size()) {
1015       MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size();
1016     }
1017 
1018     outputs->emplace_back(args[index]);
1019     return true;
1020   }
1021   return false;
1022 }
1023 }  // namespace
1024 
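// Build the user-visible outputs from the output actor after the actor DAG has finished; skipped while recovery needs a reset.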
1025 void MindRTBackendBase::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs,
1026                                          const FuncGraphPtr &root_graph) {
1027   MS_EXCEPTION_IF_NULL(actor_set);
1028   MS_EXCEPTION_IF_NULL(outputs);
1029   MS_EXCEPTION_IF_NULL(root_graph);
1030   bool need_construct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
1031                                 distributed::recovery::RecoveryContext::GetInstance()->need_reset());
1032   bool is_embedding_cache_server = false;
1033 #if defined(__linux__) && defined(WITH_BACKEND)
1034   is_embedding_cache_server = ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
1035 #endif
1036   if (need_construct_output) {
1037     MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
1038     // Update device address for output node of graph.
1039     // Summary processing uses the output device address, so this update must happen after summary processing.
1040     if (!is_embedding_cache_server) {
1041       actor_set->output_actor_->UpdateOutputDeviceAddress();
1042     }
1043 
1044     // Fetch outputs.
1045     auto &output_tensors = actor_set->output_actor_->outputs();
1046     if (!output_tensors.empty()) {
1047       size_t output_position = 0;
1048       std::vector<tensor::TensorPtr> tuple_tensors;
1049       ConstructOutputs(root_graph->output(), output_tensors, &output_position, outputs, &tuple_tensors);
1050 
1051       // The same tensor may appear more than once, so reset its device address to null only at the end.
1052       for (auto &tuple_tensor : tuple_tensors) {
1053         MS_EXCEPTION_IF_NULL(tuple_tensor);
1054         tuple_tensor->set_device_address(nullptr);
1055       }
1056     }
1057   }
1058 }
1059 
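// Ensure tensor arguments (including tensors inside value sequences) are contiguous before they are handed to the actors.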
1060 void MindRTBackendBase::ContiguousArgs(const VectorRef &args, const GraphCompilerInfo &graph_compiler_info) {
1061   for (const auto &arg : args) {
1062     if (utils::isa<tensor::BaseTensorPtr>(arg)) {
1063       auto value = utils::cast<tensor::BaseTensorPtr>(arg);
1064       runtime::DeviceAddressUtils::ConvertContiguousTensorSync(value);
1065     } else if (utils::isa<ValuePtr>(arg)) {
1066       auto value = utils::cast<ValuePtr>(arg);
1067       MS_EXCEPTION_IF_NULL(value);
1068       if (!value->isa<ValueSequence>()) {
1069         return;
1070       }
1071       auto value_tuple = value->cast<ValueSequencePtr>();
1072       MS_EXCEPTION_IF_NULL(value_tuple);
1073       auto tuple_value = value_tuple->value();
1074       for (const auto &v : tuple_value) {
1075         if (!v->isa<tensor::BaseTensor>()) {
1076           continue;
1077         }
1078         auto t = v->cast<tensor::BaseTensorPtr>();
1079         runtime::DeviceAddressUtils::ConvertContiguousTensorSync(t);
1080       }
1081     }
1082   }
1083 }
1084 
1085 void MindRTBackendBase::WaitMultiStream(const GraphCompilerInfo &graph_compiler_info) {
1086   for (auto device_context : graph_compiler_info.device_contexts_) {
1087     MS_EXCEPTION_IF_NULL(device_context);
1088     if (device_context->device_res_manager_->single_op_multi_stream_enable()) {
1089       device_context->device_res_manager_->SyncNotDefaultStreams();
1090     }
1091   }
1092 }
1093 
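// Run the compiled actor DAG identified by actor_info with the given args and collect the outputs; PyNative jit graphs take a separate path via RunGraphByCondition.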
1094 void MindRTBackendBase::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
1095   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kRuntime, runtime::ProfilerEvent::kBackendGraphRunInner,
1096                                      actor_info, true);
1097   MS_EXCEPTION_IF_NULL(root_graph_);
1098   if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
1099     return;
1100   }
1101 
1102   const auto &context_ptr = MsContext::GetInstance();
1103   MS_EXCEPTION_IF_NULL(context_ptr);
1104   if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
1105     MS_LOG(INFO) << "PrecompileOnly, stop run graph";
1106     return;
1107   }
1108 
1109   // Open abstract_lock for dynamic_shape
1110   AnfUtils::OpenAbstractLock();
1111 
1112   // Fetch the graph compiler info.
1113   const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
1114   if (graph_iter == actor_to_graph_compiler_info_.end()) {
1115     MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Can't find the graph compiler info.";
1116   }
1117   MS_EXCEPTION_IF_NULL(graph_iter->second);
1118   const auto &graph_compiler_info = *(graph_iter->second);
1119   // For mixed PyNative and graph execution.
1120   WaitTaskFinish();
1121   WaitMultiStream(graph_compiler_info);
1122 
1123   // Run in the pynative mode.
1124   MS_EXCEPTION_IF_NULL(outputs);
1125   // There can be more than one kernel graph in the heterogeneous scenario of a jit in PyNative mode.
1126   if (ms_execution_mode_ == kPynativeMode && !pynative::GraphAdapter::IsPynativeGeGraphSink(root_graph_)) {
1127     // The tensor needs to be converted to contiguous before being given to the actors.
1128     // After the view feature is supported in the graph mode, the following code will be deleted.
1129     ContiguousArgs(args, graph_compiler_info);
1130     RunGraphByCondition(actor_info, graph_compiler_info, args, outputs);
1131     return;
1132   }
1133 
1134   MS_LOG(INFO) << "Status record: start run actor: " << actor_info;
1135   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventRunGraph, kStageRunGraph, 1, 0, 0);
1136   std::vector<std::vector<tensor::TensorPtr>> input_tensors;
1137   if (graph_compiler_info.exist_flatten_concat_) {
1138     input_tensors = GetRunGraphInputs(graph_compiler_info, args);
1139     // The tensor needs to be converted to contiguous before being given to the actors.
1140     // After the view feature is supported in the graph mode, the following code will be deleted.
1141     // Single ops(run in pynative mode) output to net(context is graph mode) input.
1142     (void)std::for_each(input_tensors.begin(), input_tensors.end(), [this](const auto &tensor_vec) {
1143       (void)std::for_each(tensor_vec.begin(), tensor_vec.end(), [](const tensor::TensorPtr &t) {
1144         runtime::DeviceAddressUtils::ConvertContiguousTensorSync(t);
1145         runtime::DeviceAddressUtils::CreateKernelTensor(t);
1146       });
1147     });
1148   }
1149   // Release python gil.
1150   mindspore::ScopedLongRunning long_running;
1151   // Run actor DAG.
1152   const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
1153   MS_EXCEPTION_IF_NULL(actor_set);
1154   runtime::GraphScheduler::GetInstance().Run(actor_set, input_tensors, args);
1155 
1156   {
1157     uint64_t start_time = 0;
1158     PROFILER_START(start_time);
1159     MS_EXCEPTION_IF_NULL(graph_compiler_);
1160     graph_compiler_->Summary(graph_compiler_info.graphs_);
1161     ConstructOutputs(actor_set, outputs, root_graph_);
1162     actor_set->output_actor_->FreeSummaryNodeMem();
1163     runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
1164     PROFILER_END(start_time, runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kOutputProcess, actor_set->name_,
1165                  false);
1166   }
1167   // Close abstract_lock for dynamic_shape
1168   AnfUtils::CloseAbstractLock();
1169   (void)profiler::CollectHostInfo(kModelNameRuntime, kEventRunGraph, kStageRunGraph, 1, 0, 1);
1170   MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
1171 }
1172 
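// Query the random state of the compiled graphs through the device graph executor; an empty string is
// returned when the current device context does not provide a graph executor.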
std::string MindRTBackendBase::GetRandomStatus(const ActorInfo &actor_info) {
  auto iter = actor_to_graph_compiler_info_.find(actor_info);
  if (iter == actor_to_graph_compiler_info_.end()) {
    MS_LOG(EXCEPTION) << "Cannot find actor info " << actor_info;
  }
  MS_EXCEPTION_IF_NULL(iter->second);

  auto device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
  MS_EXCEPTION_IF_NULL(device_context);
  if (device_context->graph_executor_ == nullptr) {
    return "";
  }
  std::vector<FuncGraphPtr> graphs;
  std::transform(iter->second->graphs_.begin(), iter->second->graphs_.end(), std::back_inserter(graphs),
                 [](const auto &g) -> FuncGraphPtr { return g; });
  return device_context->graph_executor_->GetRandomStatus(graphs);
}

namespace {
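// An output whose abstract is AbstractAny may still carry a tuple at run time; this is detected by checking
// whether the tensor's device address holds a kernel tensor with a sequence shape.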
bool IsTupleOutputOfAnyType(const abstract::AbstractBasePtr &abstract, const tensor::TensorPtr &tensor) {
  if (abstract == nullptr || !abstract->isa<abstract::AbstractAny>() || tensor == nullptr) {
    return false;
  }
  auto device_tensor = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  return device_tensor != nullptr && device_tensor->user_data() == nullptr &&
         device_tensor->kernel_tensor() != nullptr && device_tensor->kernel_tensor()->GetShape() != nullptr &&
         device_tensor->kernel_tensor()->GetShape()->isa<abstract::SequenceShape>();
}
}  // namespace

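// Rebuild the output structure described by the abstract from the flat list of output tensors.
// output_position tracks how many tensors have been consumed so far, and tuple_tensors collects the
// tuple-typed tensors that get expanded into per-element outputs.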
BaseRef MindRTBackendBase::ConstructOutputByAbstract(const abstract::AbstractBasePtr &abstract,
                                                     const std::vector<tensor::TensorPtr> &output_tensors,
                                                     size_t *output_position,
                                                     std::vector<tensor::TensorPtr> *tuple_tensors) {
  MS_EXCEPTION_IF_NULL(abstract);
  MS_EXCEPTION_IF_NULL(output_position);
  MS_EXCEPTION_IF_NULL(tuple_tensors);

  size_t outputs_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
  if (*output_position + outputs_num > output_tensors.size()) {
    MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The output position is out of range: "
                               << *output_position << " need:" << outputs_num << " total:" << output_tensors.size();
  }

  if (!abstract->isa<abstract::AbstractSequence>()) {
    if (IsTupleOutputOfAnyType(abstract, output_tensors[*output_position])) {
      MS_LOG(DEBUG) << "Any output for position:" << *output_position;
      VectorRef outputs;
      auto device_tensor =
        std::dynamic_pointer_cast<device::DeviceAddress>(output_tensors[*output_position]->device_address());
      ConstructOutputByTupleTensor(output_tensors[*output_position],
                                   device_tensor->kernel_tensor()->GetShape()->cast<abstract::SequenceShapePtr>(),
                                   &outputs, tuple_tensors);
      (*output_position)++;
      std::vector<ValuePtr> values;

      (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(values),
                           [](const auto &output) { return utils::cast<ValuePtr>(output); });
      return std::make_shared<ValueList>(values);
    }

    (*output_position)++;
    return output_tensors[(*output_position) - 1];
  }

  VectorRef outputs;
  const auto &tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abstract);
  // Dynamic-length tuple.
  if (tuple_abstract->dynamic_len()) {
    auto &output_tensor = output_tensors[*output_position];
    MS_EXCEPTION_IF_NULL(output_tensor);
    auto &tensor_shape = output_tensor->base_shape_ptr();
    // Restore the tuple output from the tuple tensor.
    if ((tensor_shape != nullptr) && tensor_shape->isa<abstract::SequenceShape>()) {
      ConstructOutputByTupleTensor(output_tensor, tensor_shape->cast<abstract::SequenceShapePtr>(), &outputs,
                                   tuple_tensors);
      (*output_position)++;
      return outputs;
    }
  }

  const auto &sub_abstracts = tuple_abstract->elements();
  for (const auto &sub_abstract : sub_abstracts) {
    MS_EXCEPTION_IF_NULL(sub_abstract);
    outputs.emplace_back(ConstructOutputByAbstract(sub_abstract, output_tensors, output_position, tuple_tensors));
  }
  return outputs;
}

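// A single device tensor can hold a whole tuple. Split it back into per-element tensors by allocating a
// device address for each element and copying the corresponding slice device-to-device.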
void MindRTBackendBase::ConstructOutputByTupleTensor(tensor::TensorPtr output_tensor,
                                                     const abstract::SequenceShapePtr &tensor_shape, VectorRef *outputs,
                                                     std::vector<tensor::TensorPtr> *tuple_tensors) const {
  MS_EXCEPTION_IF_NULL(output_tensor);
  MS_EXCEPTION_IF_NULL(tensor_shape);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(tuple_tensors);
  MS_LOG(DEBUG) << "Tensor shape:" << tensor_shape->ToString();
  // If the output is an empty sequence, return an empty sequence value.
  if (tensor_shape->size() == 0) {
    if (tensor_shape->isa<abstract::TupleShape>()) {
      outputs->emplace_back(std::make_shared<ValueTuple>(std::vector<ValuePtr>()));
    } else {
      outputs->emplace_back(std::make_shared<ValueList>(std::vector<ValuePtr>()));
    }
    return;
  }
  // No need to split into multiple tensors when the tuple size is not greater than 1.
  if (tensor_shape->size() <= 1) {
    outputs->emplace_back(output_tensor);
    return;
  }

  auto tensor_type_id = output_tensor->data_type();
  auto device_tensor = std::dynamic_pointer_cast<device::DeviceAddress>(output_tensor->device_address());
  MS_EXCEPTION_IF_NULL(device_tensor);
  auto tensor_device_ptr = device_tensor->GetMutablePtr();
  auto tensor_device_size = device_tensor->GetSize();
  MS_EXCEPTION_IF_NULL(tensor_device_ptr);
  auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {device_tensor->device_name(), device_tensor->device_id()});
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  const auto &output_kernel_tensor = device_tensor->kernel_tensor();
  MS_EXCEPTION_IF_NULL(output_kernel_tensor);
  TypePtr output_type = output_kernel_tensor->GetType();
  MS_EXCEPTION_IF_NULL(output_type);
  TuplePtr output_tuple_type = output_type->cast<TuplePtr>();
  MS_EXCEPTION_IF_NULL(output_tuple_type);
  const auto &element_types = output_tuple_type->elements();
  if (tensor_shape->size() != element_types.size()) {
    MS_LOG(EXCEPTION) << "The tensor shape size[" << tensor_shape->size() << "] is not equal to output element size["
                      << element_types.size() << "].";
  }

  // Split the tuple tensor into individual tensors.
  (void)tuple_tensors->emplace_back(output_tensor);
  size_t copy_offset_size = 0;
  for (size_t i = 0; i < tensor_shape->size(); ++i) {
    // Create the split tensor.
    auto split_tensor_shape = BaseShapeToShape((*tensor_shape)[i]);
    auto split_tensor_size = SizeOf(split_tensor_shape) * GetTypeByte(TypeIdToType(tensor_type_id));
    auto split_tensor = std::make_shared<tensor::Tensor>(tensor_type_id, split_tensor_shape);

    auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
      nullptr, split_tensor_size, kernel::GetFormatFromStrToEnum(device_tensor->format()), device_tensor->type_id(),
      split_tensor_shape, device_context->device_context_key().device_name_,
      device_context->device_context_key().device_id_);
    kernel_tensor->SetType(element_types[i]);
    kernel_tensor->SetShape((*tensor_shape)[i]);
    kernel_tensor->set_stream_id(device_tensor->stream_id());
    auto split_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    MS_LOG(DEBUG) << "Create device tensor:" << split_device_tensor << " type:" << device_tensor->type_id();
    // Copy data from the origin tensor to the split tensor.
    device::DynamicMemAllocatorDebugInfo::SetDebugInfo("Split tuple outputs", device::AllocatorType::kOther);
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "ConstructOutputByTupleTensor",
                                                   "ConstructOutputByTupleTensor", "");
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "ConstructOutputByTupleTensor",
                                                   device::tracker::MemType::kOther, split_device_tensor->GetSize(),
                                                   split_device_tensor.get());
    if (!device_context->device_res_manager_->AllocateMemory(split_device_tensor.get())) {
      MS_LOG(EXCEPTION) << "#umsg#Memory not enough:#umsg#Device(id:" << device_context->device_context_key().device_id_
                        << ") memory isn't enough and alloc failed, kernel name: Split tuple outputs, alloc size: "
                        << split_device_tensor->GetSize() << "B.";
    }
    if (copy_offset_size + split_tensor_size > tensor_device_size) {
      MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The copy size is out of range, copy size:"
                                 << split_tensor_size << ", copy offset size:" << copy_offset_size
                                 << ", total size:" << tensor_device_size;
    }
    if (!split_device_tensor->SyncDeviceToDevice(split_tensor_shape, split_tensor_size, device_tensor->type_id(),
                                                 AddressOffset(tensor_device_ptr, copy_offset_size),
                                                 device_tensor->format())) {
      MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Sync device to device failed, device type:"
                                 << split_device_tensor->GetDeviceType() << ", copy size:" << split_tensor_size
                                 << ", output node: Split tuple outputs.";
    }
    copy_offset_size += split_tensor_size;

    // Fill the outputs.
    split_tensor->set_device_address(split_device_tensor);
    outputs->emplace_back(split_tensor);
  }
}

namespace {
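// Check whether the output node produces an empty dynamic-length sequence: in that case the matching
// output tensor carries a zero-length sequence shape.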
bool IsEmptySequence(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
                     const size_t *const output_position) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(output_position);
  // When the output node is a value node, the position may be out of range.
  if (*output_position >= output_tensors.size()) {
    return false;
  }

  if (output_node->abstract() == nullptr || (!output_node->abstract()->isa<abstract::AbstractSequence>())) {
    return false;
  }
  const auto &tuple_abs = output_node->abstract()->cast<abstract::AbstractSequencePtr>();
  MS_EXCEPTION_IF_NULL(tuple_abs);
  if ((!tuple_abs->dynamic_len()) && tuple_abs->dynamic_len_element_abs() == nullptr) {
    return false;
  }
  const auto &tensor = output_tensors[*output_position];
  MS_EXCEPTION_IF_NULL(tensor);
  if (tensor->base_shape_ptr() == nullptr || (!tensor->base_shape_ptr()->isa<abstract::SequenceShape>())) {
    return false;
  }
  const auto &sequence_shape = tensor->base_shape_ptr()->cast<abstract::SequenceShapePtr>();
  MS_EXCEPTION_IF_NULL(sequence_shape);
  return sequence_shape->size() == 0;
}
}  // namespace

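// Recursively walk the graph output node and assemble the nested front-end outputs (VectorRef) from the
// flat list of output tensors, expanding MakeTuple-like nodes and following Depend nodes to their real input.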
void MindRTBackendBase::ConstructOutputs(const AnfNodePtr &output_node,
                                         const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position,
                                         VectorRef *outputs, std::vector<tensor::TensorPtr> *tuple_tensors) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_position);
  MS_EXCEPTION_IF_NULL(tuple_tensors);
  static const PrimitiveSet expand_prims{
    prim::kPrimMakeTuple,
    prim::kPrimMakeCSRTensor,
    prim::kPrimMakeCOOTensor,
    prim::kPrimMakeRowTensor,
  };
  MS_LOG(DEBUG) << "output node:" << output_node->DebugString();
  // If the output is an empty sequence, return an empty sequence value.
  if (IsEmptySequence(output_node, output_tensors, output_position)) {
    if (output_node->abstract()->isa<abstract::AbstractTuple>()) {
      outputs->emplace_back(std::make_shared<ValueTuple>(std::vector<ValuePtr>()));
    } else {
      outputs->emplace_back(std::make_shared<ValueList>(std::vector<ValuePtr>()));
    }
    ++(*output_position);
    return;
  }

  // The MakeTuple/MakeSparse nodes need to be expanded and recursed into.
  if (IsOneOfPrimitiveCNode(output_node, expand_prims)) {
    auto make_tuple = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    VectorRef make_tuple_output;
    for (size_t i = 1; i < make_tuple->size(); i++) {
      ConstructOutputs(make_tuple->input(i), output_tensors, output_position, &make_tuple_output, tuple_tensors);
    }
    outputs->emplace_back(std::move(make_tuple_output));
    return;
  }

  // The Depend node needs to fetch its real input node.
  if (common::AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
    auto depend_node = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_node);
    ConstructOutputs(depend_node->input(kRealInputIndexInDepend), output_tensors, output_position, outputs,
                     tuple_tensors);
    return;
  }

  auto outputs_num = AnfAlgo::GetOutputElementNum(output_node);
  // The value node outputs its value directly, to avoid the value's host memory being freed when the value node is
  // destructed.
  if (output_node->isa<ValueNode>()) {
    auto value = output_node->cast<ValueNodePtr>()->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<ValueSequence>()) {
      outputs->emplace_back(value);
      (*output_position) += CountValueNum(value->cast<ValueSequencePtr>());
    } else if (outputs_num != 0) {
      outputs->emplace_back(value);
      (*output_position) += outputs_num;
    }
    // An empty value node returns an empty VectorRef.
    return;
  }

  if (common::AnfAlgo::IsCallNode(output_node)) {
    auto abstract = output_node->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    outputs->emplace_back(ConstructOutputByAbstract(abstract, output_tensors, output_position, tuple_tensors));
    return;
  }

  auto &output_abstract = output_node->abstract();
  MS_EXCEPTION_IF_NULL(output_abstract);
  // Wrap the output in a VectorRef if the output is a tuple.
  MS_LOG(DEBUG) << "output abstract:" << output_abstract->ToString();
  if (output_abstract->isa<abstract::AbstractSequence>()) {
    VectorRef output_tuple;
    for (size_t i = 0; i < outputs_num; ++i) {
      MS_LOG(DEBUG) << "output index:" << i;
      if (*output_position >= output_tensors.size()) {
        MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The output position is out of range: "
                                   << *output_position;
      }
      auto &output_tensor = output_tensors[*output_position];
      MS_EXCEPTION_IF_NULL(output_tensor);
      auto &tensor_shape = output_tensor->base_shape_ptr();
      // Restore the tuple output from the tuple tensor.
      if ((tensor_shape != nullptr) && tensor_shape->isa<abstract::SequenceShape>()) {
        ConstructOutputByTupleTensor(output_tensor, tensor_shape->cast<abstract::SequenceShapePtr>(), &output_tuple,
                                     tuple_tensors);
      } else {
        output_tuple.emplace_back(output_tensor);
      }
      ++(*output_position);
    }
    outputs->emplace_back(std::move(output_tuple));
  } else {
    for (size_t i = 0; i < outputs_num; ++i) {
      if (*output_position >= output_tensors.size()) {
        MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The output position is out of range: "
                                   << *output_position;
      }
      outputs->emplace_back(output_tensors[*output_position]);
      ++(*output_position);
    }
  }
}

#ifdef ENABLE_DEBUGGER
void MindRTBackendBase::SetDebuggerInit() const {
  auto debugger_ = Debugger::GetInstance();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  MS_EXCEPTION_IF_NULL(debugger_);
  debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
}
#endif

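// Collect the compiled kernel graphs and their device contexts into a GraphCompilerInfo, together with the
// control nodes, the origin output order of the root graph, and the execution strategy derived from the
// memory optimize level.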
std::shared_ptr<GraphCompilerInfo> MindRTBackendBase::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
  MS_EXCEPTION_IF_NULL(root_graph);
  MS_EXCEPTION_IF_NULL(graph_compiler_);

  std::vector<KernelGraphPtr> graphs;
  std::vector<DeviceContext *> device_contexts;
  std::string name = "kernel_graph";
  size_t graph_index = 0;
  for (const auto &graph_id_to_context : graph_id_to_device_context_) {
    (void)graphs.emplace_back(graph_compiler_->Fetch(graph_id_to_context.first));
    (void)device_contexts.emplace_back(graph_id_to_context.second);
    if (graph_index == 0) {
      (void)name.append("_").append(std::to_string(graph_id_to_context.first));
    } else if (graph_index == graph_id_to_device_context_.size() - 1) {
      (void)name.append("-").append(std::to_string(graph_id_to_context.first));
    }
    ++graph_index;
  }

  auto parser = std::make_shared<ControlNodeParser>();
  const auto &root_output =
    common::AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
  auto outputs_num = common::AnfAlgo::GetAllOutputWithIndex(root_output).size();
  runtime::KernelMapPosition outputs_order = FetchOriginOutputOrder(root_graph->output());

  std::vector<std::vector<int64_t> *> tensors_mask;
  std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
  auto strategy = runtime::GraphExecutionStrategy::kPipeline;
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) != kOptimizeO0) {
    strategy = runtime::GraphExecutionStrategy::kPipelineWithExecutionOrder;
  }
  auto compile_func = [graph_compiler = this->graph_compiler_](
                        const GraphSegmentPtr &segment, const std::pair<AnfNodePtrList, AnfNodePtrList> &io_nodes,
                        const DeviceContext *device_context, device::RunMode run_mode) -> KernelGraphPtr {
    auto graph_id = graph_compiler->CompileGraph(segment, io_nodes, device_context, run_mode, false);
    return graph_compiler->Fetch(graph_id);
  };

  return std::make_shared<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
                                             root_graph->parameters(), parser, outputs_order, outputs_num,
                                             root_graph->GetPositionalArgsCount(), name, false, strategy, compile_func);
}

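// Group the compiled kernel graphs by the front-end func graph they originate from, then let the control
// node parser analyse the control nodes against those groups.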
void MindRTBackendBase::ParseControlNodes(const GraphCompilerInfo &graph_compile_info) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  MS_EXCEPTION_IF_NULL(graph_compile_info.control_node_parser_);

  FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
  for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
    const auto &func_graph = func_graph_to_kernel_graph_ids.first;
    for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) {
      std::vector<KernelGraphPtr> kernel_graphs;
      for (const auto &graph_id : sub_kernel_graphs_ids) {
        const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
        MS_EXCEPTION_IF_NULL(kernel_graph);
        (void)kernel_graphs.emplace_back(kernel_graph);
      }
      (void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs);
    }
  }

  graph_compile_info.control_node_parser_->Parse(control_nodes_, graph_compile_info.graphs_,
                                                 graph_compile_info.device_contexts_, root_graph_,
                                                 func_graph_to_kernel_graphs);
}

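// Refresh the recorded origin output order of the root graph in the actor's graph compiler info.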
void MindRTBackendBase::UpdateGraphCompilerInfo(const ActorInfo &actor_info) {
  const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
  if (graph_iter == actor_to_graph_compiler_info_.end()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph_iter->second);
  MS_EXCEPTION_IF_NULL(root_graph_);
  graph_iter->second->origin_outputs_order_ = FetchOriginOutputOrder(root_graph_->output());
}
}  // namespace compile
}  // namespace mindspore