• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "vm/vmimpl.h"
20 
21 #include <algorithm>
22 #include <exception>
23 #include <vector>
24 #include <memory>
25 
26 #include "frontend/operator/ops.h"
27 #include "ir/manager.h"
28 #include "ir/func_graph_cloner.h"
29 #include "utils/convert_utils.h"
30 #include "utils/primitive_utils.h"
31 
32 namespace mindspore {
33 namespace compile {
34 
35 // Indicate a call to a new frame.
36 struct CallWrap : public Base {
CallWrapmindspore::compile::CallWrap37   explicit CallWrap(const VMFramePtr &vm_frame) : frame(vm_frame) {}
38   VMFramePtr frame{nullptr};
39 };
40 using CallWrapPtr = std::shared_ptr<CallWrap>;
41 
42 // Indicates a return with its value.
43 struct ReturnWrap : public Base {
ReturnWrapmindspore::compile::ReturnWrap44   explicit ReturnWrap(const BaseRef &r_value) : value(r_value) {}
45   BaseRef value{BaseRef()};
46 };
47 using ReturnWrapPtr = std::shared_ptr<ReturnWrap>;
48 
VMFrame(const AnfNodePtrList & nodes,const AnfNodePtrToBaseRefMap & values,const AnfNodePtrToBaseRefMap & closure)49 VMFrame::VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values,
50                  const AnfNodePtrToBaseRefMap &closure)
51     : values_(values), todo_(nodes), closure_(closure) {
52   std::reverse(std::begin(todo_), std::end(todo_));
53 }
54 
operator [](const AnfNodePtr & node)55 const BaseRef VMFrame::operator[](const AnfNodePtr &node) {
56   MS_EXCEPTION_IF_NULL(node);
57   auto ret = values_.find(node);
58   if (ret != values_.end()) {
59     return ret->second;
60   }
61 
62   ret = closure_.find(node);
63   if (ret != closure_.end()) {
64     return ret->second;
65   }
66 
67   if (node->isa<ValueNode>()) {
68     return GetValueNode(node);
69   }
70 
71   MS_LOG(EXCEPTION) << "ValueError " << node->type_name();
72 }
73 
Closure(const FuncGraphPtr & graph,const AnfNodePtrToBaseRefMap & values)74 Closure::Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values)
75     : func_graph_(graph), values_(values) {}
76 
operator ()(const VectorRef & args)77 BaseRef Closure::operator()(const VectorRef &args) {
78   MS_LOG(DEBUG) << "Start closure";
79   MS_EXCEPTION_IF_NULL(vm_);
80   return vm_->Evaluate(func_graph_, args, values_);
81 }
82 
Partial(const BaseRef & fn,const VectorRef & args,const VMPtr & vm)83 Partial::Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm) : fn_(fn), args_(args), vm_(vm) {}
84 
operator ()(const VectorRef & nodes)85 BaseRef Partial::operator()(const VectorRef &nodes) {
86   VectorRef arglist;
87   (void)arglist.insert(arglist.end(), args_.begin(), args_.end());
88   (void)arglist.insert(arglist.end(), nodes.begin(), nodes.end());
89   MS_EXCEPTION_IF_NULL(vm_);
90   return vm_->Call(fn_, arglist);
91 }
92 
ComputeFvs(const FuncGraphPtr & graph)93 SetRef VM::ComputeFvs(const FuncGraphPtr &graph) {
94   MS_EXCEPTION_IF_NULL(graph);
95   SetRef rval;
96   for (auto &fkv : graph->free_variables_total()) {
97     if (utils::isa<FuncGraphPtr>(fkv.first)) {
98       // Add all value_nodes of g that refer to a fv graph
99       auto g = utils::cast<FuncGraphPtr>(fkv.first);
100       for (auto &ctkv : g->value_nodes()) {
101         auto ct = ctkv.first;
102         if (GetValueNode(ct) == g) {
103           (void)rval.insert(ct);
104         }
105       }
106     } else {
107       // Add a normal fv
108       (void)rval.insert(fkv.first);
109     }
110   }
111 
112   return rval;
113 }
114 
AcquireGraph(const FuncGraphPtr & graph)115 void VM::AcquireGraph(const FuncGraphPtr &graph) {
116   // Already acquired
117   if (vars_.find(graph) != vars_.end()) {
118     return;
119   }
120   MS_EXCEPTION_IF_NULL(manager_);
121   // Add g to manager
122   manager_->AddFuncGraph(graph);
123   MS_EXCEPTION_IF_NULL(graph->manager());
124   // Compute fvs for all acquired graph
125   auto graphs = graph->manager()->func_graphs();
126   for (auto g = graphs.begin(); g != graphs.end(); ++g) {
127     vars_[*g] = ComputeFvs(*g);
128   }
129 }
130 
// Export every element of `seq` (see VM::Export) and return them in order.
VectorRef VM::ExportSequence(const VectorRef &seq) {
  std::vector<BaseRef> exported;
  exported.reserve(seq.size());
  for (const auto &item : seq) {
    exported.push_back(Export(item));
  }
  return VectorRef(exported);
}
137 
// Bind an existing closure to this VM and return the same closure object.
ClosurePtr VM::ExportClosure(const ClosurePtr &clos) {
  MS_EXCEPTION_IF_NULL(clos);
  auto self = shared_from_this();
  clos->set_vm(self);
  return clos;
}
143 
144 // transform graph to executable closure
ExportGraph(const FuncGraphPtr & g)145 ClosurePtr VM::ExportGraph(const FuncGraphPtr &g) {
146   auto c = std::make_shared<Closure>(g, AnfNodePtrToBaseRefMap());
147   MS_EXCEPTION_IF_NULL(c);
148   c->set_vm(shared_from_this());
149   return c;
150 }
151 
// Fallback export for values with no VM-specific representation: returned unchanged.
BaseRef VM::ExportObj(const BaseRef &obj) const {
  return obj;
}
153 
// Convert a value produced by the VM into its externally usable form:
// graphs become closures bound to this VM, closures are bound to this VM,
// primitives go through ExportPrimitive, sequences are exported element-wise,
// and anything else is returned through ExportObj.
BaseRef VM::Export(const BaseRef &value) {
  if (utils::isa<ValuePtr>(value)) {
    auto as_value = utils::cast<ValuePtr>(value);
    if (as_value->isa<FuncGraph>()) {
      return ExportGraph(as_value->cast<FuncGraphPtr>());
    }
    if (as_value->isa<Primitive>()) {
      return ExportPrimitive(as_value->cast<PrimitivePtr>());
    }
  }

  if (utils::isa<FuncGraphPtr>(value)) {
    auto func_graph = utils::cast<FuncGraphPtr>(value);
    return ExportGraph(func_graph);
  }
  if (utils::isa<ClosurePtr>(value)) {
    auto closure = utils::cast<ClosurePtr>(value);
    return ExportClosure(closure);
  }
  if (utils::isa<PrimitivePtr>(value)) {
    auto primitive = utils::cast<PrimitivePtr>(value);
    return ExportPrimitive(primitive);
  }
  if (utils::isa<VectorRef>(value)) {
    auto sequence = utils::cast<VectorRef>(value);
    return ExportSequence(sequence);
  }
  return ExportObj(value);
}
181 
182 // Run a graph.
183 // This will evaluate the passed-in graph and return the resulting value.
Evaluate(const FuncGraphPtr & graph,const VectorRef & args,const AnfNodePtrToBaseRefMap & closure)184 BaseRef VM::Evaluate(const FuncGraphPtr &graph, const VectorRef &args, const AnfNodePtrToBaseRefMap &closure) {
185   MS_EXCEPTION_IF_NULL(graph);
186   AcquireGraph(graph);
187   MS_LOG(DEBUG) << "Evalue arg size: " << args.size();
188   if (args.size() != graph->parameters().size()) {
189     MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graph->parameters().size() << ", but got "
190                       << args.size();
191   }
192 
193   // toposort graph nodes, the order will be reversed by frame so that the dependent be computed first
194   auto nodes = TopoSort(graph->get_return(), SuccVm(graph));
195   // mapping parameters to args
196   AnfNodePtrToBaseRefMap values;
197   for (size_t i = 0; i < args.size(); i++) {
198     values[graph->parameters()[i]] = args[i];
199   }
200   // create top frame with params initialized
201   VMFramePtrList frames{std::make_shared<VMFrame>(nodes, values, closure)};
202   // execute frames starting from top frame
203   while (!frames.empty()) {
204     auto frame = frames[frames.size() - 1];
205     MS_EXCEPTION_IF_NULL(frame);
206     auto todo = frame->todo();
207     while (!todo.empty()) {
208       auto except = HandleNode(todo[todo.size() - 1], frame);
209       if (utils::isa<CallWrapPtr>(except)) {
210         if (todo.size() == 2) {
211           // The last element is always a return, replace the ret with call frame
212           frames[frames.size() - 1] = utils::cast<CallWrapPtr>(except)->frame;
213         } else {
214           frames.push_back(utils::cast<CallWrapPtr>(except)->frame);
215         }
216         break;
217       }
218       if (utils::isa<ReturnWrapPtr>(except)) {
219         (void)frames.erase(frames.begin() + (static_cast<ssize_t>(frames.size()) - 1));
220         if (frames.size() > 0) {
221           auto top = frames[frames.size() - 1];
222           MS_EXCEPTION_IF_NULL(top);
223           auto td = top->todo();
224           // set value for top frame's last evaluated node
225           if (td.empty()) {
226             MS_LOG(EXCEPTION) << "The td is empty";
227           }
228           top->values()[td[td.size() - 1]] = utils::cast<ReturnWrapPtr>(except)->value;
229           (void)td.erase(td.begin() + (static_cast<ssize_t>(td.size()) - 1));
230         } else {
231           return Export(utils::cast<ReturnWrapPtr>(except)->value);
232         }
233         break;
234       }
235       (void)todo.erase(todo.begin() + (static_cast<ssize_t>(todo.size()) - 1));
236     }
237   }
238   MS_LOG(EXCEPTION) << "VM Evaluate error";
239 }
240 
SuccVm(const FuncGraphPtr & graph)241 SuccFunc VM::SuccVm(const FuncGraphPtr &graph) {
242   auto fn = [&, this](const AnfNodePtr &node) -> AnfNodePtrList {
243     MS_EXCEPTION_IF_NULL(node);
244     AnfNodePtrList ret;
245 
246     // Follow node.incoming
247     if (node->isa<CNode>()) {
248       auto &inputs = node->cast<CNodePtr>()->inputs();
249       for (auto &i : inputs) {
250         if (i->func_graph() == node->func_graph() ||
251             (IsValueNode<FuncGraph>(i) && GetValueNode<FuncGraphPtr>(i)->parent() == graph)) {
252           ret.push_back(i);
253         }
254       }
255     }
256 
257     // for subgraph input, add their fvs as succ nodes
258     if (IsValueNode<FuncGraph>(node) && GetValueNode<FuncGraphPtr>(node)->parent() == graph) {
259       auto fvs = utils::cast<SetRef>(vars_[GetValueNode<FuncGraphPtr>(node)]);
260       (void)std::transform(fvs.begin(), fvs.end(), std::back_inserter(ret),
261                            [](const BaseRef &value) -> AnfNodePtr { return utils::cast<AnfNodePtr>(value); });
262     }
263 
264     return ret;
265   };
266   return fn;
267 }
268 
// Invoke `fn` on `args` and return the result.
// Primitives run through RunOperation; graphs and closures are evaluated by this
// VM (closures supply their captured values). Anything else raises "Can't call fn".
BaseRef VM::Call(const BaseRef &fn, const VectorRef &args) {
  if (utils::isa<PrimitivePtr>(fn)) {
    auto primitive = utils::cast<PrimitivePtr>(fn);
    return RunOperation(primitive, args);
  }
  if (utils::isa<FuncGraphPtr>(fn)) {
    auto func_graph = utils::cast<FuncGraphPtr>(fn);
    return Evaluate(func_graph, args);
  }
  if (!utils::isa<ClosurePtr>(fn)) {
    MS_LOG(EXCEPTION) << "Can't call fn";
  }
  auto closure = utils::cast<ClosurePtr>(fn);
  return Evaluate(closure->func_graph(), args, closure->values());
}
285 
286 // make call frame for graph
_Call(const BaseRef & graph,const VectorRef & args)287 BaseRef VM::_Call(const BaseRef &graph, const VectorRef &args) {
288   AnfNodePtrToBaseRefMap clos;
289   auto func_graph = graph;
290   if (utils::isa<ClosurePtr>(func_graph)) {
291     clos = utils::cast<ClosurePtr>(func_graph)->values();
292     func_graph = utils::cast<ClosurePtr>(func_graph)->func_graph();
293   }
294   if (utils::isa<ValuePtr>(func_graph)) {
295     func_graph = utils::cast<ValuePtr>(func_graph)->cast<FuncGraphPtr>();
296   }
297 
298   if (!utils::isa<FuncGraphPtr>(func_graph)) {
299     MS_LOG(EXCEPTION) << "Graph type error";
300   }
301 
302   auto graphPtr = utils::cast<FuncGraphPtr>(func_graph);
303 
304   if (vars_.find(graphPtr) == vars_.end()) {
305     AcquireGraph(graphPtr);
306   }
307 
308   if (args.size() != graphPtr->parameters().size()) {
309     MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graphPtr->parameters().size() << ", but got "
310                       << args.size();
311   }
312 
313   auto nodes = TopoSort(graphPtr->get_return(), SuccVm(graphPtr));
314   AnfNodePtrToBaseRefMap values;
315   for (size_t i = 0; i < args.size(); i++) {
316     values[graphPtr->parameters()[i]] = args[i];
317   }
318 
319   return std::make_shared<CallWrap>(std::make_shared<VMFrame>(nodes, values, clos));
320 }
321 
322 // make closure out of graph with fv values from frame
MakeClosure(const FuncGraphPtr & graph,const VMFramePtr & frame)323 ClosurePtr VM::MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame) {
324   MS_EXCEPTION_IF_NULL(frame);
325   AnfNodePtrToBaseRefMap clos;
326 
327   for (auto &v : utils::cast<SetRef>(vars_[graph])) {
328     auto anf = utils::cast<AnfNodePtr>(v);
329     clos[anf] = (*frame)[anf];
330   }
331 
332   return std::make_shared<Closure>(graph, clos);
333 }
334 
// Apply callee `fn` to `args` on behalf of `node` executing in `frame`.
// Primitive callees:
//  - `return`     -> ReturnWrap(args[0]);
//  - `make_tuple` -> `args` stored as the node's value;
//  - `partial`    -> Partial(args[0], remaining args) stored as the node's value;
//  - otherwise    -> RunOperation result stored as the node's value.
//  These return an empty BaseRef unless a ReturnWrap is produced.
// Partial callees: bound args are prepended and dispatch continues on the wrapped
// fn, returning whatever that dispatch returns.
// Graph / Closure callees: a CallWrap for a new frame (see _Call).
// Anything else raises "Invalid fn to call".
BaseRef VM::DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args) {
  if (utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<Primitive>()) {
    auto fnval = utils::cast<ValuePtr>(fn)->cast<PrimitivePtr>();
    MS_LOG(DEBUG) << "DispatchCall prim:" << fnval->name() << ", node:" << node->DebugString(true);
    if (fnval == prim::kPrimReturn) {
      // `return` reads args[0]; other primitives (e.g. an empty make_tuple) may take no args.
      if (args.empty()) {
        MS_LOG(EXCEPTION) << "Args is empty";
      }
      MS_LOG(DEBUG) << "Return args:" << args.size();
      return std::make_shared<ReturnWrap>(args[0]);
    }
    MS_EXCEPTION_IF_NULL(frame);
    if (fnval == prim::kPrimMakeTuple) {
      frame->values()[node] = args;
      return BaseRef();
    }

    if (fnval == prim::kPrimPartial) {
      if (args.empty()) {
        MS_LOG(EXCEPTION) << "Args is empty";
      }
      VectorRef partial_args(args.begin() + 1, args.end());
      frame->values()[node] = (std::make_shared<Partial>(args[0], partial_args, shared_from_this()));
      return BaseRef();
    }

    // call prim implementation
    frame->values()[node] = RunOperation(fnval, args);
    return BaseRef();
  }

  // partial args logic: prepend the bound args and dispatch on the wrapped callee.
  // The result is returned as-is. A wrapped primitive stores its value in `frame` and
  // yields an empty BaseRef, which must not fall through to "Invalid fn to call".
  if (utils::isa<PartialPtr>(fn)) {
    auto fnPtr = utils::cast<PartialPtr>(fn);
    MS_EXCEPTION_IF_NULL(fnPtr);

    VectorRef arglist;
    (void)arglist.insert(arglist.end(), fnPtr->args().begin(), fnPtr->args().end());
    (void)arglist.insert(arglist.end(), args.begin(), args.end());

    return DispatchCall(node, frame, fnPtr->fn(), arglist);
  }

  // create frame for graph and closure
  if ((utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<FuncGraph>()) || utils::isa<ClosurePtr>(fn)) {
    auto ret = _Call(fn, args);
    if (utils::isa<CallWrapPtr>(ret) || utils::isa<ReturnWrapPtr>(ret)) {
      return ret;
    }
  }

  MS_LOG(EXCEPTION) << "Invalid fn to call";
}
387 
HandleNode(const AnfNodePtr & node,const VMFramePtr & frame)388 BaseRef VM::HandleNode(const AnfNodePtr &node, const VMFramePtr &frame) {
389   MS_EXCEPTION_IF_NULL(node);
390   if (node->isa<Parameter>()) {
391     // pass
392     return BaseRef();
393   }
394 
395   if (node->isa<ValueNode>()) {
396     // We only visit valuenode graphs
397     if (!IsValueNode<FuncGraph>(node)) {
398       MS_LOG(EXCEPTION) << "We only visit valuenode graphs ";
399     }
400     auto g = GetValueNode<FuncGraphPtr>(node);
401     MS_EXCEPTION_IF_NULL(frame);
402     // if g is a graph with fvs, we need to make a closure for it
403     auto iterG = vars_.find(g);
404     if (iterG != vars_.end() && utils::cast<SetRef>(iterG->second).size() != 0) {
405       frame->values()[node] = MakeClosure(g, frame);
406     }
407 
408     return BaseRef();
409   }
410 
411   if (node->isa<CNode>()) {
412     std::vector<BaseRef> fnArgs;
413     auto &inputs = node->cast<CNodePtr>()->inputs();
414     // set args' values in frame
415     (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(fnArgs),
416                          [&](const AnfNodePtr &inp) -> BaseRef { return (*frame)[inp]; });
417     if (fnArgs.empty()) {
418       MS_LOG(EXCEPTION) << "Function arguments is empty";
419     } else {
420       auto args = VectorRef(fnArgs.begin() + 1, fnArgs.end());
421       auto except = DispatchCall(node, frame, fnArgs[0], args);
422       return except;
423     }
424   }
425 
426   MS_LOG(EXCEPTION) << "Unknown node type";
427 }
428 
// Entry point: manage `g`, export it as a closure bound to this VM and run it on
// `args`. A VectorRef result is returned directly; any other result is wrapped in
// a one-element VectorRef.
VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) {
  this->manager_ = Manage(g);

  auto entry = utils::cast<ClosurePtr>(Export(g));
  auto result = (*entry)(args);

  if (!utils::isa<VectorRef>(result)) {
    return VectorRef({result});
  }
  return utils::cast<VectorRef>(result);
}
442 
RunOperation(const PrimitivePtr & prim,const VectorRef & args)443 BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
444   MS_LOG(DEBUG) << "Operation start " << prim->name();
445   MS_EXCEPTION_IF_NULL(prim);
446   auto result = prim->RunComputeFunction(args);
447   if (result.is_null()) {
448     result = RunComputeFunctionWithoutPyObj(prim, args);
449   }
450   if (result.is_null()) {
451     return RunComputeFunction(prim, args);
452   }
453   return result;
454 }
455 
456 }  // namespace compile
457 }  // namespace mindspore
458