1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2020 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "vm/vmimpl.h"
20
21 #include <algorithm>
22 #include <exception>
23 #include <vector>
24 #include <memory>
25
26 #include "frontend/operator/ops.h"
27 #include "ir/manager.h"
28 #include "ir/func_graph_cloner.h"
29 #include "utils/convert_utils.h"
30 #include "utils/primitive_utils.h"
31
32 namespace mindspore {
33 namespace compile {
34
35 // Indicate a call to a new frame.
36 struct CallWrap : public Base {
CallWrapmindspore::compile::CallWrap37 explicit CallWrap(const VMFramePtr &vm_frame) : frame(vm_frame) {}
38 VMFramePtr frame{nullptr};
39 };
40 using CallWrapPtr = std::shared_ptr<CallWrap>;
41
42 // Indicates a return with its value.
43 struct ReturnWrap : public Base {
ReturnWrapmindspore::compile::ReturnWrap44 explicit ReturnWrap(const BaseRef &r_value) : value(r_value) {}
45 BaseRef value{BaseRef()};
46 };
47 using ReturnWrapPtr = std::shared_ptr<ReturnWrap>;
48
VMFrame(const AnfNodePtrList & nodes,const AnfNodePtrToBaseRefMap & values,const AnfNodePtrToBaseRefMap & closure)49 VMFrame::VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values,
50 const AnfNodePtrToBaseRefMap &closure)
51 : values_(values), todo_(nodes), closure_(closure) {
52 std::reverse(std::begin(todo_), std::end(todo_));
53 }
54
operator [](const AnfNodePtr & node)55 const BaseRef VMFrame::operator[](const AnfNodePtr &node) {
56 MS_EXCEPTION_IF_NULL(node);
57 auto ret = values_.find(node);
58 if (ret != values_.end()) {
59 return ret->second;
60 }
61
62 ret = closure_.find(node);
63 if (ret != closure_.end()) {
64 return ret->second;
65 }
66
67 if (node->isa<ValueNode>()) {
68 return GetValueNode(node);
69 }
70
71 MS_LOG(EXCEPTION) << "ValueError " << node->type_name();
72 }
73
Closure(const FuncGraphPtr & graph,const AnfNodePtrToBaseRefMap & values)74 Closure::Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values)
75 : func_graph_(graph), values_(values) {}
76
operator ()(const VectorRef & args)77 BaseRef Closure::operator()(const VectorRef &args) {
78 MS_LOG(DEBUG) << "Start closure";
79 MS_EXCEPTION_IF_NULL(vm_);
80 return vm_->Evaluate(func_graph_, args, values_);
81 }
82
Partial(const BaseRef & fn,const VectorRef & args,const VMPtr & vm)83 Partial::Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm) : fn_(fn), args_(args), vm_(vm) {}
84
operator ()(const VectorRef & nodes)85 BaseRef Partial::operator()(const VectorRef &nodes) {
86 VectorRef arglist;
87 (void)arglist.insert(arglist.end(), args_.begin(), args_.end());
88 (void)arglist.insert(arglist.end(), nodes.begin(), nodes.end());
89 MS_EXCEPTION_IF_NULL(vm_);
90 return vm_->Call(fn_, arglist);
91 }
92
ComputeFvs(const FuncGraphPtr & graph)93 SetRef VM::ComputeFvs(const FuncGraphPtr &graph) {
94 MS_EXCEPTION_IF_NULL(graph);
95 SetRef rval;
96 for (auto &fkv : graph->free_variables_total()) {
97 if (utils::isa<FuncGraphPtr>(fkv.first)) {
98 // Add all value_nodes of g that refer to a fv graph
99 auto g = utils::cast<FuncGraphPtr>(fkv.first);
100 for (auto &ctkv : g->value_nodes()) {
101 auto ct = ctkv.first;
102 if (GetValueNode(ct) == g) {
103 (void)rval.insert(ct);
104 }
105 }
106 } else {
107 // Add a normal fv
108 (void)rval.insert(fkv.first);
109 }
110 }
111
112 return rval;
113 }
114
AcquireGraph(const FuncGraphPtr & graph)115 void VM::AcquireGraph(const FuncGraphPtr &graph) {
116 // Already acquired
117 if (vars_.find(graph) != vars_.end()) {
118 return;
119 }
120 MS_EXCEPTION_IF_NULL(manager_);
121 // Add g to manager
122 manager_->AddFuncGraph(graph);
123 MS_EXCEPTION_IF_NULL(graph->manager());
124 // Compute fvs for all acquired graph
125 auto graphs = graph->manager()->func_graphs();
126 for (auto g = graphs.begin(); g != graphs.end(); ++g) {
127 vars_[*g] = ComputeFvs(*g);
128 }
129 }
130
// Export every element of `seq` through VM::Export and return the results as a new VectorRef.
VectorRef VM::ExportSequence(const VectorRef &seq) {
  std::vector<BaseRef> exported;
  exported.reserve(seq.size());
  for (const auto &item : seq) {
    exported.push_back(Export(item));
  }
  return VectorRef(exported);
}
137
// Bind `clos` to this VM (so calling it evaluates here) and return the same closure.
ClosurePtr VM::ExportClosure(const ClosurePtr &clos) {
  MS_EXCEPTION_IF_NULL(clos);
  auto self = shared_from_this();
  clos->set_vm(self);
  return clos;
}
143
144 // transform graph to executable closure
ExportGraph(const FuncGraphPtr & g)145 ClosurePtr VM::ExportGraph(const FuncGraphPtr &g) {
146 auto c = std::make_shared<Closure>(g, AnfNodePtrToBaseRefMap());
147 MS_EXCEPTION_IF_NULL(c);
148 c->set_vm(shared_from_this());
149 return c;
150 }
151
// Fallback export for values that need no conversion: the object is returned unchanged.
BaseRef VM::ExportObj(const BaseRef &obj) const {
  return obj;
}
153
// Convert a VM value into its externally usable form, checked in this order:
//   - FuncGraph held as a Value      -> ExportGraph
//   - Primitive held as a Value      -> ExportPrimitive
//   - FuncGraphPtr                   -> ExportGraph
//   - ClosurePtr                     -> ExportClosure (bound to this VM)
//   - PrimitivePtr                   -> ExportPrimitive
//   - VectorRef                      -> ExportSequence (element-wise)
//   - anything else                  -> ExportObj
BaseRef VM::Export(const BaseRef &value) {
  if (utils::isa<ValuePtr>(value)) {
    auto as_value = utils::cast<ValuePtr>(value);
    if (as_value->isa<FuncGraph>()) {
      return ExportGraph(as_value->cast<FuncGraphPtr>());
    }
    if (as_value->isa<Primitive>()) {
      return ExportPrimitive(as_value->cast<PrimitivePtr>());
    }
  }

  if (utils::isa<FuncGraphPtr>(value)) {
    return ExportGraph(utils::cast<FuncGraphPtr>(value));
  }
  if (utils::isa<ClosurePtr>(value)) {
    return ExportClosure(utils::cast<ClosurePtr>(value));
  }
  if (utils::isa<PrimitivePtr>(value)) {
    return ExportPrimitive(utils::cast<PrimitivePtr>(value));
  }
  if (utils::isa<VectorRef>(value)) {
    return ExportSequence(utils::cast<VectorRef>(value));
  }
  return ExportObj(value);
}
181
182 // Run a graph.
183 // This will evaluate the passed-in graph and return the resulting value.
Evaluate(const FuncGraphPtr & graph,const VectorRef & args,const AnfNodePtrToBaseRefMap & closure)184 BaseRef VM::Evaluate(const FuncGraphPtr &graph, const VectorRef &args, const AnfNodePtrToBaseRefMap &closure) {
185 MS_EXCEPTION_IF_NULL(graph);
186 AcquireGraph(graph);
187 MS_LOG(DEBUG) << "Evalue arg size: " << args.size();
188 if (args.size() != graph->parameters().size()) {
189 MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graph->parameters().size() << ", but got "
190 << args.size();
191 }
192
193 // toposort graph nodes, the order will be reversed by frame so that the dependent be computed first
194 auto nodes = TopoSort(graph->get_return(), SuccVm(graph));
195 // mapping parameters to args
196 AnfNodePtrToBaseRefMap values;
197 for (size_t i = 0; i < args.size(); i++) {
198 values[graph->parameters()[i]] = args[i];
199 }
200 // create top frame with params initialized
201 VMFramePtrList frames{std::make_shared<VMFrame>(nodes, values, closure)};
202 // execute frames starting from top frame
203 while (!frames.empty()) {
204 auto frame = frames[frames.size() - 1];
205 MS_EXCEPTION_IF_NULL(frame);
206 auto todo = frame->todo();
207 while (!todo.empty()) {
208 auto except = HandleNode(todo[todo.size() - 1], frame);
209 if (utils::isa<CallWrapPtr>(except)) {
210 if (todo.size() == 2) {
211 // The last element is always a return, replace the ret with call frame
212 frames[frames.size() - 1] = utils::cast<CallWrapPtr>(except)->frame;
213 } else {
214 frames.push_back(utils::cast<CallWrapPtr>(except)->frame);
215 }
216 break;
217 }
218 if (utils::isa<ReturnWrapPtr>(except)) {
219 (void)frames.erase(frames.begin() + (static_cast<ssize_t>(frames.size()) - 1));
220 if (frames.size() > 0) {
221 auto top = frames[frames.size() - 1];
222 MS_EXCEPTION_IF_NULL(top);
223 auto td = top->todo();
224 // set value for top frame's last evaluated node
225 if (td.empty()) {
226 MS_LOG(EXCEPTION) << "The td is empty";
227 }
228 top->values()[td[td.size() - 1]] = utils::cast<ReturnWrapPtr>(except)->value;
229 (void)td.erase(td.begin() + (static_cast<ssize_t>(td.size()) - 1));
230 } else {
231 return Export(utils::cast<ReturnWrapPtr>(except)->value);
232 }
233 break;
234 }
235 (void)todo.erase(todo.begin() + (static_cast<ssize_t>(todo.size()) - 1));
236 }
237 }
238 MS_LOG(EXCEPTION) << "VM Evaluate error";
239 }
240
SuccVm(const FuncGraphPtr & graph)241 SuccFunc VM::SuccVm(const FuncGraphPtr &graph) {
242 auto fn = [&, this](const AnfNodePtr &node) -> AnfNodePtrList {
243 MS_EXCEPTION_IF_NULL(node);
244 AnfNodePtrList ret;
245
246 // Follow node.incoming
247 if (node->isa<CNode>()) {
248 auto &inputs = node->cast<CNodePtr>()->inputs();
249 for (auto &i : inputs) {
250 if (i->func_graph() == node->func_graph() ||
251 (IsValueNode<FuncGraph>(i) && GetValueNode<FuncGraphPtr>(i)->parent() == graph)) {
252 ret.push_back(i);
253 }
254 }
255 }
256
257 // for subgraph input, add their fvs as succ nodes
258 if (IsValueNode<FuncGraph>(node) && GetValueNode<FuncGraphPtr>(node)->parent() == graph) {
259 auto fvs = utils::cast<SetRef>(vars_[GetValueNode<FuncGraphPtr>(node)]);
260 (void)std::transform(fvs.begin(), fvs.end(), std::back_inserter(ret),
261 [](const BaseRef &value) -> AnfNodePtr { return utils::cast<AnfNodePtr>(value); });
262 }
263
264 return ret;
265 };
266 return fn;
267 }
268
// Call `fn` on `args` synchronously and return its result.
// A Primitive runs through RunOperation; a FuncGraph is evaluated with no closure values; a
// Closure is evaluated with its captured values. Anything else raises an exception.
BaseRef VM::Call(const BaseRef &fn, const VectorRef &args) {
  if (utils::isa<PrimitivePtr>(fn)) {
    return RunOperation(utils::cast<PrimitivePtr>(fn), args);
  }
  if (utils::isa<FuncGraphPtr>(fn)) {
    return Evaluate(utils::cast<FuncGraphPtr>(fn), args);
  }
  if (!utils::isa<ClosurePtr>(fn)) {
    MS_LOG(EXCEPTION) << "Can't call fn";
  }
  auto closure = utils::cast<ClosurePtr>(fn);
  return Evaluate(closure->func_graph(), args, closure->values());
}
285
286 // make call frame for graph
_Call(const BaseRef & graph,const VectorRef & args)287 BaseRef VM::_Call(const BaseRef &graph, const VectorRef &args) {
288 AnfNodePtrToBaseRefMap clos;
289 auto func_graph = graph;
290 if (utils::isa<ClosurePtr>(func_graph)) {
291 clos = utils::cast<ClosurePtr>(func_graph)->values();
292 func_graph = utils::cast<ClosurePtr>(func_graph)->func_graph();
293 }
294 if (utils::isa<ValuePtr>(func_graph)) {
295 func_graph = utils::cast<ValuePtr>(func_graph)->cast<FuncGraphPtr>();
296 }
297
298 if (!utils::isa<FuncGraphPtr>(func_graph)) {
299 MS_LOG(EXCEPTION) << "Graph type error";
300 }
301
302 auto graphPtr = utils::cast<FuncGraphPtr>(func_graph);
303
304 if (vars_.find(graphPtr) == vars_.end()) {
305 AcquireGraph(graphPtr);
306 }
307
308 if (args.size() != graphPtr->parameters().size()) {
309 MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graphPtr->parameters().size() << ", but got "
310 << args.size();
311 }
312
313 auto nodes = TopoSort(graphPtr->get_return(), SuccVm(graphPtr));
314 AnfNodePtrToBaseRefMap values;
315 for (size_t i = 0; i < args.size(); i++) {
316 values[graphPtr->parameters()[i]] = args[i];
317 }
318
319 return std::make_shared<CallWrap>(std::make_shared<VMFrame>(nodes, values, clos));
320 }
321
322 // make closure out of graph with fv values from frame
MakeClosure(const FuncGraphPtr & graph,const VMFramePtr & frame)323 ClosurePtr VM::MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame) {
324 MS_EXCEPTION_IF_NULL(frame);
325 AnfNodePtrToBaseRefMap clos;
326
327 for (auto &v : utils::cast<SetRef>(vars_[graph])) {
328 auto anf = utils::cast<AnfNodePtr>(v);
329 clos[anf] = (*frame)[anf];
330 }
331
332 return std::make_shared<Closure>(graph, clos);
333 }
334
// Dispatch the call made by CNode `node` (executing in `frame`) to callee `fn` with `args`.
// Returns:
//   - a ReturnWrap carrying args[0] when `fn` is the Return primitive;
//   - an empty BaseRef for any other primitive, after storing its result in
//     frame->values()[node] (MakeTuple stores `args`, Partial stores a Partial object,
//     everything else stores RunOperation's result);
//   - a CallWrap holding a new frame when `fn` is a FuncGraph value or a Closure.
// A Partial callee is unwrapped: its bound arguments are prepended to `args` and the wrapped
// function is dispatched in its place. Any other callee raises an exception.
BaseRef VM::DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args) {
  MS_EXCEPTION_IF_NULL(node);
  if (utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<Primitive>()) {
    auto fnval = utils::cast<ValuePtr>(fn)->cast<PrimitivePtr>();
    MS_LOG(DEBUG) << "DispatchCall prim:" << fnval->name() << ", node:" << node->DebugString(true);
    if (args.empty()) {
      MS_LOG(EXCEPTION) << "Args is empty";
    }
    if (fnval == prim::kPrimReturn) {
      MS_LOG(DEBUG) << "Return args:" << args.size();
      return std::make_shared<ReturnWrap>(args[0]);
    }
    MS_EXCEPTION_IF_NULL(frame);
    if (fnval == prim::kPrimMakeTuple) {
      frame->values()[node] = args;
      return BaseRef();
    }

    if (fnval == prim::kPrimPartial) {
      // args[0] is the function being partially applied; the rest are its bound arguments.
      VectorRef partial_args(args.begin() + 1, args.end());
      frame->values()[node] = (std::make_shared<Partial>(args[0], partial_args, shared_from_this()));
      return BaseRef();
    }

    // call prim implementation
    frame->values()[node] = RunOperation(fnval, args);
    return BaseRef();
  }

  // partial args logic
  if (utils::isa<PartialPtr>(fn)) {
    auto fnPtr = utils::cast<PartialPtr>(fn);
    MS_EXCEPTION_IF_NULL(fnPtr);
    // Bind once so begin()/end() come from the same container even if args() returns by value.
    const auto &bound_args = fnPtr->args();
    VectorRef arglist;
    (void)arglist.insert(arglist.end(), bound_args.begin(), bound_args.end());
    (void)arglist.insert(arglist.end(), args.begin(), args.end());

    // The wrapped call's outcome is this call's outcome: a Call/Return wrap to propagate, or an
    // empty BaseRef when a primitive already stored its result in frame->values()[node].
    // Falling through instead would reject a Partial of a primitive as "Invalid fn to call".
    return DispatchCall(node, frame, fnPtr->fn(), arglist);
  }

  // create frame for graph and closure
  if ((utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<FuncGraph>()) || utils::isa<ClosurePtr>(fn)) {
    auto ret = _Call(fn, args);
    if (utils::isa<CallWrapPtr>(ret) || utils::isa<ReturnWrapPtr>(ret)) {
      return ret;
    }
  }

  MS_LOG(EXCEPTION) << "Invalid fn to call";
}
387
HandleNode(const AnfNodePtr & node,const VMFramePtr & frame)388 BaseRef VM::HandleNode(const AnfNodePtr &node, const VMFramePtr &frame) {
389 MS_EXCEPTION_IF_NULL(node);
390 if (node->isa<Parameter>()) {
391 // pass
392 return BaseRef();
393 }
394
395 if (node->isa<ValueNode>()) {
396 // We only visit valuenode graphs
397 if (!IsValueNode<FuncGraph>(node)) {
398 MS_LOG(EXCEPTION) << "We only visit valuenode graphs ";
399 }
400 auto g = GetValueNode<FuncGraphPtr>(node);
401 MS_EXCEPTION_IF_NULL(frame);
402 // if g is a graph with fvs, we need to make a closure for it
403 auto iterG = vars_.find(g);
404 if (iterG != vars_.end() && utils::cast<SetRef>(iterG->second).size() != 0) {
405 frame->values()[node] = MakeClosure(g, frame);
406 }
407
408 return BaseRef();
409 }
410
411 if (node->isa<CNode>()) {
412 std::vector<BaseRef> fnArgs;
413 auto &inputs = node->cast<CNodePtr>()->inputs();
414 // set args' values in frame
415 (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(fnArgs),
416 [&](const AnfNodePtr &inp) -> BaseRef { return (*frame)[inp]; });
417 if (fnArgs.empty()) {
418 MS_LOG(EXCEPTION) << "Function arguments is empty";
419 } else {
420 auto args = VectorRef(fnArgs.begin() + 1, fnArgs.end());
421 auto except = DispatchCall(node, frame, fnArgs[0], args);
422 return except;
423 }
424 }
425
426 MS_LOG(EXCEPTION) << "Unknown node type";
427 }
428
// Run `g` on `args` with this VM (managing `g` with a fresh manager first) and return the
// result as a VectorRef; a result that is not a VectorRef is wrapped as a single element.
VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) {
  manager_ = Manage(g);

  auto entry = utils::cast<ClosurePtr>(Export(g));
  auto result = (*entry)(args);

  if (!utils::isa<VectorRef>(result)) {
    return VectorRef({result});
  }
  return utils::cast<VectorRef>(result);
}
442
RunOperation(const PrimitivePtr & prim,const VectorRef & args)443 BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
444 MS_LOG(DEBUG) << "Operation start " << prim->name();
445 MS_EXCEPTION_IF_NULL(prim);
446 auto result = prim->RunComputeFunction(args);
447 if (result.is_null()) {
448 result = RunComputeFunctionWithoutPyObj(prim, args);
449 }
450 if (result.is_null()) {
451 return RunComputeFunction(prim, args);
452 }
453 return result;
454 }
455
456 } // namespace compile
457 } // namespace mindspore
458