1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
#include "backend/graph_compiler/transform.h"

#include <algorithm>
#include <map>
#include <queue>
#include <string>
#include <utility>
#include <vector>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "abstract/abstract_value.h"
#include "abstract/abstract_function.h"
#include "ir/graph_utils.h"
#include "utils/ms_context.h"
#include "utils/trace_base.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/ps_context.h"
#endif
38
39 namespace mindspore {
40 namespace compile {
41 using mindspore::abstract::AbstractFunction;
42 using mindspore::abstract::AbstractFunctionPtr;
43 using PrimTypePair = std::pair<PrimitivePtr, AbstractFunctionPtr>;
44 using MapPrimTypeFuncGraph = std::map<PrimTypePair, FuncGraphPtr>;
45 using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
46
GetNonlinearOps()47 const std::vector<PrimitivePtr> &GetNonlinearOps() {
48 static std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
49 prim::kPrimMakeTuple, prim::kPrimBpropCut};
50 return nonlinear_ops;
51 }
52
GetControlOps()53 const std::vector<PrimitivePtr> &GetControlOps() {
54 static std::vector<PrimitivePtr> control_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
55 prim::kPrimMakeTuple, prim::kPrimSwitchLayer};
56 return control_ops;
57 }
58
GetMsNonlinearOps()59 const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
60 static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial,
61 prim::kPrimSwitch, prim::kPrimMakeTuple,
62 prim::kPrimBpropCut, prim::kPrimSwitchLayer};
63 return ms_nonlinear_ops;
64 }
65
CompileGraph(const BackendPtr & backend,const std::vector<PrimitivePtr> & cut_list)66 CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
67 MS_EXCEPTION_IF_NULL(backend_);
68 lin_convert_ = backend_->convert_fn();
69 if (lin_convert_ == nullptr) {
70 MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name();
71 }
72 graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend->name());
73 }
74
75 // Push the value node on the stack.
Push(const AnfNodePtr & node)76 void CompileGraph::Push(const AnfNodePtr &node) {
77 MS_EXCEPTION_IF_NULL(node);
78 if (slots_.count(node) > 0) {
79 MS_LOG(WARNING) << "Push failed node in slots:" << node->DebugString()
80 << " NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
81 return;
82 }
83 MS_LOG(DEBUG) << "Push node: " << node->DebugString(true) << " height_: " << height_
84 << " is parameter: " << node->isa<Parameter>();
85 slots_[node] = height_;
86 set_height(height_ + 1);
87 }
88
AddInst(const Instruction & inst,const int64_t & arg)89 void CompileGraph::AddInst(const Instruction &inst, const int64_t &arg) {
90 VectorRef args;
91 args.push_back(arg);
92 AddInst(inst, args);
93 }
94
AddInst(const Instruction & inst,const ValuePtr & arg)95 void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) {
96 VectorRef args;
97 args.push_back(arg);
98 AddInst(inst, args);
99 }
100
AddInst(const Instruction & inst,const VectorRef & args)101 void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) {
102 inst_.push_back(std::make_pair(inst, args));
103 }
104
105 // Gets the stack reference for the node value. If the node is a constant,
106 // it may actually cause the push in to not be mentioned before.
Ref(const AnfNodePtr & node)107 int64_t CompileGraph::Ref(const AnfNodePtr &node) {
108 MS_EXCEPTION_IF_NULL(node);
109 MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_;
110 if (slots_.count(node) == 0 && node->isa<ValueNode>()) {
111 if (IsValueNode<FuncGraph>(node)) {
112 MS_LOG(DEBUG) << "Push graph.";
113 AddInst(Instruction::kGraph, GetValueNode(node));
114 } else {
115 MS_LOG(DEBUG) << "Push.";
116 if (IsValueNode<Primitive>(node)) {
117 MS_LOG(EXCEPTION) << "must not be primitive in here NodeInfo: " << trace::GetDebugInfoStr(node->debug_info());
118 } else {
119 AddInst(Instruction::kPush, GetValueNode(node));
120 }
121 }
122 Push(node);
123 }
124 MS_LOG(DEBUG) << "End Ref node end height_: " << height_ << ", slots: " << slots_[node]
125 << ", return: " << slots_[node] - height_;
126 return slots_[node] - height_;
127 }
128
129 // Make sure the value of node is at the top of the stack.
AddInput(const AnfNodePtr & node)130 void CompileGraph::AddInput(const AnfNodePtr &node) {
131 MS_EXCEPTION_IF_NULL(node);
132 if (slots_.count(node) == 0) {
133 MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true);
134 (void)Ref(node);
135 return;
136 }
137 AddInst(Instruction::kInput, Ref(node));
138 set_height(height_ + 1);
139 }
140
141 // Call back effect in stack
Ret(int64_t nargs)142 void CompileGraph::Ret(int64_t nargs) { set_height(height_ - nargs); }
143
PushParameters(const FuncGraphPtr & graph)144 void CompileGraph::PushParameters(const FuncGraphPtr &graph) {
145 MS_EXCEPTION_IF_NULL(graph);
146 std::vector<AnfNodePtr> parameters = graph->parameters();
147 for (size_t i = parameters.size(); i != 0; i--) {
148 MS_EXCEPTION_IF_NULL(parameters[i - 1]);
149 Push(parameters[i - 1]);
150 MS_LOG(DEBUG) << "Push parameter " << (i - 1) << ": " << parameters[i - 1]->DebugString(true);
151 }
152 }
153
LinConvert(const FuncGraphPtr & graph,const GraphSegmentPtr & segment,const std::string & target)154 int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) {
155 MS_EXCEPTION_IF_NULL(segment);
156 MS_LOG(DEBUG) << "LinConvert start";
157 LinConvertResult result;
158
159 result = lin_convert_(segment, target);
160
161 if (result.run == nullptr) {
162 MS_LOG(ERROR) << "LinConvert failed";
163 return RET_FAILED;
164 }
165
166 if (!(*result.run)) {
167 if (result.inputs.size() != result.outputs.size()) {
168 MS_EXCEPTION_IF_NULL(graph);
169 MS_LOG(EXCEPTION) << "must inputs equal outputs NodeInfo: " << trace::GetDebugInfoStr(graph->debug_info());
170 } else {
171 size_t size = result.inputs.size();
172 for (size_t i = 0; i < size; i++) {
173 Tie(result.inputs[i], result.outputs[i]);
174 }
175 return RET_CONTINUE;
176 }
177 }
178 AddExternal(result);
179
180 return RET_SUCCESS;
181 }
182
InterpretNode(const FuncGraphPtr & graph,const CNodePtr & node)183 int64_t CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) {
184 MS_EXCEPTION_IF_NULL(node);
185 MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true);
186 std::vector<AnfNodePtr> node_inputs = node->inputs();
187 if (node_inputs.empty()) {
188 MS_LOG(EXCEPTION) << "The node->inputs() is empty";
189 }
190 AnfNodePtr fn = node_inputs[0];
191 if (IsValueNode<Primitive>(fn)) {
192 PrimitivePtr value = GetValueNode<PrimitivePtr>(fn);
193 MS_LOG(DEBUG) << "The fn is primitive " << (*value).name();
194 for (size_t i = node_inputs.size() - 1; i > 0; i--) {
195 AddInput(node->input(i));
196 }
197 if (IsPrimitive(fn, prim::kPrimReturn)) {
198 AddReturn(node);
199 return RET_BREAK;
200 }
201 if (IsPrimitive(fn, prim::kPrimPartial)) {
202 AddPartial(node);
203 } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
204 AddSwitch(node);
205 } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
206 AddSwitchLayer(node);
207 } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
208 AddMakeTuple(node);
209 } else {
210 AddPrimitive(node, value);
211 }
212 } else {
213 int64_t ret = AddCall(graph, node);
214 if (ret == RET_BREAK) {
215 return ret;
216 }
217 }
218 Push(node);
219 return RET_SUCCESS;
220 }
221
Compile(const FuncGraphPtr & graph)222 bool CompileGraph::Compile(const FuncGraphPtr &graph) {
223 MS_LOG(DEBUG) << "Start split graph";
224 MS_EXCEPTION_IF_NULL(graph);
225 MS_EXCEPTION_IF_NULL(graph_partition_);
226 auto segments = graph_partition_->Partition(graph);
227
228 MS_LOG(DEBUG) << "Split nodes size:" << segments.size();
229 for (auto &segment : segments) {
230 MS_EXCEPTION_IF_NULL(segment);
231 int64_t ret = RET_SUCCESS;
232 if (!segment->is_cut_) {
233 MS_LOG(DEBUG) << "Start a extern LinConvert";
234 if (!segment->nodes_.empty()) {
235 std::string cur_target = GetCNodeTarget(segment->nodes_[0]);
236 ret = LinConvert(graph, segment, cur_target);
237 } else {
238 ret = LinConvert(graph, segment);
239 }
240 MS_LOG(DEBUG) << "End a extern LinConvert";
241 if (ret == RET_FAILED) {
242 return false;
243 }
244 if (ret == RET_CONTINUE) {
245 continue;
246 }
247 } else if (!segment->nodes_.empty()) {
248 MS_LOG(DEBUG) << "Start a cut node";
249 auto &cut_node = segment->nodes_[0];
250 MS_EXCEPTION_IF_NULL(cut_node);
251 if (!cut_node->isa<CNode>()) {
252 MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfoStr(graph->debug_info());
253 }
254 auto node = cut_node->cast<CNodePtr>();
255 ret = InterpretNode(graph, node);
256 MS_LOG(DEBUG) << "End a cut node";
257 if (ret == RET_BREAK) {
258 break;
259 }
260 }
261 }
262 MS_LOG(DEBUG) << "End split graph";
263 return true;
264 }
265
Run(const FuncGraphPtr & graph)266 InstSet CompileGraph::Run(const FuncGraphPtr &graph) {
267 MS_EXCEPTION_IF_NULL(graph);
268
269 Reset();
270 PushParameters(graph);
271
272 int64_t param_height = height_;
273 MS_EXCEPTION_IF_NULL(graph->get_return());
274 MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true);
275
276 if (!Compile(graph)) {
277 return inst_;
278 }
279
280 AddPadStack(param_height);
281 auto ret = inst_;
282 Reset();
283 return ret;
284 }
285
AddPadStack(int64_t param_height)286 void CompileGraph::AddPadStack(int64_t param_height) {
287 int64_t stack_sizes = max_height_ - param_height;
288 MS_LOG(DEBUG) << "Pad stack max_height_:" << max_height_ << " param:" << param_height
289 << " need_stack:" << stack_sizes;
290 if (stack_sizes > 0) {
291 VectorRef need_stacks({stack_sizes});
292 (void)inst_.insert(inst_.cbegin(), std::make_pair(Instruction::kPadStack, need_stacks));
293 }
294 }
295
AddTailCall(const AnfNodePtr & fn,size_t size)296 void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) {
297 VectorRef args;
298 args.emplace_back(Ref(fn));
299 args.emplace_back(height_);
300 args.emplace_back(static_cast<int64_t>(size - 1));
301 MS_LOG(DEBUG) << "Tail call:" << Ref(fn) << ", " << height_ << ", " << (size - 1);
302 AddInst(Instruction::kTailCall, args);
303 }
304
AddPartial(const CNodePtr & node)305 void CompileGraph::AddPartial(const CNodePtr &node) {
306 MS_EXCEPTION_IF_NULL(node);
307 auto inputs = node->inputs();
308 VectorRef args;
309 if (inputs.size() <= 1) {
310 MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
311 }
312 auto fn = inputs[1];
313 if (!IsValueNode<FuncGraph>(fn)) {
314 MS_LOG(EXCEPTION) << "The type of 1st input of node must be FuncGraph, but got:" << fn->ToString();
315 }
316 for (size_t i = 1; i < inputs.size(); i++) {
317 args.emplace_back(Ref(inputs[i]));
318 }
319 AddInst(Instruction::kPartial, args);
320 }
321
AddMakeTuple(const CNodePtr & node)322 void CompileGraph::AddMakeTuple(const CNodePtr &node) {
323 MS_EXCEPTION_IF_NULL(node);
324 auto inputs = node->inputs();
325 VectorRef args;
326 for (size_t i = 1; i < inputs.size(); i++) {
327 args.emplace_back(Ref(inputs[i]));
328 }
329 AddInst(Instruction::kTuple, args);
330 }
331
AddSwitch(const CNodePtr & node)332 void CompileGraph::AddSwitch(const CNodePtr &node) {
333 MS_EXCEPTION_IF_NULL(node);
334 auto inputs = node->inputs();
335 if (inputs.size() < kSwitchInputSize) {
336 MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4";
337 }
338 VectorRef args;
339 args.emplace_back(Ref(inputs[kPartialGraphIndex]));
340 args.emplace_back(Ref(inputs[kSwitchTrueBranchIndex]));
341 args.emplace_back(Ref(inputs[kSwitchFalseBranchIndex]));
342 AddInst(Instruction::kSwitch, args);
343 }
344
AddSwitchLayer(const CNodePtr & node)345 void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
346 MS_EXCEPTION_IF_NULL(node);
347 auto inputs = node->inputs();
348 if (inputs.size() != kSwitchLayerInputSize) {
349 MS_LOG(EXCEPTION) << "Switch layer must have index and branches.";
350 }
351 VectorRef args;
352 const size_t cond_index = 1;
353 const size_t tuple_index = 2;
354 args.emplace_back(Ref(inputs[cond_index]));
355 args.emplace_back(Ref(inputs[tuple_index]));
356 AddInst(Instruction::kSwitchLayer, args);
357 }
358
AddReturn(const CNodePtr & node)359 void CompileGraph::AddReturn(const CNodePtr &node) {
360 MS_EXCEPTION_IF_NULL(node);
361 VectorRef args;
362 if (node->size() <= 1) {
363 MS_LOG(EXCEPTION) << "The node:" << node->DebugString() << "do not have two input.";
364 }
365 args.emplace_back(Ref(node->input(1)));
366 args.emplace_back(height_);
367 AddInst(Instruction::kReturn, args);
368 }
369
AddPrimitive(const CNodePtr & node,const PrimitivePtr & prim)370 void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) {
371 MS_EXCEPTION_IF_NULL(node);
372 auto inputs = node->inputs();
373 VectorRef args;
374 args.push_back(prim);
375 for (size_t i = 1; i < inputs.size(); i++) {
376 args.emplace_back(Ref(inputs[i]));
377 }
378 AddInst(Instruction::kPrim, args);
379 }
380
AddCall(const FuncGraphPtr & graph,const CNodePtr & node)381 int64_t CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) {
382 MS_EXCEPTION_IF_NULL(graph);
383 MS_EXCEPTION_IF_NULL(node);
384 auto inputs = node->inputs();
385 if (inputs.empty()) {
386 MS_LOG(EXCEPTION) << "The node->inputs() is empty.";
387 }
388 AnfNodePtr fn = inputs[0];
389 (void)Ref(fn);
390 size_t size = inputs.size();
391 for (size_t i = size - 1; i > 0; i--) {
392 AddInput(inputs[i]);
393 }
394 if (node == graph->output()) {
395 AddTailCall(fn, size);
396 return RET_BREAK;
397 }
398 MS_LOG(DEBUG) << "Call:" << Ref(fn) << ", " << height_ << ", " << (size - 1);
399 AddInst(Instruction::kCall, Ref(fn));
400 Ret(static_cast<int64_t>(size - 1));
401
402 for (size_t i = size - 1; i > 0; i--) {
403 const auto iter = slots_.find(inputs[i]);
404 if (iter != slots_.end() && iter->second >= height_) {
405 (void)slots_.erase(inputs[i]);
406 }
407 }
408 return RET_SUCCESS;
409 }
410
AddExternal(const LinConvertResult & result)411 void CompileGraph::AddExternal(const LinConvertResult &result) {
412 VectorRef args;
413 args.push_back(result.run);
414 args.push_back(result.simu_run);
415 size_t size = result.inputs.size();
416 for (size_t i = 0; i < size; i++) {
417 args.emplace_back(Ref(result.inputs[i]));
418 }
419 AddInst(Instruction::kExternal, args);
420 for (auto &out : result.outputs) {
421 Push(out);
422 }
423 }
424
TraverseGraphMap(const FuncGraphManagerPtr & manager_ptr,FuncGraphTransaction * tr,const FuncGraphSet & fgs,const std::function<std::shared_ptr<FuncGraph> (const PrimitivePtr,const AbstractFunctionPtr)> & get_prim_graph)425 void TraverseGraphMap(
426 const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *tr, const FuncGraphSet &fgs,
427 const std::function<std::shared_ptr<FuncGraph>(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) {
428 MS_EXCEPTION_IF_NULL(manager_ptr);
429 MS_EXCEPTION_IF_NULL(tr);
430 for (const auto &fg : fgs) {
431 MS_EXCEPTION_IF_NULL(fg);
432 for (const auto &ct_any : fg->value_nodes()) {
433 AnfNodePtr const_primitive_node = ct_any.first;
434 if (const_primitive_node != nullptr && IsValueNode<Primitive>(const_primitive_node)) {
435 auto users = manager_ptr->node_users()[const_primitive_node];
436 for (auto &use : users) {
437 CNodePtr node = use.first->cast<CNodePtr>();
438 MS_EXCEPTION_IF_NULL(node);
439 if (node->func_graph() != fg) {
440 continue;
441 }
442 int64_t key = use.second;
443 if (key != 0) {
444 MS_EXCEPTION_IF_NULL(node->input(0));
445 bool key_is_const = node->input(0)->isa<ValueNode>();
446 PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0));
447 if (value != nullptr) {
448 bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name()));
449 bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name()));
450 if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) {
451 continue;
452 }
453 }
454 FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node),
455 dyn_cast<AbstractFunction>(const_primitive_node->abstract()));
456 tr->SetEdge(node, key, NewValueNode(g));
457 }
458 }
459 }
460 }
461 }
462 }
463
// Rewrites, in place through the graph's manager, every use of a Primitive constant as a
// call argument so that it refers to a small wrapper FuncGraph instead (see TraverseGraphMap).
// A wrapper has one parameter per entry of the primitive's TypedPrimitiveAbstractClosure
// argument list and a single CNode applying the primitive to them, with the closure's
// output abstract. Wrappers are cached per (primitive, abstract) pair.
// Returns `graph` itself.
FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  FuncGraphManagerPtr manager_ptr = graph->manager();
  MS_EXCEPTION_IF_NULL(manager_ptr);
  MapPrimTypeFuncGraph prim_graphs;
  const auto &get_prim_graph = [&prim_graphs](const PrimitivePtr &prim, const AbstractFunctionPtr &type) {
    PrimTypePair prim_type = std::make_pair(prim, type);
    // One lookup: reuse a wrapper that was already built for this (primitive, abstract).
    const auto cached = prim_graphs.find(prim_type);
    if (cached != prim_graphs.end()) {
      return cached->second;
    }

    MS_EXCEPTION_IF_NULL(type);
    TypedPrimitiveAbstractClosurePtr tp = dyn_cast<abstract::TypedPrimitiveAbstractClosure>(type->GetUnique());
    if (tp == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "Not TypedPrimitiveAbstractClosure, but got " << type->GetUnique()->ToString();
    }

    FuncGraphPtr g = std::make_shared<FuncGraph>();
    std::vector<AnfNodePtr> args;
    ValueNodePtr prim_ct = NewValueNode(prim);
    MS_EXCEPTION_IF_NULL(prim_ct);
    prim_ct->set_abstract(type);
    args.push_back(prim_ct);
    for (const auto &t : tp->args_abs_list()) {
      ParameterPtr p = g->add_parameter();
      p->set_abstract(t);
      args.push_back(p);
    }
    AnfNodePtr out = g->NewCNode(args);
    out->set_abstract(tp->output());
    g->set_output(out);
    (void)prim_graphs.emplace(prim_type, g);
    return g;
  };

  FuncGraphTransaction tr = manager_ptr->Transact();
  auto &fgs = manager_ptr->func_graphs();
  TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph);
  tr.Commit();

  return graph;
}
505
CompileGraphs(const BackendPtr & backend,const std::vector<PrimitivePtr> & cut_list)506 CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list) : backend_(backend) {
507 MS_EXCEPTION_IF_NULL(backend);
508 MS_LOG(DEBUG) << "Start vm: " << backend->name();
509 transform_ = std::make_shared<CompileGraph>(backend, cut_list);
510 Reset();
511 }
512
513 // Convert graphs to unlinked instructions.
Compile(const FuncGraphPtr & graph)514 void CompileGraphs::Compile(const FuncGraphPtr &graph) {
515 MS_LOG(DEBUG) << "Start";
516 mapping_[graph] = static_cast<int64_t>(insts_.size());
517 if (transform_ != nullptr) {
518 InstSet insts = transform_->Run(graph);
519 if (!insts.empty()) {
520 (void)insts_.insert(insts_.cend(), insts.cbegin(), insts.cend());
521 }
522 }
523 MS_LOG(DEBUG) << "End";
524 }
525
526 // Link instructions from multiple function graphs together.
Link()527 FinalVMPtr CompileGraphs::Link() {
528 MS_LOG(DEBUG) << "Start";
529 for (std::size_t i = 0; i < insts_.size(); i++) {
530 InstType inst = insts_[i];
531 MS_LOG(DEBUG) << "Link point:" << inst_str[inst.first];
532 if (Instruction::kGraph == inst.first) {
533 if (inst.second.empty()) {
534 MS_LOG(EXCEPTION) << "The second element of inst is empty";
535 }
536 FuncGraphPtr func_graph = utils::cast<ValuePtr>(inst.second[0])->cast<FuncGraphPtr>();
537 MS_LOG(DEBUG) << "Link graph:" << func_graph->ToString();
538 insts_[i] = std::make_pair(Instruction::kPush, VectorRef(std::vector<BaseRef>{mapping_[func_graph]}));
539 }
540 }
541
542 FinalVMPtr rt = std::make_shared<FinalVM>(insts_, backend_);
543 MS_LOG(DEBUG) << "End";
544 return rt;
545 }
546
547 // Convert all graphs to unlinked instructions and link them.
CompileAndLink(const FuncGraphPtr & graph)548 FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) {
549 MS_EXCEPTION_IF_NULL(graph);
550 MS_LOG(DEBUG) << "Start";
551 Reset();
552 MS_LOG(DEBUG) << "Begin parameter:" << graph->parameters().size();
553
554 FuncGraphPtr prim_graph = WrapPrimitives(graph);
555 Compile(prim_graph);
556 MS_EXCEPTION_IF_NULL(prim_graph);
557 MS_EXCEPTION_IF_NULL(prim_graph->manager());
558 FuncGraphSet graphs = prim_graph->manager()->func_graphs();
559 for (const auto &g : graphs) {
560 if (g != graph && g != nullptr) {
561 Compile(g);
562 }
563 }
564
565 FinalVMPtr rt = Link();
566 Reset();
567 MS_LOG(DEBUG) << "End";
568 return rt;
569 }
570
CreateBackend()571 BackendPtr CreateBackend() {
572 auto context_ptr = MsContext::GetInstance();
573 MS_EXCEPTION_IF_NULL(context_ptr);
574 std::string name = context_ptr->backend_policy();
575 MS_LOG(INFO) << "CreateBackend is: " << name;
576 context_ptr->Refresh();
577
578 if (name == kMsConvert || name == kGeVm) {
579 std::string target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
580 uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
581 BackendPtr backend = nullptr;
582 // Create MindRTBackend or MsBackend according to whether mindrt is used.
583 if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
584 backend = std::make_shared<MindRTBackend>(name, target, device_id);
585 } else {
586 backend = std::make_shared<MsBackend>(name, target, device_id);
587 }
588 if (target == kAscendDevice && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
589 backend->set_is_multi_graph_sink(false);
590 }
591 return backend;
592 }
593
594 return std::make_shared<Backend>(name);
595 }
596
SetMindRTEnable()597 void SetMindRTEnable() {
598 auto context_ptr = MsContext::GetInstance();
599 MS_EXCEPTION_IF_NULL(context_ptr);
600 MS_LOG(DEBUG) << "Enable mindRT.";
601 context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, true);
602 }
603 } // namespace compile
604 } // namespace mindspore
605