1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/graph_compiler/backend.h"
17 
18 #include <algorithm>
19 #include <vector>
20 #include <map>
21 #include <stack>
22 #include <unordered_map>
23 #include "ops/sequence_ops.h"
24 #include "ops/nn_op_name.h"
25 #include "ops/structure_op_name.h"
26 #include "include/common/utils/parallel_context.h"
27 #include "backend/graph_compiler/transform.h"
28 #include "backend/common/session/session_factory.h"
29 #include "runtime/pynative/op_executor.h"
30 #include "runtime/pynative/op_compiler.h"
31 #include "include/backend/optimizer/helper.h"
32 #include "pipeline/jit/ps/action.h"
33 #include "pipeline/jit/ps/parse/data_converter.h"
34 #include "pipeline/pynative/grad/jit/jit_call_graph.h"
35 #include "ir/anf.h"
36 #include "pybind_api/ir/base_ref_py.h"
37 #include "pybind_api/pybind_patch.h"
38 #include "include/common/utils/callbacks.h"
39 #include "include/common/utils/convert_utils.h"
40 #include "include/common/utils/convert_utils_py.h"
41 #include "utils/log_adapter.h"
42 #include "utils/ms_utils.h"
43 #include "runtime/hardware/device_context_manager.h"
44 #include "runtime/graph_scheduler/graph_compiler.h"
45 #include "runtime/pynative/op_runner.h"
46 #include "runtime/pynative/graph_adapter.h"
47 #include "kernel/pyboost/pyboost_utils.h"
48 #include "runtime/pynative/op_function/pyboost_grad_functions.h"
49 #include "include/backend/distributed/recovery/recovery_context.h"
50 #include "pybind_api/gil_scoped_long_running.h"
51 #ifdef ENABLE_DEBUGGER
52 #include "include/backend/debug/debugger/debugger.h"
53 #endif
54 #ifndef ENABLE_SECURITY
55 #include "include/backend/debug/data_dump/dump_json_parser.h"
56 #endif
57 #if defined(__linux__) && defined(WITH_BACKEND)
58 #include "include/backend/distributed/ps/ps_context.h"
59 #endif
60 
61 #include "runtime/device/device_address_utils.h"
62 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
63 
64 namespace mindspore {
65 namespace compile {
66 LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
67   MS_LOG(DEBUG) << "MsConvert";
68   MS_EXCEPTION_IF_NULL(segment);
69   MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
70   LinConvertResult result;
71   FuncGraphPtr fg;
72   AnfNodePtrList inputs;
73   AnfNodePtrList outputs;
74   std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
75   result.inputs = inputs;
76   result.outputs = outputs;
77   result.graph_id = kInvalidGraphId;
78   auto current_session = target_sess_;
79   if (target != target_device_ && !target.empty()) {
80     CreateOtherSession(target);
81     current_session = other_sess_;
82   }
83   MS_EXCEPTION_IF_NULL(current_session);
84   GraphId graph_id = current_session->CompileGraph(segment, outputs);
85   segment->graph_id_ = graph_id;
86   auto graph = current_session->GetGraph(graph_id);
87   MS_EXCEPTION_IF_NULL(graph);
88   for (const auto &pre_segment : segment->pre_segments_) {
89     MS_EXCEPTION_IF_NULL(pre_segment);
90     MS_EXCEPTION_IF_NULL(target_sess_);
91     auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
92     if (pre_graph == nullptr) {
93       MS_EXCEPTION_IF_NULL(other_sess_);
94       pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
95     }
96     MS_EXCEPTION_IF_NULL(pre_graph);
97     pre_graph->AddPostGraph(graph);
98     graph->AddPreGraph(pre_graph);
99     MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph_id;
100   }
101 
102   if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
103     MS_LOG(INFO) << "PrecompileOnly, stop run graph";
104     return result;
105   }
106   auto ms_context = MsContext::GetInstance();
107   const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
108   if (!pynative_mode || target != "Ascend") {
109     if (target != target_device_ && !target.empty()) {
110       MS_EXCEPTION_IF_NULL(other_sess_);
111       other_sess_->BuildGraph(graph_id);
112     } else if (!is_multi_graph_sink_) {
113       MS_EXCEPTION_IF_NULL(target_sess_);
114       target_sess_->BuildGraph(graph_id);
115     }
116   }
117   result.run = std::make_shared<RunFunc>(
118     [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
119   MS_EXCEPTION_IF_NULL(result.run);
120 
121   result.simu_run = std::make_shared<RunFunc>(
122     [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id); });
123   MS_EXCEPTION_IF_NULL(result.simu_run);
124   result.graph_id = graph_id;
125 
126   graph_id_map_[graph_id] = result;
127   return result;
128 }
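// Editorial sketch (not part of the original file): how a caller might consume the LinConvertResult
// built above. `backend`, `segment`, and `args` are hypothetical stand-ins; the members used
// (run, simu_run, graph_id) are the ones populated by MsConvert.
//
//   LinConvertResult converted = backend->MsConvert(segment, "Ascend");
//   VectorRef args;                              // flattened inputs for the compiled segment
//   VectorRef outputs = (*converted.run)(args);  // forwards to MsRunGraph(converted.graph_id, args, target)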
129 
130 // Simulated run used during compilation: it only returns the graph's recorded output nodes without executing.
131 VectorRef MsBackend::MsSimuRunGraph(const GraphId &g) {
132   MS_LOG(DEBUG) << "Set graph input:" << g;
133   std::vector<BaseRef> outputs;
134   (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
135                        [](const AnfNodePtr &v) { return v; });
136   return VectorRef(outputs);
137 }
138 
139 namespace {
140 void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context, bool is_gradient_out) {
141   MS_EXCEPTION_IF_NULL(graph);
142   for (const auto &node : graph->execution_order()) {
143     auto output_address_num = AnfAlgo::GetOutputAddressNum(node);
144     // Clear old output device address of kernel
145     for (size_t i = 0; i < output_address_num; ++i) {
146       if (!AnfAlgo::OutputAddrExist(node, i, false)) {
147         continue;
148       }
149       const auto &device_address = AnfAlgo::GetMutableOutputAddr(node, i, false);
150       if (device_address == nullptr) {
151         continue;
152       }
153       MS_EXCEPTION_IF_NULL(device_context);
154       auto new_device_address = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(device_address, device_context);
155       if (is_gradient_out) {
156         new_device_address->set_from_persistent_mem(true);
157       }
158       AnfAlgo::SetOutputAddr(new_device_address, i, node.get());
159     }
160 
161     // Clear old workspace device address of kernel
162     auto kernel_mod = AnfAlgo::GetKernelMod(node);
163     MS_EXCEPTION_IF_NULL(kernel_mod);
164     auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
165     for (size_t i = 0; i < workspace_lists.size(); ++i) {
166       if (!AnfAlgo::WorkspaceAddrExist(node, i)) {
167         continue;
168       }
169       const auto &device_address = AnfAlgo::GetMutableWorkspaceAddr(node, i);
170       auto new_device_address = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(device_address, device_context);
171       AnfAlgo::SetWorkspaceAddr(new_device_address, i, node.get());
172     }
173   }
174 }
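// Editorial note (illustration only): the "clone empty address" pattern above swaps every existing
// output/workspace address for a fresh, unallocated copy so the next launch of the kernel starts from
// clean device memory. The core step, using the same helpers as the loop above, is roughly:
//
//   auto cloned = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(old_address, device_context);
//   AnfAlgo::SetOutputAddr(cloned, output_index, node.get());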
175 
176 void ClearInputDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) {
177   MS_EXCEPTION_IF_NULL(graph);
178   MS_EXCEPTION_IF_NULL(device_context);
179   for (const auto &node : graph->input_nodes()) {
180     MS_EXCEPTION_IF_NULL(node);
181     if (node->isa<Parameter>()) {
182       auto device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
183       if (device_address == nullptr) {
184         continue;
185       }
186       auto new_device_address = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(device_address, device_context);
187       AnfAlgo::SetOutputAddr(new_device_address, 0, node.get());
188     }
189   }
190 }
191 
192 void AllocateMemForTensor(const tensor::BaseTensorPtr &tensor, DeviceContext *device_context,
193                           bool is_cpu_address_exist) {
194   MS_EXCEPTION_IF_NULL(tensor);
195   MS_EXCEPTION_IF_NULL(device_context);
196 
197   auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
198   MS_EXCEPTION_IF_NULL(device_address);
199   device_address->set_is_view(true);
200   if (is_cpu_address_exist) {
201     if (device_address->from_mem_pool()) {
202       // If the CPU address exists and the address comes from the memory pool, no copy is needed.
203       return;
204     } else {
205       // If not from the pool, the lifetime of the device ptr is guaranteed elsewhere.
206       // Before applying for a new address, clear the old one; otherwise a warning is generated.
207       device_address->set_ptr(nullptr);
208       if (device_context->GetDeviceType() != device_address->GetDeviceType()) {
209         device_context = runtime::OpRunner::GetDeviceContext(kCPUDevice);
210         MS_EXCEPTION_IF_NULL(device_context);
211       }
212     }
213   }
214 
215   device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "PyNative", "ContiguousAllocMem", "");
216   auto mem_type = tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
217   device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
218                                                  device_address.get());
219   if ((device_address->GetPtr() == nullptr) &&
220       (!device_context->device_res_manager_->AllocateMemory(device_address.get()))) {
221     MS_LOG(EXCEPTION) << "Allocate memory failed";
222   }
223 
224   auto tensor_size = LongToSize(tensor->data().nbytes());
225   auto tensor_type = tensor->data_type();
226   if (!device_address->SyncHostToDevice(tensor->shape(), tensor_size, tensor_type, "DefaultFormat",
227                                         tensor->data_ptr())) {
228     MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
229   }
230 }
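// Editorial usage sketch (hypothetical call site): for a tensor that already carries a DeviceAddress,
// AllocateMemForTensor allocates device memory through device_res_manager_ when the pointer is still
// null and then copies the host data with SyncHostToDevice:
//
//   AllocateMemForTensor(tensor, device_context, /*is_cpu_address_exist=*/false);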
231 
232 device::DeviceAddressPtrList GetOutputDeviceAddress(const OpCompilerInfoPtr &op_compiler_info) {
233   const auto &output_edges = op_compiler_info->simple_graph_->outputs_;
234   device::DeviceAddressPtrList output_address;
235   output_address.reserve(output_edges.size());
236   std::transform(output_edges.begin(), output_edges.end(), std::back_inserter(output_address),
237                  [](const pynative::EdgePtr &edge) { return edge->address_; });
238   return output_address;
239 }
240 
241 void ClearOpInputOutput(const OpCompilerInfoPtr &op_compiler_info) {
242   const auto &all_edges = op_compiler_info->simple_graph_->all_edges_;
243   for (const auto &edge : all_edges) {
244     if (edge->type_ != pynative::EdgeType::kValueNodeEdge) {
245       // Just set the edge address to null rather than cloning an empty address.
246       // An empty address is cloned in the next RunOp if needed.
247       edge->address_ = nullptr;
248     }
249   }
250 }
251 }  // namespace
252 
253 VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
254   MS_LOG(DEBUG) << "Start ms graph run:" << args.size() << ", g:" << g;
255   // Run graph
256   std::vector<tensor::TensorPtr> inputs;
257   for (const auto &arg : args) {
258     std::vector<tensor::TensorPtr> flatten_values;
259     AnfAlgo::FlattenInputArg(arg, nullptr, &flatten_values);
260     (void)std::copy(flatten_values.begin(), flatten_values.end(), std::back_inserter(inputs));
261   }
262 
263   VectorRef outputs;
264   // Call ms RunGraphAsync
265   const session::SessionPtr &exe_session = ((target != target_device_ && !target.empty()) ? other_sess_ : target_sess_);
266   MS_EXCEPTION_IF_NULL(exe_session);
267 
268 #if defined(__linux__) && defined(WITH_BACKEND)
269   // In PS mode, the graph must run in sync mode in case the weights on the server were not updated in the last step.
270   if (ps::PSContext::instance()->is_ps_mode()) {
271     exe_session->RunGraph(g, inputs, &outputs);
272     return outputs;
273   }
274 #endif
275 
276   auto ms_context = MsContext::GetInstance();
277   const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
278   if (pynative_mode) {
279     MS_LOG(EXCEPTION) << "Pynative can't call this function anymore!";
280   }
281   exe_session->RunGraphAsync(g, inputs, &outputs);
282 
283   MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
284   return outputs;
285 }
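// Editorial note (sketch, not in the original source): MsRunGraph flattens every front-end argument
// into plain tensors before handing them to the session, so a tuple argument contributes all of its
// leaf tensors to `inputs`:
//
//   std::vector<tensor::TensorPtr> flatten_values;
//   AnfAlgo::FlattenInputArg(arg, nullptr, &flatten_values);  // one call per element of `args`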
286 
287 MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
288   convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
289   target_sess_ = session::SessionFactory::Get().Create(target);
290   if (target_sess_ == nullptr) {
291     MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
292   }
293   target_sess_->Init(device_id);
294 #ifndef ENABLE_SECURITY
295   target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
296 #endif
297   target_device_ = target;
298 }
299 
300 void MsBackend::CreateOtherSession(const std::string &target) {
301   if (other_sess_ != nullptr && other_device_ == target) {
302     return;
303   }
304   other_sess_ = session::SessionFactory::Get().Create(target);
305   if (other_sess_ == nullptr) {
306     MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
307   }
308   auto context_ptr = MsContext::GetInstance();
309   MS_EXCEPTION_IF_NULL(context_ptr);
310   uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
311   other_sess_->Init(device_id);
312 #ifndef ENABLE_SECURITY
313   other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
314 #endif
315   other_device_ = target;
316 }
317 
318 GraphId MsBackend::CompileGraph(const NotNull<FuncGraphPtr> &fg) {
319   MS_EXCEPTION_IF_NULL(target_sess_);
320   return target_sess_->CompileGraph(fg);
321 }
322 
323 VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }
324 
325 void MsBackend::ClearSessionGraphs() {
326   if (target_sess_ != nullptr) {
327     target_sess_->ClearGraph();
328   }
329 }
330 
331 #ifdef ENABLE_DEBUGGER
332 void MsBackend::SetDebugger() {
333   MS_EXCEPTION_IF_NULL(target_sess_);
334   target_sess_->SetDebugger();
335 }
336 #endif
337 
338 namespace {
339 ValuePtr GetInputofBpropCut(const std::shared_ptr<GraphCompiler> &graph_compiler, const CNodePtr &parent_node,
340                             const AnfNodePtr &input_node,
341                             const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
342                             const std::map<AnfNodePtr, size_t> &parameter_index,
343                             const std::vector<TensorPtr> &graph_inputs, InputInfo *input_info, size_t input_index) {
344   if (!IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) {
345     auto real_input = common::AnfAlgo::VisitKernel(input_node, 0).first;
346     MS_EXCEPTION_IF_NULL(real_input);
347     ValuePtr value = nullptr;
348     if (!real_input->isa<ValueNode>()) {
349       if (real_input->abstract() != nullptr && real_input->abstract()->isa<abstract::AbstractSparseTensor>()) {
350         value = TensorListToSparseTensor(real_input->abstract(), graph_inputs);
351       } else {
352         value = graph_compiler->GetSingleOpInputTensorByIndex(parent_node, op_output, parameter_index, graph_inputs,
353                                                               input_info, input_index);
354       }
355       MS_EXCEPTION_IF_NULL(value);
356     } else {
357       const auto &value_node = real_input->cast<ValueNodePtr>();
358       MS_EXCEPTION_IF_NULL(value_node);
359       value = value_node->value();
360       MS_EXCEPTION_IF_NULL(value);
361     }
362     return value;
363   }
364   auto cnode = input_node->cast<CNodePtr>();
365   MS_EXCEPTION_IF_NULL(cnode);
366 
367   std::vector<ValuePtr> args_tuple;
368   for (size_t i = 1; i < cnode->size(); ++i) {
369     auto input = cnode->inputs()[i];
370     auto value =
371       GetInputofBpropCut(graph_compiler, cnode, input, op_output, parameter_index, graph_inputs, input_info, i - 1);
372     MS_EXCEPTION_IF_NULL(value);
373     (void)args_tuple.emplace_back(value);
374   }
375   auto arg = std::make_shared<ValueTuple>(args_tuple);
376   return arg;
377 }
378 
379 ValuePtr GetFrontArgByParameter(const std::vector<AnfNodePtr> &origin_paramters, const VectorRef &front_args,
380                                 const AnfNodePtr &front_node) {
381   const auto &iter = std::find(origin_paramters.begin(), origin_paramters.end(), front_node);
382   const size_t index = static_cast<size_t>(iter - origin_paramters.begin());
383   // If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
384   // and there is no need to input a tensor.
385   if (index >= front_args.size()) {
386     MS_LOG(EXCEPTION) << "Position out of front args range, position value is " << index << " and args size is "
387                       << front_args.size() << ".";
388   }
389   auto value = utils::cast<ValuePtr>(front_args[index]);
390   MS_EXCEPTION_IF_NULL(value);
391   return value;
392 }
393 
394 void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler,
395                        const std::vector<AnfNodePtr> &origin_paramters, const VectorRef &front_args,
396                        const CNodePtr &front_cnode, const CNodePtr &backend_cnode,
397                        const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output_map,
398                        const std::map<AnfNodePtr, size_t> &parameter_index,
399                        const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info, VectorRef *args) {
400   MS_EXCEPTION_IF_NULL(front_cnode);
401   MS_EXCEPTION_IF_NULL(backend_cnode);
402   MS_EXCEPTION_IF_NULL(graph_compiler);
403   MS_EXCEPTION_IF_NULL(args);
404   auto front_size = front_cnode->size();
405   auto back_size = backend_cnode->size();
406   if (front_size != back_size) {
407     MS_LOG(EXCEPTION) << "Bpropcut op front cnode size: " << front_size << ", back cnode size:" << back_size
408                       << ", the inputs of a bprop_cut op should not be flattened";
409   }
410   for (size_t index = 1; index < back_size; ++index) {
411     auto input_node = backend_cnode->input(index);
412     ValuePtr value = nullptr;
413     if (input_node->isa<Parameter>() && input_node->abstract() != nullptr &&
414         input_node->abstract()->isa<abstract::AbstractSequence>()) {
415       auto front_input_node = front_cnode->input(index);
416       value = GetFrontArgByParameter(origin_paramters, front_args, front_input_node);
417     } else {
418       value = GetInputofBpropCut(graph_compiler, backend_cnode, input_node, op_output_map, parameter_index,
419                                  graph_inputs, input_info, index - 1);
420     }
421     MS_EXCEPTION_IF_NULL(value);
422     (void)args->emplace_back(value);
423   }
424 }
425 
426 void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler,
427                         const std::vector<AnfNodePtr> &origin_paramters, const VectorRef &front_args,
428                         const KernelGraphPtr &graph, const CNodePtr &kernel,
429                         const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output_map,
430                         const std::map<AnfNodePtr, size_t> &parameter_index,
431                         const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info,
432                         VectorRef *op_outputs) {
433   MS_EXCEPTION_IF_NULL(graph);
434   MS_EXCEPTION_IF_NULL(kernel);
435   MS_EXCEPTION_IF_NULL(op_outputs);
436   AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel);
437   if (front_node == nullptr && graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
438     front_node = kernel;
439   }
440   MS_EXCEPTION_IF_NULL(front_node);
441   if (!front_node->isa<CNode>()) {
442     MS_LOG(EXCEPTION) << "The front node of bprop_cut is not a CNode";
443   }
444   CNodePtr cnode = front_node->cast<CNodePtr>();
445   MS_EXCEPTION_IF_NULL(cnode);
446   const std::vector<AnfNodePtr> &node_inputs = cnode->inputs();
447   if (node_inputs.empty()) {
448     MS_LOG(EXCEPTION) << "The inputs of node[" << cnode->fullname_with_scope() << "] is empty";
449   }
450 
451   const AnfNodePtr &fn = node_inputs.at(0);
452   if (!IsValueNode<Primitive>(fn)) {
453     MS_LOG(EXCEPTION) << "The input[0] of kernel[" << kernel->fullname_with_scope()
454                       << "] is not a ValueNode of Primitive";
455   }
456 
457   PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn);
458   MS_EXCEPTION_IF_NULL(prim);
459   if (prim->name() == kBpropCutOpName) {
460     VectorRef args;
461     GetControlOpInput(graph_compiler, origin_paramters, front_args, cnode, kernel, op_output_map, parameter_index,
462                       graph_inputs, input_info, &args);
463     py::gil_scoped_acquire acquire;
464     BaseRef out = python_adapter::PyAdapterCallback::RunPrimitivePyHookFunction(prim, args);
465     // Convert pyobject output to tensor.
466     if (utils::isa<PyObjectRef>(out)) {
467       PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
468       auto out_py_tuple = py_ref.object_;
469       std::vector<ValuePtr> output_tensors;
470       ConvertPyObjectToTensor(out_py_tuple, &output_tensors);
471       // If the bprop changes the grad, the kernel abstract needs to be updated for its users.
472       std::vector<abstract::AbstractBasePtr> output_tensor_abs;
473       for (auto &tensor : output_tensors) {
474         (void)output_tensor_abs.emplace_back(tensor->ToAbstract()->Broaden());
475         (void)op_outputs->elements_.emplace_back(std::move(tensor));
476       }
477       kernel->set_abstract(std::make_shared<abstract::AbstractTuple>(output_tensor_abs));
478     }
479   }
480 }
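// Editorial summary (condensed restatement of the bprop_cut path above, no new APIs): collect the hook
// arguments with GetControlOpInput, invoke the registered Python hook under the GIL, then convert the
// returned py::object back into tensors and refresh the kernel abstract:
//
//   py::gil_scoped_acquire acquire;
//   BaseRef out = python_adapter::PyAdapterCallback::RunPrimitivePyHookFunction(prim, args);
//   ConvertPyObjectToTensor(utils::cast<PyObjectRef>(out).object_, &output_tensors);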
481 
482 void UpdateOutputAbstract(const VectorRef &outputs, const session::BackendOpRunInfoPtr &op_run_info) {
483   auto output_size = outputs.size();
484   if (output_size == 1 && op_run_info->base_op_run_info.op_name != kGetNextOpName) {
485     auto output_tensor = utils::cast<tensor::BaseTensorPtr>(outputs[0]);
486     MS_EXCEPTION_IF_NULL(output_tensor);
487     op_run_info->base_op_run_info.abstract = output_tensor->ToAbstract();
488     MS_LOG(DEBUG) << "Update output abstract of " << op_run_info->base_op_run_info.op_name << " to "
489                   << op_run_info->base_op_run_info.abstract->ToString();
490     return;
491   }
492   AbstractBasePtrList elements;
493   for (size_t i = 0; i < output_size; ++i) {
494     auto output_tensor = utils::cast<tensor::BaseTensorPtr>(outputs[i]);
495     MS_EXCEPTION_IF_NULL(output_tensor);
496     (void)elements.emplace_back(output_tensor->ToAbstract());
497   }
498   op_run_info->base_op_run_info.abstract = std::make_shared<abstract::AbstractTuple>(elements);
499   MS_LOG(DEBUG) << "Update output abstract of " << op_run_info->base_op_run_info.op_name << " to "
500                 << op_run_info->base_op_run_info.abstract->ToString();
501 }
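// Editorial sketch (hypothetical two-output op): UpdateOutputAbstract rebuilds the op's output
// abstract from the freshly produced tensors. A single-output op (other than GetNext) gets
// out0->ToAbstract() directly, while a multi-output op ends up with
//
//   AbstractBasePtrList elements = {out0->ToAbstract(), out1->ToAbstract()};
//   op_run_info->base_op_run_info.abstract = std::make_shared<abstract::AbstractTuple>(elements);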
502 
503 tensor::BaseTensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index) {
504   MS_EXCEPTION_IF_NULL(output_node);
505   const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
506   MS_EXCEPTION_IF_NULL(device_tensor);
507 
508   const auto &user_data = device_tensor->user_data();
509   bool is_map_tensor_output = user_data && user_data->get<UserDataType>(kUserDataType) &&
510                               *(user_data->get<UserDataType>(kUserDataType)) == UserDataType::kUserTypeHashTable;
511   if (is_map_tensor_output) {
512     return AnfAlgo::CreateMapTensor(output_node, output_index);
513   }
514 
515   device_tensor->SetNodeIndex(output_node, output_index);
516   device_tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));
517   runtime::DeviceAddressUtils::UpdateDeviceAddressHostInfoByNode(device_tensor, output_node, output_index);
518 
519   const auto &kernel_tensor = device_tensor->kernel_tensor();
520   MS_EXCEPTION_IF_NULL(kernel_tensor);
521 
522   // Create the host tensor. The output tensor should use the inferred type; it will be handled correctly by tensor
523   // data sync when the inferred type is not equal to the device type.
524   auto tensor = std::make_shared<tensor::BaseTensor>(kernel_tensor->dtype_id(), kernel_tensor->GetShapeVector());
525 
526   // Put device tensor into host tensor.
527   tensor->set_device_address(device_tensor);
528   tensor->set_sync_status(kNeedSyncDeviceToHost);
529 
530   // MindRT is disabled in the multi-graph scenario.
531   // Delete tensor->data_sync() once MindRT is enabled in all scenarios.
532   auto ms_context = MsContext::GetInstance();
533   MS_EXCEPTION_IF_NULL(ms_context);
534   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
535       !runtime::OpExecutor::GetInstance().async_for_graph()) {
536     // If the execution mode in MsContext is Graph Mode, the tensor will be an input of a graph that executes in
537     // Graph Mode. If the graph contains no CNode after optimization, the tensor needs to be synced to the host.
538     tensor->data_sync(false);
539   }
540 
541   return tensor;
542 }
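// Editorial usage sketch (hypothetical call site): the tensor returned above wraps the device address
// and is marked kNeedSyncDeviceToHost, so host data is materialised only when data_sync() runs:
//
//   auto out = CreateOutputTensor(output_node, /*output_index=*/0);
//   out->data_sync(false);  // pulls device memory to the host on demand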
543 tensor::BaseTensorPtr CreateOutputTensorDynamicImpl(const OpCompilerInfoPtr &op_compiler_info,
544                                                     const AnfNodePtr &output_node, size_t output_index,
545                                                     const std::shared_ptr<device::DeviceAddress> &address,
546                                                     size_t idx_in_graph_outputs) {
547   MS_EXCEPTION_IF_NULL(output_node);
548   MS_EXCEPTION_IF_NULL(address);
549   MS_EXCEPTION_IF_NULL(op_compiler_info);
550 
551   const auto &user_data = address->user_data();
552   bool is_map_tensor_output = user_data && user_data->get<UserDataType>(kUserDataType) &&
553                               *(user_data->get<UserDataType>(kUserDataType)) == UserDataType::kUserTypeHashTable;
554   if (is_map_tensor_output) {
555     return AnfAlgo::CreateMapTensor(address);
556   }
557 
558   // Create the host tensor. The output tensor should use the inferred type; it will be handled correctly by tensor
559   // data sync when the inferred type is not equal to the device type.
560   auto tensor = std::make_shared<tensor::BaseTensor>(address->type_id(), address->host_shape());
561 
562   // Put device tensor into host tensor.
563   address->SetNodeIndex(output_node, output_index);
564   address->set_padding_type(op_compiler_info->graph_outputs_padding_type_[idx_in_graph_outputs]);
565   tensor->set_device_address(address);
566 
567   // MindRT is disabled in the multi-graph scenario.
568   // Delete tensor->data_sync() once MindRT is enabled in all scenarios.
569   auto ms_context = MsContext::GetInstance();
570   MS_EXCEPTION_IF_NULL(ms_context);
571   if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
572       !runtime::OpExecutor::GetInstance().async_for_graph()) {
573     // If the execution mode in MsContext is Graph Mode, the tensor will be an input of a graph that executes in
574     // Graph Mode. If the graph contains no CNode after optimization, the tensor needs to be synced to the host.
575     tensor->data_sync(false);
576   }
577   return tensor;
578 }
579 
580 #if !defined(__APPLE__)
581 bool EnablePyNativeSyncRunning() {
582   auto ms_context = MsContext::GetInstance();
583   MS_EXCEPTION_IF_NULL(ms_context);
584   return ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
585 }
586 #endif
587 
588 bool DisableRunOpAsync(const OpCompilerInfoPtr &op_compiler_info, const session::BackendOpRunInfoPtr &op_run_info) {
589 #if defined(__APPLE__)
590   return true;
591 #else
592   return op_run_info->base_op_run_info.has_dynamic_output ||  // Infer output is dynamic.
593          op_compiler_info->need_refresh_abstract_ ||          // Graph output is dynamic after IR Pass. (e.g. Dropout)
594          op_compiler_info->need_erase_ ||                     // Random op cache need to be erased.
595          runtime::OpExecutor::NeedSync() ||                   // Cannot find a wait point before compile graph.
596          EnablePyNativeSyncRunning();                         // context.set_context(pynative_synchronize=True)
597 #endif
598 }
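// Editorial note (sketch of the decision, mirroring RunOpImpl further below): DisableRunOpAsync is the
// single switch between the asynchronous dispatch path and the synchronous fallback:
//
//   if (!DisableRunOpAsync(op_compiler_info, op_run_info)) {
//     DispatchOpTask(single_op_cache_hit, outputs, op_compiler_info, op_run_info);  // queued, async
//   } else {
//     runtime::OpExecutor::GetInstance().WaitAll();  // drain pending tasks, then run in place
//   }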
599 }  // namespace
600 
601 void CreateKernelTensor(const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
602                         std::vector<DeviceContext *> device_contexts) {
603   if (input_tensors.size() < device_contexts.size()) {
604     MS_LOG(EXCEPTION) << "Invalid input_tensors size " << input_tensors.size() << " device_contexts size "
605                       << device_contexts.size();
606   }
607   for (size_t i = 0; i < device_contexts.size(); ++i) {
608     const auto &tensors = input_tensors[i];
609     const auto &device_context = device_contexts[i];
610     MS_EXCEPTION_IF_NULL(device_context);
611     for (const auto &tensor : tensors) {
612       if (tensor != nullptr && tensor->device_address() != nullptr) {
613         auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
614         MS_EXCEPTION_IF_NULL(device_address);
615         if (device_address->kernel_tensor() == nullptr) {
616           runtime::DeviceAddressUtils::CreateKernelTensor(device_address, tensor);
617         }
618       }
619     }
620   }
621 }
622 
623 void CreateKernelTensor(const BaseRef &arg) {
624   if (utils::isa<tensor::BaseTensor>(arg)) {
625     auto tensor = utils::cast<tensor::BaseTensorPtr>(arg);
626     auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
627     if (device_address != nullptr) {
628       runtime::DeviceAddressUtils::CreateKernelTensor(device_address, tensor);
629     }
630   } else if (utils::isa<ValueSequencePtr>(arg)) {
631     auto value_sequence = utils::cast<ValueSequencePtr>(arg);
632     MS_EXCEPTION_IF_NULL(value_sequence);
633     const auto &sequence_value = value_sequence->value();
634     for (const auto &value : sequence_value) {
635       CreateKernelTensor(value);
636     }
637   } else {
638     MS_LOG(DEBUG) << "Only tensor need create KernelTensor";
639   }
640 }
641 
642 void CreateKernelTensor(const VectorRef &args) {
643   for (const auto &arg : args) {
644     CreateKernelTensor(arg);
645   }
646 }
647 
648 runtime::ActorSet *MindRTBackend::RealCompileGraphBeforeRunActor(const GraphCompilerInfo &graph_compiler_info,
649                                                                  const VectorRef &args, bool no_multi_graph) {
650   auto graphs = graph_compiler_info.graphs_;
651   auto device_contexts = graph_compiler_info.device_contexts_;
652   CreateKernelTensor(args);
653 
654   for (size_t i = 0; i < graphs.size(); ++i) {
655     const auto &graph = graphs[i];
656     MS_EXCEPTION_IF_NULL(graph);
657     graph->set_flag(kFlagPyNativeRunInGraph, true);
658     graph->set_flag(kFlagIsPynativeBpropGraph, root_graph_->has_flag(kFlagIsPynativeBpropGraph));
659     if (graph->is_any_type_input()) {
660       continue;
661     }
662     if (no_multi_graph) {
663       MS_LOG(INFO) << "Replace parameter format";
664       // The input tensors of heterogeneous graphs or control flow graphs are null.
665       // The tensors need to be fetched after ParseControlNodes.
666       auto input_tensors = GetRunGraphInputs(graph_compiler_info, args);
667       pynative::GraphAdapter::ReplaceGraphParameterProperties(graph, input_tensors.at(i), device_contexts[i]);
668     }
669     (void)graph_compiler_->CompileGraphImpl(graph, device_contexts[i]);
670     pynative::GraphAdapter::RemoveUnusedValueNodes(graph);
671     // When PyNative uses the kernel graph, the front node and the back node are the same; but in PyNative task sink,
672     // the backend still creates a new kernel graph.
673     if (root_graph_->has_flag(kFlagIsPyNativeBpropKernelGraph) &&
674         !pynative::GraphAdapter::PyNativeEnableTaskSink(root_graph_)) {
675       graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, {graph->output()});
676     } else {
677       graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, graph->front_outputs());
678     }
679     // Clear front outputs after the outputs is cached.
680     graph->set_front_outputs({});
681     AnfAlgo::UpdateGraphValidRefPair(graph);
682     pynative::GraphAdapter::SensTensorToDevice(graph, device_contexts[i]);
683   }
684 
685   ParseControlNodes(graph_compiler_info);
686   UpdateGraphCompilerInfo(graph_compiler_info.name_);
687   auto actor_set = runtime::GraphScheduler::GetInstance().Transform(graph_compiler_info);
688   MS_EXCEPTION_IF_NULL(actor_set);
689   constexpr auto kKernelActorThreshold = 5000;
690   // Turning off multithreading may cause stack overflow in control flow scenarios.
691   if (no_multi_graph && actor_set->kernel_actors_.size() < kKernelActorThreshold &&
692       root_graph_->has_flag(kFlagIsPynativeBpropGraph)) {
693     // Multithreading can cause spikes in memory usage and performance fluctuations.
694     actor_set->is_multi_thread_execution_ = false;
695     MS_LOG(INFO) << "Actor Multithreading is turned off!";
696   }
697   runtime::GraphScheduler::GetInstance().Schedule(actor_set);
698 
699   for (size_t i = 0; i < graphs.size(); ++i) {
700     pynative::GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(graphs[i], device_contexts[i]);
701     pynative::GraphAdapter::GenerateRefCountForBpropValueNode(graphs[i]);
702     graph_adapter_.GenerateBackoffValueNodeOwners(graphs[i]);
703   }
704   return actor_set;
705 }
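// Editorial summary (names taken from this file): the compile-before-run path above is essentially
// CompileGraphImpl per graph, ParseControlNodes, GraphScheduler::Transform to build the actor set,
// and GraphScheduler::Schedule. RunGraphByActors below then reuses the cached actor set:
//
//   auto actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
//   if (actor_set == nullptr) {
//     actor_set = RealCompileGraphBeforeRunActor(graph_compiler_info, args, no_multi_graph);
//   }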
706 
707 void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
708                                      const VectorRef &args, VectorRef *outputs) {
709   MS_LOG(INFO) << "Status record: begin run actor: " << actor_info;
710   WaitTaskFinish();
711   MS_EXCEPTION_IF_NULL(graph_compiler_);
712   auto graphs = graph_compiler_info.graphs_;
713   auto device_contexts = graph_compiler_info.device_contexts_;
714   if (device_contexts.size() != graphs.size()) {
715     MS_LOG(EXCEPTION) << "Graphs size " << graphs.size() << " is not equal to device_contexts size "
716                       << device_contexts.size();
717   }
718 
719   // KernelByKernel: the size of control_nodes is at least 1 since there is a return node in the graph.
720   // GraphMode: No control nodes.
721   bool no_multi_graph = control_nodes_.size() <= 1 && graphs.size() == 1;
722   auto actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
723   if (actor_set == nullptr) {
724     actor_set = RealCompileGraphBeforeRunActor(graph_compiler_info, args, no_multi_graph);
725   }
726 
727   if (root_graph_->has_flag(kFlagIsPynativeBpropGraph)) {
728     for (size_t i = 0; i < graphs.size(); ++i) {
729       graph_adapter_.UpdateForwardOutputInBpropGraph(graphs[i], device_contexts[i], no_multi_graph);
730       pynative::GraphAdapter::UpdateDynamicValueNodeAbstract(graphs[i]);
731     }
732   }
733 
734   auto input_tensors = GetRunGraphInputs(graph_compiler_info, args);
735   if (graphs.size() > input_tensors.size()) {
736     MS_LOG(EXCEPTION) << "The actor_set " << actor_info << " graphs size " << graphs.size()
737                       << " should be less than or equal to inputs size " << input_tensors.size();
738   }
739   pynative::GraphAdapter::HandleHeterogeneousTensors(input_tensors, device_contexts);
740   CreateKernelTensor(input_tensors, device_contexts);
741 
742   // Release GIL and run actor DAG.
743   GilReleaseWithCheck release_gil;
744   VectorRef empty_args;
745   runtime::GraphScheduler::GetInstance().Run(actor_set, input_tensors, empty_args);
746 
747   MS_EXCEPTION_IF_NULL(graph_compiler_);
748   graph_compiler_->Summary(graph_compiler_info.graphs_);
749 
750   auto output = root_graph_->output();
751   MS_LOG(DEBUG) << "Current out " << output->DebugString();
752   if (root_graph_->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
753     MS_EXCEPTION_IF_NULL(output_node_);
754     root_graph_->set_output(output_node_);
755   }
756   ConstructOutputs(actor_set, outputs, root_graph_);
757   actor_set->output_actor_->FreeSummaryNodeMem();
758   runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
759   // Close abstract_lock for dynamic_shape
760   AnfUtils::CloseAbstractLock();
761   MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
762 }
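// Editorial note (sketch of the run phase, using only calls that appear above): once the actor set
// exists, a run is GetRunGraphInputs -> HandleHeterogeneousTensors -> GraphScheduler::Run ->
// ConstructOutputs, with the GIL released around the actor DAG execution:
//
//   GilReleaseWithCheck release_gil;
//   runtime::GraphScheduler::GetInstance().Run(actor_set, input_tensors, empty_args);
//   ConstructOutputs(actor_set, outputs, root_graph_);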
763 
764 void MindRTBackend::RunMsGradGraph(const CNodePtr &kernel, const VectorRef &args, VectorRef *outputs) const {
765   MS_EXCEPTION_IF_NULL(kernel);
766   auto jit_call_graph = kernel->user_data<pynative::JitCallGraph>();
767   MS_EXCEPTION_IF_NULL(jit_call_graph);
768   *outputs = jit_call_graph->Run(args);
769 }
770 
771 void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_info, const VectorRef &args,
772                                        VectorRef *outputs) {
773   WaitTaskFinish();
774 
775   MS_LOG(INFO) << "Status record: begin run graph by single op";
776   MS_EXCEPTION_IF_NULL(graph_compiler_);
777   const auto &graphs = graph_compiler_info.graphs_;
778   auto inputs = GetRunGraphInputs(graph_compiler_info, args);
779   for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
780     const auto &graph = graphs[graph_index];
781     MS_EXCEPTION_IF_NULL(graph);
782     std::map<KernelWithIndex, tensor::BaseTensorPtr> op_output_map;
783     std::map<AnfNodePtr, size_t> parameter_index;
784     GraphOutputInfo graph_output_info;
785     graph_output_info.graph_outputs = outputs;
786     graph_compiler_->GetParamAndOutputIndex(graph, inputs[graph_index], outputs, &parameter_index,
787                                             &graph_output_info.output_indexes);
788 
789     std::map<KernelWithIndex, size_t> cnode_ref_count;
790     auto iter = cnode_ref_counts_.find(graph->graph_id());
791     if (iter == cnode_ref_counts_.end()) {
792       graph_compiler_->CalculateRefCount(graph, &cnode_ref_count);
793       (void)cnode_ref_counts_.emplace(graph->graph_id(), cnode_ref_count);
794     } else {
795       cnode_ref_count = iter->second;
796     }
797 
798     MS_EXCEPTION_IF_NULL(root_graph_);
799     if (root_graph_->has_flag(kFlagIsPynativeBpropGraph)) {
800       graph_compiler_->CalculateForwardOpOutputCount(graph, inputs[graph_index], &forward_op_output_tensor_id_,
801                                                      parameter_index);
802     }
803 
804     GilReleaseWithCheck gil_release;
805     auto is_dynamic = root_graph_->has_flag(kFlagPyNativeBpropGraphIsDynamic);
806     bool has_bprop_cut = root_graph_->has_flag(kFlagPyNativeBpropGraphWithBpropCut);
807     auto ms_context = MsContext::GetInstance();
808     MS_EXCEPTION_IF_NULL(ms_context);
809     const std::string &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
810     for (const auto &kernel : graph->execution_order()) {
811       MS_LOG(DEBUG) << "Split and run op " << kernel->fullname_with_scope();
812       InputInfo input_info;
813       VectorRef op_outputs;
814       if (has_bprop_cut && common::AnfAlgo::IsBpropCutOpExecInBackend(kernel)) {
815         const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
816         RunControlOperator(graph_compiler_, origin_parameters, args, graph, kernel, op_output_map, parameter_index,
817                            inputs[graph_index], &input_info, &op_outputs);
818         // Execute remaining lazy tasks before PyNative hook exit.
819         WaitTaskFinish();
820       } else if (common::AnfAlgo::HasNodeAttr(kAttrJitCallNode, kernel)) {
821         graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], false,
822                                                  &input_info);
823         VectorRef input_args;
824         (void)std::transform(input_info.input_values.begin(), input_info.input_values.end(),
825                              std::back_inserter(input_args.elements_),
826                              [](ValuePtr &value) { return std::move(value); });
827 
828         RunMsGradGraph(kernel, input_args, &op_outputs);
829         WaitTaskFinish();
830       } else {
831         const auto &primitive = common::AnfAlgo::GetCNodePrimitive(kernel);
832         MS_EXCEPTION_IF_NULL(primitive);
833         if (runtime::PyBoostOpExecute::GetInstance().IsPyBoostOpRegistered(primitive->name()) &&
834             (kernel::pyboost::PyBoostUtils::IsKernelModRegistered(device_target, primitive->name()) ||
835              kernel::pyboost::PyBoostUtils::IsPyBoostCustomRegistered(device_target, primitive->name()))) {
836           MS_LOG(DEBUG) << "Run " << primitive->name() << " by pyboost";
837           graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], true,
838                                                    &input_info);
839           runtime::OpRunnerInfo op_runner_info{
840             primitive, device_target, input_info.input_values, input_info.input_abs, {}, kernel->abstract()};
841           runtime::PyBoostOpExecute::GetInstance().RunPyBoostCall(&op_runner_info, &op_outputs);
842         } else {
843           MS_LOG(DEBUG) << "Run " << primitive->name() << " by single op graph";
844           session::BackendOpRunInfoPtr op_run_info;
845           graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], false,
846                                                    &input_info);
847           graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_info, is_dynamic, &op_run_info,
848                                                           &graph_output_info);
849           if (is_dynamic) {
850             op_run_info->op_prim = std::make_shared<Primitive>(*op_run_info->op_prim);
851             AnfAlgo::SetDynamicAttrToPrim(op_run_info->op_prim);
852             RunOpDynamic(op_run_info, &op_outputs);
853           } else {
854             RunOp(op_run_info, &op_outputs);
855           }
856         }
857       }
858 
859       graph_compiler_->UpdateRefCount(input_info.input_kernel, &cnode_ref_count, &op_output_map);
860 
861       graph_output_info.graph_output_tensors.clear();
862       graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
863     }
864     WaitTaskFinish();
865   }
866   python_adapter::PyAdapterCallback::ProcessUnPairedCellHook(true);
867   MS_LOG(INFO) << "Status record: end run graph by single op";
868 }
869 
870 void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
871                                         const VectorRef &args, VectorRef *outputs) {
872   bool enable_run_graph_by_single_op =
873     std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
874                 [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagEnableRunGraphBySingleOp); });
875   if (enable_run_graph_by_single_op) {
876     RunGraphBySingleOp(graph_compiler_info, args, outputs);
877   } else {
878     RunGraphByActors(actor_info, graph_compiler_info, args, outputs);
879   }
880 }
881 
882 void MindRTBackend::WaitTaskFinish() const {
883   runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kWaitTaskFinish,
884                                      runtime::kDefaultOpName);
885   runtime::OpExecutor::GetInstance().WaitAll();
886 }
887 
888 void MindRTBackend::ClearOpExecutorResource() const { runtime::OpExecutor::GetInstance().Reset(); }
889 
890 void MindRTBackend::SyncStream() {
891   const auto &device_context =
892     device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
893   MS_EXCEPTION_IF_NULL(device_context);
894   MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
895 
896   auto ret = device_context->device_res_manager_->SyncAllStreams();
897   if (!ret) {
898     MS_LOG(EXCEPTION) << "Sync Stream failed";
899   }
900 }
901 
902 void MindRTBackend::EraseSingleOpCache(const GraphInfo &graph_info) const {
903   pynative::OpCompiler::GetInstance().ClearOpCache(graph_info);
904 }
905 
906 void MindRTBackend::ReleaseForwardOutput(const std::vector<ValuePtr> &input_values) {
907   graph_compiler_->UpdateForwardOpOutputRefCount(input_values, &forward_op_output_tensor_id_);
908 }
909 
910 void MindRTBackend::OpRunCallback(const std::shared_ptr<runtime::OpTaskContext> &context) {
911   MS_LOG(DEBUG) << "OpRunCallback start";
912   auto ms_context = MsContext::GetInstance();
913   MS_EXCEPTION_IF_NULL(ms_context);
914   auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
915   MS_EXCEPTION_IF_NULL(context);
916   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, context->is_pynative_infer());
917   runtime::OpRunner::RunSingleOpGraph(context->op_run_info(), context->op_compiler_info(),
918                                       runtime::OpRunner::GetTensorWithoutValueMask(context->op_run_info()));
919 
920   MS_EXCEPTION_IF_NULL(context->op_run_info());
921   if (!context->op_run_info()->is_infer) {
922     ReleaseForwardOutput(context->op_run_info()->base_op_run_info.expanded_input_values);
923   }
924 
925   ClearGraphDeviceAddress(context->graph(), context->device_context(), context->op_run_info()->is_gradient_out);
926   ClearInputDeviceAddress(context->graph(), context->device_context());
927   ClearOpInputOutput(context->op_compiler_info());
928 
929   // Reset PyNative infer flag.
930   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, infer_flag);
931   MS_LOG(DEBUG) << "OpRunCallback end";
932 }
933 
934 void MindRTBackend::OpRunCallbackDynamic(const std::shared_ptr<runtime::OpTaskContext> &context) {
935   MS_LOG(DEBUG) << "OpRunCallback start";
936   auto ms_context = MsContext::GetInstance();
937   MS_EXCEPTION_IF_NULL(ms_context);
938   auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
939   MS_EXCEPTION_IF_NULL(context);
940   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, context->is_pynative_infer());
941 
942   runtime::DynamicOpRunner::RunSingleOpGraph(context->op_run_info(), context->op_compiler_info(),
943                                              runtime::OpRunner::GetTensorWithoutValueMask(context->op_run_info()));
944 
945   MS_EXCEPTION_IF_NULL(context->op_run_info());
946   if (!context->op_run_info()->is_infer) {
947     ReleaseForwardOutput(context->op_run_info()->base_op_run_info.expanded_input_values);
948   }
949 
950   ClearOpInputOutput(context->op_compiler_info());
951   // Reset PyNative infer flag.
952   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, infer_flag);
953   MS_LOG(DEBUG) << "OpRunCallback end";
954 }
955 
956 void MindRTBackend::DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs,
957                                    const OpCompilerInfoPtr &op_compiler_info,
958                                    const session::BackendOpRunInfoPtr &op_run_info) {
959   MS_EXCEPTION_IF_NULL(op_compiler_info);
960   const auto &graph = op_compiler_info->graph_;
961   MS_EXCEPTION_IF_NULL(graph);
962 
963   runtime::OpRunner::UpdateDeviceAddress(graph, runtime::OpRunner::GetTensorWithoutValueMask(op_run_info),
964                                          op_compiler_info->device_context_, false);
965   // Create output tensor
966   UpdateOutput(op_compiler_info->graph_output_nodes_, outputs);
967 
968   auto ms_context = MsContext::GetInstance();
969   MS_EXCEPTION_IF_NULL(ms_context);
970   auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
971   auto run_op_context =
972     std::make_shared<runtime::OpTaskContext>(graph->graph_id(), graph, op_run_info, op_compiler_info, infer_flag);
973 
974   auto &op_executor = runtime::OpExecutor::GetInstance();
975   if (!single_op_cache_hit) {
976     CompileSingleOpGraph(op_compiler_info, op_compiler_info->device_context_);
977   }
978 
979   auto run_task = std::make_shared<runtime::DeviceOpRunTask>(
980     run_op_context, [this](const std::shared_ptr<runtime::OpTaskContext> &ctx) { OpRunCallback(ctx); });
981   runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(run_task->task_id());
982   op_executor.PushOpRunTask(run_task);
983 }
984 
985 void MindRTBackend::DispatchOpTaskDynamic(VectorRef *outputs, const OpCompilerInfoPtr &op_compiler_info,
986                                           const session::BackendOpRunInfoPtr &op_run_info,
987                                           const vector<device::DeviceAddressPtr> &device_address_list) {
988   MS_EXCEPTION_IF_NULL(op_compiler_info);
989   const auto &graph = op_compiler_info->graph_;
990   MS_EXCEPTION_IF_NULL(graph);
991 
992   auto ms_context = MsContext::GetInstance();
993   MS_EXCEPTION_IF_NULL(ms_context);
994   auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
995   auto run_op_context =
996     std::make_shared<runtime::OpTaskContext>(graph->graph_id(), graph, op_run_info, op_compiler_info, infer_flag);
997 
998   auto &op_executor = runtime::OpExecutor::GetInstance();
999   auto task = std::make_shared<runtime::DeviceOpRunTask>(
1000     run_op_context, [this](const std::shared_ptr<runtime::OpTaskContext> &ctx) { OpRunCallbackDynamic(ctx); });
1001   runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(task->task_id());
1002   op_executor.PushOpRunTask(task);
1003 }
1004 
1005 void MindRTBackend::RunOpImpl(bool single_op_cache_hit, const OpCompilerInfoPtr &op_compiler_info,
1006                               const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
1007   MS_EXCEPTION_IF_NULL(op_run_info);
1008   MS_EXCEPTION_IF_NULL(op_compiler_info);
1009   // Fetch outputs.
1010   const auto &graph = op_compiler_info->graph_;
1011   MS_EXCEPTION_IF_NULL(graph);
1012   MS_EXCEPTION_IF_NULL(graph_compiler_);
1013   const auto &output_nodes = op_compiler_info->graph_output_nodes_;
1014   MS_EXCEPTION_IF_NULL(outputs);
1015 
1016   auto device_context = op_compiler_info->device_context_;
1017   auto &op_executor = runtime::OpExecutor::GetInstance();
1018   if (!DisableRunOpAsync(op_compiler_info, op_run_info)) {
1019     MS_LOG(DEBUG) << "Async exec enabled, op: " << op_run_info->base_op_run_info.op_name;
1020     DispatchOpTask(single_op_cache_hit, outputs, op_compiler_info, op_run_info);
1021     return;
1022   }
1023 
1024   MS_LOG(DEBUG) << "Async exec disabled, op: " << op_run_info->base_op_run_info.op_name;
1025   if (!op_executor.RunQueueEmpty()) {
1026     WaitTaskFinish();
1027   }
1028   if (!single_op_cache_hit) {
1029     CompileSingleOpGraph(op_compiler_info, device_context);
1030   }
1031   const auto &tensors_without_value_mask = runtime::OpRunner::GetTensorWithoutValueMask(op_run_info);
1032   runtime::OpRunner::UpdateDeviceAddress(graph, tensors_without_value_mask, device_context, true);
1033 
1034   runtime::OpRunner::RunSingleOpGraph(op_run_info, op_compiler_info, tensors_without_value_mask);
1035 
1036   if (!op_run_info->is_infer) {
1037     ReleaseForwardOutput(op_run_info->base_op_run_info.expanded_input_values);
1038   }
1039   UpdateOutput(output_nodes, outputs);
1040 
1041   ClearGraphDeviceAddress(graph, device_context, op_run_info->is_gradient_out);
1042   ClearInputDeviceAddress(graph, device_context);
1043   ClearOpInputOutput(op_compiler_info);
1044 
1045   if (op_run_info->base_op_run_info.has_dynamic_output || op_compiler_info->need_refresh_abstract_) {
1046     UpdateOutputAbstract(*outputs, op_run_info);
1047   }
1048   if (op_compiler_info->need_erase_) {
1049     EraseSingleOpCache(op_compiler_info->graph_info_);
1050   }
1051 }
1052 
1053 void MindRTBackend::RunOpImplDynamic(bool single_op_cache_hit, const OpCompilerInfoPtr &op_compiler_info,
1054                                      const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
1055   MS_EXCEPTION_IF_NULL(op_run_info);
1056   MS_EXCEPTION_IF_NULL(op_compiler_info);
1057   MS_LOG(DEBUG) << "RunOpImplDynamic " << op_run_info->base_op_run_info.op_name;
1058   // Fetch outputs.
1059   const auto &graph = op_compiler_info->graph_;
1060   MS_EXCEPTION_IF_NULL(graph);
1061   MS_EXCEPTION_IF_NULL(graph_compiler_);
1062   MS_EXCEPTION_IF_NULL(outputs);
1063 
1064   auto device_context = op_compiler_info->device_context_;
1065   if (!single_op_cache_hit) {
1066     CompileSingleOpGraph(op_compiler_info, device_context, true);
1067   }
1068   if (!DisableRunOpAsync(op_compiler_info, op_run_info)) {
1069     MS_LOG(DEBUG) << "Async exec enabled, op: " << op_run_info->base_op_run_info.op_name;
1070     auto input_tensors = runtime::OpRunner::GetTensorWithoutValueMask(op_run_info);
1071     runtime::DynamicOpRunner::UpdateInputDeviceAddress(op_compiler_info, input_tensors, false);
1072     auto device_address_list = runtime::DeviceAddressUtils::CreateGraphOutputDeviceAddress(
1073       op_compiler_info, op_run_info->base_op_run_info.abstract, op_run_info->base_op_run_info.stream_id);
1074     // Create output tensor
1075     UpdateOutputDynamic(op_run_info, op_compiler_info, device_address_list, outputs);
1076     DispatchOpTaskDynamic(outputs, op_compiler_info, op_run_info, device_address_list);
1077     return;
1078   }
1079   MS_LOG(DEBUG) << "Async exec disabled, op: " << op_run_info->base_op_run_info.op_name;
1080   auto &op_executor = runtime::OpExecutor::GetInstance();
1081   if (!op_executor.RunQueueEmpty()) {
1082     WaitTaskFinish();
1083   }
1084   auto input_tensors = runtime::OpRunner::GetTensorWithoutValueMask(op_run_info);
1085   runtime::DynamicOpRunner::UpdateInputDeviceAddress(op_compiler_info, input_tensors, true);
1086   runtime::DynamicOpRunner::RunSingleOpGraph(op_run_info, op_compiler_info, input_tensors);
1087 
1088   if (!op_run_info->is_infer) {
1089     ReleaseForwardOutput(op_run_info->base_op_run_info.expanded_input_values);
1090   }
1091 
1092   const auto &device_address_list = GetOutputDeviceAddress(op_compiler_info);
1093   // Create output tensor
1094   UpdateOutputDynamic(op_run_info, op_compiler_info, device_address_list, outputs);
1095   UpdateOutputAbstract(*outputs, op_run_info);
1096   ClearOpInputOutput(op_compiler_info);
1097   if (op_compiler_info->need_erase_) {
1098     EraseSingleOpCache(op_compiler_info->graph_info_);
1099   }
1100 }
1101 
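// Front end of the single-op execution path: compile the op (or hit the single-op cache), wait until the
// compiled graph is ready, then execute it through RunOpImpl.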
void MindRTBackend::RunOp(const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  MS_LOG(DEBUG) << "Run Op " << op_run_info->base_op_run_info.op_name;

  bool single_op_cache_hit = true;
  auto op_compiler_info =
    pynative::OpCompiler::GetInstance().Compile(op_run_info, &single_op_cache_hit, device_name_, device_id_);
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  op_compiler_info->WaitReady();
  RunOpImpl(single_op_cache_hit, op_compiler_info, op_run_info, outputs);
}

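// Same as RunOp, but routes execution through RunOpImplDynamic for ops with dynamic shapes.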
void MindRTBackend::RunOpDynamic(const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  MS_LOG(DEBUG) << "Run Op " << op_run_info->base_op_run_info.op_name;

  // Single op graph compile
  bool single_op_cache_hit = true;
  auto op_compiler_info =
    pynative::OpCompiler::GetInstance().Compile(op_run_info, &single_op_cache_hit, device_name_, device_id_);
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  op_compiler_info->WaitReady();
  RunOpImplDynamic(single_op_cache_hit, op_compiler_info, op_run_info, outputs);
}

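// Queues a view kernel task on the OpExecutor so it is launched asynchronously on the given stream.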
void MindRTBackend::RunViewKernelTaskAsyncImpl(const runtime::KernelTaskType &task_type, DeviceContext *device_context,
                                               const device::DeviceAddressPtrList &input_addr_list,
                                               const device::DeviceAddressPtrList &output_addr_list,
                                               const size_t &stream_id) {
  // Capture the address lists by value: the task is executed asynchronously after this function returns, so
  // reference captures of the caller's arguments would dangle. The lambda must not be static either, otherwise
  // every later call would reuse the captures of the first invocation.
  auto kernel_task_func = [stream_id, task_type, input_addr_list, output_addr_list, device_context]() {
    runtime::OpRunner::LaunchKernelTask(task_type, device_context, input_addr_list, output_addr_list, stream_id);
  };

  runtime::OpExecutor::GetInstance().PushSimpleOpRunTask(
    std::make_shared<runtime::PassthroughDeviceTask>(kernel_task_func));
}

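// Gathers device addresses for the inputs and outputs of a view kernel task. Inputs that have no device address
// yet get a newly created one plus a memory-allocation task; the task is then launched either asynchronously or,
// after draining the run queue, synchronously.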
void MindRTBackend::RunViewKernelTask(const pynative::BaseOpRunInfo &base_op_run_info,
                                      const runtime::KernelTaskType &task_type, bool enable_async) {
  device::DeviceAddressPtrList input_addr_list;
  device::DeviceAddressPtrList output_addr_list;

  const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {base_op_run_info.device_target, MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
  MS_EXCEPTION_IF_NULL(device_context);

  for (size_t idx = 0; idx < base_op_run_info.expanded_input_values.size(); idx++) {
    auto input_tensor = base_op_run_info.expanded_input_values[idx]->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(input_tensor);
    if (input_tensor->device_address() == nullptr) {
      if (idx == 0) {
        MS_LOG(EXCEPTION) << "Device address of the first input tensor cannot be nullptr, op name:"
                          << base_op_run_info.op_name;
      }
      auto address_size = GetTypeByte(TypeIdToType(input_tensor->data_type())) * SizeOf(input_tensor->shape());

      auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
        nullptr, address_size, Format::DEFAULT_FORMAT, input_tensor->data_type(), input_tensor->shape(),
        device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
      kernel_tensor->SetType(std::make_shared<TensorType>(input_tensor->Dtype()));
      kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(input_tensor->shape()));
      kernel_tensor->set_stream_id(base_op_run_info.stream_id);
      auto input_addr = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);

      input_tensor->set_device_address(input_addr);
      RunAllocMemTask(device_context, input_tensor, enable_async, false);
      (void)input_addr_list.emplace_back(input_addr);
    } else {
      auto input_addr = std::static_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
      MS_EXCEPTION_IF_NULL(input_addr);
      if (input_addr->GetDeviceType() == device::DeviceType::kCPU) {
        RunAllocMemTask(device_context, input_tensor, enable_async, true);
      }

      (void)input_addr_list.emplace_back(input_addr);
    }
  }

  std::transform(base_op_run_info.output_tensors.begin(), base_op_run_info.output_tensors.end(),
                 std::back_inserter(output_addr_list), [](const auto &tensor) {
                   return std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
                 });

  if (enable_async) {
    RunViewKernelTaskAsyncImpl(task_type, device_context, input_addr_list, output_addr_list,
                               base_op_run_info.stream_id);
  } else {
    WaitTaskFinish();
    runtime::OpRunner::LaunchKernelTask(task_type, device_context, input_addr_list, output_addr_list,
                                        base_op_run_info.stream_id);
  }
}

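// Allocates device memory for a tensor, either inline (after waiting for queued tasks) or as a queued
// passthrough task when async execution is enabled.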
void MindRTBackend::RunAllocMemTask(DeviceContext *device_context, const tensor::BaseTensorPtr &tensor,
                                    bool enable_async, bool is_cpu_address_exist) {
  if (!enable_async) {
    WaitTaskFinish();
    return AllocateMemForTensor(tensor, device_context, is_cpu_address_exist);
  }
  auto alloc_mem_func = [device_context, tensor, is_cpu_address_exist]() {
    AllocateMemForTensor(tensor, device_context, is_cpu_address_exist);
  };
  runtime::OpExecutor::GetInstance().PushSimpleOpRunTask(
    std::make_shared<runtime::PassthroughDeviceTask>(alloc_mem_func));
}

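// Builds the kernels for an already constructed single-op graph.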
void MindRTBackend::CompileSingleOpGraph(const OpCompilerInfoPtr &op_compiler_info, const DeviceContext *device_context,
                                         bool is_dynamic_shape) const {
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  MS_EXCEPTION_IF_NULL(device_context);
  pynative::OpCompiler::GetInstance().KernelBuild(op_compiler_info, device_context, is_dynamic_shape);
}

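// Wraps each graph output node into an output tensor and appends it to *outputs, skipping nodes that
// produce no output tensors.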
void MindRTBackend::UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes,
                                 VectorRef *const outputs) const {
  MS_EXCEPTION_IF_NULL(outputs);

  for (auto &item_with_index : output_nodes) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
      continue;
    }
    auto output_tensor = CreateOutputTensor(item_with_index.first, item_with_index.second);
    MS_EXCEPTION_IF_NULL(output_tensor);
    output_tensor->set_need_pipeline_sync(true);
    outputs->emplace_back(output_tensor);
  }
}

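// Dynamic-shape output handling: checks that the per-output bookkeeping is consistent, then pairs each graph
// output node with its pre-created device address and wraps the pair into an output tensor.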
void MindRTBackend::UpdateOutputDynamic(const session::BackendOpRunInfoPtr &op_run_info,
                                        const OpCompilerInfoPtr &op_compiler_info,
                                        const vector<device::DeviceAddressPtr> &device_address_list,
                                        VectorRef *const outputs) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_LOG(DEBUG) << "No promise, just create tensor and address, op " << op_run_info->base_op_run_info.op_name;
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  auto output_nodes = op_compiler_info->graph_output_nodes_;
  auto outputs_size = output_nodes.size();
  if (op_compiler_info->graph_outputs_tensor_num_.size() != outputs_size) {
    MS_LOG(EXCEPTION) << "The size of graph_outputs_tensor_num_:" << op_compiler_info->graph_outputs_tensor_num_.size()
                      << " is not equal to outputs_size:" << outputs_size;
  }

  if (device_address_list.size() != outputs_size) {
    MS_LOG(EXCEPTION) << "The size of device_address_list:" << device_address_list.size()
                      << " is not equal to outputs_size:" << outputs_size;
  }

  for (size_t i = 0; i < outputs_size; ++i) {
    auto item_with_index = output_nodes[i];
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    if (op_compiler_info->graph_outputs_tensor_num_[i] == 0) {
      continue;
    }
    auto output_address = device_address_list[i];
    MS_EXCEPTION_IF_NULL(output_address);
    auto output_tensor =
      CreateOutputTensorDynamicImpl(op_compiler_info, item_with_index.first, item_with_index.second, output_address, i);
    MS_EXCEPTION_IF_NULL(output_tensor);
    output_tensor->set_need_pipeline_sync(true);
    outputs->emplace_back(output_tensor);
  }
}

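// Drops all cached compilation state and re-creates the graph compiler.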
void MindRTBackend::ClearResource() {
  graph_compiler_ = std::make_shared<GraphCompiler>();
  graph_id_to_device_context_.clear();
  func_graph_to_kernel_graph_ids_.clear();
  graph_info_to_device_context_.clear();
  control_nodes_.clear();
  actor_to_graph_compiler_info_.clear();
  cnode_ref_counts_.clear();
}

KernelGraphPtr MindRTBackend::GetGraphById(GraphId graph_id) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  return graph_compiler_->Fetch(graph_id);
}
}  // namespace compile
}  // namespace mindspore