1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2024 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "backend/graph_compiler/vmimpl.h"
20
21 #include <algorithm>
22 #include <exception>
23 #include <vector>
24 #include <memory>
25
26 #include "mindspore/core/ops/sequence_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "frontend/operator/ops.h"
29 #include "ir/manager.h"
30 #include "ir/func_graph_cloner.h"
31 #include "include/common/utils/convert_utils.h"
32 #include "include/common/utils/primitive_utils.h"
33
34 namespace mindspore {
35 namespace compile {
36
37 // Indicate a call to a new frame.
38 struct CallWrap final : public Base {
CallWrapmindspore::compile::CallWrap39 explicit CallWrap(const VMFramePtr &vm_frame) : frame(vm_frame) {}
40 MS_DECLARE_PARENT(CallWrap, Base);
41 VMFramePtr frame{nullptr};
42 };
43 using CallWrapPtr = std::shared_ptr<CallWrap>;
44
45 // Indicates a return with its value.
46 struct ReturnWrap final : public Base {
ReturnWrapmindspore::compile::ReturnWrap47 explicit ReturnWrap(const BaseRef &r_value) : value(r_value) {}
48 MS_DECLARE_PARENT(ReturnWrap, Base);
49 BaseRef value{BaseRef()};
50 };
51 using ReturnWrapPtr = std::shared_ptr<ReturnWrap>;
52
VMFrame(const AnfNodePtrList & nodes,const AnfNodePtrToBaseRefMap & values,const AnfNodePtrToBaseRefMap & closure)53 VMFrame::VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values,
54 const AnfNodePtrToBaseRefMap &closure)
55 : values_(values), todo_(nodes), closure_(closure) {
56 std::reverse(std::begin(todo_), std::end(todo_));
57 }
58
operator [](const AnfNodePtr & node)59 const BaseRef VMFrame::operator[](const AnfNodePtr &node) {
60 MS_EXCEPTION_IF_NULL(node);
61 auto ret = values_.find(node);
62 if (ret != values_.end()) {
63 return ret->second;
64 }
65
66 ret = closure_.find(node);
67 if (ret != closure_.end()) {
68 return ret->second;
69 }
70
71 if (node->isa<ValueNode>()) {
72 return GetValueNode(node);
73 }
74
75 MS_LOG(EXCEPTION) << "ValueError " << node->type_name();
76 }
77
Closure(const FuncGraphPtr & graph,const AnfNodePtrToBaseRefMap & values)78 Closure::Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values)
79 : func_graph_(graph), values_(values) {}
80
operator ()(const VectorRef & args)81 BaseRef Closure::operator()(const VectorRef &args) {
82 MS_LOG(DEBUG) << "Start closure";
83 MS_EXCEPTION_IF_NULL(vm_);
84 return vm_->Evaluate(func_graph_, args, values_);
85 }
86
Partial(const BaseRef & fn,const VectorRef & args,const VMPtr & vm)87 Partial::Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm) : fn_(fn), args_(args), vm_(vm) {}
88
operator ()(const VectorRef & nodes)89 BaseRef Partial::operator()(const VectorRef &nodes) {
90 VectorRef arglist;
91 (void)arglist.insert(arglist.end(), args_.begin(), args_.end());
92 (void)arglist.insert(arglist.end(), nodes.begin(), nodes.end());
93 MS_EXCEPTION_IF_NULL(vm_);
94 return vm_->Call(fn_, arglist);
95 }
96
ComputeFvs(const FuncGraphPtr & graph) const97 SetRef VM::ComputeFvs(const FuncGraphPtr &graph) const {
98 MS_EXCEPTION_IF_NULL(graph);
99 SetRef rval;
100 for (auto &fkv : graph->free_variables_total()) {
101 if (utils::isa<FuncGraphPtr>(fkv.first)) {
102 // Add all value_nodes of g that refer to a fv graph
103 auto g = utils::cast<FuncGraphPtr>(fkv.first);
104 for (auto &ctkv : g->value_nodes()) {
105 auto ct = ctkv.first;
106 if (GetValueNode(ct) == g) {
107 (void)rval.insert(ct);
108 }
109 }
110 } else {
111 // Add a normal fv
112 (void)rval.insert(fkv.first);
113 }
114 }
115
116 return rval;
117 }
118
AcquireGraph(const FuncGraphPtr & graph)119 void VM::AcquireGraph(const FuncGraphPtr &graph) {
120 // Already acquired
121 if (vars_.find(graph) != vars_.end()) {
122 return;
123 }
124 MS_EXCEPTION_IF_NULL(manager_);
125 // Add g to manager
126 manager_->AddFuncGraph(graph);
127 MS_EXCEPTION_IF_NULL(graph->manager());
128 // Compute fvs for all acquired graph
129 auto graphs = graph->manager()->func_graphs();
130 for (auto g = graphs.begin(); g != graphs.end(); ++g) {
131 vars_[*g] = ComputeFvs(*g);
132 }
133 }
134
// Export every element of `seq` through Export() and return the results in the same order.
VectorRef VM::ExportSequence(const VectorRef &seq) {
  std::vector<BaseRef> exported;
  for (auto it = std::begin(seq); it != std::end(seq); ++it) {
    const BaseRef &element = *it;
    exported.push_back(Export(element));
  }
  return VectorRef(exported);
}
141
// Attach this VM to an existing closure so that it can be invoked, and return the same closure.
// Raises if `clos` is null.
ClosurePtr VM::ExportClosure(const ClosurePtr &clos) {
  MS_EXCEPTION_IF_NULL(clos);
  auto self = shared_from_this();
  clos->set_vm(self);
  return clos;
}
147
148 // transform graph to executable closure
ExportGraph(const FuncGraphPtr & g)149 ClosurePtr VM::ExportGraph(const FuncGraphPtr &g) {
150 auto c = std::make_shared<Closure>(g, AnfNodePtrToBaseRefMap());
151 MS_EXCEPTION_IF_NULL(c);
152 c->set_vm(shared_from_this());
153 return c;
154 }
155
// Export a plain value: it is handed back unchanged.
BaseRef VM::ExportObj(const BaseRef &obj) const {
  return obj;
}
157
// Convert a VM-internal value into its exported form by dispatching on the held type:
// graphs become closures, closures get this VM attached, primitives go through ExportPrimitive,
// sequences are exported element-wise, everything else is passed to ExportObj.
BaseRef VM::Export(const BaseRef &value) {
  // Values held as ValuePtr: look at the dynamic type of the pointee first.
  if (utils::isa<ValuePtr>(value)) {
    if (utils::cast<ValuePtr>(value)->isa<FuncGraph>()) {
      return ExportGraph(utils::cast<ValuePtr>(value)->cast<FuncGraphPtr>());
    }
    if (utils::cast<ValuePtr>(value)->isa<Primitive>()) {
      return ExportPrimitive(utils::cast<ValuePtr>(value)->cast<PrimitivePtr>());
    }
  }

  if (utils::isa<FuncGraphPtr>(value)) {
    return ExportGraph(utils::cast<FuncGraphPtr>(value));
  }
  if (utils::isa<ClosurePtr>(value)) {
    return ExportClosure(utils::cast<ClosurePtr>(value));
  }
  if (utils::isa<PrimitivePtr>(value)) {
    return ExportPrimitive(utils::cast<PrimitivePtr>(value));
  }
  if (utils::isa<VectorRef>(value)) {
    return ExportSequence(utils::cast<VectorRef>(value));
  }

  // No special handling applies.
  return ExportObj(value);
}
185
186 // Run a graph.
187 // This will evaluate the passed-in graph and return the resulting value.
Evaluate(const FuncGraphPtr & graph,const VectorRef & args,const AnfNodePtrToBaseRefMap & closure)188 BaseRef VM::Evaluate(const FuncGraphPtr &graph, const VectorRef &args, const AnfNodePtrToBaseRefMap &closure) {
189 MS_EXCEPTION_IF_NULL(graph);
190 AcquireGraph(graph);
191 MS_LOG(DEBUG) << "Evalue arg size: " << args.size();
192 if (args.size() != graph->parameters().size()) {
193 MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graph->parameters().size() << ", but got "
194 << args.size();
195 }
196
197 // toposort graph nodes, the order will be reversed by frame so that the dependent be computed first
198 auto nodes = TopoSort(graph->get_return(), SuccVm(graph));
199 // mapping parameters to args
200 AnfNodePtrToBaseRefMap values;
201 for (size_t i = 0; i < args.size(); i++) {
202 values[graph->parameters()[i]] = args[i];
203 }
204 // create top frame with params initialized
205 VMFramePtrList frames{std::make_shared<VMFrame>(nodes, values, closure)};
206 // execute frames starting from top frame
207 while (!frames.empty()) {
208 auto frame = frames[frames.size() - 1];
209 MS_EXCEPTION_IF_NULL(frame);
210 auto todo = frame->todo();
211 while (!todo.empty()) {
212 auto except = HandleNode(todo[todo.size() - 1], frame);
213 if (utils::isa<CallWrapPtr>(except)) {
214 if (todo.size() == 2) {
215 // The last element is always a return, replace the ret with call frame
216 frames[frames.size() - 1] = utils::cast<CallWrapPtr>(except)->frame;
217 } else {
218 frames.push_back(utils::cast<CallWrapPtr>(except)->frame);
219 }
220 break;
221 }
222 if (utils::isa<ReturnWrapPtr>(except)) {
223 (void)frames.erase(frames.cbegin() + (static_cast<ssize_t>(frames.size()) - 1));
224 if (frames.size() > 0) {
225 auto top = frames[frames.size() - 1];
226 MS_EXCEPTION_IF_NULL(top);
227 auto td = top->todo();
228 // set value for top frame's last evaluated node
229 if (td.empty()) {
230 MS_LOG(EXCEPTION) << "The td is empty";
231 }
232 top->values()[td[td.size() - 1]] = utils::cast<ReturnWrapPtr>(except)->value;
233 (void)td.erase(td.cbegin() + (static_cast<ssize_t>(td.size()) - 1));
234 } else {
235 return Export(utils::cast<ReturnWrapPtr>(except)->value);
236 }
237 break;
238 }
239 (void)todo.erase(todo.cbegin() + (static_cast<ssize_t>(todo.size()) - 1));
240 }
241 }
242 MS_LOG(EXCEPTION) << "VM Evaluate error";
243 }
244
SuccVm(const FuncGraphPtr & graph)245 SuccFunc VM::SuccVm(const FuncGraphPtr &graph) {
246 auto fn = [&, this](const AnfNodePtr &node) -> AnfNodeWeakPtrList {
247 MS_EXCEPTION_IF_NULL(node);
248 AnfNodeWeakPtrList res;
249
250 // Follow node.incoming
251 if (node->isa<CNode>()) {
252 auto &inputs = node->cast<CNodePtr>()->weak_inputs();
253 for (auto &weak_input : inputs) {
254 auto i = weak_input.lock();
255 MS_EXCEPTION_IF_NULL(i);
256 if (i->func_graph() == node->func_graph() ||
257 (IsValueNode<FuncGraph>(i) && GetValueNode<FuncGraphPtr>(i)->parent() == graph)) {
258 res.push_back(i);
259 }
260 }
261 }
262
263 // for subgraph input, add their fvs as succ nodes
264 if (IsValueNode<FuncGraph>(node) && GetValueNode<FuncGraphPtr>(node)->parent() == graph) {
265 auto fvs = utils::cast<SetRef>(vars_[GetValueNode<FuncGraphPtr>(node)]);
266 (void)std::transform(fvs.begin(), fvs.end(), std::back_inserter(res),
267 [](const BaseRef &value) -> AnfNodePtr { return utils::cast<AnfNodePtr>(value); });
268 }
269
270 return res;
271 };
272 return fn;
273 }
274
// Call `fn` with `args` and return the result.
// Primitives are run through RunOperation; graphs and closures are evaluated (closures with their
// captured values). Any other callable raises "Can't call fn".
BaseRef VM::Call(const BaseRef &fn, const VectorRef &args) {
  if (utils::isa<PrimitivePtr>(fn)) {
    auto prim = utils::cast<PrimitivePtr>(fn);
    return RunOperation(prim, args);
  }

  if (utils::isa<FuncGraphPtr>(fn)) {
    auto callee_graph = utils::cast<FuncGraphPtr>(fn);
    return Evaluate(callee_graph, args);
  }

  if (!utils::isa<ClosurePtr>(fn)) {
    MS_LOG(EXCEPTION) << "Can't call fn";
  }
  auto closure = utils::cast<ClosurePtr>(fn);
  return Evaluate(closure->func_graph(), args, closure->values());
}
291
292 // make call frame for graph
_Call(const BaseRef & graph,const VectorRef & args)293 BaseRef VM::_Call(const BaseRef &graph, const VectorRef &args) {
294 AnfNodePtrToBaseRefMap clos;
295 auto func_graph = graph;
296 if (utils::isa<ClosurePtr>(func_graph)) {
297 clos = utils::cast<ClosurePtr>(func_graph)->values();
298 func_graph = utils::cast<ClosurePtr>(func_graph)->func_graph();
299 }
300 if (utils::isa<ValuePtr>(func_graph)) {
301 func_graph = utils::cast<ValuePtr>(func_graph)->cast<FuncGraphPtr>();
302 }
303
304 if (!utils::isa<FuncGraphPtr>(func_graph)) {
305 MS_LOG(EXCEPTION) << "Graph type error";
306 }
307
308 auto graphPtr = utils::cast<FuncGraphPtr>(func_graph);
309
310 if (vars_.find(graphPtr) == vars_.end()) {
311 AcquireGraph(graphPtr);
312 }
313
314 if (args.size() != graphPtr->parameters().size()) {
315 MS_LOG(EXCEPTION) << "Call with wrong number of arguments, expect " << graphPtr->parameters().size() << ", but got "
316 << args.size();
317 }
318
319 auto nodes = TopoSort(graphPtr->get_return(), SuccVm(graphPtr));
320 AnfNodePtrToBaseRefMap values;
321 for (size_t i = 0; i < args.size(); i++) {
322 values[graphPtr->parameters()[i]] = args[i];
323 }
324
325 return std::make_shared<CallWrap>(std::make_shared<VMFrame>(nodes, values, clos));
326 }
327
328 // make closure out of graph with fv values from frame
MakeClosure(const FuncGraphPtr & graph,const VMFramePtr & frame)329 ClosurePtr VM::MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame) {
330 MS_EXCEPTION_IF_NULL(frame);
331 AnfNodePtrToBaseRefMap clos;
332
333 for (const auto &v : utils::cast<SetRef>(vars_[graph])) {
334 auto anf = utils::cast<AnfNodePtr>(v);
335 clos[anf] = (*frame)[anf];
336 }
337
338 return std::make_shared<Closure>(graph, clos);
339 }
340
// Dispatch the call of `fn` with `args` made by `node` inside `frame`.
//   - Primitive: Return yields a ReturnWrap around the first argument; MakeTuple, Partial and all
//     other primitives store their result in the frame under `node` and yield an empty BaseRef.
//     A primitive called without any argument raises "Args is empty".
//   - Partial: the bound arguments are prepended to `args` and the bound function is dispatched;
//     whatever that dispatch yields is returned.
//   - Graph or closure: a CallWrap around a new frame is returned.
// Anything else raises "Invalid fn to call".
BaseRef VM::DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args) {
  if (utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<Primitive>()) {
    auto fnval = utils::cast<ValuePtr>(fn)->cast<PrimitivePtr>();
    MS_LOG(DEBUG) << "DispatchCall prim:" << fnval->name() << ", node:" << node->DebugString(true);
    if (args.empty()) {
      MS_LOG(EXCEPTION) << "Args is empty";
    }
    if (fnval == prim::kPrimReturn) {
      MS_LOG(DEBUG) << "Return args:" << args.size();
      return std::make_shared<ReturnWrap>(args[0]);
    }
    MS_EXCEPTION_IF_NULL(frame);
    if (fnval == prim::kPrimMakeTuple) {
      frame->values()[node] = args;
      return BaseRef();
    }

    if (fnval == prim::kPrimPartial) {
      // args[0] is the function being bound, the remaining arguments are the bound ones.
      VectorRef partial_args(args.begin() + 1, args.end());
      frame->values()[node] = (std::make_shared<Partial>(args[0], partial_args, shared_from_this()));
      return BaseRef();
    }

    // call prim implementation
    frame->values()[node] = RunOperation(fnval, args);
    return BaseRef();
  }

  // partial args logic
  if (utils::isa<PartialPtr>(fn)) {
    auto fnPtr = utils::cast<PartialPtr>(fn);

    VectorRef arglist;
    (void)arglist.insert(arglist.end(), fnPtr->args().begin(), fnPtr->args().end());
    (void)arglist.insert(arglist.end(), args.begin(), args.end());

    // The nested dispatch has fully handled the call. Its result is returned as is, including the
    // empty BaseRef produced by an ordinary primitive; falling through here would wrongly raise
    // "Invalid fn to call" after the primitive had already been executed.
    return DispatchCall(node, frame, fnPtr->fn(), arglist);
  }

  // create frame for graph and closure
  if ((utils::isa<ValuePtr>(fn) && utils::cast<ValuePtr>(fn)->isa<FuncGraph>()) || utils::isa<ClosurePtr>(fn)) {
    auto ret = _Call(fn, args);
    if (utils::isa<CallWrapPtr>(ret) || utils::isa<ReturnWrapPtr>(ret)) {
      return ret;
    }
  }

  MS_LOG(EXCEPTION) << "Invalid fn to call";
}
393
HandleNode(const AnfNodePtr & node,const VMFramePtr & frame)394 BaseRef VM::HandleNode(const AnfNodePtr &node, const VMFramePtr &frame) {
395 MS_EXCEPTION_IF_NULL(node);
396 if (node->isa<Parameter>()) {
397 // pass
398 return BaseRef();
399 }
400
401 if (node->isa<ValueNode>()) {
402 // We only visit valuenode graphs
403 if (!IsValueNode<FuncGraph>(node)) {
404 MS_LOG(EXCEPTION) << "We only visit valuenode graphs ";
405 }
406 auto g = GetValueNode<FuncGraphPtr>(node);
407 MS_EXCEPTION_IF_NULL(frame);
408 // if g is a graph with fvs, we need to make a closure for it
409 auto iterG = vars_.find(g);
410 if (iterG != vars_.end() && utils::cast<SetRef>(iterG->second).size() != 0) {
411 frame->values()[node] = MakeClosure(g, frame);
412 }
413
414 return BaseRef();
415 }
416
417 if (node->isa<CNode>()) {
418 std::vector<BaseRef> fnArgs;
419 auto &inputs = node->cast<CNodePtr>()->inputs();
420 // set args' values in frame
421 (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(fnArgs),
422 [&](const AnfNodePtr &inp) -> BaseRef { return (*frame)[inp]; });
423 if (fnArgs.empty()) {
424 MS_LOG(EXCEPTION) << "Function arguments is empty";
425 } else {
426 auto args = VectorRef(fnArgs.begin() + 1, fnArgs.end());
427 auto except = DispatchCall(node, frame, fnArgs[0], args);
428 return except;
429 }
430 }
431
432 MS_LOG(EXCEPTION) << "Unknown node type";
433 }
434
// Run graph `g` with `args` and return the result as a VectorRef.
// The graph is put under management first; a result that is not already a VectorRef is wrapped
// into a one-element VectorRef.
VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) {
  this->manager_ = Manage(g);

  auto closure = utils::cast<ClosurePtr>(Export(g));
  auto output = (*closure)(args);

  if (!utils::isa<VectorRef>(output)) {
    VectorRef wrapped({output});
    return wrapped;
  }
  return utils::cast<VectorRef>(output);
}
448
RunOperation(const PrimitivePtr & prim,const VectorRef & args)449 BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
450 MS_EXCEPTION_IF_NULL(prim);
451 MS_LOG(DEBUG) << "Operation start " << prim->name();
452 auto result = prim->RunComputeFunction(args);
453 if (result.is_null()) {
454 result = RunComputeFunctionWithoutPyObj(prim, args);
455 }
456 if (result.is_null()) {
457 return RunComputeFunction(prim, args);
458 }
459 return result;
460 }
461
462 } // namespace compile
463 } // namespace mindspore
464