/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
18 
19 #include "backend/graph_compiler/vmimpl.h"
20 
21 #include <algorithm>
22 #include <exception>
23 #include <vector>
24 #include <memory>
25 
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "frontend/operator/ops.h"
29 #include "ir/manager.h"
30 #include "ir/func_graph_cloner.h"
31 #include "include/common/utils/convert_utils.h"
32 #include "include/common/utils/primitive_utils.h"
33 
34 namespace mindspore {
35 namespace compile {
36 
37 // Indicate a call to a new frame.
38 struct CallWrap final : public Base {
CallWrapmindspore::compile::CallWrap39   explicit CallWrap(const VMFramePtr &vm_frame) : frame(vm_frame) {}
40   MS_DECLARE_PARENT(CallWrap, Base);
41   VMFramePtr frame{nullptr};
42 };
43 using CallWrapPtr = std::shared_ptr<CallWrap>;
44 
45 // Indicates a return with its value.
46 struct ReturnWrap final : public Base {
ReturnWrapmindspore::compile::ReturnWrap47   explicit ReturnWrap(const BaseRef &r_value) : value(r_value) {}
48   MS_DECLARE_PARENT(ReturnWrap, Base);
49   BaseRef value{BaseRef()};
50 };
51 using ReturnWrapPtr = std::shared_ptr<ReturnWrap>;
52 
VMFrame(const AnfNodePtrList & nodes,const AnfNodePtrToBaseRefMap & values,const AnfNodePtrToBaseRefMap & closure)53 VMFrame::VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values,
54                  const AnfNodePtrToBaseRefMap &closure)
55     : values_(values), todo_(nodes), closure_(closure) {
56   std::reverse(std::begin(todo_), std::end(todo_));
57 }
58 
operator [](const AnfNodePtr & node)59 const BaseRef VMFrame::operator[](const AnfNodePtr &node) {
60   MS_EXCEPTION_IF_NULL(node);
61   auto ret = values_.find(node);
62   if (ret != values_.end()) {
63     return ret->second;
64   }
65 
66   ret = closure_.find(node);
67   if (ret != closure_.end()) {
68     return ret->second;
69   }
70 
71   if (node->isa<ValueNode>()) {
72     return GetValueNode(node);
73   }
74 
75   MS_LOG(EXCEPTION) << "ValueError " << node->type_name();
76 }
77 
Closure(const FuncGraphPtr & graph,const AnfNodePtrToBaseRefMap & values)78 Closure::Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values)
79     : func_graph_(graph), values_(values) {}
80 
operator ()(const VectorRef & args)81 BaseRef Closure::operator()(const VectorRef &args) {
82   MS_LOG(DEBUG) << "Start closure";
83   MS_EXCEPTION_IF_NULL(vm_);
84   return vm_->Evaluate(func_graph_, args, values_);
85 }
86 
Partial(const BaseRef & fn,const VectorRef & args,const VMPtr & vm)87 Partial::Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm) : fn_(fn), args_(args), vm_(vm) {}
88 
operator ()(const VectorRef & nodes)89 BaseRef Partial::operator()(const VectorRef &nodes) {
90   VectorRef arglist;
91   (void)arglist.insert(arglist.end(), args_.begin(), args_.end());
92   (void)arglist.insert(arglist.end(), nodes.begin(), nodes.end());
93   MS_EXCEPTION_IF_NULL(vm_);
94   return vm_->Call(fn_, arglist);
95 }
96 
ComputeFvs(const FuncGraphPtr & graph) const97 SetRef VM::ComputeFvs(const FuncGraphPtr &graph) const {
98   MS_EXCEPTION_IF_NULL(graph);
99   SetRef rval;
100   for (auto &fkv : graph->free_variables_total()) {
101     if (utils::isa<FuncGraphPtr>(fkv.first)) {
102       // Add all value_nodes of g that refer to a fv graph
103       auto g = utils::cast<FuncGraphPtr>(fkv.first);
104       for (auto &ctkv : g->value_nodes()) {
105         auto ct = ctkv.first;
106         if (GetValueNode(ct) == g) {
107           (void)rval.insert(ct);
108         }
109       }
110     } else {
111       // Add a normal fv
112       (void)rval.insert(fkv.first);
113     }
114   }
115 
116   return rval;
117 }
118 
AcquireGraph(const FuncGraphPtr & graph)119 void VM::AcquireGraph(const FuncGraphPtr &graph) {
120   // Already acquired
121   if (vars_.find(graph) != vars_.end()) {
122     return;
123   }
124   MS_EXCEPTION_IF_NULL(manager_);
125   // Add g to manager
126   manager_->AddFuncGraph(graph);
127   MS_EXCEPTION_IF_NULL(graph->manager());
128   // Compute fvs for all acquired graph
129   auto graphs = graph->manager()->func_graphs();
130   for (auto g = graphs.begin(); g != graphs.end(); ++g) {
131     vars_[*g] = ComputeFvs(*g);
132   }
133 }
134 
// Exports every element of `seq` with Export() and returns the results in the same order.
VectorRef VM::ExportSequence(const VectorRef &seq) {
  std::vector<BaseRef> exported;
  for (const BaseRef &item : seq) {
    exported.push_back(Export(item));
  }
  return VectorRef(exported);
}
141 
// Attaches this VM to `clos` and hands the same closure back.
// Raises an exception when `clos` is null.
ClosurePtr VM::ExportClosure(const ClosurePtr &clos) {
  MS_EXCEPTION_IF_NULL(clos);
  auto self = shared_from_this();
  clos->set_vm(self);
  return clos;
}
147 
148 // transform graph to executable closure
ExportGraph(const FuncGraphPtr & g)149 ClosurePtr VM::ExportGraph(const FuncGraphPtr &g) {
150   auto c = std::make_shared<Closure>(g, AnfNodePtrToBaseRefMap());
151   MS_EXCEPTION_IF_NULL(c);
152   c->set_vm(shared_from_this());
153   return c;
154 }
155 
// Export of a plain object: the object is returned unchanged.
BaseRef VM::ExportObj(const BaseRef &obj) const {
  return obj;
}
157 
// Converts a VM value into its exported form by dispatching on what `value` holds.
// Graphs become closures, closures get this VM attached, primitives go through ExportPrimitive(),
// sequences are exported element by element, and anything else goes through ExportObj().
BaseRef VM::Export(const BaseRef &value) {
  // Values are inspected first: they may wrap a graph or a primitive.
  if (utils::isa<ValuePtr>(value)) {
    auto held = utils::cast<ValuePtr>(value);
    if (held->isa<FuncGraph>()) {
      return ExportGraph(held->cast<FuncGraphPtr>());
    }
    if (held->isa<Primitive>()) {
      return ExportPrimitive(held->cast<PrimitivePtr>());
    }
  }

  if (utils::isa<FuncGraphPtr>(value)) {
    return ExportGraph(utils::cast<FuncGraphPtr>(value));
  }
  if (utils::isa<ClosurePtr>(value)) {
    return ExportClosure(utils::cast<ClosurePtr>(value));
  }
  if (utils::isa<PrimitivePtr>(value)) {
    return ExportPrimitive(utils::cast<PrimitivePtr>(value));
  }
  if (utils::isa<VectorRef>(value)) {
    return ExportSequence(utils::cast<VectorRef>(value));
  }

  return ExportObj(value);
}
185 
186 // Run a graph.
187 // This will evaluate the passed-in graph and return the resulting value.
Evaluate(const FuncGraphPtr & graph,const VectorRef & args,const AnfNodePtrToBaseRefMap & closure)188 BaseRef VM::Evaluate(const FuncGraphPtr &graph, const VectorRef &args, const AnfNodePtrToBaseRefMap &closure) {
189   MS_EXCEPTION_IF_NULL(graph);
190   AcquireGraph(graph);
191   MS_LOG(DEBUG) << "Evalue arg size: " << args.size();
192   if (args.size() != graph->parameters().size()) {
193     MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graph->parameters().size() << ", but got "
194                       << args.size();
195   }
196 
197   // toposort graph nodes, the order will be reversed by frame so that the dependent be computed first
198   auto nodes = TopoSort(graph->get_return(), SuccVm(graph));
199   // mapping parameters to args
200   AnfNodePtrToBaseRefMap values;
201   for (size_t i = 0; i < args.size(); i++) {
202     values[graph->parameters()[i]] = args[i];
203   }
204   // create top frame with params initialized
205   VMFramePtrList frames{std::make_shared<VMFrame>(nodes, values, closure)};
206   // execute frames starting from top frame
207   while (!frames.empty()) {
208     auto frame = frames[frames.size() - 1];
209     MS_EXCEPTION_IF_NULL(frame);
210     auto todo = frame->todo();
211     while (!todo.empty()) {
212       auto except = HandleNode(todo[todo.size() - 1], frame);
213       if (utils::isa<CallWrapPtr>(except)) {
214         if (todo.size() == 2) {
215           // The last element is always a return, replace the ret with call frame
216           frames[frames.size() - 1] = utils::cast<CallWrapPtr>(except)->frame;
217         } else {
218           frames.push_back(utils::cast<CallWrapPtr>(except)->frame);
219         }
220         break;
221       }
222       if (utils::isa<ReturnWrapPtr>(except)) {
223         (void)frames.erase(frames.cbegin() + (static_cast<ssize_t>(frames.size()) - 1));
224         if (frames.size() > 0) {
225           auto top = frames[frames.size() - 1];
226           MS_EXCEPTION_IF_NULL(top);
227           auto td = top->todo();
228           // set value for top frame's last evaluated node
229           if (td.empty()) {
230             MS_LOG(EXCEPTION) << "The td is empty";
231           }
232           top->values()[td[td.size() - 1]] = utils::cast<ReturnWrapPtr>(except)->value;
233           (void)td.erase(td.cbegin() + (static_cast<ssize_t>(td.size()) - 1));
234         } else {
235           return Export(utils::cast<ReturnWrapPtr>(except)->value);
236         }
237         break;
238       }
239       (void)todo.erase(todo.cbegin() + (static_cast<ssize_t>(todo.size()) - 1));
240     }
241   }
242   MS_LOG(EXCEPTION) << "VM Evaluate error";
243 }
244 
SuccVm(const FuncGraphPtr & graph)245 SuccFunc VM::SuccVm(const FuncGraphPtr &graph) {
246   auto fn = [&, this](const AnfNodePtr &node) -> AnfNodeWeakPtrList {
247     MS_EXCEPTION_IF_NULL(node);
248     AnfNodeWeakPtrList res;
249 
250     // Follow node.incoming
251     if (node->isa<CNode>()) {
252       auto &inputs = node->cast<CNodePtr>()->weak_inputs();
253       for (auto &weak_input : inputs) {
254         auto i = weak_input.lock();
255         MS_EXCEPTION_IF_NULL(i);
256         if (i->func_graph() == node->func_graph() ||
257             (IsValueNode<FuncGraph>(i) && GetValueNode<FuncGraphPtr>(i)->parent() == graph)) {
258           res.push_back(i);
259         }
260       }
261     }
262 
263     // for subgraph input, add their fvs as succ nodes
264     if (IsValueNode<FuncGraph>(node) && GetValueNode<FuncGraphPtr>(node)->parent() == graph) {
265       auto fvs = utils::cast<SetRef>(vars_[GetValueNode<FuncGraphPtr>(node)]);
266       (void)std::transform(fvs.begin(), fvs.end(), std::back_inserter(res),
267                            [](const BaseRef &value) -> AnfNodePtr { return utils::cast<AnfNodePtr>(value); });
268     }
269 
270     return res;
271   };
272   return fn;
273 }
274 
// Calls `fn` with `args` and returns the result.
// A primitive is run with RunOperation(), a graph is evaluated directly and a closure is evaluated
// with its stored values. Any other kind of `fn` raises an exception.
BaseRef VM::Call(const BaseRef &fn, const VectorRef &args) {
  if (utils::isa<PrimitivePtr>(fn)) {
    auto prim = utils::cast<PrimitivePtr>(fn);
    return RunOperation(prim, args);
  }

  if (utils::isa<FuncGraphPtr>(fn)) {
    auto func_graph = utils::cast<FuncGraphPtr>(fn);
    return Evaluate(func_graph, args);
  }

  if (!utils::isa<ClosurePtr>(fn)) {
    MS_LOG(EXCEPTION) << "Can't call fn";
  }
  auto closure = utils::cast<ClosurePtr>(fn);
  return Evaluate(closure->func_graph(), args, closure->values());
}
291 
292 // make call frame for graph
_Call(const BaseRef & graph,const VectorRef & args)293 BaseRef VM::_Call(const BaseRef &graph, const VectorRef &args) {
294   AnfNodePtrToBaseRefMap clos;
295   auto func_graph = graph;
296   if (utils::isa<ClosurePtr>(func_graph)) {
297     clos = utils::cast<ClosurePtr>(func_graph)->values();
298     func_graph = utils::cast<ClosurePtr>(func_graph)->func_graph();
299   }
300   if (utils::isa<ValuePtr>(func_graph)) {
301     func_graph = utils::cast<ValuePtr>(func_graph)->cast<FuncGraphPtr>();
302   }
303 
304   if (!utils::isa<FuncGraphPtr>(func_graph)) {
305     MS_LOG(EXCEPTION) << "Graph type error";
306   }
307 
308   auto graphPtr = utils::cast<FuncGraphPtr>(func_graph);
309 
310   if (vars_.find(graphPtr) == vars_.end()) {
311     AcquireGraph(graphPtr);
312   }
313 
314   if (args.size() != graphPtr->parameters().size()) {
315     MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graphPtr->parameters().size() << ", but got "
316                       << args.size();
317   }
318 
319   auto nodes = TopoSort(graphPtr->get_return(), SuccVm(graphPtr));
320   AnfNodePtrToBaseRefMap values;
321   for (size_t i = 0; i < args.size(); i++) {
322     values[graphPtr->parameters()[i]] = args[i];
323   }
324 
325   return std::make_shared<CallWrap>(std::make_shared<VMFrame>(nodes, values, clos));
326 }
327 
328 // make closure out of graph with fv values from frame
MakeClosure(const FuncGraphPtr & graph,const VMFramePtr & frame)329 ClosurePtr VM::MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame) {
330   MS_EXCEPTION_IF_NULL(frame);
331   AnfNodePtrToBaseRefMap clos;
332 
333   for (const auto &v : utils::cast<SetRef>(vars_[graph])) {
334     auto anf = utils::cast<AnfNodePtr>(v);
335     clos[anf] = (*frame)[anf];
336   }
337 
338   return std::make_shared<Closure>(graph, clos);
339 }
340 
// Performs the call `fn(args)` for `node` inside `frame`.
//
// - Primitive: Return yields a ReturnWrap with the first argument; MakeTuple, Partial and every
//   other primitive store their result as the value of `node` in `frame` and yield an empty BaseRef.
// - Partial: the bound arguments are put in front of `args` and the call is dispatched again on the
//   wrapped callee; whatever that dispatch yields is the result.
// - Graph or closure: a CallWrap holding the new frame is yielded.
//
// Raises an exception when a primitive is called without arguments or when `fn` is not callable.
BaseRef VM::DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args) {
  if (utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<Primitive>()) {
    auto fnval = utils::cast<ValuePtr>(fn)->cast<PrimitivePtr>();
    MS_LOG(DEBUG) << "DispatchCall prim:" << fnval->name() << ", node:" << node->DebugString(true);
    if (args.empty()) {
      MS_LOG(EXCEPTION) << "Args is empty";
    }
    if (fnval == prim::kPrimReturn) {
      MS_LOG(DEBUG) << "Return args:" << args.size();
      return std::make_shared<ReturnWrap>(args[0]);
    }
    MS_EXCEPTION_IF_NULL(frame);
    if (fnval == prim::kPrimMakeTuple) {
      frame->values()[node] = args;
      return BaseRef();
    }

    if (fnval == prim::kPrimPartial) {
      // The first argument is the callee, the remaining ones are the arguments to bind.
      VectorRef partial_args(args.begin() + 1, args.end());
      frame->values()[node] = (std::make_shared<Partial>(args[0], partial_args, shared_from_this()));
      return BaseRef();
    }

    // call prim implementation
    frame->values()[node] = RunOperation(fnval, args);
    return BaseRef();
  }

  // partial args logic
  if (utils::isa<PartialPtr>(fn)) {
    auto fnPtr = utils::cast<PartialPtr>(fn);
    MS_EXCEPTION_IF_NULL(fnPtr);

    VectorRef arglist;
    (void)arglist.insert(arglist.end(), fnPtr->args().begin(), fnPtr->args().end());
    (void)arglist.insert(arglist.end(), args.begin(), args.end());

    // The nested dispatch either raises or yields a valid result, including the empty BaseRef of a
    // primitive that was evaluated in place, so its result is forwarded unconditionally.
    return DispatchCall(node, frame, fnPtr->fn(), arglist);
  }

  // create frame for graph and closure
  if ((utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<FuncGraph>()) || utils::isa<ClosurePtr>(fn)) {
    auto ret = _Call(fn, args);
    if (utils::isa<CallWrapPtr>(ret) || utils::isa<ReturnWrapPtr>(ret)) {
      return ret;
    }
  }

  MS_LOG(EXCEPTION) << "Invalid fn to call";
}
393 
HandleNode(const AnfNodePtr & node,const VMFramePtr & frame)394 BaseRef VM::HandleNode(const AnfNodePtr &node, const VMFramePtr &frame) {
395   MS_EXCEPTION_IF_NULL(node);
396   if (node->isa<Parameter>()) {
397     // pass
398     return BaseRef();
399   }
400 
401   if (node->isa<ValueNode>()) {
402     // We only visit valuenode graphs
403     if (!IsValueNode<FuncGraph>(node)) {
404       MS_LOG(EXCEPTION) << "We only visit valuenode graphs ";
405     }
406     auto g = GetValueNode<FuncGraphPtr>(node);
407     MS_EXCEPTION_IF_NULL(frame);
408     // if g is a graph with fvs, we need to make a closure for it
409     auto iterG = vars_.find(g);
410     if (iterG != vars_.end() && utils::cast<SetRef>(iterG->second).size() != 0) {
411       frame->values()[node] = MakeClosure(g, frame);
412     }
413 
414     return BaseRef();
415   }
416 
417   if (node->isa<CNode>()) {
418     std::vector<BaseRef> fnArgs;
419     auto &inputs = node->cast<CNodePtr>()->inputs();
420     // set args' values in frame
421     (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(fnArgs),
422                          [&](const AnfNodePtr &inp) -> BaseRef { return (*frame)[inp]; });
423     if (fnArgs.empty()) {
424       MS_LOG(EXCEPTION) << "Function arguments is empty";
425     } else {
426       auto args = VectorRef(fnArgs.begin() + 1, fnArgs.end());
427       auto except = DispatchCall(node, frame, fnArgs[0], args);
428       return except;
429     }
430   }
431 
432   MS_LOG(EXCEPTION) << "Unknown node type";
433 }
434 
// Runs graph `g` on `args` and returns the result as a VectorRef.
// manager_ is replaced by the manager obtained from Manage(g). A result that is not already a
// VectorRef is wrapped into a VectorRef with that single element.
VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) {
  this->manager_ = Manage(g);

  auto closure = utils::cast<ClosurePtr>(Export(g));
  auto result = (*closure)(args);

  if (!utils::isa<VectorRef>(result)) {
    return VectorRef({result});
  }
  return utils::cast<VectorRef>(result);
}
448 
RunOperation(const PrimitivePtr & prim,const VectorRef & args)449 BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
450   MS_EXCEPTION_IF_NULL(prim);
451   MS_LOG(DEBUG) << "Operation start " << prim->name();
452   auto result = prim->RunComputeFunction(args);
453   if (result.is_null()) {
454     result = RunComputeFunctionWithoutPyObj(prim, args);
455   }
456   if (result.is_null()) {
457     return RunComputeFunction(prim, args);
458   }
459   return result;
460 }
461 
462 }  // namespace compile
463 }  // namespace mindspore
464