/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "vm/backend.h"

#include <algorithm>
#include <vector>
#include <map>

#include "vm/transform.h"
#include "backend/session/session_factory.h"
#include "backend/optimizer/common/helper.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/parse/data_converter.h"
#include "ir/anf.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/callbacks.h"
#include "utils/convert_utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_utils.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/framework/graph_compiler.h"
#include "utils/scoped_long_running.h"
#ifdef ENABLE_GE
#include "utils/callbacks_ge.h"
#endif
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#ifndef ENABLE_SECURITY
#include "debug/data_dump/dump_json_parser.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/running_data_recorder.h"
#endif

namespace mindspore {
namespace compile {
bool Backend::GetCond(const BaseRef &c, bool *const value) {
  mindspore::ScopedLongRunning long_running;
  return BaseRefToBool(c, value);
}
bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }

Backend::Backend(const std::string &name) : name_(name) {
  MS_LOG(DEBUG) << "Select backend:" << name;
  convert_fn_ = MsVmConvert;
  is_multi_graph_sink_ = false;
}

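// Compile one graph segment on the given target device. Returns a LinConvertResult that
// records the segment's inputs/outputs, the compiled graph id and the run/simu_run callables.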
LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) {
  MS_LOG(DEBUG) << "MsConvert";
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
  LinConvertResult result;
  FuncGraphPtr fg;
  AnfNodePtrList inputs;
  AnfNodePtrList outputs;
  std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
  result.inputs = inputs;
  result.outputs = outputs;
  result.graph_id = kInvalidGraphId;
  auto current_session = target_sess_;
  if (target != target_device_ && !target.empty()) {
    CreateOtherSession(target);
    current_session = other_sess_;
  }
  MS_EXCEPTION_IF_NULL(current_session);
  GraphId graph_id = current_session->CompileGraph(segment, outputs);
  segment->graph_id_ = graph_id;
  auto graph = current_session->GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(graph);
  for (auto &pre_segment : segment->pre_segments_) {
    MS_EXCEPTION_IF_NULL(pre_segment);
    MS_EXCEPTION_IF_NULL(target_sess_);
    auto pre_graph = target_sess_->GetGraph(pre_segment->graph_id_);
    if (pre_graph == nullptr) {
      MS_EXCEPTION_IF_NULL(other_sess_);
      pre_graph = other_sess_->GetGraph(pre_segment->graph_id_);
    }
    MS_EXCEPTION_IF_NULL(pre_graph);
    pre_graph->AddPostGraph(graph);
    graph->AddPreGraph(pre_graph);
    MS_LOG(INFO) << "Link graph " << pre_segment->graph_id_ << " to " << graph_id;
  }

  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return result;
  }
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (!pynative_mode || target != "Ascend") {
    if (target != target_device_ && !target.empty()) {
      MS_EXCEPTION_IF_NULL(other_sess_);
      other_sess_->BuildGraph(graph_id);
    } else if (!is_multi_graph_sink_) {
      MS_EXCEPTION_IF_NULL(target_sess_);
      target_sess_->BuildGraph(graph_id);
    }
  }
  result.run = std::make_shared<RunFunc>(
    [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); });
  MS_EXCEPTION_IF_NULL(result.run);

  result.simu_run = std::make_shared<RunFunc>(
    [graph_id, this](const VectorRef &args) -> VectorRef { return MsSimuRunGraph(graph_id); });
  MS_EXCEPTION_IF_NULL(result.simu_run);
  result.graph_id = graph_id;

  graph_id_map_[graph_id] = result;
  return result;
}

// Compile only: record the graph inputs and outputs without running the graph.
VectorRef MsBackend::MsSimuRunGraph(const GraphId &g) {
  MS_LOG(DEBUG) << "Set graph input:" << g;
  std::vector<BaseRef> outputs;
  (void)std::transform(graph_id_map_[g].outputs.begin(), graph_id_map_[g].outputs.end(), std::back_inserter(outputs),
                       [](const AnfNodePtr &v) { return v; });
  return VectorRef(outputs);
}

namespace {
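// Flatten a single BaseRef argument (tensor, value tuple, scalar, monad, python object or
// nested VectorRef) into the flat list of input tensors.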
void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
  MS_EXCEPTION_IF_NULL(inputs);
  if (utils::isa<tensor::TensorPtr>(arg)) {
    auto value = utils::cast<tensor::TensorPtr>(arg);
    inputs->push_back(value);
  } else if (utils::isa<ValuePtr>(arg)) {
    auto value = utils::cast<ValuePtr>(arg);
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<ValueTuple>()) {
      auto value_tuple = value->cast<ValueTuplePtr>();
      MS_EXCEPTION_IF_NULL(value_tuple);
      auto tuple_value = value_tuple->value();
      (void)std::transform(tuple_value.begin(), tuple_value.end(), std::back_inserter(*inputs),
                           [](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
    } else if (value->isa<Scalar>()) {
      tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
      inputs->push_back(scalar_tensor);
    } else if (value->isa<Monad>()) {
      // If value is a monad, replace it with an unused tensor.
      inputs->push_back(std::make_shared<tensor::Tensor>(int64_t(0), kBool));
    } else {
      inputs->push_back(value->cast<tensor::TensorPtr>());
    }
  } else if (utils::isa<PyObjectRef>(arg)) {
    auto value = utils::cast<PyObjectRef>(arg).object_;
    inputs->push_back(py::cast<tensor::TensorPtr>(value));
  } else if (utils::isa<VectorRefPtr>(arg)) {
    const auto &args_new = utils::cast<VectorRef>(arg);
    for (const auto &v : args_new) {
      PushInputTensor(v, inputs);
    }
  } else {
    MS_LOG(WARNING) << "Invalid input type.";
  }
}

// Insert the tensor related to the front_node into input_tensor.
void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
                std::vector<tensor::TensorPtr> *input_tensor) {
  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
  if (iter == parameters.end()) {
    (void)((*input_tensor).emplace_back(nullptr));
    return;
  }
  auto position = iter - parameters.begin();
  PushInputTensor(args[position], input_tensor);
}

void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, OpRunInfo *op_run_info) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &kernels = kernel_graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
      op_run_info->abstract = kernel->abstract();
    }
  }
}

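// Create a host tensor for one graph output with the inferred type/shape and bind it to the
// output's device address, so the data can be synced from device on demand.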
TensorPtr CreateOutputTensor(const AnfNodePtr &output_node, size_t output_index) {
  MS_EXCEPTION_IF_NULL(output_node);
  // Create the host tensor with the inferred type; it will be handled correctly by tensor data sync
  // when the inferred type is not equal to the device type.
  auto type_id = AnfAlgo::GetOutputInferDataType(output_node, output_index);
  std::vector<int64_t> temp_shape;
  const auto &shape = AnfAlgo::GetOutputInferShape(output_node, output_index);
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  auto tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(output_node, output_index));

  // Put device tensor into host tensor.
  const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);
  MS_EXCEPTION_IF_NULL(device_tensor);
  tensor->set_device_address(device_tensor);

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
    // In graph mode the tensor is the input of a graph that will also execute in graph mode;
    // if that graph contains no CNode after optimization, the tensor needs to sync to host.
    tensor->data_sync(false);
  }

  return tensor;
}

void UpdateOutput(const std::vector<session::KernelWithIndex> &output_nodes, VectorRef *const outputs) {
  MS_EXCEPTION_IF_NULL(outputs);
  for (auto &item_with_index : output_nodes) {
    MS_EXCEPTION_IF_NULL(item_with_index.first);
    // If the graph returns nothing, this function should return an empty list.
    if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
      continue;
    }
    outputs->emplace_back(CreateOutputTensor(item_with_index.first, item_with_index.second));
  }
}

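// Replace each existing graph output device address with a fresh, unallocated address of the
// same size/format/type, keeping the original ref count.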
void UpdateOutputDeviceAddress(const std::vector<session::KernelWithIndex> &output_nodes,
                               const DeviceContext *device_context) {
  for (auto &item_with_index : output_nodes) {
    auto &output_node = item_with_index.first;
    auto output_index = item_with_index.second;
    if (output_node != nullptr) {
      if (!AnfAlgo::OutputAddrExist(output_node, output_index, false)) {
        continue;
      }
      const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(output_node, output_index, false);

      if ((device_tensor == nullptr) || (device_tensor->GetPtr() == nullptr)) {
        continue;
      }

      MS_EXCEPTION_IF_NULL(device_context);
      auto new_device_tensor = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
                                                                   device_tensor->format(), device_tensor->type_id());
      MS_EXCEPTION_IF_NULL(new_device_tensor);
      new_device_tensor->set_original_ref_count(device_tensor->original_ref_count());
      new_device_tensor->ResetRefCount();
      AnfAlgo::SetOutputAddr(new_device_tensor, output_index, output_node.get());
    }
  }
}

void UpdateInputDeviceAddress(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &node : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(node);
    if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
      AnfAlgo::SetOutputAddr(nullptr, 0, node.get());
    }
  }
}
}  // namespace

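// Execute a compiled graph: flatten args into input tensors, then dispatch to RunOpsInGraph
// in PyNative mode or RunGraphAsync in graph mode on the selected session.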
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
  MS_LOG(DEBUG) << "Start ms graph run:" << args.size() << ", g:" << g;
  // Run graph
  std::vector<tensor::TensorPtr> inputs;
  for (const auto &arg : args) {
    PushInputTensor(arg, &inputs);
  }

  VectorRef outputs;
  // Call RunGraphAsync or RunOpsInGraph with (graph_id, inputs, outputs).
  const session::SessionPtr &exe_session = ((target != target_device_ && !target.empty()) ? other_sess_ : target_sess_);
  MS_EXCEPTION_IF_NULL(exe_session);
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  if (pynative_mode) {
    exe_session->RunOpsInGraph(g, inputs, &outputs);
  } else {
    exe_session->RunGraphAsync(g, inputs, &outputs);
  }

  MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size();
  return outputs;
}

MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) {
  convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2);
  target_sess_ = session::SessionFactory::Get().Create(target);
  if (target_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
  }
  target_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
  target_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
  target_device_ = target;
}

void MsBackend::CreateOtherSession(const std::string &target) {
  if (other_sess_ != nullptr && other_device_ == target) {
    return;
  }
  other_sess_ = session::SessionFactory::Get().Create(target);
  if (other_sess_ == nullptr) {
    MS_LOG(EXCEPTION) << "Session create failed! Please make sure target device:" << target << " is available.";
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  other_sess_->Init(device_id);
#ifndef ENABLE_SECURITY
  other_sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
#endif
  other_device_ = target;
}

GraphId MsBackend::CompileGraph(NotNull<FuncGraphPtr> fg) {
  MS_EXCEPTION_IF_NULL(target_sess_);
  return target_sess_->CompileGraph(fg);
}

VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); }

void MsBackend::ClearSessionGraphs() {
  if (target_sess_ != nullptr) {
    target_sess_->ClearGraph();
  }
}

#ifdef ENABLE_DEBUGGER
void MsBackend::SetDebugger() {
  MS_EXCEPTION_IF_NULL(target_sess_);
  target_sess_->SetDebugger();
}
#endif

MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
    : Backend(backend_name), device_name_(device_name) {
  root_graph_ = nullptr;
  auto ms_context = MsContext::GetInstance();
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  auto &cut_list = pynative_mode ? compile::control_ops : GetMsNonlinearOps();

  graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
  graph_compiler_ = std::make_shared<GraphCompiler>();

  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
  device_context->Initialize();
  device_id_ = device_context->device_context_key().device_id_;
#ifdef ENABLE_DEBUGGER
  SetDebuggerInit();
#endif
  runtime::GraphScheduler::GetInstance().Initialize();
}

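// Compile a whole func graph for the MindRT backend: compile the root graph and all sub graphs,
// build the GraphCompilerInfo, and in graph mode transform and schedule the actor DAG.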
const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto root_graph = WrapPrimitives(func_graph);
  MS_EXCEPTION_IF_NULL(root_graph);
  root_graph_ = root_graph.get();
  // Register a summary callback function, which is called in the final stages of summary.
  graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  real_execution_mode_ = ms_execution_mode_;

  // Compile root graph.
  graph_id_to_device_context_.clear();
  control_nodes_.clear();
  CompileGraph(root_graph);

  // Compile sub graphs.
  MS_EXCEPTION_IF_NULL(root_graph->manager());
  FuncGraphSet sub_graphs = root_graph->manager()->func_graphs();
  for (auto sub_graph : sub_graphs) {
    if (sub_graph != func_graph && sub_graph != nullptr) {
      CompileGraph(sub_graph);
    }
  }

  // Construct the graph compiler info.
  auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
  MS_EXCEPTION_IF_NULL(graph_compiler_info);

  if (real_execution_mode_ == kGraphMode) {
    // Transform graph to actor DAG, and schedule the actor DAG.
    const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
    runtime::GraphScheduler::GetInstance().Schedule(actor_set);
  }
  const ActorInfo &actor_info = graph_compiler_info->name_;
  (void)actor_to_graph_compiler_info_.emplace(graph_compiler_info->name_, std::move(graph_compiler_info));
  return actor_info;
}

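// Partition a func graph into segments: each normal segment is compiled into a kernel graph
// on its target device, while cut nodes (control flow) are collected into control_nodes_.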
void MindRTBackend::CompileGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(graph_partition_);
  MS_EXCEPTION_IF_NULL(graph_compiler_);

  bool contain_multi_target = false;
  // Split graph to segments.
  const auto &segments = graph_partition_->Partition(func_graph, &contain_multi_target);
  MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size:" << segments.size();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);

  // Compile each segment into a graph.
  for (const auto &segment : segments) {
    MS_EXCEPTION_IF_NULL(segment);
    // Compile the normal nodes, which don't contain the cut node.
    if (segment->nodes_.size() == 0) {
      MS_LOG(EXCEPTION) << "The segments size is 0.";
    }
    if (!segment->is_cut_) {
      MS_EXCEPTION_IF_NULL(segment->nodes_[0]);
      MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->fullname_with_scope();

      // Get the device context.
      const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]);
      const auto &device_context =
        device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_});
      MS_EXCEPTION_IF_NULL(device_context);
      device_context->Initialize();

      // Transform nodes to inputs and outputs.
      FuncGraphPtr fg;
      AnfNodePtrList inputs;
      AnfNodePtrList outputs;
      std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);

      // There will be more than one kernel graph in a heterogeneous scenario within an ms_function of PyNative mode.
      if (contain_multi_target && ms_execution_mode_ == kPynativeMode) {
        real_execution_mode_ = kGraphMode;
        context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
      }

      // Compile graph.
      auto graph_id = graph_compiler_->CompileGraph(segment->nodes_, outputs, device_context);

      if (ms_execution_mode_ != real_execution_mode_) {
        context_ptr->set_param<int>(MS_CTX_EXECUTION_MODE, ms_execution_mode_);
      }

      graph_id_to_device_context_[graph_id] = device_context;
    } else {
      // Compile the cut node.
      auto cut_node = segment->nodes_[0];
      MS_EXCEPTION_IF_NULL(cut_node);
      MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->fullname_with_scope();
      control_nodes_.push_back(cut_node);
    }
  }
}

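// Compile a single op for PyNative execution. The actor set is cached by "graph_id + op name";
// on a cache hit the existing GraphCompilerInfo is returned, otherwise a new actor DAG is
// transformed and scheduled.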
const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                                             const std::vector<int64_t> *tensors_mask,
                                             std::vector<tensor::TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  // Get the device context.
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
  MS_EXCEPTION_IF_NULL(device_context);
  device_context->Initialize();

  bool single_op_cache_hit = true;
  auto graph_id = graph_compiler_->CompileGraph(op_run_info, graph_info, tensors_mask, input_tensors,
                                                &single_op_cache_hit, device_context);
  // The actor set name: graph_id + single operator name.
  std::string actor_info = std::to_string(graph_id) + "_" + op_run_info.op_name;
  if (single_op_cache_hit) {
    auto iter = actor_to_graph_compiler_info_.find(actor_info);
    if (iter == actor_to_graph_compiler_info_.end()) {
      MS_LOG(EXCEPTION) << "Can not find graph compiler info for actor set: " << actor_info;
    }
    return iter->first;
  }

  graph_info_to_device_context_.clear();
  graph_info_to_device_context_[graph_info] = device_context;

  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool enable_cache = context_ptr->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_OP_GRAPH_CACHE);
  auto graph_compiler_info = ConstructGraphCompilerInfo(actor_info, tensors_mask, input_tensors, !enable_cache);
  MS_EXCEPTION_IF_NULL(graph_compiler_info);
  const auto actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
  runtime::GraphScheduler::GetInstance().Schedule(actor_set);
  graph_compiler_info->input_tensors_.clear();

  auto ret = actor_to_graph_compiler_info_.emplace(actor_info, std::move(graph_compiler_info));
  return ret.first->first;
}

namespace {
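// Collect the real arguments of a control op such as bprop_cut: non-value inputs are fetched
// as tensors from the single-op output map, while value nodes are appended directly.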
void GetControlOpInput(const std::shared_ptr<GraphCompiler> &graph_compiler, const CNodePtr &front_cnode,
                       const CNodePtr &backend_cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
                       const std::map<AnfNodePtr, size_t> &parameter_index,
                       const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
                       VectorRef *args) {
  MS_EXCEPTION_IF_NULL(front_cnode);
  MS_EXCEPTION_IF_NULL(backend_cnode);
  MS_EXCEPTION_IF_NULL(graph_compiler);
  MS_EXCEPTION_IF_NULL(args);
  size_t input_index = 0;
  auto inputs = front_cnode->inputs();
  for (size_t i = 1; i < inputs.size(); i++) {
    const auto &input_node = inputs[i];
    MS_EXCEPTION_IF_NULL(input_node);
    auto kernel_with_index = AnfAlgo::VisitKernel(input_node, 0);
    auto real_input = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);

    if (!real_input->isa<ValueNode>()) {
      TensorPtr tensor = graph_compiler->GetSingleOpInputTensorByIndex(backend_cnode, op_output_map, parameter_index,
                                                                       graph_inputs, input_tensor_info, input_index);
      MS_EXCEPTION_IF_NULL(tensor);
      args->emplace_back(tensor);
      input_index++;
      continue;
    }

    // Get value from value node.
    const auto &value_node = real_input->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);

    if (value->isa<ValueSequeue>()) {
      const auto &value_sequeue = value->cast<ValueSequeuePtr>();
      MS_EXCEPTION_IF_NULL(value_sequeue);
      input_index += value_sequeue->size();
    } else {
      input_index++;
    }

    args->emplace_back(value);
  }
}

void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(tensors);
  for (const auto &input_object : tuple_inputs) {
    if (!py::isinstance<tensor::Tensor>(input_object)) {
      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
    }
    auto tensor = py::cast<tensor::TensorPtr>(input_object);
    MS_EXCEPTION_IF_NULL(tensor);
    (void)tensors->emplace_back(tensor);
  }
}

void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(tensors);
  ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
  MS_EXCEPTION_IF_NULL(input_value);
  if (!input_value->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }

  auto value_tuple = input_value->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(value_tuple);
  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
  MS_EXCEPTION_IF_NULL(tensor_ptr);
  (void)tensors->emplace_back(tensor_ptr);
}

void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
  MS_EXCEPTION_IF_NULL(tensors);
  if (!py::isinstance<py::tuple>(input_object)) {
    MS_LOG(EXCEPTION) << "The input should be a tuple!";
  }

  auto inputs = py::cast<py::tuple>(input_object);
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
  }

  if (py::isinstance<tensor::Tensor>(inputs[0])) {
    PlantTensorTupleToVector(inputs, tensors);
  } else {
    ConvertValueTupleToTensor(input_object, tensors);
  }
}

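// Execute a control operator on the host. Currently only bprop_cut is handled: its hook
// function is run and the python outputs are converted back to tensors.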
void RunControlOperator(const std::shared_ptr<GraphCompiler> &graph_compiler, const KernelGraphPtr &graph,
                        const CNodePtr &kernel, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output_map,
                        const std::map<AnfNodePtr, size_t> &parameter_index,
                        const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info,
                        VectorRef *op_outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_outputs);
  AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel);
  MS_EXCEPTION_IF_NULL(front_node);
  if (!front_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The front node of bprop_cut is not CNode";
  }
  CNodePtr cnode = front_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const std::vector<AnfNodePtr> &node_inputs = cnode->inputs();
  if (node_inputs.empty()) {
    MS_LOG(EXCEPTION) << "The inputs of node[" << cnode->fullname_with_scope() << "] is empty";
  }

  const AnfNodePtr &fn = node_inputs.at(0);
  if (!IsValueNode<Primitive>(fn)) {
    MS_LOG(EXCEPTION) << "The input[0] of kernel[" << kernel->fullname_with_scope()
                      << "] is not a ValueNode of Primitive";
  }

  PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn);
  MS_EXCEPTION_IF_NULL(prim);
  if (prim->name() == kBpropCutOpName) {
    VectorRef args;
    GetControlOpInput(graph_compiler, cnode, kernel, op_output_map, parameter_index, graph_inputs, input_tensor_info,
                      &args);
    BaseRef out = prim->RunHookFunction(args);
    // Convert pyobject output to tensor.
    if (utils::isa<PyObjectRef>(out)) {
      PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
      auto out_py_tuple = py_ref.object_;
      std::vector<tensor::TensorPtr> output_tensors;
      ConvertMultiPyObjectToTensor(out_py_tuple, &output_tensors);
      (void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
                           [](tensor::TensorPtr &tensor) { return std::move(tensor); });
    }
  }
}

void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(value);
  MS_EXCEPTION_IF_NULL(outputs);
  if (value->isa<ValueTuple>()) {
    auto value_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(value_tuple);
    for (size_t i = 0; i < value_tuple->size(); ++i) {
      ValuePtr element = value_tuple->value()[i];
      MS_EXCEPTION_IF_NULL(element);
      if (element->isa<tensor::Tensor>()) {
        auto tensor = element->cast<tensor::TensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        outputs->emplace_back(tensor);
      } else if (element->isa<ValueTuple>()) {
        TensorValueToVector(element, outputs);
      }
    }
  } else if (value->isa<tensor::Tensor>()) {
    auto tensor = value->cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    outputs->emplace_back(tensor);
  }
}

bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const VectorRef &args, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(graph_output);
  MS_EXCEPTION_IF_NULL(outputs);
  if (graph_output->isa<ValueNode>()) {
    MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
    VectorRef output_tmp;
    ValuePtr value = GetValueNode(graph_output);
    TensorValueToVector(value, &output_tmp);
    if (output_tmp.size() == 1) {
      *outputs = std::move(output_tmp);
    } else if (output_tmp.size() > 1) {
      outputs->emplace_back(output_tmp);
    } else {
      MS_LOG(EXCEPTION) << "Output is empty!";
    }
    return true;
  }

  if (graph_output->isa<Parameter>()) {
    MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
    // Find the right parameter as ret_val.
    auto func_graph = graph_output->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    auto params = func_graph->parameters();
    if (args.size() != params.size()) {
      MS_LOG(EXCEPTION) << "Input size " << args.size() << " not equal to graph input size " << params.size();
    }

    auto it = std::find(params.begin(), params.end(), graph_output);
    if (it == params.end()) {
      MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
    }
    size_t index = it - params.cbegin();
    if (index >= args.size()) {
      MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size();
    }

    outputs->emplace_back(args[index]);
    return true;
  }
  return false;
}
}  // namespace

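// PyNative execution path: run every kernel of each graph as a single op, maintaining the
// ref count of intermediate outputs and recovering the graph outputs kernel by kernel.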
void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                                       const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) {
    const auto &graph = graphs[graph_index];
    MS_EXCEPTION_IF_NULL(graph);
    std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
    std::map<AnfNodePtr, size_t> parameter_index;
    GraphOutputInfo graph_output_info;
    graph_output_info.graph_outputs = outputs;
    graph_compiler_->GetParamAndOutputIndex(graph, inputs[graph_index], outputs, &parameter_index,
                                            &graph_output_info.output_indexes);

    std::map<KernelWithIndex, size_t> cnode_ref_count;
    auto iter = cnode_ref_counts_.find(graph->graph_id());
    if (iter == cnode_ref_counts_.end()) {
      graph_compiler_->CalculateRefCount(graph, &cnode_ref_count);
      (void)cnode_ref_counts_.emplace(graph->graph_id(), cnode_ref_count);
    } else {
      cnode_ref_count = iter->second;
    }

    // Clear bucket resources every step.
    if (graph->is_bprop()) {
      graph_compiler_->ClearAllBucket(graph->graph_id());
    }

    for (const auto &kernel : graph->execution_order()) {
      InputTensorInfo input_tensor_info;
      VectorRef op_outputs;

      if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
        OpRunInfo op_run_info;
        GraphInfo graph_info;
        graph_compiler_->GetSingleOpInputTensors(kernel, op_output_map, parameter_index, inputs[graph_index],
                                                 &input_tensor_info);
        graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info,
                                                        &graph_info);

        const ActorInfo &actor_info = CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask,
                                                   &input_tensor_info.input_tensors);
        RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors,
                 &op_outputs);
      } else {
        RunControlOperator(graph_compiler_, graph, kernel, op_output_map, parameter_index, inputs[graph_index],
                           &input_tensor_info, &op_outputs);
      }

      graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);

      graph_output_info.graph_output_tensors.clear();
      graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);

      // Save grad node to bucket.
      if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel))) {
        graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
      }
    }
  }
}

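// Graph-mode entry point: map args to per-graph input tensors, run the scheduled actor DAG
// (or fall back to single-op execution in PyNative mode), sync the device streams and
// construct the final outputs.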
void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
  MS_LOG(INFO) << "Run actor begin, actor name: " << actor_info;
  MS_EXCEPTION_IF_NULL(root_graph_);
  if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
    return;
  }

  const auto &context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
    MS_LOG(INFO) << "PrecompileOnly, stop run graph";
    return;
  }

  // Fetch the graph compiler info.
  const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
  if (graph_iter == actor_to_graph_compiler_info_.end()) {
    MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
  }
  MS_EXCEPTION_IF_NULL(graph_iter->second);
  const auto &graph_compiler_info = *(graph_iter->second);
  const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;

  // Transform args to input tensors.
  // Input tensors of the graph.
  std::vector<std::vector<tensor::TensorPtr>> input_tensors;
  for (const auto &kernel_graph : graph_compiler_info.graphs_) {
    std::vector<tensor::TensorPtr> input_tensor;
    MS_EXCEPTION_IF_NULL(kernel_graph);
    for (const auto &input_node : kernel_graph->input_nodes()) {
      const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
      PushTensor(args, origin_parameters, front_node, &input_tensor);
    }
    (void)input_tensors.emplace_back(input_tensor);
  }

  // Input tensors of the control node.
  std::vector<tensor::TensorPtr> input_tensor;
  MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
  // Get inputs of control node which come from the host actor.
  const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
  for (const auto &parameter : control_node_parameters) {
    PushTensor(args, origin_parameters, parameter, &input_tensor);
  }
  (void)input_tensors.emplace_back(input_tensor);

  // Run in the pynative mode.
  MS_EXCEPTION_IF_NULL(outputs);
  // There will be more than one kernel graph in a heterogeneous scenario within an ms_function of PyNative mode.
  if (real_execution_mode_ == kPynativeMode) {
    RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs);
    return;
  }
  // Run actor DAG.
  mindspore::ScopedLongRunning long_running;
  const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
  MS_EXCEPTION_IF_NULL(actor_set);
  if (!runtime::GraphScheduler::GetInstance().Run(actor_set, input_tensors)) {
#ifdef ENABLE_DUMP_IR
    mindspore::RDR::TriggerAll();
#endif
    MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_;
  }

  if (graph_compiler_info.device_contexts_.empty()) {
    MS_LOG(EXCEPTION) << "The device contexts is empty.";
  }
  // Sync device stream.
  const auto &first_device_context = graph_compiler_info.device_contexts_[0];
  MS_EXCEPTION_IF_NULL(first_device_context);
  if (!first_device_context->SyncStream()) {
    MS_LOG(EXCEPTION) << "Sync stream failed:" << first_device_context->device_context_key().ToString();
  }
  for (size_t i = 0; i < graph_compiler_info.device_contexts_.size(); ++i) {
    const auto &device_context = graph_compiler_info.device_contexts_[i];
    MS_EXCEPTION_IF_NULL(device_context);
    if ((device_context != first_device_context) && (!device_context->SyncStream())) {
      MS_LOG(EXCEPTION) << "Sync stream failed:" << device_context->device_context_key().ToString();
    }
  }

  // Fetch outputs.
  MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
  auto &output_tensors = actor_set->output_actor_->outputs();
  if (output_tensors.size() > 0) {
    size_t output_position = 0;
    ConstructOutputs(root_graph_->output(), output_tensors, &output_position, outputs);
  }

  MS_EXCEPTION_IF_NULL(graph_compiler_);
  graph_compiler_->Summary(graph_compiler_info.graphs_);

  // Update device address for output node of graph.
  actor_set->output_actor_->UpdateOutputDeviceAddress();
  MS_LOG(INFO) << "Run actor end, actor name: " << actor_info;
}

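// Recursively rebuild the user-visible output structure (make_tuple, depend and value nodes)
// from the flat list of output tensors produced by the output actor.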
void MindRTBackend::ConstructOutputs(const AnfNodePtr &output_node,
                                     const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position,
                                     VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(output_node);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_position);
  // The make_tuple node needs to be expanded recursively.
  if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimMakeTuple)) {
    auto make_tuple = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(make_tuple);
    VectorRef make_tuple_output;
    for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
      ConstructOutputs(make_tuple->input(i), output_tensors, output_position, &make_tuple_output);
    }
    outputs->emplace_back(std::move(make_tuple_output));
    return;
  }

  // The depend node needs to fetch its real input node.
  if (AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
    auto depend_node = output_node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(depend_node);
    ConstructOutputs(depend_node->input(kRealInputIndexInDepend), output_tensors, output_position, outputs);
    return;
  }

  auto outputs_num = AnfAlgo::GetOutputTensorNum(output_node);
  // A value node outputs its value directly, to avoid the host memory of the value being freed when the value node
  // is destructed.
  if (output_node->isa<ValueNode>()) {
    auto value = output_node->cast<ValueNodePtr>()->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<ValueTuple>()) {
      outputs->emplace_back(value);
      (*output_position) += CountValueNum(value->cast<ValueTuplePtr>());
    } else if (outputs_num != 0) {
      outputs->emplace_back(value);
      (*output_position) += outputs_num;
    }
    // An empty value node returns an empty VectorRef.
    return;
  }

  auto &output_abstract = output_node->abstract();
  MS_EXCEPTION_IF_NULL(output_abstract);
  // Wrap output to VectorRef if the output is tuple.
  if (output_abstract->isa<abstract::AbstractTuple>()) {
    VectorRef output_tuple;
    for (size_t i = 0; i < outputs_num; ++i) {
      if (*output_position >= output_tensors.size()) {
        MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
      }
      output_tuple.emplace_back(std::move(output_tensors[*output_position]));
      ++(*output_position);
    }
    outputs->emplace_back(std::move(output_tuple));
  } else {
    for (size_t i = 0; i < outputs_num; ++i) {
      if (*output_position >= output_tensors.size()) {
        MS_LOG(EXCEPTION) << "The output position is out of range: " << *output_position;
      }
      outputs->emplace_back(std::move(output_tensors[*output_position]));
      ++(*output_position);
    }
  }
}

#ifdef ENABLE_DEBUGGER
void MindRTBackend::SetDebuggerInit() {
  auto debugger_ = Debugger::GetInstance();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
}
#endif

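// Assemble the GraphCompilerInfo for a whole func graph: gather the compiled kernel graphs
// and their device contexts, parse the control nodes, and record the output order mapping.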
std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
  MS_EXCEPTION_IF_NULL(root_graph);
  MS_EXCEPTION_IF_NULL(graph_compiler_);

  std::vector<KernelGraphPtr> graphs;
  std::vector<DeviceContext *> device_contexts;
  std::string name = "kernel_graph";
  for (const auto &graph_id_to_context : graph_id_to_device_context_) {
    (void)graphs.emplace_back(graph_compiler_->Fetch(graph_id_to_context.first));
    (void)device_contexts.emplace_back(graph_id_to_context.second);
    (void)name.append("_").append(std::to_string(graph_id_to_context.first));
  }

  auto parser = std::make_shared<ControlNodeParser>();
  parser->Parse(control_nodes_, graphs, device_contexts, root_graph);

  runtime::KernelMapPosition outputs_order;
  size_t outputs_num = 0;
  const auto &root_output =
    AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
  size_t position = 0;
  auto outputs = AnfAlgo::GetAllOutputWithIndex(root_output);
  if (runtime::IsCallNode(root_output)) {
    std::vector<AnfNodePtr> call_nodes;
    size_t call_output_num = runtime::FetchOutputSizebyCallNode(root_output, &call_nodes);
    for (size_t i = 0; i < call_output_num; ++i) {
      (void)outputs.emplace_back(root_output, i);
    }
  }
  outputs_num = outputs.size();
  for (const auto &output : outputs) {
    if (outputs_order.count(output) == 0) {
      outputs_order[output] = {position++};
    } else {
      (void)outputs_order[output].emplace_back(position++);
    }
  }

  std::vector<std::vector<int64_t> *> tensors_mask;
  std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
                                             root_graph->parameters(), parser, outputs_order, outputs_num, name, false,
                                             runtime::GraphExecutionStrategy::kPipeline);
}

std::unique_ptr<GraphCompilerInfo> MindRTBackend::ConstructGraphCompilerInfo(
  const ActorInfo &actor_info, const std::vector<int64_t> *tensors_mask,
  const std::vector<tensor::TensorPtr> *input_tensors, bool need_erase) {
  std::vector<KernelGraphPtr> graphs;
  std::vector<DeviceContext *> device_contexts;
  runtime::KernelMapPosition outputs_order;
  size_t position = 0;
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  for (const auto &graph_info_to_context : graph_info_to_device_context_) {
    const auto &graph = graph_compiler_->Fetch(graph_info_to_context.first);
    MS_EXCEPTION_IF_NULL(graph);
    (void)graphs.emplace_back(graph);
    (void)device_contexts.emplace_back(graph_info_to_context.second);

    auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
    for (const auto &output : outputs) {
      if (outputs_order.count(output) == 0) {
        outputs_order[output] = {position++};
      } else {
        (void)outputs_order[output].emplace_back(position++);
      }
    }
  }

  std::vector<std::vector<int64_t> *> tensors_mask_list(1, const_cast<std::vector<int64_t> *>(tensors_mask));
  std::vector<std::vector<TensorPtr> *> input_tensors_list(1,
                                                           const_cast<std::vector<tensor::TensorPtr> *>(input_tensors));
  auto parser = std::make_shared<ControlNodeParser>();
  return std::make_unique<GraphCompilerInfo>(graphs, device_contexts, tensors_mask_list, input_tensors_list,
                                             std::vector<AnfNodePtr>(), std::vector<AnfNodePtr>(), parser,
                                             outputs_order, outputs_order.size(), actor_info, need_erase,
                                             runtime::GraphExecutionStrategy::kStep);
}

void MindRTBackend::EraseSingleOpCache(const ActorInfo &actor_info, const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (graph_info_to_device_context_.empty()) {
    MS_LOG(EXCEPTION) << "The map graph_info_to_device_context_ is empty.";
  }
  const auto &graph_info = graph_info_to_device_context_.begin()->first;
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  graph_compiler_->EraseSingleOpCache(graph_info, graph->graph_id());
  actor_to_graph_compiler_info_.erase(actor_info);
}

void MindRTBackend::RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info,
                             const std::vector<int64_t> *tensors_mask,
                             const std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(tensors_mask);
  const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
  if (graph_iter == actor_to_graph_compiler_info_.end()) {
    MS_LOG(EXCEPTION) << "Can't find the graph compiler info.";
  }
  MS_EXCEPTION_IF_NULL(graph_iter->second);
  const auto &graph_compiler_info = *(graph_iter->second);

  const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
  MS_EXCEPTION_IF_NULL(actor_set);

  // Erase value node tensor.
  std::vector<tensor::TensorPtr> tensors_without_value_node;
  if (input_tensors->size() != tensors_mask->size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
                      << tensors_mask->size();
  }
  for (size_t index = 0; index < tensors_mask->size(); ++index) {
    if (tensors_mask->at(index) != kValueNodeTensorMask) {
      (void)tensors_without_value_node.emplace_back(input_tensors->at(index));
    }
  }

  for (auto &tensor : tensors_without_value_node) {
    MS_EXCEPTION_IF_NULL(tensor);
    if (tensor->NeedWaitDevice()) {
      tensor->WaitDevice();
    }
  }

  if (!runtime::GraphScheduler::GetInstance().Run(actor_set, {tensors_without_value_node}, *input_tensors,
                                                  runtime::GraphExecutionStrategy::kStep)) {
    MS_LOG(EXCEPTION) << "The actor runs failed, actor name: " << actor_set->name_;
  }

  // Fetch outputs.
  const auto &graph = graph_compiler_info.graphs_.front();
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(graph_compiler_);
  const auto &output_nodes = graph_compiler_->GetGraphOutputNodes(graph->graph_id());
  MS_EXCEPTION_IF_NULL(outputs);
  UpdateOutput(output_nodes, outputs);

  // Update output abstract of dynamic op to op_run_info.
  if (op_run_info->is_dynamic_shape) {
    UpdateOutputAbstract(graph, op_run_info);
  }

  // Release the kernel resource.
  const auto &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (kOpCacheBlackList.find(AnfAlgo::GetCNodeName(kernel)) != kOpCacheBlackList.end()) {
      auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
      if (kernel_mod) {
        kernel_mod->ReleaseResource();
      }
    }
  }

  // Update device address for input and output of graph.
  UpdateOutputDeviceAddress(output_nodes, graph_compiler_info.device_contexts_.front());
  UpdateInputDeviceAddress(graph);

  if (graph_compiler_info.need_erase_) {
    EraseSingleOpCache(actor_info, graph);
  }
}
}  // namespace compile
}  // namespace mindspore