/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/graph_compiler/backend.h"

#include <algorithm>
#include <vector>
#include <map>
#include <stack>
#include <unordered_map>
#include "ops/sequence_ops.h"
#include "ops/nn_op_name.h"
#include "ops/structure_op_name.h"
#include "include/common/utils/parallel_context.h"
#include "backend/graph_compiler/transform.h"
#include "backend/common/session/session_factory.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_compiler.h"
#include "include/backend/optimizer/helper.h"
#include "pipeline/jit/ps/action.h"
#include "pipeline/jit/ps/parse/data_converter.h"
#include "pipeline/pynative/grad/jit/jit_call_graph.h"
#include "ir/anf.h"
#include "pybind_api/ir/base_ref_py.h"
#include "pybind_api/pybind_patch.h"
#include "include/common/utils/callbacks.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/convert_utils_py.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "runtime/pynative/op_runner.h"
#include "runtime/pynative/graph_adapter.h"
#include "kernel/pyboost/pyboost_utils.h"
#include "runtime/pynative/op_function/pyboost_grad_functions.h"
#include "include/backend/distributed/recovery/recovery_context.h"
#include "pybind_api/gil_scoped_long_running.h"
#ifdef ENABLE_DEBUGGER
#include "include/backend/debug/debugger/debugger.h"
#endif
#ifndef ENABLE_SECURITY
#include "include/backend/debug/data_dump/dump_json_parser.h"
#endif
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/ps_context.h"
#endif

#include "runtime/device/device_address_utils.h"
#include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"

namespace mindspore {
namespace compile {
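// Convert a graph segment into an executable entity on the given target: compile it through the session,
// link it with its predecessor graphs, and return the run/simu_run functors together with the graph id.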
LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
  MS_LOG(DEBUG) << "MsConvert";
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  LinConvertResult result;
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
  result.inputs = inputs;
  result.outputs = outputs;
  result.graph_id = kInvalidGraphId;
  auto current_session = target_sess_;
  if (target != target_device_ && !target.empty()) {
    CreateOtherSession(target);
    current_session = other_sess_;
  }
  MS_EXCEPTION_IF_NULL(current_session);
  GraphId graph_id = current_session->CompileGraph(segment, outputs);
  segment->graph_id_ = graph_id;
  auto graph = current_session->GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &pre_segment : segment->pre_segments_) {
    MS_EXCEPTION_IF_NULL(pre_segment);
    MS_EXCEPTION_IF_NULL(target_sess_);
    auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
    if (pre_graph == nullptr) {
      MS_EXCEPTION_IF_NULL(other_sess_);
      pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
    }
    MS_EXCEPTION_IF_NULL(pre_graph);
    pre_graph->AddPostGraph(graph);
    graph->AddPreGraph(pre_graph);
    MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph_id;
  }

  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;
  }
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (!pynative_mode || target != "Ascend") {
    if (target != target_device_ && !target.empty()) {
      MS_EXCEPTION_IF_NULL(other_sess_);
      other_sess_->BuildGraph(graph_id);
    } else if (!is_multi_graph_sink_) {
      MS_EXCEPTION_IF_NULL(target_sess_);
      target_sess_->BuildGraph(graph_id);
    }
  }
  result.run = std::make_shared<RunFunc>(
    [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
  MS_EXCEPTION_IF_NULL(result.run);

  result.simu_run = std::make_shared<RunFunc>(
    [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id); });
  MS_EXCEPTION_IF_NULL(result.simu_run);
  result.graph_id = graph_id;

  graph_id_map_[graph_id] = result;
  return result;
}

// Simulated run used at compile time: simply return the recorded output nodes of the graph.
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g) {
  MS_LOG(DEBUG) << "Set graph input:" << g;
  std::vector<BaseRef> outputs;
  (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
                       [](const AnfNodePtr &v) { return v; });
  return VectorRef(outputs);
}

namespace {
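// Replace every kernel's output and workspace device address in the graph with an empty clone, so stale
// device memory from the previous launch is not reused on the next run.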
void ClearGraphDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context, bool is_gradient_out) {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &node : graph->execution_order()) {
    auto output_address_num = AnfAlgo::GetOutputAddressNum(node);
    // Clear old output device address of kernel
    for (size_t i = 0; i < output_address_num; ++i) {
      if (!AnfAlgo::OutputAddrExist(node, i, false)) {
        continue;
      }
      const auto &device_address = AnfAlgo::GetMutableOutputAddr(node, i, false);
      if (device_address == nullptr) {
        continue;
      }
      MS_EXCEPTION_IF_NULL(device_context);
      auto new_device_address = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(device_address, device_context);
      if (is_gradient_out) {
        new_device_address->set_from_persistent_mem(true);
      }
      AnfAlgo::SetOutputAddr(new_device_address, i, node.get());
    }

    // Clear old workspace device address of kernel
    auto kernel_mod = AnfAlgo::GetKernelMod(node);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_lists.size(); ++i) {
      if (!AnfAlgo::WorkspaceAddrExist(node, i)) {
        continue;
      }
      const auto &device_address = AnfAlgo::GetMutableWorkspaceAddr(node, i);
      auto new_device_address = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(device_address, device_context);
      AnfAlgo::SetWorkspaceAddr(new_device_address, i, node.get());
    }
  }
}

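// Likewise reset the device addresses of the graph's Parameter inputs to empty clones.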
void ClearInputDeviceAddress(const KernelGraphPtr &graph, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);
  for (const auto &node : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<Parameter>()) {
      auto device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
      if (device_address == nullptr) {
        continue;
      }
      auto new_device_address = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(device_address, device_context);
      AnfAlgo::SetOutputAddr(new_device_address, 0, node.get());
    }
  }
}

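// Allocate device memory for a tensor's device address (if it has none yet) and copy the host data in.
// When a CPU address already exists but does not come from the memory pool, the address is re-bound to the
// CPU device context before allocation.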
void AllocateMemForTensor(const tensor::BaseTensorPtr &tensor, DeviceContext *device_context,
                          bool is_cpu_address_exist) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(device_context);

  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_NULL(device_address);
  device_address->set_is_view(true);
  if (is_cpu_address_exist) {
    if (device_address->from_mem_pool()) {
      // If the CPU address exists and comes from the memory pool, there is no need to copy.
      return;
    } else {
      // If not from the pool, the lifetime of the device ptr is guaranteed elsewhere.
      // Before applying for a new address, clear the old pointer; otherwise a warning is generated.
      device_address->set_ptr(nullptr);
      if (device_context->GetDeviceType() != device_address->GetDeviceType()) {
        device_context = runtime::OpRunner::GetDeviceContext(kCPUDevice);
        MS_EXCEPTION_IF_NULL(device_context);
      }
    }
  }

  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "PyNative", "ContiguousAllocMem", "");
  auto mem_type = tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
                                                 device_address.get());
  if ((device_address->GetPtr() == nullptr) &&
      (!device_context->device_res_manager_->AllocateMemory(device_address.get()))) {
    MS_LOG(EXCEPTION) << "Allocate memory failed";
  }

  auto tensor_size = LongToSize(tensor->data().nbytes());
  auto tensor_type = tensor->data_type();
  if (!device_address->SyncHostToDevice(tensor->shape(), tensor_size, tensor_type, "DefaultFormat",
                                        tensor->data_ptr())) {
    MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
  }
}

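// Collect the output device addresses recorded on the simple graph's output edges.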
device::DeviceAddressPtrList GetOutputDeviceAddress(const OpCompilerInfoPtr &op_compiler_info) {
  const auto &output_edges = op_compiler_info->simple_graph_->outputs_;
  device::DeviceAddressPtrList output_address;
  output_address.reserve(output_edges.size());
  std::transform(output_edges.begin(), output_edges.end(), std::back_inserter(output_address),
                 [](const pynative::EdgePtr &edge) { return edge->address_; });
  return output_address;
}

void ClearOpInputOutput(const OpCompilerInfoPtr &op_compiler_info) {
  const auto &all_edges = op_compiler_info->simple_graph_->all_edges_;
  for (const auto &edge : all_edges) {
    if (edge->type_ != pynative::EdgeType::kValueNodeEdge) {
      // Just set edge address to null rather than clone empty address.
      // Clone empty address in next RunOp if needed.
      edge->address_ = nullptr;
    }
  }
}
}  // namespace

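// Execute a previously compiled graph: flatten the input args into tensors, then run the graph through the
// owning session (synchronously in PS mode, otherwise asynchronously).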
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
  MS_LOG(DEBUG) << "Start ms graph run:" << args.size() << ", g:" << g;
  // Run graph
  std::vector<tensor::TensorPtr> inputs;
  for (const auto &arg : args) {
    std::vector<tensor::TensorPtr> flatten_values;
    AnfAlgo::FlattenInputArg(arg, nullptr, &flatten_values);
    (void)std::copy(flatten_values.begin(), flatten_values.end(), std::back_inserter(inputs));
  }

  VectorRef outputs;
  // Call ms RunGraphAsync
  const session::SessionPtr &exe_session = ((target != target_device_ && !target.empty()) ? other_sess_ : target_sess_);
  MS_EXCEPTION_IF_NULL(exe_session);

#if defined(__linux__) && defined(WITH_BACKEND)
  // In PS mode, the graph must be run synchronously in case the weights on the server were not updated in the
  // last step.
  if (ps::PSContext::instance()->is_ps_mode()) {
    exe_session->RunGraph(g, inputs, &outputs);
    return outputs;
  }
#endif

  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (pynative_mode) {
    MS_LOG(EXCEPTION) << "Pynative can't call this function anymore!";
  }
  exe_session->RunGraphAsync(g, inputs, &outputs);

  MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  return outputs;
}

MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
  target_sess_ = session::SessionFactory::Get().Create(target);
  if (target_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
  }
  target_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
  target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
  target_device_ = target;
}

void MsBackend::CreateOtherSession(const std::string &target) {
  if (other_sess_ != nullptr && other_device_ == target) {
    return;
  }
  other_sess_ = session::SessionFactory::Get().Create(target);
  if (other_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  other_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
  other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
  other_device_ = target;
}

GraphId MsBackend::CompileGraph(const NotNull<FuncGraphPtr> &fg) {
  MS_EXCEPTION_IF_NULL(target_sess_);
  return target_sess_->CompileGraph(fg);
}

VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }

void MsBackend::ClearSessionGraphs() {
  if (target_sess_ != nullptr) {
    target_sess_->ClearGraph();
  }
}

#ifdef ENABLE_DEBUGGER
void MsBackend::SetDebugger() {
  MS_EXCEPTION_IF_NULL(target_sess_);
  target_sess_->SetDebugger();
}
#endif

namespace {
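// Recursively resolve one input of a bprop_cut node to a Value: MakeTuple inputs become ValueTuples, value
// nodes return their value directly, and everything else is looked up from op outputs / graph inputs.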
ValuePtr GetInputofBpropCut(const std::shared_ptr<GraphCompiler> &graph_compiler, const CNodePtr &parent_node,
                            const AnfNodePtr &input_node,
                            const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
                            const std::map<AnfNodePtr, size_t> &parameter_index,
                            const std::vector<TensorPtr> &graph_inputs, InputInfo *input_info, size_t input_index) {
  if (!IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) {
    auto real_input = common::AnfAlgo::VisitKernel(input_node, 0).first;
    MS_EXCEPTION_IF_NULL(real_input);
    ValuePtr value = nullptr;
    if (!real_input->isa<ValueNode>()) {
      if (real_input->abstract() != nullptr && real_input->abstract()->isa<abstract::AbstractSparseTensor>()) {
        value = TensorListToSparseTensor(real_input->abstract(), graph_inputs);
      } else {
        value = graph_compiler->GetSingleOpInputTensorByIndex(parent_node, op_output, parameter_index, graph_inputs,
                                                              input_info, input_index);
      }
      MS_EXCEPTION_IF_NULL(value);
    } else {
      const auto &value_node = real_input->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      value = value_node->value();
      MS_EXCEPTION_IF_NULL(value);
    }
    return value;
  }
  auto cnode = input_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  std::vector<ValuePtr> args_tuple;
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto input = cnode->inputs()[i];
    auto value =
      GetInputofBpropCut(graph_compiler, cnode, input, op_output, parameter_index, graph_inputs, input_info, i - 1);
    MS_EXCEPTION_IF_NULL(value);
    (void)args_tuple.emplace_back(value);
  }
  auto arg = std::make_shared<ValueTuple>(args_tuple);
  return arg;
}

ValuePtr GetFrontArgByParameter(const std::vector<AnfNodePtr> &origin_paramters, const VectorRef &front_args,
                                const AnfNodePtr &front_node) {
  const auto &iter = std::find(origin_paramters.begin(), origin_paramters.end(), front_node);
  const size_t index = static_cast<size_t>(iter - origin_paramters.begin());
  // If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
  // and there is no need to input a tensor.
  if (index >= front_args.size()) {
    MS_LOG(EXCEPTION) << "Position out of front args range, position value is " << index << " and args size is "
                      << front_args.size() << ".";
  }
  auto value = utils::cast<ValuePtr>(front_args[index]);
  MS_EXCEPTION_IF_NULL(value);
  return value;
}

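// Gather the runtime input values of a bprop_cut CNode (front/backend node pair) into `args` for the Python
// hook call; sequence parameters are taken from the front args, everything else from the backend graph.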
void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler,
                       const std::vector<AnfNodePtr> &origin_paramters, const VectorRef &front_args,
                       const CNodePtr &front_cnode, const CNodePtr &backend_cnode,
                       const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output_map,
                       const std::map<AnfNodePtr, size_t> &parameter_index,
                       const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info, VectorRef *args) {
  MS_EXCEPTION_IF_NULL(front_cnode);
  MS_EXCEPTION_IF_NULL(backend_cnode);
  MS_EXCEPTION_IF_NULL(graph_compiler);
  MS_EXCEPTION_IF_NULL(args);
  auto front_size = front_cnode->size();
  auto back_size = backend_cnode->size();
  if (front_size != back_size) {
    MS_LOG(EXCEPTION) << "Bpropcut op front cnode size: " << front_size << ", back cnode size:" << back_size
                      << ", bpropcut op should not flatten";
  }
  for (size_t index = 1; index < back_size; ++index) {
    auto input_node = backend_cnode->input(index);
    ValuePtr value = nullptr;
    if (input_node->isa<Parameter>() && input_node->abstract() != nullptr &&
        input_node->abstract()->isa<abstract::AbstractSequence>()) {
      auto front_input_node = front_cnode->input(index);
      value = GetFrontArgByParameter(origin_paramters, front_args, front_input_node);
    } else {
      value = GetInputofBpropCut(graph_compiler, backend_cnode, input_node, op_output_map, parameter_index,
                                 graph_inputs, input_info, index - 1);
    }
    MS_EXCEPTION_IF_NULL(value);
    (void)args->emplace_back(value);
  }
}

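// Execute a control operator that cannot run as a device kernel. Currently this handles bprop_cut: the
// corresponding Python hook is invoked under the GIL and its outputs are converted back to tensors.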
void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler,
                        const std::vector<AnfNodePtr> &origin_paramters, const VectorRef &front_args,
                        const KernelGraphPtr &graph, const CNodePtr &kernel,
                        const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output_map,
                        const std::map<AnfNodePtr, size_t> &parameter_index,
                        const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info,
                        VectorRef *op_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_outputs);
  AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel);
  if (front_node == nullptr && graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
    front_node = kernel;
  }
  MS_EXCEPTION_IF_NULL(front_node);
  if (!front_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The front node of bprop_cut is not CNode";
  }
  CNodePtr cnode = front_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const std::vector<AnfNodePtr> &node_inputs = cnode->inputs();
  if (node_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The inputs of node[" << cnode->fullname_with_scope() << "] is empty";
  }

  const AnfNodePtr &fn = node_inputs.at(0);
  if (!IsValueNode<Primitive>(fn)) {
    MS_LOG(EXCEPTION) << "The input[0] of kernel[" << kernel->fullname_with_scope()
                      << "] is not a ValueNode of Primitive";
  }

  PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == kBpropCutOpName) {
    VectorRef args;
    GetControlOpInput(graph_compiler, origin_paramters, front_args, cnode, kernel, op_output_map, parameter_index,
                      graph_inputs, input_info, &args);
    py::gil_scoped_acquire acquire;
    BaseRef out = python_adapter::PyAdapterCallback::RunPrimitivePyHookFunction(prim, args);
    // Convert pyobject output to tensor.
    if (utils::isa<PyObjectRef>(out)) {
      PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
      auto out_py_tuple = py_ref.object_;
      std::vector<ValuePtr> output_tensors;
      ConvertPyObjectToTensor(out_py_tuple, &output_tensors);
      // If the bprop changes the grad, the kernel abstract needs to be updated for its users.
      std::vector<abstract::AbstractBasePtr> output_tensor_abs;
      for (auto &tensor : output_tensors) {
        (void)output_tensor_abs.emplace_back(tensor->ToAbstract()->Broaden());
        (void)op_outputs->elements_.emplace_back(std::move(tensor));
      }
      kernel->set_abstract(std::make_shared<abstract::AbstractTuple>(output_tensor_abs));
    }
  }
}

void UpdateOutputAbstract(const VectorRef &outputs, const session::BackendOpRunInfoPtr &op_run_info) {
  auto output_size = outputs.size();
  if (output_size == 1 && op_run_info->base_op_run_info.op_name != kGetNextOpName) {
    auto output_tensor = utils::cast<tensor::BaseTensorPtr>(outputs[0]);
    MS_EXCEPTION_IF_NULL(output_tensor);
    op_run_info->base_op_run_info.abstract = output_tensor->ToAbstract();
    MS_LOG(DEBUG) << "Update output abstract of " << op_run_info->base_op_run_info.op_name << " to "
                  << op_run_info->base_op_run_info.abstract->ToString();
    return;
  }
  AbstractBasePtrList elements;
  for (size_t i = 0; i < output_size; ++i) {
    auto output_tensor = utils::cast<tensor::BaseTensorPtr>(outputs[i]);
    MS_EXCEPTION_IF_NULL(output_tensor);
    (void)elements.emplace_back(output_tensor->ToAbstract());
  }
  op_run_info->base_op_run_info.abstract = std::make_shared<abstract::AbstractTuple>(elements);
  MS_LOG(DEBUG) << "Update output abstract of " << op_run_info->base_op_run_info.op_name << " to "
                << op_run_info->base_op_run_info.abstract->ToString();
}

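// Wrap a kernel's output device address into a host tensor. Hash-table outputs become MapTensors; otherwise
// the tensor carries the device address and is marked for lazy device-to-host sync.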
tensor::BaseTensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(output_node);
  const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
  MS_EXCEPTION_IF_NULL(device_tensor);

  const auto &user_data = device_tensor->user_data();
  bool is_map_tensor_output = user_data && user_data->get<UserDataType>(kUserDataType) &&
                              *(user_data->get<UserDataType>(kUserDataType)) == UserDataType::kUserTypeHashTable;
  if (is_map_tensor_output) {
    return AnfAlgo::CreateMapTensor(output_node, output_index);
  }

  device_tensor->SetNodeIndex(output_node, output_index);
  device_tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));
  runtime::DeviceAddressUtils::UpdateDeviceAddressHostInfoByNode(device_tensor, output_node, output_index);

  const auto &kernel_tensor = device_tensor->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);

  // Create host tensor; the output tensor should use the infer type. It will be handled correctly by tensor data
  // sync when the infer type is not equal to the device type.
  auto tensor = std::make_shared<tensor::BaseTensor>(kernel_tensor->dtype_id(), kernel_tensor->GetShapeVector());

  // Put device tensor into host tensor.
  tensor->set_device_address(device_tensor);
  tensor->set_sync_status(kNeedSyncDeviceToHost);

  // MindRT is disabled in the multi graphs scenario.
  // Delete tensor->data_sync() when MindRT is enabled in all scenes.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
      !runtime::OpExecutor::GetInstance().async_for_graph()) {
    // If the execution mode is Graph Mode in MsContext, the tensor will be the input of a graph which will execute
    // in Graph Mode; if the graph contains no CNode after optimization, the tensor needs to sync to host.
    tensor->data_sync(false);
  }

  return tensor;
}

tensor::BaseTensorPtr CreateOutputTensorDynamicImpl(const OpCompilerInfoPtr &op_compiler_info,
                                                    const AnfNodePtr &output_node, size_t output_index,
                                                    const std::shared_ptr<device::DeviceAddress> &address,
                                                    size_t idx_in_graph_outputs) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(address);
  MS_EXCEPTION_IF_NULL(op_compiler_info);

  const auto &user_data = address->user_data();
  bool is_map_tensor_output = user_data && user_data->get<UserDataType>(kUserDataType) &&
                              *(user_data->get<UserDataType>(kUserDataType)) == UserDataType::kUserTypeHashTable;
  if (is_map_tensor_output) {
    return AnfAlgo::CreateMapTensor(address);
  }

  // Create host tensor; the output tensor should use the infer type. It will be handled correctly by tensor data
  // sync when the infer type is not equal to the device type.
  auto tensor = std::make_shared<tensor::BaseTensor>(address->type_id(), address->host_shape());

  // Put device tensor into host tensor.
  address->SetNodeIndex(output_node, output_index);
  address->set_padding_type(op_compiler_info->graph_outputs_padding_type_[idx_in_graph_outputs]);
  tensor->set_device_address(address);

  // MindRT is disabled in the multi graphs scenario.
  // Delete tensor->data_sync() when MindRT is enabled in all scenes.
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
      !runtime::OpExecutor::GetInstance().async_for_graph()) {
    // If the execution mode is Graph Mode in MsContext, the tensor will be the input of a graph which will execute
    // in Graph Mode; if the graph contains no CNode after optimization, the tensor needs to sync to host.
    tensor->data_sync(false);
  }
  return tensor;
}

#if !defined(__APPLE__)
bool EnablePyNativeSyncRunning() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  return ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
}
#endif

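// Decide whether a single op must run synchronously: dynamic outputs, abstracts that need refreshing,
// erasable random-op caches, a required sync point, or pynative_synchronize=True all force sync execution.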
bool DisableRunOpAsync(const OpCompilerInfoPtr &op_compiler_info, const session::BackendOpRunInfoPtr &op_run_info) {
#if defined(__APPLE__)
  return true;
#else
  return op_run_info->base_op_run_info.has_dynamic_output ||  // Infer output is dynamic.
         op_compiler_info->need_refresh_abstract_ ||          // Graph output is dynamic after IR Pass. (e.g. Dropout)
         op_compiler_info->need_erase_ ||                     // Random op cache need to be erased.
         runtime::OpExecutor::NeedSync() ||                   // Cannot find a wait point before compile graph.
         EnablePyNativeSyncRunning();                         // context.set_context(pynative_synchronize=True)
#endif
}
}  // namespace

void CreateKernelTensor(const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
                        std::vector<DeviceContext *> device_contexts) {
  if (input_tensors.size() < device_contexts.size()) {
    MS_LOG(EXCEPTION) << "Invalid input_tensors size " << input_tensors.size() << " device_contexts size "
                      << device_contexts.size();
  }
  for (size_t i = 0; i < device_contexts.size(); ++i) {
    const auto &tensors = input_tensors[i];
    const auto &device_context = device_contexts[i];
    MS_EXCEPTION_IF_NULL(device_context);
    for (const auto &tensor : tensors) {
      if (tensor != nullptr && tensor->device_address() != nullptr) {
        auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
        MS_EXCEPTION_IF_NULL(device_address);
        if (device_address->kernel_tensor() == nullptr) {
          runtime::DeviceAddressUtils::CreateKernelTensor(device_address, tensor);
        }
      }
    }
  }
}

void CreateKernelTensor(const BaseRef &arg) {
  if (utils::isa<tensor::BaseTensor>(arg)) {
    auto tensor = utils::cast<tensor::BaseTensorPtr>(arg);
    auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (device_address != nullptr) {
      runtime::DeviceAddressUtils::CreateKernelTensor(device_address, tensor);
    }
  } else if (utils::isa<ValueSequencePtr>(arg)) {
    auto value_sequence = utils::cast<ValueSequencePtr>(arg);
    MS_EXCEPTION_IF_NULL(value_sequence);
    const auto &sequence_value = value_sequence->value();
    for (const auto &value : sequence_value) {
      CreateKernelTensor(value);
    }
  } else {
    MS_LOG(DEBUG) << "Only tensor need create KernelTensor";
  }
}

void CreateKernelTensor(const VectorRef &args) {
  for (const auto &arg : args) {
    CreateKernelTensor(arg);
  }
}

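// First-time path for running by actors: compile every kernel graph, cache the output-to-front-node mapping,
// transform and schedule the actor set, and prepare the value nodes used by the bprop graph.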
runtime::ActorSet *MindRTBackend::RealCompileGraphBeforeRunActor(const GraphCompilerInfo &graph_compiler_info,
                                                                 const VectorRef &args, bool no_multi_graph) {
  auto graphs = graph_compiler_info.graphs_;
  auto device_contexts = graph_compiler_info.device_contexts_;
  CreateKernelTensor(args);

  for (size_t i = 0; i < graphs.size(); ++i) {
    const auto &graph = graphs[i];
    MS_EXCEPTION_IF_NULL(graph);
    graph->set_flag(kFlagPyNativeRunInGraph, true);
    graph->set_flag(kFlagIsPynativeBpropGraph, root_graph_->has_flag(kFlagIsPynativeBpropGraph));
    if (graph->is_any_type_input()) {
      continue;
    }
    if (no_multi_graph) {
      MS_LOG(INFO) << "Replace parameter format";
      // The input tensors of heterogeneous graphs or control flow graphs are null.
      // Need to get tensor after ParseControlNodes.
      auto input_tensors = GetRunGraphInputs(graph_compiler_info, args);
      pynative::GraphAdapter::ReplaceGraphParameterProperties(graph, input_tensors.at(i), device_contexts[i]);
    }
    (void)graph_compiler_->CompileGraphImpl(graph, device_contexts[i]);
    pynative::GraphAdapter::RemoveUnusedValueNodes(graph);
    // When PyNative uses the kernel graph directly, the front node and the back node are the same; but in PyNative
    // task sink, the backend still creates a new kernel graph.
    if (root_graph_->has_flag(kFlagIsPyNativeBpropKernelGraph) &&
        !pynative::GraphAdapter::PyNativeEnableTaskSink(root_graph_)) {
      graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, {graph->output()});
    } else {
      graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, graph->front_outputs());
    }
    // Clear front outputs after the outputs are cached.
    graph->set_front_outputs({});
    AnfAlgo::UpdateGraphValidRefPair(graph);
    pynative::GraphAdapter::SensTensorToDevice(graph, device_contexts[i]);
  }

  ParseControlNodes(graph_compiler_info);
  UpdateGraphCompilerInfo(graph_compiler_info.name_);
  auto actor_set = runtime::GraphScheduler::GetInstance().Transform(graph_compiler_info);
  MS_EXCEPTION_IF_NULL(actor_set);
  constexpr auto kKernelActorThreshold = 5000;
  // Turning off multithreading may cause stack overflow in control flow scenarios.
  if (no_multi_graph && actor_set->kernel_actors_.size() < kKernelActorThreshold &&
      root_graph_->has_flag(kFlagIsPynativeBpropGraph)) {
    // Multithreading can cause spikes in memory usage and performance fluctuations.
    actor_set->is_multi_thread_execution_ = false;
    MS_LOG(INFO) << "Actor Multithreading is turned off!";
  }
  runtime::GraphScheduler::GetInstance().Schedule(actor_set);

  for (size_t i = 0; i < graphs.size(); ++i) {
    pynative::GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(graphs[i], device_contexts[i]);
    pynative::GraphAdapter::GenerateRefCountForBpropValueNode(graphs[i]);
    graph_adapter_.GenerateBackoffValueNodeOwners(graphs[i]);
  }
  return actor_set;
}

void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
                                     const VectorRef &args, VectorRef *outputs) {
  MS_LOG(INFO) << "Status record: begin run actor: " << actor_info;
  WaitTaskFinish();
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  auto graphs = graph_compiler_info.graphs_;
  auto device_contexts = graph_compiler_info.device_contexts_;
  if (device_contexts.size() != graphs.size()) {
    MS_LOG(EXCEPTION) << "Graphs size " << graphs.size() << " is not equal to device_contexts size "
                      << device_contexts.size();
  }

  // KernelByKernel: The size of control_nodes is at least 1 since there is return node in the graph.
  // GraphMode: No control nodes.
  bool no_multi_graph = control_nodes_.size() <= 1 && graphs.size() == 1;
  auto actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
  if (actor_set == nullptr) {
    actor_set = RealCompileGraphBeforeRunActor(graph_compiler_info, args, no_multi_graph);
  }

  if (root_graph_->has_flag(kFlagIsPynativeBpropGraph)) {
    for (size_t i = 0; i < graphs.size(); ++i) {
      graph_adapter_.UpdateForwardOutputInBpropGraph(graphs[i], device_contexts[i], no_multi_graph);
      pynative::GraphAdapter::UpdateDynamicValueNodeAbstract(graphs[i]);
    }
  }

  auto input_tensors = GetRunGraphInputs(graph_compiler_info, args);
  if (graphs.size() > input_tensors.size()) {
    MS_LOG(EXCEPTION) << "The actor_set " << actor_info << " graphs size " << graphs.size()
                      << " should be less than or equal to inputs size " << input_tensors.size();
  }
  pynative::GraphAdapter::HandleHeterogeneousTensors(input_tensors, device_contexts);
  CreateKernelTensor(input_tensors, device_contexts);

  // Release GIL and run actor DAG.
  GilReleaseWithCheck release_gil;
  VectorRef empty_args;
  runtime::GraphScheduler::GetInstance().Run(actor_set, input_tensors, empty_args);

  MS_EXCEPTION_IF_NULL(graph_compiler_);
  graph_compiler_->Summary(graph_compiler_info.graphs_);

  auto output = root_graph_->output();
  MS_LOG(DEBUG) << "Current out " << output->DebugString();
  if (root_graph_->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
    MS_EXCEPTION_IF_NULL(output_node_);
    root_graph_->set_output(output_node_);
  }
  ConstructOutputs(actor_set, outputs, root_graph_);
  actor_set->output_actor_->FreeSummaryNodeMem();
  runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
  // Close abstract_lock for dynamic_shape
  AnfUtils::CloseAbstractLock();
  MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
}

void MindRTBackend::RunMsGradGraph(const CNodePtr &kernel, const VectorRef &args, VectorRef *outputs) const {
  MS_EXCEPTION_IF_NULL(kernel);
  auto jit_call_graph = kernel->user_data<pynative::JitCallGraph>();
  MS_EXCEPTION_IF_NULL(jit_call_graph);
  *outputs = jit_call_graph->Run(args);
}

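// Fallback path that splits a kernel graph and executes it kernel by kernel: bprop_cut nodes call back into
// Python, jit call nodes reuse their grad graph, and other kernels run either through PyBoost or as single-op
// graphs, while reference counts recycle intermediate outputs.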
void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_info, const VectorRef &args,
                                       VectorRef *outputs) {
  WaitTaskFinish();

  MS_LOG(INFO) << "Status record: begin run graph by single op";
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  const auto &graphs = graph_compiler_info.graphs_;
  auto inputs = GetRunGraphInputs(graph_compiler_info, args);
  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
    const auto &graph = graphs[graph_index];
    MS_EXCEPTION_IF_NULL(graph);
    std::map<KernelWithIndex, tensor::BaseTensorPtr> op_output_map;
    std::map<AnfNodePtr, size_t> parameter_index;
    GraphOutputInfo graph_output_info;
    graph_output_info.graph_outputs = outputs;
    graph_compiler_->GetParamAndOutputIndex(graph, inputs[graph_index], outputs, &parameter_index,
                                            &graph_output_info.output_indexes);

    std::map<KernelWithIndex, size_t> cnode_ref_count;
    auto iter = cnode_ref_counts_.find(graph->graph_id());
    if (iter == cnode_ref_counts_.end()) {
      graph_compiler_->CalculateRefCount(graph, &cnode_ref_count);
      (void)cnode_ref_counts_.emplace(graph->graph_id(), cnode_ref_count);
    } else {
      cnode_ref_count = iter->second;
    }

    MS_EXCEPTION_IF_NULL(root_graph_);
    if (root_graph_->has_flag(kFlagIsPynativeBpropGraph)) {
      graph_compiler_->CalculateForwardOpOutputCount(graph, inputs[graph_index], &forward_op_output_tensor_id_,
                                                     parameter_index);
    }

    GilReleaseWithCheck gil_release;
    auto is_dynamic = root_graph_->has_flag(kFlagPyNativeBpropGraphIsDynamic);
    bool has_bprop_cut = root_graph_->has_flag(kFlagPyNativeBpropGraphWithBpropCut);
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    const std::string &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    for (const auto &kernel : graph->execution_order()) {
      MS_LOG(DEBUG) << "Split and run op " << kernel->fullname_with_scope();
      InputInfo input_info;
      VectorRef op_outputs;
      if (has_bprop_cut && common::AnfAlgo::IsBpropCutOpExecInBackend(kernel)) {
        const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
        RunControlOperator(graph_compiler_, origin_parameters, args, graph, kernel, op_output_map, parameter_index,
                           inputs[graph_index], &input_info, &op_outputs);
        // Execute remaining lazy tasks before PyNative hook exit.
        WaitTaskFinish();
      } else if (common::AnfAlgo::HasNodeAttr(kAttrJitCallNode, kernel)) {
        graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], false,
                                                 &input_info);
        VectorRef input_args;
        (void)std::transform(input_info.input_values.begin(), input_info.input_values.end(),
                             std::back_inserter(input_args.elements_),
                             [](ValuePtr &value) { return std::move(value); });

        RunMsGradGraph(kernel, input_args, &op_outputs);
        WaitTaskFinish();
      } else {
        const auto &primitive = common::AnfAlgo::GetCNodePrimitive(kernel);
        MS_EXCEPTION_IF_NULL(primitive);
        if (runtime::PyBoostOpExecute::GetInstance().IsPyBoostOpRegistered(primitive->name()) &&
            (kernel::pyboost::PyBoostUtils::IsKernelModRegistered(device_target, primitive->name()) ||
             kernel::pyboost::PyBoostUtils::IsPyBoostCustomRegistered(device_target, primitive->name()))) {
          MS_LOG(DEBUG) << "Run " << primitive->name() << " by pyboost";
          graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], true,
                                                   &input_info);
          runtime::OpRunnerInfo op_runner_info{
            primitive, device_target, input_info.input_values, input_info.input_abs, {}, kernel->abstract()};
          runtime::PyBoostOpExecute::GetInstance().RunPyBoostCall(&op_runner_info, &op_outputs);
        } else {
          MS_LOG(DEBUG) << "Run " << primitive->name() << " by single op graph";
          session::BackendOpRunInfoPtr op_run_info;
          graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index], false,
                                                   &input_info);
          graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_info, is_dynamic, &op_run_info,
                                                          &graph_output_info);
          if (is_dynamic) {
            op_run_info->op_prim = std::make_shared<Primitive>(*op_run_info->op_prim);
            AnfAlgo::SetDynamicAttrToPrim(op_run_info->op_prim);
            RunOpDynamic(op_run_info, &op_outputs);
          } else {
            RunOp(op_run_info, &op_outputs);
          }
        }
      }

      graph_compiler_->UpdateRefCount(input_info.input_kernel, &cnode_ref_count, &op_output_map);

      graph_output_info.graph_output_tensors.clear();
      graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
    }
    WaitTaskFinish();
  }
  python_adapter::PyAdapterCallback::ProcessUnPairedCellHook(true);
  MS_LOG(INFO) << "Status record: end run graph by single op";
}

void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info,
                                        const VectorRef &args, VectorRef *outputs) {
  bool enable_run_graph_by_single_op =
    std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
                [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagEnableRunGraphBySingleOp); });
  if (enable_run_graph_by_single_op) {
    RunGraphBySingleOp(graph_compiler_info, args, outputs);
  } else {
    RunGraphByActors(actor_info, graph_compiler_info, args, outputs);
  }
}

void MindRTBackend::WaitTaskFinish() const {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kWaitTaskFinish,
                                     runtime::kDefaultOpName);
  runtime::OpExecutor::GetInstance().WaitAll();
}

void MindRTBackend::ClearOpExecutorResource() const { runtime::OpExecutor::GetInstance().Reset(); }

void MindRTBackend::SyncStream() {
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);

  auto ret = device_context->device_res_manager_->SyncAllStreams();
  if (!ret) {
    MS_LOG(EXCEPTION) << "Sync Stream failed";
  }
}

void MindRTBackend::EraseSingleOpCache(const GraphInfo &graph_info) const {
  pynative::OpCompiler::GetInstance().ClearOpCache(graph_info);
}

void MindRTBackend::ReleaseForwardOutput(const std::vector<ValuePtr> &input_values) {
  graph_compiler_->UpdateForwardOpOutputRefCount(input_values, &forward_op_output_tensor_id_);
}

void MindRTBackend::OpRunCallback(const std::shared_ptr<runtime::OpTaskContext> &context) {
  MS_LOG(DEBUG) << "OpRunCallback start";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, context->is_pynative_infer());
  MS_EXCEPTION_IF_NULL(context);
  runtime::OpRunner::RunSingleOpGraph(context->op_run_info(), context->op_compiler_info(),
                                      runtime::OpRunner::GetTensorWithoutValueMask(context->op_run_info()));

  MS_EXCEPTION_IF_NULL(context->op_run_info());
  if (!context->op_run_info()->is_infer) {
    ReleaseForwardOutput(context->op_run_info()->base_op_run_info.expanded_input_values);
  }

  ClearGraphDeviceAddress(context->graph(), context->device_context(), context->op_run_info()->is_gradient_out);
  ClearInputDeviceAddress(context->graph(), context->device_context());
  ClearOpInputOutput(context->op_compiler_info());

  // Reset PyNative infer flag.
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, infer_flag);
  MS_LOG(DEBUG) << "OpRunCallback end";
}

void MindRTBackend::OpRunCallbackDynamic(const std::shared_ptr<runtime::OpTaskContext> &context) {
  MS_LOG(DEBUG) << "OpRunCallback start";
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, context->is_pynative_infer());

  MS_EXCEPTION_IF_NULL(context);
  runtime::DynamicOpRunner::RunSingleOpGraph(context->op_run_info(), context->op_compiler_info(),
                                             runtime::OpRunner::GetTensorWithoutValueMask(context->op_run_info()));

  MS_EXCEPTION_IF_NULL(context->op_run_info());
  if (!context->op_run_info()->is_infer) {
    ReleaseForwardOutput(context->op_run_info()->base_op_run_info.expanded_input_values);
  }

  ClearOpInputOutput(context->op_compiler_info());
  // Reset PyNative infer flag.
  ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, infer_flag);
  MS_LOG(DEBUG) << "OpRunCallback end";
}

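// Asynchronous single-op dispatch: pre-create the output tensors on the host, then push a DeviceOpRunTask
// whose callback launches the graph on the device thread.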
void MindRTBackend::DispatchOpTask(bool single_op_cache_hit, VectorRef *outputs,
                                   const OpCompilerInfoPtr &op_compiler_info,
                                   const session::BackendOpRunInfoPtr &op_run_info) {
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  const auto &graph = op_compiler_info->graph_;
  MS_EXCEPTION_IF_NULL(graph);

  runtime::OpRunner::UpdateDeviceAddress(graph, runtime::OpRunner::GetTensorWithoutValueMask(op_run_info),
                                         op_compiler_info->device_context_, false);
  // Create output tensor
  UpdateOutput(op_compiler_info->graph_output_nodes_, outputs);

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
  auto run_op_context =
    std::make_shared<runtime::OpTaskContext>(graph->graph_id(), graph, op_run_info, op_compiler_info, infer_flag);

  auto &op_executor = runtime::OpExecutor::GetInstance();
  if (!single_op_cache_hit) {
    CompileSingleOpGraph(op_compiler_info, op_compiler_info->device_context_);
  }

  auto run_task = std::make_shared<runtime::DeviceOpRunTask>(
    run_op_context, [this](const std::shared_ptr<runtime::OpTaskContext> &ctx) { OpRunCallback(ctx); });
  runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(run_task->task_id());
  op_executor.PushOpRunTask(run_task);
}

void MindRTBackend::DispatchOpTaskDynamic(VectorRef *outputs, const OpCompilerInfoPtr &op_compiler_info,
                                          const session::BackendOpRunInfoPtr &op_run_info,
                                          const vector<device::DeviceAddressPtr> &device_address_list) {
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  const auto &graph = op_compiler_info->graph_;
  MS_EXCEPTION_IF_NULL(graph);

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto infer_flag = ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER);
  auto run_op_context =
    std::make_shared<runtime::OpTaskContext>(graph->graph_id(), graph, op_run_info, op_compiler_info, infer_flag);

  auto &op_executor = runtime::OpExecutor::GetInstance();
  auto task = std::make_shared<runtime::DeviceOpRunTask>(
    run_op_context, [this](const std::shared_ptr<runtime::OpTaskContext> &ctx) { OpRunCallbackDynamic(ctx); });
  runtime::ProfilerAnalyzer::GetInstance().RecordFlowData(task->task_id());
  op_executor.PushOpRunTask(task);
}

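// Run a compiled single-op graph. If async execution is allowed it is dispatched to the op executor;
// otherwise the graph is launched inline and the outputs, abstracts and cached addresses are updated here.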
RunOpImpl(bool single_op_cache_hit,const OpCompilerInfoPtr & op_compiler_info,const session::BackendOpRunInfoPtr & op_run_info,VectorRef * outputs)1005 void MindRTBackend::RunOpImpl(bool single_op_cache_hit, const OpCompilerInfoPtr &op_compiler_info,
1006 const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
1007 MS_EXCEPTION_IF_NULL(op_run_info);
1008 MS_EXCEPTION_IF_NULL(op_compiler_info);
1009 // Fetch outputs.
1010 const auto &graph = op_compiler_info->graph_;
1011 MS_EXCEPTION_IF_NULL(graph);
1012 MS_EXCEPTION_IF_NULL(graph_compiler_);
1013 const auto &output_nodes = op_compiler_info->graph_output_nodes_;
1014 MS_EXCEPTION_IF_NULL(outputs);
1015
1016 auto device_context = op_compiler_info->device_context_;
1017 auto &op_executor = runtime::OpExecutor::GetInstance();
1018 if (!DisableRunOpAsync(op_compiler_info, op_run_info)) {
1019 MS_LOG(DEBUG) << "Async exec enabled, op: " << op_run_info->base_op_run_info.op_name;
1020 DispatchOpTask(single_op_cache_hit, outputs, op_compiler_info, op_run_info);
1021 return;
1022 }
1023
1024 MS_LOG(DEBUG) << "Async exec disabled, op: " << op_run_info->base_op_run_info.op_name;
1025 if (!op_executor.RunQueueEmpty()) {
1026 WaitTaskFinish();
1027 }
1028 if (!single_op_cache_hit) {
1029 CompileSingleOpGraph(op_compiler_info, device_context);
1030 }
1031 const auto &tensors_without_value_mask = runtime::OpRunner::GetTensorWithoutValueMask(op_run_info);
1032 runtime::OpRunner::UpdateDeviceAddress(graph, tensors_without_value_mask, device_context, true);
1033
1034 runtime::OpRunner::RunSingleOpGraph(op_run_info, op_compiler_info, tensors_without_value_mask);
1035
1036 if (!op_run_info->is_infer) {
1037 ReleaseForwardOutput(op_run_info->base_op_run_info.expanded_input_values);
1038 }
1039 UpdateOutput(output_nodes, outputs);
1040
1041 ClearGraphDeviceAddress(graph, device_context, op_run_info->is_gradient_out);
1042 ClearInputDeviceAddress(graph, device_context);
1043 ClearOpInputOutput(op_compiler_info);
1044
1045 if (op_run_info->base_op_run_info.has_dynamic_output || op_compiler_info->need_refresh_abstract_) {
1046 UpdateOutputAbstract(*outputs, op_run_info);
1047 }
1048 if (op_compiler_info->need_erase_) {
1049 EraseSingleOpCache(op_compiler_info->graph_info_);
1050 }
1051 }
1052
RunOpImplDynamic(bool single_op_cache_hit,const OpCompilerInfoPtr & op_compiler_info,const session::BackendOpRunInfoPtr & op_run_info,VectorRef * outputs)1053 void MindRTBackend::RunOpImplDynamic(bool single_op_cache_hit, const OpCompilerInfoPtr &op_compiler_info,
1054 const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
1055 MS_EXCEPTION_IF_NULL(op_run_info);
1056 MS_EXCEPTION_IF_NULL(op_compiler_info);
1057 MS_LOG(DEBUG) << "RunOpImplDynamic " << op_run_info->base_op_run_info.op_name;
1058 // Fetch outputs.
1059 const auto &graph = op_compiler_info->graph_;
1060 MS_EXCEPTION_IF_NULL(graph);
1061 MS_EXCEPTION_IF_NULL(graph_compiler_);
1062 MS_EXCEPTION_IF_NULL(outputs);
1063
1064 auto device_context = op_compiler_info->device_context_;
1065 if (!single_op_cache_hit) {
1066 CompileSingleOpGraph(op_compiler_info, device_context, true);
1067 }
1068 if (!DisableRunOpAsync(op_compiler_info, op_run_info)) {
1069 MS_LOG(DEBUG) << "Async exec enabled, op: " << op_run_info->base_op_run_info.op_name;
1070 auto input_tensors = runtime::OpRunner::GetTensorWithoutValueMask(op_run_info);
1071 runtime::DynamicOpRunner::UpdateInputDeviceAddress(op_compiler_info, input_tensors, false);
1072 auto device_address_list = runtime::DeviceAddressUtils::CreateGraphOutputDeviceAddress(
1073 op_compiler_info, op_run_info->base_op_run_info.abstract, op_run_info->base_op_run_info.stream_id);
1074 // Create output tensor
1075 UpdateOutputDynamic(op_run_info, op_compiler_info, device_address_list, outputs);
1076 DispatchOpTaskDynamic(outputs, op_compiler_info, op_run_info, device_address_list);
1077 return;
1078 }
1079 MS_LOG(DEBUG) << "Async exec disabled, op: " << op_run_info->base_op_run_info.op_name;
1080 auto &op_executor = runtime::OpExecutor::GetInstance();
1081 if (!op_executor.RunQueueEmpty()) {
1082 WaitTaskFinish();
1083 }
1084 auto input_tensors = runtime::OpRunner::GetTensorWithoutValueMask(op_run_info);
1085 runtime::DynamicOpRunner::UpdateInputDeviceAddress(op_compiler_info, input_tensors, true);
1086 runtime::DynamicOpRunner::RunSingleOpGraph(op_run_info, op_compiler_info, input_tensors);
1087
1088 if (!op_run_info->is_infer) {
1089 ReleaseForwardOutput(op_run_info->base_op_run_info.expanded_input_values);
1090 }
1091
1092 const auto &device_address_list = GetOutputDeviceAddress(op_compiler_info);
1093 // Create output tensor
1094 UpdateOutputDynamic(op_run_info, op_compiler_info, device_address_list, outputs);
1095 UpdateOutputAbstract(*outputs, op_run_info);
1096 ClearOpInputOutput(op_compiler_info);
1097 if (op_compiler_info->need_erase_) {
1098 EraseSingleOpCache(op_compiler_info->graph_info_);
1099 }
1100 }
1101
RunOp(const session::BackendOpRunInfoPtr & op_run_info,VectorRef * outputs)1102 void MindRTBackend::RunOp(const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
1103 MS_EXCEPTION_IF_NULL(op_run_info);
1104 MS_EXCEPTION_IF_NULL(graph_compiler_);
1105 MS_LOG(DEBUG) << "Run Op " << op_run_info->base_op_run_info.op_name;
1106
1107 bool single_op_cache_hit = true;
1108 auto op_compiler_info =
1109 pynative::OpCompiler::GetInstance().Compile(op_run_info, &single_op_cache_hit, device_name_, device_id_);
1110 MS_EXCEPTION_IF_NULL(op_compiler_info);
1111 op_compiler_info->WaitReady();
1112 RunOpImpl(single_op_cache_hit, op_compiler_info, op_run_info, outputs);
1113 }
1114
RunOpDynamic(const session::BackendOpRunInfoPtr & op_run_info,VectorRef * outputs)1115 void MindRTBackend::RunOpDynamic(const session::BackendOpRunInfoPtr &op_run_info, VectorRef *outputs) {
1116 MS_EXCEPTION_IF_NULL(op_run_info);
1117 MS_EXCEPTION_IF_NULL(graph_compiler_);
1118 MS_LOG(DEBUG) << "Run Op " << op_run_info->base_op_run_info.op_name;
1119
1120 // Single op graph compile
1121 bool single_op_cache_hit = true;
1122 auto op_compiler_info =
1123 pynative::OpCompiler::GetInstance().Compile(op_run_info, &single_op_cache_hit, device_name_, device_id_);
1124 MS_EXCEPTION_IF_NULL(op_compiler_info);
1125 op_compiler_info->WaitReady();
1126 RunOpImplDynamic(single_op_cache_hit, op_compiler_info, op_run_info, outputs);
1127 }
1128
void MindRTBackend::RunViewKernelTaskAsyncImpl(const runtime::KernelTaskType &task_type, DeviceContext *device_context,
                                               const device::DeviceAddressPtrList &input_addr_list,
                                               const device::DeviceAddressPtrList &output_addr_list,
                                               const size_t &stream_id) {
  // Capture by value: the task is queued and may run after the caller's arguments go out of scope,
  // and the closure must be rebuilt on every call instead of being cached with the first call's captures.
  auto kernel_task_func = [stream_id, task_type, input_addr_list, output_addr_list, device_context]() {
    runtime::OpRunner::LaunchKernelTask(task_type, device_context, input_addr_list, output_addr_list, stream_id);
  };

  runtime::OpExecutor::GetInstance().PushSimpleOpRunTask(
    std::make_shared<runtime::PassthroughDeviceTask>(kernel_task_func));
}

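// Gather device addresses for the inputs and outputs of a view kernel task. Inputs that have no
// device address yet get one created and memory allocated; the task is then launched either
// asynchronously or inline after the run queue drains.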
void MindRTBackend::RunViewKernelTask(const pynative::BaseOpRunInfo &base_op_run_info,
                                      const runtime::KernelTaskType &task_type, bool enable_async) {
  device::DeviceAddressPtrList input_addr_list;
  device::DeviceAddressPtrList output_addr_list;

  const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
    {base_op_run_info.device_target, MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
  MS_EXCEPTION_IF_NULL(device_context);

  for (size_t idx = 0; idx < base_op_run_info.expanded_input_values.size(); idx++) {
    auto input_tensor = base_op_run_info.expanded_input_values[idx]->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(input_tensor);
    if (input_tensor->device_address() == nullptr) {
      if (idx == 0) {
        MS_LOG(EXCEPTION) << "The device address of the first input tensor cannot be null, op name:"
                          << base_op_run_info.op_name;
      }
      // The input has no device address yet: describe it with a kernel tensor, create a device address
      // for it and allocate memory before the view task is launched.
      auto address_size = GetTypeByte(TypeIdToType(input_tensor->data_type())) * SizeOf(input_tensor->shape());

      auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
        nullptr, address_size, Format::DEFAULT_FORMAT, input_tensor->data_type(), input_tensor->shape(),
        device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
      kernel_tensor->SetType(std::make_shared<TensorType>(input_tensor->Dtype()));
      kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(input_tensor->shape()));
      kernel_tensor->set_stream_id(base_op_run_info.stream_id);
      auto input_addr = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);

      input_tensor->set_device_address(input_addr);
      RunAllocMemTask(device_context, input_tensor, enable_async, false);
      (void)input_addr_list.emplace_back(input_addr);
    } else {
      // The input already has a device address; CPU addresses still go through the memory-allocation task.
      auto input_addr = std::static_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
      MS_EXCEPTION_IF_NULL(input_addr);
      if (input_addr->GetDeviceType() == device::DeviceType::kCPU) {
        RunAllocMemTask(device_context, input_tensor, enable_async, true);
      }

      (void)input_addr_list.emplace_back(input_addr);
    }
  }

  std::transform(base_op_run_info.output_tensors.begin(), base_op_run_info.output_tensors.end(),
                 std::back_inserter(output_addr_list), [](const auto &tensor) {
                   return std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
                 });

  if (enable_async) {
    RunViewKernelTaskAsyncImpl(task_type, device_context, input_addr_list, output_addr_list,
                               base_op_run_info.stream_id);
  } else {
    WaitTaskFinish();
    runtime::OpRunner::LaunchKernelTask(task_type, device_context, input_addr_list, output_addr_list,
                                        base_op_run_info.stream_id);
  }
}

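// Allocate device memory for a tensor, either inline (after waiting for queued tasks) or as an
// asynchronous pass-through task.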
void MindRTBackend::RunAllocMemTask(DeviceContext *device_context, const tensor::BaseTensorPtr &tensor,
                                    bool enable_async, bool is_cpu_address_exist) {
  if (!enable_async) {
    WaitTaskFinish();
    return AllocateMemForTensor(tensor, device_context, is_cpu_address_exist);
  }
  auto alloc_mem_func = [device_context, tensor, is_cpu_address_exist]() {
    AllocateMemForTensor(tensor, device_context, is_cpu_address_exist);
  };
  runtime::OpExecutor::GetInstance().PushSimpleOpRunTask(
    std::make_shared<runtime::PassthroughDeviceTask>(alloc_mem_func));
}

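// Build the kernels of a compiled single-op graph on the given device.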
void MindRTBackend::CompileSingleOpGraph(const OpCompilerInfoPtr &op_compiler_info, const DeviceContext *device_context,
                                         bool is_dynamic_shape) const {
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  MS_EXCEPTION_IF_NULL(device_context);
  pynative::OpCompiler::GetInstance().KernelBuild(op_compiler_info, device_context, is_dynamic_shape);
}

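// Convert the graph output nodes into output tensors and append them to outputs, skipping nodes
// that produce no tensors.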
void MindRTBackend::UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes,
                                 VectorRef *const outputs) const {
  MS_EXCEPTION_IF_NULL(outputs);

  for (auto &item_with_index : output_nodes) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
      continue;
    }
    auto output_tensor = CreateOutputTensor(item_with_index.first, item_with_index.second);
    MS_EXCEPTION_IF_NULL(output_tensor);
    output_tensor->set_need_pipeline_sync(true);
    outputs->emplace_back(output_tensor);
  }
}

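// Dynamic-shape output handling: pair each graph output node with its device address and wrap the
// pair into an output tensor, after checking that the bookkeeping sizes are consistent.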
void MindRTBackend::UpdateOutputDynamic(const session::BackendOpRunInfoPtr &op_run_info,
                                        const OpCompilerInfoPtr &op_compiler_info,
                                        const vector<device::DeviceAddressPtr> &device_address_list,
                                        VectorRef *const outputs) const {
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_LOG(DEBUG) << "No promise, just create tensor and address, op " << op_run_info->base_op_run_info.op_name;
  MS_EXCEPTION_IF_NULL(op_compiler_info);
  auto output_nodes = op_compiler_info->graph_output_nodes_;
  auto outputs_size = output_nodes.size();
  if (op_compiler_info->graph_outputs_tensor_num_.size() != outputs_size) {
    MS_LOG(EXCEPTION) << "The size of graph_outputs_tensor_num_:" << op_compiler_info->graph_outputs_tensor_num_.size()
                      << " is not equal to outputs_size:" << outputs_size;
  }

  if (device_address_list.size() != outputs_size) {
    MS_LOG(EXCEPTION) << "The size of device_address_list:" << device_address_list.size()
                      << " is not equal to outputs_size:" << outputs_size;
  }

  for (size_t i = 0; i < outputs_size; ++i) {
    auto item_with_index = output_nodes[i];
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    if (op_compiler_info->graph_outputs_tensor_num_[i] == 0) {
      continue;
    }
    auto output_address = device_address_list[i];
    MS_EXCEPTION_IF_NULL(output_address);
    auto output_tensor = CreateOutputTensorDynamicImpl(op_compiler_info, item_with_index.first,
                                                       item_with_index.second, output_address, i);
    MS_EXCEPTION_IF_NULL(output_tensor);
    output_tensor->set_need_pipeline_sync(true);
    outputs->emplace_back(output_tensor);
  }
}

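// Drop all cached compilation state and recreate the graph compiler.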
void MindRTBackend::ClearResource() {
  graph_compiler_ = std::make_shared<GraphCompiler>();
  graph_id_to_device_context_.clear();
  func_graph_to_kernel_graph_ids_.clear();
  graph_info_to_device_context_.clear();
  control_nodes_.clear();
  actor_to_graph_compiler_info_.clear();
  cnode_ref_counts_.clear();
}

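// Fetch the kernel graph previously compiled under the given graph id.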
KernelGraphPtr MindRTBackend::GetGraphById(GraphId graph_id) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  return graph_compiler_->Fetch(graph_id);
}
}  // namespace compile
}  // namespace mindspore