/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "vm/backend.h"

#include <algorithm>
#include <vector>
#include <map>

#include "vm/transform.h"
#include "backend/session/session_factory.h"
#include "backend/optimizer/common/helper.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/parse/data_converter.h"
#include "ir/anf.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/callbacks.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/framework/graph_compiler.h"
#include "utils/scoped_long_running.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"
#endif
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/running_data_recorder.h"
#endif

namespace mindspore {
namespace compile {
bool Backend::GetCond(const BaseRef &c, bool *const value) {
  mindspore::ScopedLongRunning long_running;
  return BaseRefToBool(c, value);
}
bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }

Backend::Backend(const std::string &name) : name_(name) {
  MS_LOG(DEBUG) << "Select backend:" << name;
  convert_fn_ = MsVmConvert;
  is_multi_graph_sink_ = false;
}

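// Compiles a graph segment for the given target and returns a LinConvertResult holding
// the segment's inputs/outputs, the compiled graph id, and two run functions: `run`
// executes the compiled graph via MsRunGraph, while `simu_run` only returns the output
// nodes (apparently used when simulating execution for the multi-graph-sink path).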
LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
  MS_LOG(DEBUG) << "MsConvert";
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  LinConvertResult result;
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
  result.inputs = inputs;
  result.outputs = outputs;
  result.graph_id = kInvalidGraphId;
  auto current_session = target_sess_;
  if (target != target_device_ && !target.empty()) {
    CreateOtherSession(target);
    current_session = other_sess_;
  }
  MS_EXCEPTION_IF_NULL(current_session);
  GraphId graph_id = current_session->CompileGraph(segment, outputs);
  segment->graph_id_ = graph_id;
  auto graph = current_session->GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);
  for (auto &pre_segment : segment->pre_segments_) {
    MS_EXCEPTION_IF_NULL(pre_segment);
    MS_EXCEPTION_IF_NULL(target_sess_);
    auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
    if (pre_graph == nullptr) {
      MS_EXCEPTION_IF_NULL(other_sess_);
      pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
    }
    MS_EXCEPTION_IF_NULL(pre_graph);
    pre_graph->AddPostGraph(graph);
    graph->AddPreGraph(pre_graph);
    MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph_id;
  }

  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;
  }
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (!pynative_mode || target != "Ascend") {
    if (target != target_device_ && !target.empty()) {
      MS_EXCEPTION_IF_NULL(other_sess_);
      other_sess_->BuildGraph(graph_id);
    } else if (!is_multi_graph_sink_) {
      MS_EXCEPTION_IF_NULL(target_sess_);
      target_sess_->BuildGraph(graph_id);
    }
  }
  result.run = std::make_shared<RunFunc>(
    [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
  MS_EXCEPTION_IF_NULL(result.run);

  result.simu_run = std::make_shared<RunFunc>(
    [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id); });
  MS_EXCEPTION_IF_NULL(result.simu_run);
  result.graph_id = graph_id;

  graph_id_map_[graph_id] = result;
  return result;
}

// Returns the output nodes of the compiled graph directly, without executing it.
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g) {
  MS_LOG(DEBUG) << "Set graph input:" << g;
  std::vector<BaseRef> outputs;
  (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
                       [](const AnfNodePtr &v) { return v; });
  return VectorRef(outputs);
}

namespace {
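// Recursively flattens a BaseRef argument into tensors and appends them to *inputs.
// For example, an arg of VectorRef{Tensor a, ValueTuple(b, c)} yields {a, b, c};
// scalars are converted to tensors and monads become a placeholder bool tensor.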
void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
  MS_EXCEPTION_IF_NULL(inputs);
  if (utils::isa<tensor::TensorPtr>(arg)) {
    auto value = utils::cast<tensor::TensorPtr>(arg);
    inputs->push_back(value);
  } else if (utils::isa<ValuePtr>(arg)) {
    auto value = utils::cast<ValuePtr>(arg);
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<ValueTuple>()) {
      auto value_tuple = value->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value_tuple);
      auto tuple_value = value_tuple->value();
      (void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs),
                           [](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
    } else if (value->isa<Scalar>()) {
      tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
      inputs->push_back(scalar_tensor);
    } else if (value->isa<Monad>()) {
      // If value is a monad, replace it with an unused tensor.
      inputs->push_back(std::make_shared<tensor::Tensor>(int64_t(0), kBool));
    } else {
      inputs->push_back(value->cast<tensor::TensorPtr>());
    }
  } else if (utils::isa<PyObjectRef>(arg)) {
    auto value = utils::cast<PyObjectRef>(arg).object_;
    inputs->push_back(py::cast<tensor::TensorPtr>(value));
  } else if (utils::isa<VectorRefPtr>(arg)) {
    const auto &args_new = utils::cast<VectorRef>(arg);
    for (const auto &v : args_new) {
      PushInputTensor(v, inputs);
    }
  } else {
    MS_LOG(WARNING) << "Invalid input type.";
  }
}

// Insert the tensor associated with front_node into input_tensor; if front_node is not
// one of the graph parameters, a nullptr placeholder is pushed instead.
void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
                std::vector<tensor::TensorPtr> *input_tensor) {
  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
  if (iter == parameters.end()) {
    (void)((*input_tensor).emplace_back(nullptr));
    return;
  }
  auto position = iter - parameters.begin();
  PushInputTensor(args[position], input_tensor);
}

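// Copies the inferred abstract of the kernel whose name matches op_run_info->op_name
// back into op_run_info; used below to refresh the output abstract of dynamic-shape ops.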
void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, OpRunInfo *op_run_info) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &kernels = kernel_graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
      op_run_info->abstract = kernel->abstract();
    }
  }
}

TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(output_node);
  // Create a host tensor. The output tensor should use the inferred type; it will be handled correctly by tensor
  // data sync when the inferred type is not equal to the device type.
  auto type_id = AnfAlgo::GetOutputInferDataType(output_node, output_index);
  std::vector<int64_t> temp_shape;
  const auto &shape = AnfAlgo::GetOutputInferShape(output_node, output_index);
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  auto tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));

  // Put device tensor into host tensor.
  const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
  MS_EXCEPTION_IF_NULL(device_tensor);
  tensor->set_device_address(device_tensor);

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
    // If the execution mode in MsContext is graph mode, the tensor will be the input of a graph executed in graph
    // mode; if the graph contains no CNode after optimization, the tensor needs to sync to host.
    tensor->data_sync(false);
  }

  return tensor;
}

void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  for (auto &item_with_index : output_nodes) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    // If the graph returns nothing, the function should return an empty VectorRef.
    if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
      continue;
    }
    outputs->emplace_back(CreateOutputTensor(item_with_index.first, item_with_index.second));
  }
}

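// Replaces each output node's device address with a fresh, unallocated one of the same
// size/format/type, presumably so the next launch cannot overwrite the result whose old
// address is still referenced by the output tensor handed back to the caller.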
void UpdateOutputDeviceAddress(const std::vector<session::KernelWithIndex> &output_nodes,
                               const DeviceContext *device_context) {
  for (auto &item_with_index : output_nodes) {
    auto &output_node = item_with_index.first;
    auto output_index = item_with_index.second;
    if (output_node != nullptr) {
      if (!AnfAlgo::OutputAddrExist(output_node, output_index, false)) {
        continue;
      }
      const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);

      if ((device_tensor == nullptr) || (device_tensor->GetPtr() == nullptr)) {
        continue;
      }

      MS_EXCEPTION_IF_NULL(device_context);
      auto new_device_tensor = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
                                                                   device_tensor->format(), device_tensor->type_id());
      MS_EXCEPTION_IF_NULL(new_device_tensor);
      new_device_tensor->set_original_ref_count(device_tensor->original_ref_count());
      new_device_tensor->ResetRefCount();
      AnfAlgo::SetOutputAddr(new_device_tensor, output_index, output_node.get());
    }
  }
}

void UpdateInputDeviceAddress(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &node : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
      AnfAlgo::SetOutputAddr(nullptr, 0, node.get());
    }
  }
}
}  // namespace

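// Executes compiled graph `g`: flattens args into input tensors, then dispatches to
// RunOpsInGraph (PyNative mode) or RunGraphAsync (graph mode) on the session that
// owns the graph (other_sess_ for a non-default target, otherwise target_sess_).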
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
  MS_LOG(DEBUG) << "Start ms graph run:" << args.size() << ", g:" << g;
  // Run graph
  std::vector<tensor::TensorPtr> inputs;
  for (const auto &arg : args) {
    PushInputTensor(arg, &inputs);
  }

  VectorRef outputs;
  // Call RunGraphAsync or RunOpsInGraph with (graph_id, inputs, outputs).
  const session::SessionPtr &exe_session = ((target != target_device_ && !target.empty()) ? other_sess_ : target_sess_);
  MS_EXCEPTION_IF_NULL(exe_session);
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (pynative_mode) {
    exe_session->RunOpsInGraph(g, inputs, &outputs);
  } else {
    exe_session->RunGraphAsync(g, inputs, &outputs);
  }

  MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  return outputs;
}

MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
  target_sess_ = session::SessionFactory::Get().Create(target);
  if (target_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
  }
  target_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
  target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
  target_device_ = target;
}

void MsBackend::CreateOtherSession(const std::string &target) {
  if (other_sess_ != nullptr && other_device_ == target) {
    return;
  }
  other_sess_ = session::SessionFactory::Get().Create(target);
  if (other_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  other_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
  other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
  other_device_ = target;
}

GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) {
  MS_EXCEPTION_IF_NULL(target_sess_);
  return target_sess_->CompileGraph(fg);
}

VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }

void MsBackend::ClearSessionGraphs() {
  if (target_sess_ != nullptr) {
    target_sess_->ClearGraph();
  }
}

#ifdef ENABLE_DEBUGGER
void MsBackend::SetDebugger() {
  MS_EXCEPTION_IF_NULL(target_sess_);
  target_sess_->SetDebugger();
}
#endif

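// The MindRT backend drives the actor-based runtime: func graphs are partitioned into
// segments, compiled into kernel graphs, transformed into an actor DAG by the
// GraphScheduler, and then scheduled and run on the selected device context.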
MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
    : Backend(backend_name), device_name_(device_name) {
  root_graph_ = nullptr;
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  auto &cut_list = pynative_mode ? compile::control_ops : GetMsNonlinearOps();

  graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
  graph_compiler_ = std::make_shared<GraphCompiler>();

  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
  device_context->Initialize();
  device_id_ = device_context->device_context_key().device_id_;
#ifdef ENABLE_DEBUGGER
  SetDebuggerInit();
#endif
  runtime::GraphScheduler::GetInstance().Initialize();
}

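// Compiles the root func graph and all of its sub graphs into kernel graphs, builds the
// GraphCompilerInfo describing them and, in graph mode, transforms and schedules the
// actor DAG up front. Returns the ActorInfo key that RunGraph later uses for lookup.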
const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto root_graph = WrapPrimitives(func_graph);
  MS_EXCEPTION_IF_NULL(root_graph);
  root_graph_ = root_graph.get();
  // Register a summary callback function, which is called in the final stages of summary.
  graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  real_execution_mode_ = ms_execution_mode_;

  // Compile root graph.
  graph_id_to_device_context_.clear();
  control_nodes_.clear();
  CompileGraph(root_graph);

  // Compile sub graphs.
  MS_EXCEPTION_IF_NULL(root_graph->manager());
  FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
  for (auto sub_graph : sub_graphs) {
    if (sub_graph != func_graph && sub_graph != nullptr) {
      CompileGraph(sub_graph);
    }
  }

  // Construct the graph compiler info.
  auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);

  if (real_execution_mode_ == kGraphMode) {
    // Transform graph to actor DAG, and schedule the actor DAG.
    const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
    runtime::GraphScheduler::GetInstance().Schedule(actor_set);
  }
  MS_EXCEPTION_IF_NULL(graph_compiler_info);
  const ActorInfo &actor_info = graph_compiler_info->name_;
  (void)actor_to_graph_compiler_info_.emplace(graph_compiler_info->name_, std::move(graph_compiler_info));
  return actor_info;
}

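// Partitions a func graph into segments and compiles each one. Normal segments are
// compiled into kernel graphs on the device that hosts their first node; cut segments
// (control-flow nodes) are collected into control_nodes_ for the control node parser.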
void MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(graph_partition_);
  MS_EXCEPTION_IF_NULL(graph_compiler_);

  bool contain_multi_target = false;
  // Split the graph into segments.
  const auto &segments = graph_partition_->Partition(func_graph, &contain_multi_target);
  MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);

  // Iterate over the segments and compile each graph.
  for (const auto &segment : segments) {
    MS_EXCEPTION_IF_NULL(segment);
    // Compile the normal nodes, which don't contain the cut node.
    if (segment->nodes_.size() == 0) {
      MS_LOG(EXCEPTION) << "The segment's node list is empty.";
    }
    if (!segment->is_cut_) {
      MS_EXCEPTION_IF_NULL(segment->nodes_[0]);
      MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope();

      // Get the device context.
      const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]);
      const auto &device_context =
        device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_});
      MS_EXCEPTION_IF_NULL(device_context);
      device_context->Initialize();

      // Transform nodes to inputs and outputs.
      FuncGraphPtr fg;
      AnfNodePtrList inputs;
      AnfNodePtrList outputs;
      std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);

      // There will be more than one kernel graph in a heterogeneous scenario in a ms_function of PyNative mode.
      if (contain_multi_target && ms_execution_mode_ == kPynativeMode) {
        real_execution_mode_ = kGraphMode;
        context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
      }

      // Compile graph.
      auto graph_id = graph_compiler_->CompileGraph(segment->nodes_, outputs, device_context);

      if (ms_execution_mode_ != real_execution_mode_) {
        context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
      }

      graph_id_to_device_context_[graph_id] = device_context;
    } else {
      // Compile the cut node.
      auto cut_node = segment->nodes_[0];
      MS_EXCEPTION_IF_NULL(cut_node);
      MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope();
      control_nodes_.push_back(cut_node);
    }
  }
}

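// Single-op compilation path for PyNative mode. Actor sets are cached under the key
// "graph_id + op name": on a cache hit the existing ActorInfo is returned, otherwise a
// new single-op actor set is transformed and scheduled in step mode.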
const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                                             const std::vector<int64_t> *tensors_mask,
                                             std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  // Get the device context.
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
  MS_EXCEPTION_IF_NULL(device_context);
  device_context->Initialize();

  bool single_op_cache_hit = true;
  auto graph_id = graph_compiler_->CompileGraph(op_run_info, graph_info, tensors_mask, input_tensors,
                                                &single_op_cache_hit, device_context);
  // The actor set name: graph_id + single operator name.
  std::string actor_info = std::to_string(graph_id) + "_" + op_run_info.op_name;
  if (single_op_cache_hit) {
    auto iter = actor_to_graph_compiler_info_.find(actor_info);
    if (iter == actor_to_graph_compiler_info_.end()) {
      MS_LOG(EXCEPTION) << "Can not find graph compiler info for actor set: " << actor_info;
    }
    return iter->first;
  }

  graph_info_to_device_context_.clear();
  graph_info_to_device_context_[graph_info] = device_context;

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool enable_cache = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  auto graph_compiler_info = ConstructGraphCompilerInfo(actor_info, tensors_mask, input_tensors, !enable_cache);
  const auto actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
  runtime::GraphScheduler::GetInstance().Schedule(actor_set);
  MS_EXCEPTION_IF_NULL(graph_compiler_info);
  graph_compiler_info->input_tensors_.clear();

  auto ret = actor_to_graph_compiler_info_.emplace(actor_info, std::move(graph_compiler_info));
  return ret.first->first;
}

namespace {
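// Collects the runtime arguments of a control operator such as bprop_cut: non-value
// inputs are fetched as tensors through the single-op machinery, while value-node
// inputs are passed through as values (a ValueSequeue advances the index by its size).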
void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, const CNodePtr &front_cnode,
                       const CNodePtr &backend_cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
                       const std::map<AnfNodePtr, size_t> &parameter_index,
                       const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
                       VectorRef *args) {
  MS_EXCEPTION_IF_NULL(front_cnode);
  MS_EXCEPTION_IF_NULL(backend_cnode);
  MS_EXCEPTION_IF_NULL(graph_compiler);
  MS_EXCEPTION_IF_NULL(args);
  size_t input_index = 0;
  auto inputs = front_cnode->inputs();
  for (size_t i = 1; i < inputs.size(); i++) {
    const auto &input_node = inputs[i];
    MS_EXCEPTION_IF_NULL(input_node);
    auto kernel_with_index = AnfAlgo::VisitKernel(input_node, 0);
    auto real_input = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);

    if (!real_input->isa<ValueNode>()) {
      TensorPtr tensor = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index,
                                                                       graph_inputs, input_tensor_info, input_index);
      MS_EXCEPTION_IF_NULL(tensor);
      args->emplace_back(tensor);
      input_index++;
      continue;
    }

    // Get value from value node.
    const auto &value_node = real_input->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);

    if (value->isa<ValueSequeue>()) {
      const auto &value_sequeue = value->cast<ValueSequeuePtr>();
      MS_EXCEPTION_IF_NULL(value_sequeue);
      input_index += value_sequeue->size();
    } else {
      input_index++;
    }

    args->emplace_back(value);
  }
}

void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(tensors);
  for (const auto &input_object : tuple_inputs) {
    if (!py::isinstance<tensor::Tensor>(input_object)) {
      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
    }
    auto tensor = py::cast<tensor::TensorPtr>(input_object);
    MS_EXCEPTION_IF_NULL(tensor);
    (void)tensors->emplace_back(tensor);
  }
}

void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(tensors);
  ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  MS_EXCEPTION_IF_NULL(input_value);
  if (!input_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }

  auto value_tuple = input_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  (void)tensors->emplace_back(tensor_ptr);
}

void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(tensors);
  if (!py::isinstance<py::tuple>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a tuple!";
  }

  auto inputs = py::cast<py::tuple>(input_object);
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  }

  if (py::isinstance<tensor::Tensor>(inputs[0])) {
    PlantTensorTupleToVector(inputs, tensors);
  } else {
    ConvertValueTupleToTensor(input_object, tensors);
  }
}

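// Executes a control operator on the host. For a bprop_cut node this gathers the op
// inputs, runs the registered Python hook function, and converts the returned
// py::object back into output tensors.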
void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, const KernelGraphPtr &graph,
                        const CNodePtr &kernel, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
                        const std::map<AnfNodePtr, size_t> &parameter_index,
                        const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
                        VectorRef *op_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_outputs);
  AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel);
  MS_EXCEPTION_IF_NULL(front_node);
  if (!front_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The front node of bprop_cut is not a CNode";
  }
  CNodePtr cnode = front_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const std::vector<AnfNodePtr> &node_inputs = cnode->inputs();
  if (node_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The inputs of node[" << cnode->fullname_with_scope() << "] is empty";
  }

  const AnfNodePtr &fn = node_inputs.at(0);
  if (!IsValueNode<Primitive>(fn)) {
    MS_LOG(EXCEPTION) << "The input[0] of kernel[" << kernel->fullname_with_scope()
                      << "] is not a ValueNode of Primitive";
  }

  PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == kBpropCutOpName) {
    VectorRef args;
    GetControlOpInput(graph_compiler, cnode, kernel, op_output_map, parameter_index, graph_inputs, input_tensor_info,
                      &args);
    BaseRef out = prim->RunHookFunction(args);
    // Convert the pyobject output to tensors.
    if (utils::isa<PyObjectRef>(out)) {
      PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
      auto out_py_tuple = py_ref.object_;
      std::vector<tensor::TensorPtr> output_tensors;
      ConvertMultiPyObjectToTensor(out_py_tuple, &output_tensors);
      (void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
                           [](tensor::TensorPtr &tensor) { return std::move(tensor); });
    }
  }
}

void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(value);
  MS_EXCEPTION_IF_NULL(outputs);
  if (value->isa<ValueTuple>()) {
    auto value_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(value_tuple);
    for (size_t i = 0; i < value_tuple->size(); ++i) {
      ValuePtr element = value_tuple->value()[i];
      MS_EXCEPTION_IF_NULL(element);
      if (element->isa<tensor::Tensor>()) {
        auto tensor = element->cast<tensor::TensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        outputs->emplace_back(tensor);
      } else if (element->isa<ValueTuple>()) {
        TensorValueToVector(element, outputs);
      }
    }
  } else if (value->isa<tensor::Tensor>()) {
    auto tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    outputs->emplace_back(tensor);
  }
}

bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const VectorRef &args, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(graph_output);
  MS_EXCEPTION_IF_NULL(outputs);
  if (graph_output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    VectorRef output_tmp;
    ValuePtr value = GetValueNode(graph_output);
    TensorValueToVector(value, &output_tmp);
    if (output_tmp.size() == 1) {
      *outputs = std::move(output_tmp);
    } else if (output_tmp.size() > 1) {
      outputs->emplace_back(output_tmp);
    } else {
      MS_LOG(EXCEPTION) << "Output is empty!";
    }
    return true;
  }

  if (graph_output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    // Find the right parameter as ret_val.
    auto func_graph = graph_output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    if (args.size() != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " is not equal to graph input size " << params.size();
    }

    auto it = std::find(params.begin(), params.end(), graph_output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When the graph output is a Parameter, it should be found in the graph parameters";
    }
    size_t index = it - params.cbegin();
    if (index >= args.size()) {
      MS_EXCEPTION(UnknownError) << "Index " << index << " is equal to or larger than args size " << args.size();
    }

    outputs->emplace_back(args[index]);
    return true;
  }
  return false;
}
}  // namespace

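// PyNative fallback: instead of launching a whole actor DAG, run each kernel graph
// operator by operator. Reference counts track when intermediate op outputs can be
// released, and control operators (e.g. bprop_cut) are executed on the host.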
void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                                       const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
    const auto &graph = graphs[graph_index];
    MS_EXCEPTION_IF_NULL(graph);
    std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
    std::map<AnfNodePtr, size_t> parameter_index;
    GraphOutputInfo graph_output_info;
    graph_output_info.graph_outputs = outputs;
    graph_compiler_->GetParamAndOutputIndex(graph, inputs[graph_index], outputs, &parameter_index,
                                            &graph_output_info.output_indexes);

    std::map<KernelWithIndex, size_t> cnode_ref_count;
    auto iter = cnode_ref_counts_.find(graph->graph_id());
    if (iter == cnode_ref_counts_.end()) {
      graph_compiler_->CalculateRefCount(graph, &cnode_ref_count);
      (void)cnode_ref_counts_.emplace(graph->graph_id(), cnode_ref_count);
    } else {
      cnode_ref_count = iter->second;
    }

    // Clear bucket resources every step
    if (graph->is_bprop()) {
      graph_compiler_->ClearAllBucket(graph->graph_id());
    }

    for (const auto &kernel : graph->execution_order()) {
      InputTensorInfo input_tensor_info;
      VectorRef op_outputs;

      if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
        OpRunInfo op_run_info;
        GraphInfo graph_info;
        graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
                                                 &input_tensor_info);
        graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info,
                                                        &graph_info);

        const ActorInfo &actor_info = CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask,
                                                   &input_tensor_info.input_tensors);
        RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors,
                 &op_outputs);
      } else {
        RunControlOperator(graph_compiler_, graph, kernel, op_output_map, parameter_index, inputs[graph_index],
                           &input_tensor_info, &op_outputs);
      }

      graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);

      graph_output_info.graph_output_tensors.clear();
      graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);

      // Save grad node to Bucket
      if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel))) {
        graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
      }
    }
  }
}

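// Runs the actor DAG built by CompileGraphs: maps args onto each kernel graph's input
// tensors, launches the actors (or falls back to RunGraphBySingleOp in PyNative mode),
// syncs the device streams, and reconstructs the nested outputs from the output actor.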
void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
  MS_LOG(INFO) << "Run actor begin, actor name: " << actor_info;
  MS_EXCEPTION_IF_NULL(root_graph_);
  if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
    return;
  }

  const auto &context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return;
  }

  // Fetch the graph compiler info.
  const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
  if (graph_iter == actor_to_graph_compiler_info_.end()) {
    MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
  }
  MS_EXCEPTION_IF_NULL(graph_iter->second);
  const auto &graph_compiler_info = *(graph_iter->second);
  const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;

  // Transform args to input tensors.
  // Input tensors of the graph.
  std::vector<std::vector<tensor::TensorPtr>> input_tensors;
  for (const auto &kernel_graph : graph_compiler_info.graphs_) {
    std::vector<tensor::TensorPtr> input_tensor;
    MS_EXCEPTION_IF_NULL(kernel_graph);
    for (const auto &input_node : kernel_graph->input_nodes()) {
      const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
      PushTensor(args, origin_parameters, front_node, &input_tensor);
    }
    (void)input_tensors.emplace_back(input_tensor);
  }

  // Input tensors of the control node.
  std::vector<tensor::TensorPtr> input_tensor;
  MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
  // Get inputs of the control node which come from the host actor.
  const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
  for (const auto &parameter : control_node_parameters) {
    PushTensor(args, origin_parameters, parameter, &input_tensor);
  }
  (void)input_tensors.emplace_back(input_tensor);

  // Run in the pynative mode.
  MS_EXCEPTION_IF_NULL(outputs);
  // There will be more than one kernel graph in a heterogeneous scenario in a ms_function of PyNative mode.
  if (real_execution_mode_ == kPynativeMode) {
    RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
    return;
  }
  // Run the actor DAG.
  mindspore::ScopedLongRunning long_running;
  const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
  MS_EXCEPTION_IF_NULL(actor_set);
  if (!runtime::GraphScheduler::GetInstance().Run(actor_set, input_tensors)) {
#ifdef ENABLE_DUMP_IR
    mindspore::RDR::TriggerAll();
#endif
    MS_LOG(EXCEPTION) << "The actor run failed, actor name: " << actor_set->name_;
  }

  if (graph_compiler_info.device_contexts_.empty()) {
    MS_LOG(EXCEPTION) << "The device contexts are empty.";
  }
  // Sync the device streams.
  const auto &first_device_context = graph_compiler_info.device_contexts_[0];
  MS_EXCEPTION_IF_NULL(first_device_context);
  if (!first_device_context->SyncStream()) {
    MS_LOG(EXCEPTION) << "Sync stream failed:" << first_device_context->device_context_key().ToString();
  }
  for (size_t i = 0; i < graph_compiler_info.device_contexts_.size(); ++i) {
    const auto &device_context = graph_compiler_info.device_contexts_[i];
    MS_EXCEPTION_IF_NULL(device_context);
    if ((device_context != first_device_context) && (!device_context->SyncStream())) {
      MS_LOG(EXCEPTION) << "Sync stream failed:" << device_context->device_context_key().ToString();
    }
  }

  // Fetch outputs.
  MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
  auto &output_tensors = actor_set->output_actor_->outputs();
  if (output_tensors.size() > 0) {
    size_t output_position = 0;
    ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
  }

  MS_EXCEPTION_IF_NULL(graph_compiler_);
  graph_compiler_->Summary(graph_compiler_info.graphs_);

  // Update the device addresses for the output nodes of the graph.
  actor_set->output_actor_->UpdateOutputDeviceAddress();
  MS_LOG(INFO) << "Run actor end, actor name: " << actor_info;
}

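// Recursively rebuilds the front-end output structure from the flat list of output
// tensors: MakeTuple expands into a nested VectorRef, Depend recurses into its real
// input, value nodes emit their value directly, and every other node consumes
// output_tensors starting at *output_position.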
void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
                                     const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position,
                                     VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_position);
  // The MakeTuple node needs to be expanded recursively.
  if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
    auto make_tuple = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    VectorRef make_tuple_output;
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      ConstructOutputs(make_tuple->input(i), output_tensors, output_position, &make_tuple_output);
    }
    outputs->emplace_back(std::move(make_tuple_output));
    return;
  }

  // For a Depend node, recurse into its real input.
  if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
    auto depend_node = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_node);
    ConstructOutputs(depend_node->input(kRealInputIndexInDepend), output_tensors, output_position, outputs);
    return;
  }

  auto outputs_num = AnfAlgo::GetOutputTensorNum(output_node);
  // A value node outputs its value directly, to avoid the host memory of the value being freed when the value node
  // is destructed.
  if (output_node->isa<ValueNode>()) {
    auto value = output_node->cast<ValueNodePtr>()->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<ValueTuple>()) {
      outputs->emplace_back(value);
      (*output_position) += CountValueNum(value->cast<ValueTuplePtr>());
    } else if (outputs_num != 0) {
      outputs->emplace_back(value);
      (*output_position) += outputs_num;
    }
    // An empty value node returns an empty VectorRef.
    return;
  }

  auto &output_abstract = output_node->abstract();
  MS_EXCEPTION_IF_NULL(output_abstract);
  // Wrap the output in a VectorRef if the output is a tuple.
  if (output_abstract->isa<abstract::AbstractTuple>()) {
    VectorRef output_tuple;
    for (size_t i = 0; i < outputs_num; ++i) {
      if (*output_position >= output_tensors.size()) {
        MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
      }
      output_tuple.emplace_back(std::move(output_tensors[*output_position]));
      ++(*output_position);
    }
    outputs->emplace_back(std::move(output_tuple));
  } else {
    for (size_t i = 0; i < outputs_num; ++i) {
      if (*output_position >= output_tensors.size()) {
        MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
      }
      outputs->emplace_back(std::move(output_tensors[*output_position]));
      ++(*output_position);
    }
  }
}

#ifdef ENABLE_DEBUGGER
void MindRTBackend::SetDebuggerInit() {
  auto debugger_ = Debugger::GetInstance();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
}
#endif

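// Builds the GraphCompilerInfo for whole-graph (pipeline) execution: gathers the
// compiled kernel graphs and their device contexts, parses the control nodes, and
// records the position of every root-graph output so the output actor can order them.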
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
  MS_EXCEPTION_IF_NULL(root_graph);
  MS_EXCEPTION_IF_NULL(graph_compiler_);

  std::vector<KernelGraphPtr> graphs;
  std::vector<DeviceContext *> device_contexts;
  std::string name = "kernel_graph";
  for (const auto &graph_id_to_context : graph_id_to_device_context_) {
    (void)graphs.emplace_back(graph_compiler_->Fetch(graph_id_to_context.first));
    (void)device_contexts.emplace_back(graph_id_to_context.second);
    (void)name.append("_").append(std::to_string(graph_id_to_context.first));
  }

  auto parser = std::make_shared<ControlNodeParser>();
  parser->Parse(control_nodes_, graphs, device_contexts, root_graph);

  runtime::KernelMapPosition outputs_order;
  size_t outputs_num = 0;
  const auto &root_output =
    AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
  size_t position = 0;
  auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
  if (runtime::IsCallNode(root_output)) {
    std::vector<AnfNodePtr> call_nodes;
    size_t call_output_num = runtime::FetchOutputSizebyCallNode(root_output, &call_nodes);
    for (size_t i = 0; i < call_output_num; ++i) {
      (void)outputs.emplace_back(root_output, i);
    }
  }
  outputs_num = outputs.size();
  for (const auto &output : outputs) {
    if (outputs_order.count(output) == 0) {
      outputs_order[output] = {position++};
    } else {
      (void)outputs_order[output].emplace_back(position++);
    }
  }

  std::vector<std::vector<int64_t> *> tensors_mask;
  std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
                                             root_graph->parameters(), parser, outputs_order, outputs_num, name, false,
                                             runtime::GraphExecutionStrategy::kPipeline);
}

std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
  const ActorInfo &actor_info, const std::vector<int64_t> *tensors_mask,
  const std::vector<tensor::TensorPtr> *input_tensors, bool need_erase) {
  std::vector<KernelGraphPtr> graphs;
  std::vector<DeviceContext *> device_contexts;
  runtime::KernelMapPosition outputs_order;
  size_t position = 0;
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  for (const auto &graph_info_to_context : graph_info_to_device_context_) {
    const auto &graph = graph_compiler_->Fetch(graph_info_to_context.first);
    MS_EXCEPTION_IF_NULL(graph);
    (void)graphs.emplace_back(graph);
    (void)device_contexts.emplace_back(graph_info_to_context.second);

    auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
    for (const auto &output : outputs) {
      if (outputs_order.count(output) == 0) {
        outputs_order[output] = {position++};
      } else {
        (void)outputs_order[output].emplace_back(position++);
      }
    }
  }

  std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(tensors_mask));
  std::vector<std::vector<TensorPtr> *> input_tensors_list(1,
                                                           const_cast<std::vector<tensor::TensorPtr> *>(input_tensors));
  auto parser = std::make_shared<ControlNodeParser>();
  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask_list, input_tensors_list,
                                             std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), parser,
                                             outputs_order, outputs_order.size(), actor_info, need_erase,
                                             runtime::GraphExecutionStrategy::kStep);
}

void MindRTBackend::EraseSingleOpCache(const ActorInfo &actor_info, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (graph_info_to_device_context_.empty()) {
    MS_LOG(EXCEPTION) << "The map graph_info_to_device_context_ is empty.";
  }
  const auto &graph_info = graph_info_to_device_context_.begin()->first;
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  graph_compiler_->EraseSingleOpCache(graph_info, graph->graph_id());
  actor_to_graph_compiler_info_.erase(actor_info);
}

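// Runs a cached single-op actor set in step mode: value-node tensors are filtered out
// of the inputs, the op is launched synchronously, outputs (and, for dynamic-shape ops,
// the output abstract) are collected, and the per-op cache is erased when need_erase_ is set.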
void MindRTBackend::RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info,
                             const std::vector<int64_t> *tensors_mask,
                             const std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
  if (graph_iter == actor_to_graph_compiler_info_.end()) {
    MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
  }
  MS_EXCEPTION_IF_NULL(graph_iter->second);
  const auto &graph_compiler_info = *(graph_iter->second);

  const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
  MS_EXCEPTION_IF_NULL(actor_set);

  // Erase value node tensors.
  std::vector<tensor::TensorPtr> tensors_without_value_node;
  if (input_tensors->size() != tensors_mask->size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
                      << tensors_mask->size();
  }
  for (size_t index = 0; index < tensors_mask->size(); ++index) {
    if (tensors_mask->at(index) != kValueNodeTensorMask) {
      (void)tensors_without_value_node.emplace_back(input_tensors->at(index));
    }
  }

  for (auto &tensor : tensors_without_value_node) {
    MS_EXCEPTION_IF_NULL(tensor);
    if (tensor->NeedWaitDevice()) {
      tensor->WaitDevice();
    }
  }

  if (!runtime::GraphScheduler::GetInstance().Run(actor_set, {tensors_without_value_node}, *input_tensors,
                                                  runtime::GraphExecutionStrategy::kStep)) {
    MS_LOG(EXCEPTION) << "The actor run failed, actor name: " << actor_set->name_;
  }

  // Fetch outputs.
  const auto &graph = graph_compiler_info.graphs_.front();
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  const auto &output_nodes = graph_compiler_->GetGraphOutputNodes(graph->graph_id());
  MS_EXCEPTION_IF_NULL(outputs);
  UpdateOutput(output_nodes, outputs);

  // Update the output abstract of a dynamic op to op_run_info.
  if (op_run_info->is_dynamic_shape) {
    UpdateOutputAbstract(graph, op_run_info);
  }

  // Release the kernel resources.
  const auto &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (kOpCacheBlackList.find(AnfAlgo::GetCNodeName(kernel)) != kOpCacheBlackList.end()) {
      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
      if (kernel_mod) {
        kernel_mod->ReleaseResource();
      }
    }
  }

  // Update the device addresses for the inputs and outputs of the graph.
  UpdateOutputDeviceAddress(output_nodes, graph_compiler_info.device_contexts_.front());
  UpdateInputDeviceAddress(graph);

  if (graph_compiler_info.need_erase_) {
    EraseSingleOpCache(actor_info, graph);
  }
}
}  // namespace compile
}  // namespace mindspore