• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "backend/graph_compiler/vm.h"
20 #include <algorithm>
21 #include "ops/nn_op_name.h"
22 #include "backend/graph_compiler/vmimpl.h"
23 #include "backend/graph_compiler/backend.h"
24 #include "pipeline/jit/ps/parse/data_converter.h"
25 #include "pybind_api/ir/base_ref_py.h"
26 #include "pybind_api/ir/primitive_py.h"
27 
28 namespace mindspore {
29 namespace compile {
30 
31 // Initialize StructPartial.
32 // Arguments:
33 //   fn_: Callable function.
34 //   args_: Sequence of function args.
35 //   fg_: Graph of function.
StructPartial(int64_t fn,const VectorRef & args,const FuncGraphPtr & fg)36 StructPartial::StructPartial(int64_t fn, const VectorRef &args, const FuncGraphPtr &fg)
37     : fn_(fn), args_(args), fg_(fg) {}
38 
operator <<(std::ostream & os,const StructPartial & other)39 std::ostream &operator<<(std::ostream &os, const StructPartial &other) {
40   os << "Partial(" << other.fn_ << ", " << other.args_.ToString() << ")";
41   return os;
42 }
43 
operator ==(const StructPartial & lhs,const StructPartial & rhs)44 bool operator==(const StructPartial &lhs, const StructPartial &rhs) {
45   return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_ && lhs.fg_ == rhs.fg_);
46 }
47 
StructSimuSwitch(const BaseRef & fn,const BaseRef & value)48 StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {}
49 
operator <<(std::ostream & os,const StructSimuSwitch & other)50 std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other) {
51   os << "SimulSwitch(" << other.fn_.ToString() << ", " << other.value_.ToString() << ")";
52   return os;
53 }
54 
operator ==(const StructSimuSwitch & lhs,const StructSimuSwitch & rhs)55 bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs) {
56   return (lhs.fn_ == rhs.fn_ && lhs.value_ == rhs.value_);
57 }
58 
// Print a SwitchCondStatus as its underlying integer value, e.g. "SwitchCondStatus(1)".
std::ostream &operator<<(std::ostream &os, const SwitchCondStatus &other) {
  const auto numeric = static_cast<int64_t>(other);
  return os << "SwitchCondStatus(" << numeric << ")";
}
63 
64 // Follow the specified instructions to create a VM.
65 // Arguments:
66 //   insts_: std::vector<std::map<std::string, VectorRef>>
67 //   insts_stack_: The value stack.
68 //   retp_: The call stack.
69 //   pc_: program counter (next instruction)
70 //   sp_: stack pointer (for the value stack)
FinalVM(const InstSet & insts,const BackendPtr & backend)71 FinalVM::FinalVM(const InstSet &insts, const BackendPtr &backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) {
72   MS_LOG(DEBUG) << "InstSet size:" << insts_.size();
73   insts_stack_.emplace_back(BaseRef());
74   retp_.push(-1);
75 }
76 
Push(const BaseRef & v)77 void FinalVM::Push(const BaseRef &v) {
78   MS_LOG(DEBUG) << "Push " << v.ToString() << " sp_:" << sp_;
79   insts_stack_[IntToSize(sp_++)] = v;
80 }
81 
Pop(int64_t n)82 void FinalVM::Pop(int64_t n) {
83   if (n > sp_) {
84     MS_LOG(EXCEPTION) << "Invalid value of n " << n << ", it should not be more than " << (sp_ - 1);
85   }
86   for (int64_t i = 0; i < n; i++) {
87     insts_stack_[IntToSize(sp_ - i - 1)] = BaseRef();
88   }
89   sp_ -= n;
90 }
91 
MoveStack(int64_t nitems,int64_t height)92 void FinalVM::MoveStack(int64_t nitems, int64_t height) {
93   if (nitems > height || height > sp_) {
94     MS_LOG(EXCEPTION) << "MoveStack arg error: nitems=" << nitems << " height=" << height << " sp=" << sp_;
95   }
96   int64_t n = height - nitems;
97   int64_t src = sp_ - height;
98   int64_t dst = sp_ - nitems;
99   for (int64_t i = 0; i < nitems; i++) {
100     insts_stack_[IntToSize(src + i)] = insts_stack_[IntToSize(dst + i)];
101   }
102   Pop(n);
103 }
104 
Ref(int64_t i)105 BaseRef FinalVM::Ref(int64_t i) {
106   MS_LOG(DEBUG) << "Ref i:" << i << " sp_:" << sp_;
107   size_t sp_next = LongToSize(sp_ + i);
108   if (sp_next < insts_stack_.size()) {
109     if (utils::isa<PyObjectRef>(insts_stack_[sp_next])) {
110       py::object value = utils::cast<PyObjectRef>(insts_stack_[sp_next]).object_;
111       MS_LOG(DEBUG) << "VM ref python:" << py::str(value);
112       return python_adapter::PyAdapterCallback::PyDataToValue(value);
113     }
114     MS_LOG(DEBUG) << "Ref not python :" << insts_stack_[sp_next].ToString();
115     return insts_stack_[sp_next];
116   }
117 
118   MS_LOG(EXCEPTION) << "IndexError: index(" << sp_next << ") out of range [0, " << insts_stack_.size() << ").";
119 }
120 
Pushp()121 void FinalVM::Pushp() { retp_.push(pc_); }
122 
Popp()123 void FinalVM::Popp() {
124   if (retp_.empty()) {
125     MS_LOG(EXCEPTION) << "Stack retp_ is empty";
126   }
127   pc_ = retp_.top();
128   MS_LOG(DEBUG) << "Pop pc:" << pc_ << ", sp:" << sp_;
129   retp_.pop();
130 }
131 
Pushsp()132 void FinalVM::Pushsp() { retsp_.push(sp_); }
133 
Popsp()134 void FinalVM::Popsp() {
135   int64_t sp = retsp_.top();
136   MS_LOG(DEBUG) << "Current sp:" << sp_ << ", before sp:" << sp << ", " << sp_ - sp;
137   if (sp_ >= sp) {
138     Pop((sp_ - sp) + 1);
139     retsp_.pop();
140   } else {
141     MS_LOG(EXCEPTION) << "Stack point sp_:" << sp << " must be bigger than sp:" << sp_;
142   }
143 }
144 
DoJmp(const BaseRef & jmp_orig)145 void FinalVM::DoJmp(const BaseRef &jmp_orig) {
146   MS_LOG(DEBUG) << "Start";
147 
148   BaseRef jmp = jmp_orig;
149   if (utils::isa<StructPartial>(jmp)) {  // need to inherit from Base
150     MS_LOG(DEBUG) << "Start jump StructPartial";
151     auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
152     auto args = new_jmp->args_;
153     InstPadStack(VectorRef(std::vector<BaseRef>{static_cast<int64_t>(args.size())}));
154     auto iter = args.rbegin();
155     for (; iter != args.rend(); ++iter) {
156       Push(*iter);
157     }
158     pc_ = new_jmp->fn_;
159     return;
160   }
161 
162   if (!utils::isa<int64_t>(jmp)) {
163     MS_LOG(EXCEPTION) << "Jmp inst should be an int64_t";
164   }
165   pc_ = utils::cast<int64_t>(jmp);
166   MS_LOG(DEBUG) << "End do jump pc_:" << pc_;
167 }
168 
Eval(const VectorRef & args)169 BaseRef FinalVM::Eval(const VectorRef &args) {
170   MS_LOG(DEBUG) << "Start: " << args.size();
171   insts_stack_.clear();
172   insts_stack_.resize(args.size());
173   std::stack<int64_t>().swap(retp_);
174   retp_.push(-1);
175   pc_ = 0;
176   sp_ = 0;
177 
178   auto riter = args.rbegin();
179   for (; riter != args.rend(); ++riter) {
180     if (utils::isa<PyObjectRef>(*riter)) {
181       PyObjectRef py_ref = utils::cast<PyObjectRef>(*riter);
182       py::object value = py_ref.object_;
183       if (py::isinstance<py::bool_>(value)) {
184         auto a = py::cast<bool>(value);
185         Push(static_cast<int64_t>(a));
186         continue;
187       }
188     }
189     Push(*riter);
190   }
191 
192   while (pc_ >= 0) {
193     auto inst = insts_[IntToSize(pc_)];
194     MS_LOG(DEBUG) << "Loop " << insts_.size() << ", pc:" << pc_ << ", inst:" << inst_str[inst.first];
195     ++pc_;
196     auto iter = inst_function_map.find(inst.first);
197     if (iter != inst_function_map.end()) {
198       iter->second(inst.second);
199     } else {
200       MS_LOG(EXCEPTION) << "Unknown instruction {" << inst_str[inst.first] << "}";
201     }
202   }
203 
204   MS_LOG(DEBUG) << "End";
205   return insts_stack_[0];
206 }
207 
InstCall(const VectorRef & args)208 void FinalVM::InstCall(const VectorRef &args) {
209   MS_LOG(DEBUG) << "Start";
210   const size_t args_size = 1;
211   if (args.size() != args_size) {
212     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
213                   << ".";
214     return;
215   }
216 
217   int64_t jmp = utils::cast<int64_t>(args[0]);
218   MS_LOG(DEBUG) << "Call pushp:" << pc_ << ", jmp:" << jmp << ", sp:" << sp_;
219   Pushp();
220   DoJmp(Ref(jmp));
221   MS_LOG(DEBUG) << "Instcall end sp :" << sp_;
222 }
223 
InstTailCall(const VectorRef & args)224 void FinalVM::InstTailCall(const VectorRef &args) {
225   MS_LOG(DEBUG) << "Start";
226   const size_t args_size = 3;
227   if (args.size() != args_size) {
228     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
229                   << ".";
230     return;
231   }
232 
233   const size_t jmp_index = 0;
234   const size_t height_index = 1;
235   const size_t nargs_index = 2;
236   int64_t jmp = utils::cast<int64_t>(args[jmp_index]);
237   int64_t height = utils::cast<int64_t>(args[height_index]);
238   int64_t nargs = utils::cast<int64_t>(args[nargs_index]);
239 
240   auto new_jmp = Ref(jmp);
241   MoveStack(nargs, height);
242   MS_LOG(DEBUG) << "TailCall pushp:" << pc_ << ", jmp:" << jmp;
243   DoJmp(new_jmp);
244   MS_LOG(DEBUG) << "End";
245 }
246 
InstSwitchReturn(const VectorRef & args)247 void FinalVM::InstSwitchReturn(const VectorRef &args) {
248   MS_LOG(DEBUG) << "Start";
249   if (args.size() != 1) {
250     MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
251     return;
252   }
253   Pop(1);
254   Popsp();
255 }
256 
InstReturn(const VectorRef & args)257 void FinalVM::InstReturn(const VectorRef &args) {
258   MS_LOG(DEBUG) << "Start";
259   const size_t args_size = 2;
260   if (args.size() != args_size) {
261     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
262                   << ".";
263     return;
264   }
265 
266   int64_t rpos = utils::cast<int64_t>(args[0]);
267   int64_t height = utils::cast<int64_t>(args[1]);
268 
269   auto rv = Ref(rpos);
270   Pop(height);
271   Push(rv);
272   Popp();
273   MS_LOG(DEBUG) << "End";
274 }
275 
InstRealPartial(const VectorRef & args)276 void FinalVM::InstRealPartial(const VectorRef &args) {
277   const size_t args_size = 1;
278   if (args.size() < args_size) {
279     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
280                   << args.size() << ".";
281     return;
282   }
283 
284   int64_t fn_ = utils::cast<int64_t>(args[0]);
285   auto fn = utils::cast<int64_t>(Ref(fn_));
286   MS_LOG(DEBUG) << "Partial argssize:" << args.size();
287   std::vector<BaseRef> outs(args.size() - 1);
288   (void)std::transform(args.begin() + 1, args.end(), outs.begin(),
289                        [&, this](const BaseRef &a) { return Ref(utils::cast<int64_t>(a)); });
290   Push(std::make_shared<StructPartial>(fn, VectorRef(outs)));
291 }
292 
InstPartial(const VectorRef & args)293 void FinalVM::InstPartial(const VectorRef &args) {
294   MS_LOG(DEBUG) << "Start";
295   InstRealPartial(args);
296   MS_LOG(DEBUG) << "End";
297 }
298 
InstRealSwitch(const VectorRef & args)299 void FinalVM::InstRealSwitch(const VectorRef &args) {
300   const size_t args_size = 3;
301   if (args.size() != args_size) {
302     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
303                   << ".";
304     return;
305   }
306 
307   const size_t cond_index = 0;
308   const size_t vtrue_index = 1;
309   const size_t vfalse_index = 2;
310   int64_t cond = utils::cast<int64_t>(args[cond_index]);
311   int64_t vtrue = utils::cast<int64_t>(args[vtrue_index]);
312   int64_t vfalse = utils::cast<int64_t>(args[vfalse_index]);
313 
314   BaseRef c = Ref(cond);
315   MS_LOG(DEBUG) << vtrue << " false:" << vfalse << " InstSwitch: " << c.ToString();
316   bool bool_value = false;
317   MS_EXCEPTION_IF_NULL(backend_);
318   if (backend_->GetCond(c, &bool_value)) {
319     MS_LOG(DEBUG) << "Cond:" << bool_value;
320     if (bool_value) {
321       Push(Ref(vtrue));
322     } else {
323       Push(Ref(vfalse));
324     }
325   } else {
326     MS_LOG(EXCEPTION) << "Not supported type to be casted to bool";
327   }
328 }
329 
InstSwitch(const VectorRef & args)330 void FinalVM::InstSwitch(const VectorRef &args) {
331   MS_LOG(DEBUG) << "Start";
332   InstRealSwitch(args);
333   MS_LOG(DEBUG) << "End";
334 }
335 
InstSwitchLayer(const VectorRef & args)336 void FinalVM::InstSwitchLayer(const VectorRef &args) {
337   MS_LOG(DEBUG) << "Start";
338   const size_t args_size = 2;
339   if (args.size() != args_size) {
340     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
341                   << ".";
342     return;
343   }
344 
345   int64_t idx = utils::cast<int64_t>(args[0]);
346   VectorRef branches = utils::cast<VectorRef>(Ref(utils::cast<int64_t>(args[1])));
347   int64_t size = static_cast<int64_t>(branches.size());
348 
349   BaseRef index = Ref(idx);
350   int64_t idx_value = 0;
351   MS_EXCEPTION_IF_NULL(backend_);
352   if (!backend_->GetIndex(index, &idx_value)) {
353     MS_LOG(EXCEPTION) << "Not supported type to be casted to int64_t.";
354   }
355   auto ori_value = idx_value;
356   if (idx_value < 0) {
357     // Add support negative index range [-size, -1].
358     idx_value += size;
359   }
360   if (idx_value < 0 || idx_value >= size) {
361     MS_EXCEPTION(IndexError) << __FUNCTION__ << " given index " << ori_value
362                              << " out of range. Please make sure the value "
363                              << "of index in [" << -size << ", " << size << "), and the type is int32.";
364   }
365   Push(branches[idx_value]);
366   MS_LOG(DEBUG) << "End";
367 }
368 
InstTuple(const VectorRef & args)369 void FinalVM::InstTuple(const VectorRef &args) {
370   MS_LOG(DEBUG) << "Start";
371   VectorRef tuple;
372   auto iter = args.begin();
373   for (; iter != args.end(); ++iter) {
374     auto a = utils::cast<int64_t>(*iter);
375     tuple.push_back(Ref(a));
376   }
377   Push(tuple);
378   MS_LOG(DEBUG) << "End";
379 }
380 
InstPush(const VectorRef & args)381 void FinalVM::InstPush(const VectorRef &args) {
382   MS_LOG(DEBUG) << "Start";
383   const size_t args_size = 1;
384   if (args.size() != args_size) {
385     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
386                   << ".";
387     return;
388   }
389 
390   auto v = args[0];
391   Push(v);
392   MS_LOG(DEBUG) << "End";
393 }
394 
InstInput(const VectorRef & args)395 void FinalVM::InstInput(const VectorRef &args) {
396   MS_LOG(DEBUG) << "Start";
397   const size_t args_size = 1;
398   if (args.size() != args_size) {
399     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
400                   << ".";
401     return;
402   }
403 
404   int64_t rpos = utils::cast<int64_t>(args[0]);
405   Push(Ref(rpos));
406   MS_LOG(DEBUG) << "End";
407 }
408 
InstPadStack(const VectorRef & args)409 void FinalVM::InstPadStack(const VectorRef &args) {
410   MS_LOG(DEBUG) << "Start";
411   const size_t args_size = 1;
412   if (args.size() != args_size) {
413     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
414                   << ".";
415     return;
416   }
417 
418   int64_t sz = utils::cast<int64_t>(args[0]);
419   MS_LOG(DEBUG) << insts_stack_.size() << " need padstack " << sz << " sp_ " << sp_;
420   size_t stack_size = insts_stack_.size();
421   int64_t need = sz - (static_cast<int64_t>(stack_size) - sp_);
422   if (need > 0) {
423     MS_LOG(DEBUG) << "InstPadStack resize: size:" << insts_stack_.size() << " need pad:" << need;
424     insts_stack_.resize(stack_size + IntToSize(need));
425   }
426   MS_LOG(DEBUG) << "End";
427 }
428 
InstExternal(const VectorRef & args)429 void FinalVM::InstExternal(const VectorRef &args) {
430   MS_LOG(DEBUG) << "Start:" << args.size();
431 
432   if (args.empty()) {
433     MS_LOG(EXCEPTION) << "Args is empty!";
434   }
435 
436   VectorRef tuple;
437   RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]);
438   compile::RunFuncPtr fn = run_ref.func_;
439   const size_t arg_start_index = 2;
440   for (size_t i = arg_start_index; i < args.size(); ++i) {
441     auto index = utils::cast<int64_t>(args[i]);
442     tuple.push_back(Ref(index));
443   }
444 
445   if (!fn) {
446     MS_LOG(EXCEPTION) << "Function not callable";
447   }
448 
449   auto outs = (*fn)(tuple);
450   MS_LOG(DEBUG) << "The 'fn' out size:" << outs.size();
451   for (auto &o : outs) {
452     MS_LOG(DEBUG) << "InstExternal value:" << o.ToString();
453     Push(o);
454   }
455   MS_LOG(DEBUG) << "End";
456 }
457 
InstPushPrim(const VectorRef & args)458 void FinalVM::InstPushPrim(const VectorRef &args) {
459   MS_LOG(DEBUG) << "Start: " << args.size();
460   const size_t args_size = 2;
461   if (args.size() < args_size) {
462     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
463                   << args.size() << ".";
464     return;
465   }
466 
467   auto prim = utils::cast<PrimitivePtr>(args[0]);
468   VectorRef tuple;
469   for (size_t i = 1; i < args.size(); ++i) {
470     auto index = utils::cast<int64_t>(args[i]);
471     tuple.push_back(Ref(index));
472   }
473 
474   if (prim->name() == kBpropCutOpName) {
475     BaseRef outs = python_adapter::PyAdapterCallback::RunPrimitivePyHookFunction(prim, tuple);
476     Push(outs);
477   } else {
478     auto outs = RunOperation(prim, tuple);
479     Push(outs);
480   }
481 
482   MS_LOG(DEBUG) << "End";
483 }
484 }  // namespace compile
485 }  // namespace mindspore
486