• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "vm/transform.h"
20 
21 #include <algorithm>
22 #include <map>
23 #include <queue>
24 #include <string>
25 #include <vector>
26 
27 #include "abstract/abstract_value.h"
28 #ifdef ENABLE_GE
29 #include "transform/graph_ir/convert.h"
30 #endif
31 #include "ir/graph_utils.h"
32 #include "utils/ms_context.h"
33 #include "debug/trace.h"
34 #include "debug/anf_ir_dump.h"
35 #if ((defined ENABLE_CPU) && (!defined _WIN32))
36 #include "ps/ps_context.h"
37 #endif
38 
39 namespace mindspore {
40 namespace compile {
41 using mindspore::abstract::AbstractFunction;
42 using mindspore::abstract::AbstractFunctionPtr;
43 using PrimTypePair = std::pair<PrimitivePtr, AbstractFunctionPtr>;
44 using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
45 using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
46 
47 std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
48                                            prim::kPrimMakeTuple, prim::kPrimBpropCut};
49 
50 std::vector<PrimitivePtr> control_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, prim::kPrimMakeTuple,
51                                          prim::kPrimSwitchLayer};
52 
GetMsNonlinearOps()53 const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
54   static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn,   prim::kPrimPartial,
55                                                              prim::kPrimSwitch,   prim::kPrimMakeTuple,
56                                                              prim::kPrimBpropCut, prim::kPrimSwitchLayer};
57   return ms_nonlinear_ops;
58 }
59 
CompileGraph(const BackendPtr & backend,const std::vector<PrimitivePtr> & cut_list)60 CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
61   MS_EXCEPTION_IF_NULL(backend_);
62   lin_convert_ = backend_->convert_fn();
63   if (lin_convert_ == nullptr) {
64     MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
65   }
66   graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name());
67 }
68 
69 // Push the value node on the stack.
Push(const AnfNodePtr & node)70 void CompileGraph::Push(const AnfNodePtr &node) {
71   MS_EXCEPTION_IF_NULL(node);
72   if (slots_.count(node) > 0) {
73     MS_LOG(WARNING) << "Push failed node in slots:" << node->DebugString()
74                     << " NodeInfo: " << trace::GetDebugInfo(node->debug_info());
75     return;
76   }
77   MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
78                 << " is parameter: " << node->isa<Parameter>();
79   slots_[node] = height_;
80   set_height(height_ + 1);
81 }
82 
AddInst(const Instruction & inst,const int64_t & arg)83 void CompileGraph::AddInst(const Instruction &inst, const int64_t &arg) {
84   VectorRef args;
85   args.push_back(arg);
86   AddInst(inst, args);
87 }
88 
AddInst(const Instruction & inst,const ValuePtr & arg)89 void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) {
90   VectorRef args;
91   args.push_back(arg);
92   AddInst(inst, args);
93 }
94 
AddInst(const Instruction & inst,const VectorRef & args)95 void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) {
96   inst_.push_back(std::make_pair(inst, args));
97 }
98 
99 // Gets the stack reference for the node value. If the node is a constant,
100 // it may actually cause the push in to not be mentioned before.
Ref(const AnfNodePtr & node)101 int64_t CompileGraph::Ref(const AnfNodePtr &node) {
102   MS_EXCEPTION_IF_NULL(node);
103   MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
104   if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
105     if (IsValueNode<FuncGraph>(node)) {
106       MS_LOG(DEBUG) << "Push graph.";
107       AddInst(Instruction::kGraph, GetValueNode(node));
108     } else {
109       MS_LOG(DEBUG) << "Push.";
110       if (IsValueNode<Primitive>(node)) {
111         MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfo(node->debug_info());
112       } else {
113         AddInst(Instruction::kPush, GetValueNode(node));
114       }
115     }
116     Push(node);
117   }
118   MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
119                 << ", return: " << slots_[node] - height_;
120   return slots_[node] - height_;
121 }
122 
123 // Make sure the value of node is at the top of the stack.
AddInput(const AnfNodePtr & node)124 void CompileGraph::AddInput(const AnfNodePtr &node) {
125   MS_EXCEPTION_IF_NULL(node);
126   if (slots_.count(node) == 0) {
127     MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
128     (void)Ref(node);
129     return;
130   }
131   AddInst(Instruction::kInput, Ref(node));
132   set_height(height_ + 1);
133 }
134 
135 // Call back effect in stack
Ret(int64_t nargs)136 void CompileGraph::Ret(int64_t nargs) { set_height(height_ - nargs); }
137 
PushParameters(const FuncGraphPtr & graph)138 void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
139   MS_EXCEPTION_IF_NULL(graph);
140   std::vector<AnfNodePtr> parameters = graph->parameters();
141   for (size_t i = parameters.size(); i != 0; i--) {
142     MS_EXCEPTION_IF_NULL(parameters[i - 1]);
143     Push(parameters[i - 1]);
144     MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << parameters[i - 1]->DebugString(true);
145   }
146 }
147 
LinConvert(const FuncGraphPtr & graph,const GraphSegmentPtr & segment,const std::string & target)148 int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
149   MS_EXCEPTION_IF_NULL(segment);
150   MS_LOG(DEBUG) << "LinConvert start";
151   LinConvertResult result;
152 
153   result = lin_convert_(segment, target);
154 
155   if (result.run == nullptr) {
156     MS_LOG(ERROR) << "LinConvert failed";
157     return RET_FAILED;
158   }
159 
160   if (!(*result.run)) {
161     if (result.inputs.size() != result.outputs.size()) {
162       MS_EXCEPTION_IF_NULL(graph);
163       MS_LOG(EXCEPTION) << "must inputs equal outputs NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
164     } else {
165       size_t size = result.inputs.size();
166       for (size_t i = 0; i < size; i++) {
167         Tie(result.inputs[i], result.outputs[i]);
168       }
169       return RET_CONTINUE;
170     }
171   }
172   AddExternal(result);
173 
174   return RET_SUCCESS;
175 }
176 
InterpretNode(const FuncGraphPtr & graph,const CNodePtr & node)177 int64_t CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
178   MS_EXCEPTION_IF_NULL(node);
179   MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
180   std::vector<AnfNodePtr> node_inputs = node->inputs();
181   if (node_inputs.empty()) {
182     MS_LOG(EXCEPTION) << "The node->inputs() is empty";
183   }
184   AnfNodePtr fn = node_inputs[0];
185   if (IsValueNode<Primitive>(fn)) {
186     PrimitivePtr value = GetValueNode<PrimitivePtr>(fn);
187     MS_LOG(DEBUG) << "The fn is primitive " << (*value).name();
188     for (size_t i = node_inputs.size() - 1; i > 0; i--) {
189       AddInput(node->input(i));
190     }
191     if (IsPrimitive(fn, prim::kPrimReturn)) {
192       AddReturn(node);
193       return RET_BREAK;
194     }
195     if (IsPrimitive(fn, prim::kPrimPartial)) {
196       AddPartial(node);
197     } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
198       AddSwitch(node);
199     } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
200       AddSwitchLayer(node);
201     } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
202       AddMakeTuple(node);
203     } else {
204       AddPrimitive(node, value);
205     }
206   } else {
207     int64_t ret = AddCall(graph, node);
208     if (ret == RET_BREAK) {
209       return ret;
210     }
211   }
212   Push(node);
213   return RET_SUCCESS;
214 }
215 
Compile(const FuncGraphPtr & graph)216 bool CompileGraph::Compile(const FuncGraphPtr &graph) {
217   MS_LOG(DEBUG) << "Start split graph";
218   MS_EXCEPTION_IF_NULL(graph);
219   MS_EXCEPTION_IF_NULL(graph_partition_);
220   auto segments = graph_partition_->Partition(graph);
221 
222   MS_LOG(DEBUG) << "Split nodes size:" << segments.size();
223   for (auto &segment : segments) {
224     MS_EXCEPTION_IF_NULL(segment);
225     int64_t ret = RET_SUCCESS;
226     if (!segment->is_cut_) {
227       MS_LOG(DEBUG) << "Start a extern LinConvert";
228       if (!segment->nodes_.empty()) {
229         std::string cur_target = GetCNodeTarget(segment->nodes_[0]);
230         ret = LinConvert(graph, segment, cur_target);
231       } else {
232         ret = LinConvert(graph, segment);
233       }
234       MS_LOG(DEBUG) << "End a extern LinConvert";
235       if (ret == RET_FAILED) {
236         return false;
237       }
238       if (ret == RET_CONTINUE) {
239         continue;
240       }
241     } else if (!segment->nodes_.empty()) {
242       MS_LOG(DEBUG) << "Start a cut node";
243       auto &cut_node = segment->nodes_[0];
244       MS_EXCEPTION_IF_NULL(cut_node);
245       if (!cut_node->isa<CNode>()) {
246         MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info());
247       }
248       auto node = cut_node->cast<CNodePtr>();
249       ret = InterpretNode(graph, node);
250       MS_LOG(DEBUG) << "End a cut node";
251       if (ret == RET_BREAK) {
252         break;
253       }
254     }
255   }
256   MS_LOG(DEBUG) << "End split graph";
257   return true;
258 }
259 
Run(const FuncGraphPtr & graph)260 InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
261   MS_EXCEPTION_IF_NULL(graph);
262 
263   Reset();
264   PushParameters(graph);
265 
266   int64_t param_height = height_;
267   MS_EXCEPTION_IF_NULL(graph->get_return());
268   MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
269 
270   if (!Compile(graph)) {
271     return inst_;
272   }
273 
274   AddPadStack(param_height);
275   auto ret = inst_;
276   Reset();
277   return ret;
278 }
279 
AddPadStack(int64_t param_height)280 void CompileGraph::AddPadStack(int64_t param_height) {
281   int64_t stack_sizes = max_height_ - param_height;
282   MS_LOG(DEBUG) << "Pad stack max_height_:" << max_height_ << " param:" << param_height
283                 << " need_stack:" << stack_sizes;
284   if (stack_sizes > 0) {
285     VectorRef need_stacks({stack_sizes});
286     (void)inst_.insert(inst_.begin(), std::make_pair(Instruction::kPadStack, need_stacks));
287   }
288 }
289 
AddTailCall(const AnfNodePtr & fn,size_t size)290 void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
291   VectorRef args;
292   args.emplace_back(Ref(fn));
293   args.emplace_back(height_);
294   args.emplace_back(static_cast<int64_t>(size - 1));
295   MS_LOG(DEBUG) << "Tail call:" << Ref(fn) << ", " << height_ << ", " << (size - 1);
296   AddInst(Instruction::kTailCall, args);
297 }
298 
AddPartial(const CNodePtr & node)299 void CompileGraph::AddPartial(const CNodePtr &node) {
300   MS_EXCEPTION_IF_NULL(node);
301   auto inputs = node->inputs();
302   VectorRef args;
303   if (inputs.size() <= 1) {
304     MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
305   }
306   auto fn = inputs[1];
307   if (!IsValueNode<FuncGraph>(fn)) {
308     MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph";
309   }
310   for (size_t i = 1; i < inputs.size(); i++) {
311     args.emplace_back(Ref(inputs[i]));
312   }
313   AddInst(Instruction::kPartial, args);
314 }
315 
AddMakeTuple(const CNodePtr & node)316 void CompileGraph::AddMakeTuple(const CNodePtr &node) {
317   MS_EXCEPTION_IF_NULL(node);
318   auto inputs = node->inputs();
319   VectorRef args;
320   for (size_t i = 1; i < inputs.size(); i++) {
321     args.emplace_back(Ref(inputs[i]));
322   }
323   AddInst(Instruction::kTuple, args);
324 }
325 
AddSwitch(const CNodePtr & node)326 void CompileGraph::AddSwitch(const CNodePtr &node) {
327   MS_EXCEPTION_IF_NULL(node);
328   auto inputs = node->inputs();
329   if (inputs.size() < kSwitchInputSize) {
330     MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
331   }
332   VectorRef args;
333   args.emplace_back(Ref(inputs[kCallKernelGraphIndex]));
334   args.emplace_back(Ref(inputs[kSwitchTrueKernelGraphIndex]));
335   args.emplace_back(Ref(inputs[kSwitchFalseKernelGraphIndex]));
336   AddInst(Instruction::kSwitch, args);
337 }
338 
AddSwitchLayer(const CNodePtr & node)339 void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
340   MS_EXCEPTION_IF_NULL(node);
341   auto inputs = node->inputs();
342   if (inputs.size() != kSwitchLayerInputSize) {
343     MS_LOG(EXCEPTION) << "Switch layer must have index and branches.";
344   }
345   VectorRef args;
346   const size_t cond_index = 1;
347   const size_t tuple_index = 2;
348   args.emplace_back(Ref(inputs[cond_index]));
349   args.emplace_back(Ref(inputs[tuple_index]));
350   AddInst(Instruction::kSwitchLayer, args);
351 }
352 
AddReturn(const CNodePtr & node)353 void CompileGraph::AddReturn(const CNodePtr &node) {
354   MS_EXCEPTION_IF_NULL(node);
355   VectorRef args;
356   if (node->inputs().size() <= 1) {
357     MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
358   }
359   args.emplace_back(Ref(node->input(1)));
360   args.emplace_back(height_);
361   AddInst(Instruction::kReturn, args);
362 }
363 
AddPrimitive(const CNodePtr & node,const PrimitivePtr & prim)364 void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) {
365   MS_EXCEPTION_IF_NULL(node);
366   auto inputs = node->inputs();
367   VectorRef args;
368   args.push_back(prim);
369   for (size_t i = 1; i < inputs.size(); i++) {
370     args.emplace_back(Ref(inputs[i]));
371   }
372   AddInst(Instruction::kPrim, args);
373 }
374 
AddCall(const FuncGraphPtr & graph,const CNodePtr & node)375 int64_t CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
376   MS_EXCEPTION_IF_NULL(graph);
377   MS_EXCEPTION_IF_NULL(node);
378   auto inputs = node->inputs();
379   if (inputs.empty()) {
380     MS_LOG(EXCEPTION) << "The node->inputs() is empty.";
381   }
382   AnfNodePtr fn = inputs[0];
383   (void)Ref(fn);
384   size_t size = inputs.size();
385   for (size_t i = size - 1; i > 0; i--) {
386     AddInput(inputs[i]);
387   }
388   if (node == graph->output()) {
389     AddTailCall(fn, size);
390     return RET_BREAK;
391   }
392   MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << (size - 1);
393   AddInst(Instruction::kCall, Ref(fn));
394   Ret(static_cast<int64_t>(size - 1));
395 
396   for (size_t i = size - 1; i > 0; i--) {
397     const auto iter = slots_.find(inputs[i]);
398     if (iter != slots_.end() && iter->second >= height_) {
399       (void)slots_.erase(inputs[i]);
400     }
401   }
402   return RET_SUCCESS;
403 }
404 
AddExternal(const LinConvertResult & result)405 void CompileGraph::AddExternal(const LinConvertResult &result) {
406   VectorRef args;
407   args.push_back(result.run);
408   args.push_back(result.simu_run);
409   size_t size = result.inputs.size();
410   for (size_t i = 0; i < size; i++) {
411     args.emplace_back(Ref(result.inputs[i]));
412   }
413   AddInst(Instruction::kExternal, args);
414   for (auto &out : result.outputs) {
415     Push(out);
416   }
417 }
418 
TraverseGraphMap(const FuncGraphManagerPtr & manager_ptr,FuncGraphTransaction * const tr,const FuncGraphSet & fgs,const std::function<std::shared_ptr<FuncGraph> (const PrimitivePtr,const AbstractFunctionPtr)> & get_prim_graph)419 void TraverseGraphMap(
420   const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphSet &fgs,
421   const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
422   MS_EXCEPTION_IF_NULL(manager_ptr);
423   MS_EXCEPTION_IF_NULL(tr);
424   for (const auto &fg : fgs) {
425     MS_EXCEPTION_IF_NULL(fg);
426     for (const auto &ct_any : fg->value_nodes()) {
427       AnfNodePtr const_primitive_node = ct_any.first;
428       if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
429         auto users = manager_ptr->node_users()[const_primitive_node];
430         for (auto &use : users) {
431           CNodePtr node = use.first->cast<CNodePtr>();
432           MS_EXCEPTION_IF_NULL(node);
433           if (node->func_graph() != fg) {
434             continue;
435           }
436           int64_t key = use.second;
437           if (key != 0) {
438             MS_EXCEPTION_IF_NULL(node->input(0));
439             bool key_is_const = node->input(0)->isa<ValueNode>();
440             PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
441             if (value != nullptr) {
442               bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
443               bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
444               if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
445                 continue;
446               }
447             }
448             FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
449                                             dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
450             tr->SetEdge(node, LongToInt(key), NewValueNode(g));
451           }
452         }
453       }
454     }
455   }
456 }
457 
// Replaces non-callee uses of constant Primitives in all graphs managed by `graph`'s manager with wrapper graphs.
// Each wrapper takes one parameter per argument spec of the primitive's TypedPrimitiveAbstractClosure and returns
// the primitive applied to them. One wrapper is built per (primitive, abstract) pair and reused. All edge updates are
// committed in a single transaction. Returns `graph` itself.
FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  FuncGraphManagerPtr manager_ptr = graph->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  MapPrimTypeFuncGraph prim_graphs;
  auto get_prim_graph = [&prim_graphs](const PrimitivePtr &prim, const AbstractFunctionPtr &type) {
    const PrimTypePair key = std::make_pair(prim, type);
    const auto cached = prim_graphs.find(key);
    if (cached != prim_graphs.end()) {
      return cached->second;
    }

    FuncGraphPtr wrapper = std::make_shared<FuncGraph>();
    std::vector<AnfNodePtr> call_inputs;
    ValueNodePtr prim_value = NewValueNode(prim);
    MS_EXCEPTION_IF_NULL(prim_value);
    prim_value->set_abstract(type);
    call_inputs.push_back(prim_value);
    MS_EXCEPTION_IF_NULL(type);
    TypedPrimitiveAbstractClosurePtr closure = dyn_cast<abstract::TypedPrimitiveAbstractClosure>(type->GetUnique());
    MS_EXCEPTION_IF_NULL(closure);
    MS_EXCEPTION_IF_NULL(wrapper);
    for (const auto &arg_spec : closure->args_spec_list()) {
      ParameterPtr param = wrapper->add_parameter();
      param->set_abstract(arg_spec);
      call_inputs.push_back(param);
    }
    AnfNodePtr call = wrapper->NewCNode(call_inputs);
    call->set_abstract(closure->output());
    wrapper->set_output(call);
    prim_graphs[key] = wrapper;
    return wrapper;
  };

  FuncGraphTransaction tr = manager_ptr->Transact();
  const auto &fgs = manager_ptr->func_graphs();
  TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
  tr.Commit();
  return graph;
}
497 
CompileGraphs(const BackendPtr & backend,const std::vector<PrimitivePtr> & cut_list)498 CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
499   MS_EXCEPTION_IF_NULL(backend);
500   MS_LOG(DEBUG) << "Start vm: " << backend->name();
501   transform_ = std::make_shared<CompileGraph>(backend, cut_list);
502   Reset();
503 }
504 
505 // Convert graphs to unlinked instructions.
Compile(const FuncGraphPtr & graph)506 void CompileGraphs::Compile(const FuncGraphPtr &graph) {
507   MS_LOG(DEBUG) << "Start";
508   mapping_[graph] = static_cast<int64_t>(insts_.size());
509   if (transform_ != nullptr) {
510     InstSet insts = transform_->Run(graph);
511     if (!insts.empty()) {
512       (void)insts_.insert(insts_.end(), insts.begin(), insts.end());
513     }
514   }
515   MS_LOG(DEBUG) << "End";
516 }
517 
518 // Link instructions from multiple function graphs together.
Link()519 FinalVMPtr CompileGraphs::Link() {
520   MS_LOG(DEBUG) << "Start";
521   for (std::size_t i = 0; i < insts_.size(); i++) {
522     InstType inst = insts_[i];
523     MS_LOG(DEBUG) << "Link point:" << inst_str[inst.first];
524     if (Instruction::kGraph == inst.first) {
525       if (inst.second.empty()) {
526         MS_LOG(EXCEPTION) << "The second element of inst is empty";
527       }
528       FuncGraphPtr func_graph = utils::cast<ValuePtr>(inst.second[0])->cast<FuncGraphPtr>();
529       MS_LOG(DEBUG) << "Link graph:" << func_graph->ToString();
530       insts_[i] = std::make_pair(Instruction::kPush, VectorRef(std::vector<BaseRef>{mapping_[func_graph]}));
531     }
532   }
533 
534   FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
535   MS_LOG(DEBUG) << "End";
536   return rt;
537 }
538 
539 // Convert all graphs to unlinked instructions and link them.
CompileAndLink(const FuncGraphPtr & graph)540 FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
541   MS_EXCEPTION_IF_NULL(graph);
542   MS_LOG(DEBUG) << "Start";
543   Reset();
544   MS_LOG(DEBUG) << "Begin parameter:" << graph->parameters().size();
545 
546   FuncGraphPtr prim_graph = WrapPrimitives(graph);
547   Compile(prim_graph);
548   MS_EXCEPTION_IF_NULL(prim_graph);
549   MS_EXCEPTION_IF_NULL(prim_graph->manager());
550   FuncGraphSet graphs = prim_graph->manager()->func_graphs();
551   for (auto g : graphs) {
552     if (g != graph && g != nullptr) {
553       Compile(g);
554     }
555   }
556 
557   FinalVMPtr rt = Link();
558   Reset();
559   MS_LOG(DEBUG) << "End";
560   return rt;
561 }
562 
CreateBackend()563 BackendPtr CreateBackend() {
564   auto context_ptr = MsContext::GetInstance();
565   MS_EXCEPTION_IF_NULL(context_ptr);
566   std::string name = context_ptr->backend_policy();
567   MS_LOG(INFO) << "CreateBackend is: " << name;
568   if (backend_list.count(name) == 0) {
569     MS_LOG(EXCEPTION) << "Backend is error: " << name;
570   }
571 
572   if (name == kMsConvert) {
573     std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
574     uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
575     BackendPtr backend = nullptr;
576     // Create MindRTBackend or MsBackend according to whether mindrt is used.
577     if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
578       backend = std::make_shared<MindRTBackend>(name, target, device_id);
579     } else {
580       backend = std::make_shared<MsBackend>(name, target, device_id);
581     }
582 
583     if (target == kAscendDevice) {
584       if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
585         backend->set_is_multi_graph_sink(false);
586         context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
587       } else {
588         auto single_op = common::GetEnv(kGraphOpRun);
589         if (single_op == "1") {
590           context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
591         }
592         auto enable_mem_scheduler = common::GetEnv(kEnableMemScheduler);
593         if (enable_mem_scheduler == "1") {
594           context_ptr->set_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER, true);
595           context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, false);
596         }
597       }
598     }
599     return backend;
600   }
601 
602   return std::make_shared<Backend>(name);
603 }
604 
SetMindRTEnable()605 void SetMindRTEnable() {
606   auto context_ptr = MsContext::GetInstance();
607   MS_EXCEPTION_IF_NULL(context_ptr);
608   if (context_ptr->get_param<bool>(MS_CTX_ALREADY_SET_ENABLE_MINDRT)) {
609     return;
610   }
611 
612   std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
613   if ((target != kGPUDevice) && (target != kCPUDevice)) {
614     return;
615   }
616 
617 #if defined(_WIN32) || defined(_WIN64)
618   return;
619 #endif
620 
621 #if ((defined ENABLE_CPU) && (!defined _WIN32))
622   if (ps::PSContext::instance()->is_ps_mode()) {
623     return;
624   }
625 #endif
626 
627   MS_LOG(DEBUG) << "Enable mindRT.";
628   context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, true);
629 }
630 }  // namespace compile
631 }  // namespace mindspore
632