• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "backend/graph_compiler/transform.h"
20 
21 #include <algorithm>
22 #include <map>
23 #include <queue>
24 #include <string>
25 #include <vector>
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/nn_ops.h"
28 #include "mindspore/core/ops/array_ops.h"
29 #include "mindspore/core/ops/framework_ops.h"
30 #include "abstract/abstract_value.h"
31 #include "abstract/abstract_function.h"
32 #include "ir/graph_utils.h"
33 #include "utils/ms_context.h"
34 #include "utils/trace_base.h"
35 #if defined(__linux__) && defined(WITH_BACKEND)
36 #include "include/backend/distributed/ps/ps_context.h"
37 #endif
38 
39 namespace mindspore {
40 namespace compile {
41 using mindspore::abstract::AbstractFunction;
42 using mindspore::abstract::AbstractFunctionPtr;
43 using PrimTypePair = std::pair<PrimitivePtr, AbstractFunctionPtr>;
44 using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
45 using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
46 
GetNonlinearOps()47 const std::vector<PrimitivePtr> &GetNonlinearOps() {
48   static std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
49                                                     prim::kPrimMakeTuple, prim::kPrimBpropCut};
50   return nonlinear_ops;
51 }
52 
GetControlOps()53 const std::vector<PrimitivePtr> &GetControlOps() {
54   static std::vector<PrimitivePtr> control_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
55                                                   prim::kPrimMakeTuple, prim::kPrimSwitchLayer};
56   return control_ops;
57 }
58 
GetMsNonlinearOps()59 const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
60   static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn,   prim::kPrimPartial,
61                                                              prim::kPrimSwitch,   prim::kPrimMakeTuple,
62                                                              prim::kPrimBpropCut, prim::kPrimSwitchLayer};
63   return ms_nonlinear_ops;
64 }
65 
CompileGraph(const BackendPtr & backend,const std::vector<PrimitivePtr> & cut_list)66 CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
67   MS_EXCEPTION_IF_NULL(backend_);
68   lin_convert_ = backend_->convert_fn();
69   if (lin_convert_ == nullptr) {
70     MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
71   }
72   graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name());
73 }
74 
75 // Push the value node on the stack.
Push(const AnfNodePtr & node)76 void CompileGraph::Push(const AnfNodePtr &node) {
77   MS_EXCEPTION_IF_NULL(node);
78   if (slots_.count(node) > 0) {
79     MS_LOG(WARNING) << "Push failed node in slots:" << node->DebugString()
80                     << " NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
81     return;
82   }
83   MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
84                 << " is parameter: " << node->isa<Parameter>();
85   slots_[node] = height_;
86   set_height(height_ + 1);
87 }
88 
AddInst(const Instruction & inst,const int64_t & arg)89 void CompileGraph::AddInst(const Instruction &inst, const int64_t &arg) {
90   VectorRef args;
91   args.push_back(arg);
92   AddInst(inst, args);
93 }
94 
AddInst(const Instruction & inst,const ValuePtr & arg)95 void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) {
96   VectorRef args;
97   args.push_back(arg);
98   AddInst(inst, args);
99 }
100 
AddInst(const Instruction & inst,const VectorRef & args)101 void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) {
102   inst_.push_back(std::make_pair(inst, args));
103 }
104 
105 // Gets the stack reference for the node value. If the node is a constant,
106 // it may actually cause the push in to not be mentioned before.
Ref(const AnfNodePtr & node)107 int64_t CompileGraph::Ref(const AnfNodePtr &node) {
108   MS_EXCEPTION_IF_NULL(node);
109   MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
110   if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
111     if (IsValueNode<FuncGraph>(node)) {
112       MS_LOG(DEBUG) << "Push graph.";
113       AddInst(Instruction::kGraph, GetValueNode(node));
114     } else {
115       MS_LOG(DEBUG) << "Push.";
116       if (IsValueNode<Primitive>(node)) {
117         MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
118       } else {
119         AddInst(Instruction::kPush, GetValueNode(node));
120       }
121     }
122     Push(node);
123   }
124   MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
125                 << ", return: " << slots_[node] - height_;
126   return slots_[node] - height_;
127 }
128 
129 // Make sure the value of node is at the top of the stack.
AddInput(const AnfNodePtr & node)130 void CompileGraph::AddInput(const AnfNodePtr &node) {
131   MS_EXCEPTION_IF_NULL(node);
132   if (slots_.count(node) == 0) {
133     MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
134     (void)Ref(node);
135     return;
136   }
137   AddInst(Instruction::kInput, Ref(node));
138   set_height(height_ + 1);
139 }
140 
141 // Call back effect in stack
Ret(int64_t nargs)142 void CompileGraph::Ret(int64_t nargs) { set_height(height_ - nargs); }
143 
PushParameters(const FuncGraphPtr & graph)144 void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
145   MS_EXCEPTION_IF_NULL(graph);
146   std::vector<AnfNodePtr> parameters = graph->parameters();
147   for (size_t i = parameters.size(); i != 0; i--) {
148     MS_EXCEPTION_IF_NULL(parameters[i - 1]);
149     Push(parameters[i - 1]);
150     MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << parameters[i - 1]->DebugString(true);
151   }
152 }
153 
LinConvert(const FuncGraphPtr & graph,const GraphSegmentPtr & segment,const std::string & target)154 int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
155   MS_EXCEPTION_IF_NULL(segment);
156   MS_LOG(DEBUG) << "LinConvert start";
157   LinConvertResult result;
158 
159   result = lin_convert_(segment, target);
160 
161   if (result.run == nullptr) {
162     MS_LOG(ERROR) << "LinConvert failed";
163     return RET_FAILED;
164   }
165 
166   if (!(*result.run)) {
167     if (result.inputs.size() != result.outputs.size()) {
168       MS_EXCEPTION_IF_NULL(graph);
169       MS_LOG(EXCEPTION) << "must inputs equal outputs NodeInfo: " << trace::GetDebugInfoStr(graph->debug_info());
170     } else {
171       size_t size = result.inputs.size();
172       for (size_t i = 0; i < size; i++) {
173         Tie(result.inputs[i], result.outputs[i]);
174       }
175       return RET_CONTINUE;
176     }
177   }
178   AddExternal(result);
179 
180   return RET_SUCCESS;
181 }
182 
InterpretNode(const FuncGraphPtr & graph,const CNodePtr & node)183 int64_t CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
184   MS_EXCEPTION_IF_NULL(node);
185   MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
186   std::vector<AnfNodePtr> node_inputs = node->inputs();
187   if (node_inputs.empty()) {
188     MS_LOG(EXCEPTION) << "The node->inputs() is empty";
189   }
190   AnfNodePtr fn = node_inputs[0];
191   if (IsValueNode<Primitive>(fn)) {
192     PrimitivePtr value = GetValueNode<PrimitivePtr>(fn);
193     MS_LOG(DEBUG) << "The fn is primitive " << (*value).name();
194     for (size_t i = node_inputs.size() - 1; i > 0; i--) {
195       AddInput(node->input(i));
196     }
197     if (IsPrimitive(fn, prim::kPrimReturn)) {
198       AddReturn(node);
199       return RET_BREAK;
200     }
201     if (IsPrimitive(fn, prim::kPrimPartial)) {
202       AddPartial(node);
203     } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
204       AddSwitch(node);
205     } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
206       AddSwitchLayer(node);
207     } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
208       AddMakeTuple(node);
209     } else {
210       AddPrimitive(node, value);
211     }
212   } else {
213     int64_t ret = AddCall(graph, node);
214     if (ret == RET_BREAK) {
215       return ret;
216     }
217   }
218   Push(node);
219   return RET_SUCCESS;
220 }
221 
Compile(const FuncGraphPtr & graph)222 bool CompileGraph::Compile(const FuncGraphPtr &graph) {
223   MS_LOG(DEBUG) << "Start split graph";
224   MS_EXCEPTION_IF_NULL(graph);
225   MS_EXCEPTION_IF_NULL(graph_partition_);
226   auto segments = graph_partition_->Partition(graph);
227 
228   MS_LOG(DEBUG) << "Split nodes size:" << segments.size();
229   for (auto &segment : segments) {
230     MS_EXCEPTION_IF_NULL(segment);
231     int64_t ret = RET_SUCCESS;
232     if (!segment->is_cut_) {
233       MS_LOG(DEBUG) << "Start a extern LinConvert";
234       if (!segment->nodes_.empty()) {
235         std::string cur_target = GetCNodeTarget(segment->nodes_[0]);
236         ret = LinConvert(graph, segment, cur_target);
237       } else {
238         ret = LinConvert(graph, segment);
239       }
240       MS_LOG(DEBUG) << "End a extern LinConvert";
241       if (ret == RET_FAILED) {
242         return false;
243       }
244       if (ret == RET_CONTINUE) {
245         continue;
246       }
247     } else if (!segment->nodes_.empty()) {
248       MS_LOG(DEBUG) << "Start a cut node";
249       auto &cut_node = segment->nodes_[0];
250       MS_EXCEPTION_IF_NULL(cut_node);
251       if (!cut_node->isa<CNode>()) {
252         MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfoStr(graph->debug_info());
253       }
254       auto node = cut_node->cast<CNodePtr>();
255       ret = InterpretNode(graph, node);
256       MS_LOG(DEBUG) << "End a cut node";
257       if (ret == RET_BREAK) {
258         break;
259       }
260     }
261   }
262   MS_LOG(DEBUG) << "End split graph";
263   return true;
264 }
265 
Run(const FuncGraphPtr & graph)266 InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
267   MS_EXCEPTION_IF_NULL(graph);
268 
269   Reset();
270   PushParameters(graph);
271 
272   int64_t param_height = height_;
273   MS_EXCEPTION_IF_NULL(graph->get_return());
274   MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
275 
276   if (!Compile(graph)) {
277     return inst_;
278   }
279 
280   AddPadStack(param_height);
281   auto ret = inst_;
282   Reset();
283   return ret;
284 }
285 
AddPadStack(int64_t param_height)286 void CompileGraph::AddPadStack(int64_t param_height) {
287   int64_t stack_sizes = max_height_ - param_height;
288   MS_LOG(DEBUG) << "Pad stack max_height_:" << max_height_ << " param:" << param_height
289                 << " need_stack:" << stack_sizes;
290   if (stack_sizes > 0) {
291     VectorRef need_stacks({stack_sizes});
292     (void)inst_.insert(inst_.cbegin(), std::make_pair(Instruction::kPadStack, need_stacks));
293   }
294 }
295 
AddTailCall(const AnfNodePtr & fn,size_t size)296 void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
297   VectorRef args;
298   args.emplace_back(Ref(fn));
299   args.emplace_back(height_);
300   args.emplace_back(static_cast<int64_t>(size - 1));
301   MS_LOG(DEBUG) << "Tail call:" << Ref(fn) << ", " << height_ << ", " << (size - 1);
302   AddInst(Instruction::kTailCall, args);
303 }
304 
AddPartial(const CNodePtr & node)305 void CompileGraph::AddPartial(const CNodePtr &node) {
306   MS_EXCEPTION_IF_NULL(node);
307   auto inputs = node->inputs();
308   VectorRef args;
309   if (inputs.size() <= 1) {
310     MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
311   }
312   auto fn = inputs[1];
313   if (!IsValueNode<FuncGraph>(fn)) {
314     MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph, but got:" << fn->ToString();
315   }
316   for (size_t i = 1; i < inputs.size(); i++) {
317     args.emplace_back(Ref(inputs[i]));
318   }
319   AddInst(Instruction::kPartial, args);
320 }
321 
AddMakeTuple(const CNodePtr & node)322 void CompileGraph::AddMakeTuple(const CNodePtr &node) {
323   MS_EXCEPTION_IF_NULL(node);
324   auto inputs = node->inputs();
325   VectorRef args;
326   for (size_t i = 1; i < inputs.size(); i++) {
327     args.emplace_back(Ref(inputs[i]));
328   }
329   AddInst(Instruction::kTuple, args);
330 }
331 
AddSwitch(const CNodePtr & node)332 void CompileGraph::AddSwitch(const CNodePtr &node) {
333   MS_EXCEPTION_IF_NULL(node);
334   auto inputs = node->inputs();
335   if (inputs.size() < kSwitchInputSize) {
336     MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
337   }
338   VectorRef args;
339   args.emplace_back(Ref(inputs[kPartialGraphIndex]));
340   args.emplace_back(Ref(inputs[kSwitchTrueBranchIndex]));
341   args.emplace_back(Ref(inputs[kSwitchFalseBranchIndex]));
342   AddInst(Instruction::kSwitch, args);
343 }
344 
AddSwitchLayer(const CNodePtr & node)345 void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
346   MS_EXCEPTION_IF_NULL(node);
347   auto inputs = node->inputs();
348   if (inputs.size() != kSwitchLayerInputSize) {
349     MS_LOG(EXCEPTION) << "Switch layer must have index and branches.";
350   }
351   VectorRef args;
352   const size_t cond_index = 1;
353   const size_t tuple_index = 2;
354   args.emplace_back(Ref(inputs[cond_index]));
355   args.emplace_back(Ref(inputs[tuple_index]));
356   AddInst(Instruction::kSwitchLayer, args);
357 }
358 
AddReturn(const CNodePtr & node)359 void CompileGraph::AddReturn(const CNodePtr &node) {
360   MS_EXCEPTION_IF_NULL(node);
361   VectorRef args;
362   if (node->size() <= 1) {
363     MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
364   }
365   args.emplace_back(Ref(node->input(1)));
366   args.emplace_back(height_);
367   AddInst(Instruction::kReturn, args);
368 }
369 
AddPrimitive(const CNodePtr & node,const PrimitivePtr & prim)370 void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) {
371   MS_EXCEPTION_IF_NULL(node);
372   auto inputs = node->inputs();
373   VectorRef args;
374   args.push_back(prim);
375   for (size_t i = 1; i < inputs.size(); i++) {
376     args.emplace_back(Ref(inputs[i]));
377   }
378   AddInst(Instruction::kPrim, args);
379 }
380 
AddCall(const FuncGraphPtr & graph,const CNodePtr & node)381 int64_t CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
382   MS_EXCEPTION_IF_NULL(graph);
383   MS_EXCEPTION_IF_NULL(node);
384   auto inputs = node->inputs();
385   if (inputs.empty()) {
386     MS_LOG(EXCEPTION) << "The node->inputs() is empty.";
387   }
388   AnfNodePtr fn = inputs[0];
389   (void)Ref(fn);
390   size_t size = inputs.size();
391   for (size_t i = size - 1; i > 0; i--) {
392     AddInput(inputs[i]);
393   }
394   if (node == graph->output()) {
395     AddTailCall(fn, size);
396     return RET_BREAK;
397   }
398   MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << (size - 1);
399   AddInst(Instruction::kCall, Ref(fn));
400   Ret(static_cast<int64_t>(size - 1));
401 
402   for (size_t i = size - 1; i > 0; i--) {
403     const auto iter = slots_.find(inputs[i]);
404     if (iter != slots_.end() && iter->second >= height_) {
405       (void)slots_.erase(inputs[i]);
406     }
407   }
408   return RET_SUCCESS;
409 }
410 
AddExternal(const LinConvertResult & result)411 void CompileGraph::AddExternal(const LinConvertResult &result) {
412   VectorRef args;
413   args.push_back(result.run);
414   args.push_back(result.simu_run);
415   size_t size = result.inputs.size();
416   for (size_t i = 0; i < size; i++) {
417     args.emplace_back(Ref(result.inputs[i]));
418   }
419   AddInst(Instruction::kExternal, args);
420   for (auto &out : result.outputs) {
421     Push(out);
422   }
423 }
424 
TraverseGraphMap(const FuncGraphManagerPtr & manager_ptr,FuncGraphTransaction * tr,const FuncGraphSet & fgs,const std::function<std::shared_ptr<FuncGraph> (const PrimitivePtr,const AbstractFunctionPtr)> & get_prim_graph)425 void TraverseGraphMap(
426   const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *tr, const FuncGraphSet &fgs,
427   const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
428   MS_EXCEPTION_IF_NULL(manager_ptr);
429   MS_EXCEPTION_IF_NULL(tr);
430   for (const auto &fg : fgs) {
431     MS_EXCEPTION_IF_NULL(fg);
432     for (const auto &ct_any : fg->value_nodes()) {
433       AnfNodePtr const_primitive_node = ct_any.first;
434       if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
435         auto users = manager_ptr->node_users()[const_primitive_node];
436         for (auto &use : users) {
437           CNodePtr node = use.first->cast<CNodePtr>();
438           MS_EXCEPTION_IF_NULL(node);
439           if (node->func_graph() != fg) {
440             continue;
441           }
442           int64_t key = use.second;
443           if (key != 0) {
444             MS_EXCEPTION_IF_NULL(node->input(0));
445             bool key_is_const = node->input(0)->isa<ValueNode>();
446             PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
447             if (value != nullptr) {
448               bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
449               bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
450               if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
451                 continue;
452               }
453             }
454             FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
455                                             dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
456             tr->SetEdge(node, key, NewValueNode(g));
457           }
458         }
459       }
460     }
461   }
462 }
463 
// Wrap primitives that are used as values (not called directly) in every graph owned by `graph`'s
// manager. Each such use is replaced by a small FuncGraph that:
//   - has one parameter per argument abstract of the primitive's TypedPrimitiveAbstractClosure;
//   - calls the primitive on those parameters.
// One wrapper is built per (primitive, abstract) pair and shared between uses.
// Returns `graph` itself after the rewrite transaction has been committed.
FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  FuncGraphManagerPtr manager_ptr = graph->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  MapPrimTypeFuncGraph prim_graphs;
  const auto &get_prim_graph = [&prim_graphs](const PrimitivePtr &prim, const AbstractFunctionPtr &type) {
    PrimTypePair prim_type = std::make_pair(prim, type);
    // Single lookup: reuse the wrapper already built for this (primitive, abstract) pair.
    auto cached = prim_graphs.find(prim_type);
    if (cached != prim_graphs.end()) {
      return cached->second;
    }
    // Validate `type` before anything uses it.
    MS_EXCEPTION_IF_NULL(type);
    TypedPrimitiveAbstractClosurePtr tp = dyn_cast<abstract::TypedPrimitiveAbstractClosure>(type->GetUnique());
    if (tp == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "Not TypedPrimitiveAbstractClosure, but got " << type->GetUnique()->ToString();
    }

    FuncGraphPtr g = std::make_shared<FuncGraph>();
    MS_EXCEPTION_IF_NULL(g);
    ValueNodePtr prim_ct = NewValueNode(prim);
    MS_EXCEPTION_IF_NULL(prim_ct);
    prim_ct->set_abstract(type);
    std::vector<AnfNodePtr> args{prim_ct};
    for (const auto &t : tp->args_abs_list()) {
      ParameterPtr p = g->add_parameter();
      MS_EXCEPTION_IF_NULL(p);
      p->set_abstract(t);
      args.push_back(p);
    }
    AnfNodePtr out = g->NewCNode(args);
    MS_EXCEPTION_IF_NULL(out);
    out->set_abstract(tp->output());
    g->set_output(out);
    (void)prim_graphs.emplace(prim_type, g);
    return g;
  };

  FuncGraphTransaction tr = manager_ptr->Transact();
  auto &fgs = manager_ptr->func_graphs();
  TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
  tr.Commit();

  return graph;
}
505 
CompileGraphs(const BackendPtr & backend,const std::vector<PrimitivePtr> & cut_list)506 CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
507   MS_EXCEPTION_IF_NULL(backend);
508   MS_LOG(DEBUG) << "Start vm: " << backend->name();
509   transform_ = std::make_shared<CompileGraph>(backend, cut_list);
510   Reset();
511 }
512 
513 // Convert graphs to unlinked instructions.
Compile(const FuncGraphPtr & graph)514 void CompileGraphs::Compile(const FuncGraphPtr &graph) {
515   MS_LOG(DEBUG) << "Start";
516   mapping_[graph] = static_cast<int64_t>(insts_.size());
517   if (transform_ != nullptr) {
518     InstSet insts = transform_->Run(graph);
519     if (!insts.empty()) {
520       (void)insts_.insert(insts_.cend(), insts.cbegin(), insts.cend());
521     }
522   }
523   MS_LOG(DEBUG) << "End";
524 }
525 
526 // Link instructions from multiple function graphs together.
Link()527 FinalVMPtr CompileGraphs::Link() {
528   MS_LOG(DEBUG) << "Start";
529   for (std::size_t i = 0; i < insts_.size(); i++) {
530     InstType inst = insts_[i];
531     MS_LOG(DEBUG) << "Link point:" << inst_str[inst.first];
532     if (Instruction::kGraph == inst.first) {
533       if (inst.second.empty()) {
534         MS_LOG(EXCEPTION) << "The second element of inst is empty";
535       }
536       FuncGraphPtr func_graph = utils::cast<ValuePtr>(inst.second[0])->cast<FuncGraphPtr>();
537       MS_LOG(DEBUG) << "Link graph:" << func_graph->ToString();
538       insts_[i] = std::make_pair(Instruction::kPush, VectorRef(std::vector<BaseRef>{mapping_[func_graph]}));
539     }
540   }
541 
542   FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
543   MS_LOG(DEBUG) << "End";
544   return rt;
545 }
546 
547 // Convert all graphs to unlinked instructions and link them.
CompileAndLink(const FuncGraphPtr & graph)548 FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
549   MS_EXCEPTION_IF_NULL(graph);
550   MS_LOG(DEBUG) << "Start";
551   Reset();
552   MS_LOG(DEBUG) << "Begin parameter:" << graph->parameters().size();
553 
554   FuncGraphPtr prim_graph = WrapPrimitives(graph);
555   Compile(prim_graph);
556   MS_EXCEPTION_IF_NULL(prim_graph);
557   MS_EXCEPTION_IF_NULL(prim_graph->manager());
558   FuncGraphSet graphs = prim_graph->manager()->func_graphs();
559   for (const auto &g : graphs) {
560     if (g != graph && g != nullptr) {
561       Compile(g);
562     }
563   }
564 
565   FinalVMPtr rt = Link();
566   Reset();
567   MS_LOG(DEBUG) << "End";
568   return rt;
569 }
570 
CreateBackend()571 BackendPtr CreateBackend() {
572   auto context_ptr = MsContext::GetInstance();
573   MS_EXCEPTION_IF_NULL(context_ptr);
574   std::string name = context_ptr->backend_policy();
575   MS_LOG(INFO) << "CreateBackend is: " << name;
576   context_ptr->Refresh();
577 
578   if (name == kMsConvert || name == kGeVm) {
579     std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
580     uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
581     BackendPtr backend = nullptr;
582     // Create MindRTBackend or MsBackend according to whether mindrt is used.
583     if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
584       backend = std::make_shared<MindRTBackend>(name, target, device_id);
585     } else {
586       backend = std::make_shared<MsBackend>(name, target, device_id);
587     }
588     if (target == kAscendDevice && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
589       backend->set_is_multi_graph_sink(false);
590     }
591     return backend;
592   }
593 
594   return std::make_shared<Backend>(name);
595 }
596 
SetMindRTEnable()597 void SetMindRTEnable() {
598   auto context_ptr = MsContext::GetInstance();
599   MS_EXCEPTION_IF_NULL(context_ptr);
600   MS_LOG(DEBUG) << "Enable mindRT.";
601   context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, true);
602 }
603 }  // namespace compile
604 }  // namespace mindspore
605