• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "vm/vm.h"
20 #include <algorithm>
21 #include "vm/vmimpl.h"
22 #include "vm/backend.h"
23 #include "pipeline/jit/parse/data_converter.h"
24 #include "pybind_api/ir/base_ref_py.h"
25 
26 namespace mindspore {
27 namespace compile {
28 
29 // Initialize StructPartial.
30 // Arguments:
31 //   fn_: Callable function.
32 //   args_: Sequence of function args.
33 //   fg_: Graph of function.
StructPartial(int64_t fn,const VectorRef & args,const FuncGraphPtr & fg)34 StructPartial::StructPartial(int64_t fn, const VectorRef &args, const FuncGraphPtr &fg)
35     : fn_(fn), args_(args), fg_(fg) {}
36 
operator <<(std::ostream & os,const StructPartial & other)37 std::ostream &operator<<(std::ostream &os, const StructPartial &other) {
38   os << "Partial(" << other.fn_ << ", " << other.args_.ToString() << ")";
39   return os;
40 }
41 
operator ==(const StructPartial & lhs,const StructPartial & rhs)42 bool operator==(const StructPartial &lhs, const StructPartial &rhs) {
43   return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_ && lhs.fg_ == rhs.fg_);
44 }
45 
StructSimuSwitch(const BaseRef & fn,const BaseRef & value)46 StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {}
47 
operator <<(std::ostream & os,const StructSimuSwitch & other)48 std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other) {
49   os << "SimulSwitch(" << other.fn_.ToString() << ", " << other.value_.ToString() << ")";
50   return os;
51 }
52 
operator ==(const StructSimuSwitch & lhs,const StructSimuSwitch & rhs)53 bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs) {
54   return (lhs.fn_ == rhs.fn_ && lhs.value_ == rhs.value_);
55 }
56 
// Print the status as its underlying integer value: "SwitchCondStatus(<n>)".
std::ostream &operator<<(std::ostream &os, const SwitchCondStatus &other) {
  const auto raw = static_cast<int64_t>(other);
  return os << "SwitchCondStatus(" << raw << ")";
}
61 
62 // Follow the specified instructions to create a VM.
63 // Arguments:
64 //   insts_: std::vector<std::map<std::string, VectorRef>>
65 //   insts_stack_: The value stack.
66 //   retp_: The call stack.
67 //   pc_: program counter (next instruction)
68 //   sp_: stack pointer (for the value stack)
FinalVM(const InstSet & insts,const BackendPtr & backend)69 FinalVM::FinalVM(const InstSet &insts, const BackendPtr &backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) {
70   MS_LOG(DEBUG) << "InstSet size:" << insts_.size();
71   insts_stack_.emplace_back(BaseRef());
72   retp_.push(-1);
73 }
74 
Push(const BaseRef & v)75 void FinalVM::Push(const BaseRef &v) {
76   MS_LOG(DEBUG) << "Push " << v.ToString() << " sp_:" << sp_;
77   insts_stack_[IntToSize(sp_++)] = v;
78 }
79 
Pop(int64_t n)80 void FinalVM::Pop(int64_t n) {
81   if (n > sp_) {
82     MS_LOG(EXCEPTION) << "Invalid value of n " << n << ", it should be not more than " << (sp_ - 1);
83   }
84   for (int64_t i = 0; i < n; i++) {
85     insts_stack_[IntToSize(sp_ - i - 1)] = BaseRef();
86   }
87   sp_ -= n;
88 }
89 
MoveStack(int64_t nitems,int64_t height)90 void FinalVM::MoveStack(int64_t nitems, int64_t height) {
91   if (nitems > height || height > sp_) {
92     MS_LOG(EXCEPTION) << "MoveStack arg error: nitems=" << nitems << " height=" << height << " sp=" << sp_;
93   }
94   int64_t n = height - nitems;
95   int64_t src = sp_ - height;
96   int64_t dst = sp_ - nitems;
97   for (int64_t i = 0; i < nitems; i++) {
98     insts_stack_[IntToSize(src + i)] = insts_stack_[IntToSize(dst + i)];
99   }
100   Pop(n);
101 }
102 
Ref(int64_t i)103 BaseRef FinalVM::Ref(int64_t i) {
104   MS_LOG(DEBUG) << "Ref i:" << i << " sp_:" << sp_;
105   size_t sp_next = LongToSize(sp_ + i);
106   if (sp_next < insts_stack_.size()) {
107     if (utils::isa<PyObjectRef>(insts_stack_[sp_next])) {
108       py::object value = utils::cast<PyObjectRef>(insts_stack_[sp_next]).object_;
109       MS_LOG(DEBUG) << "VM ref python:" << py::str(value);
110       return parse::data_converter::PyDataToValue(value);
111     }
112     MS_LOG(DEBUG) << "Ref not python :" << insts_stack_[sp_next].ToString();
113     return insts_stack_[sp_next];
114   }
115 
116   MS_LOG(EXCEPTION) << "IndexError: index(" << sp_next << ") out of range [0, " << insts_stack_.size() << ").";
117 }
118 
Pushp()119 void FinalVM::Pushp() { retp_.push(pc_); }
120 
Popp()121 void FinalVM::Popp() {
122   if (retp_.empty()) {
123     MS_LOG(EXCEPTION) << "Stack retp_ is empty";
124   }
125   pc_ = retp_.top();
126   MS_LOG(DEBUG) << "Pop pc:" << pc_ << ", sp:" << sp_;
127   retp_.pop();
128 }
129 
Pushsp()130 void FinalVM::Pushsp() { retsp_.push(sp_); }
131 
Popsp()132 void FinalVM::Popsp() {
133   int64_t sp = retsp_.top();
134   MS_LOG(DEBUG) << "Current sp:" << sp_ << ", before sp:" << sp << ", " << sp_ - sp;
135   if (sp_ >= sp) {
136     Pop(sp_ - sp + 1);
137     retsp_.pop();
138   } else {
139     MS_LOG(EXCEPTION) << "Stack point sp_:" << sp << " must biger than sp:" << sp_;
140   }
141 }
142 
DoJmp(const BaseRef & jmp_orig)143 void FinalVM::DoJmp(const BaseRef &jmp_orig) {
144   MS_LOG(DEBUG) << "Start";
145 
146   BaseRef jmp = jmp_orig;
147   if (utils::isa<StructPartial>(jmp)) {  // need to inherit from Base
148     MS_LOG(DEBUG) << "Start jump StructPartial";
149     auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
150     auto args = new_jmp->args_;
151     InstPadStack(VectorRef(std::vector<BaseRef>{static_cast<int64_t>(args.size())}));
152     auto iter = args.rbegin();
153     for (; iter != args.rend(); ++iter) {
154       Push(*iter);
155     }
156     pc_ = new_jmp->fn_;
157     return;
158   }
159 
160   if (!utils::isa<int64_t>(jmp)) {
161     MS_LOG(EXCEPTION) << "Jmp inst should be a int64_t";
162   }
163   pc_ = utils::cast<int64_t>(jmp);
164   MS_LOG(DEBUG) << "End do jump pc_:" << pc_;
165 }
166 
Eval(const VectorRef & args)167 BaseRef FinalVM::Eval(const VectorRef &args) {
168   MS_LOG(DEBUG) << "Start: " << args.size();
169   insts_stack_.clear();
170   insts_stack_.resize(args.size());
171   std::stack<int64_t>().swap(retp_);
172   retp_.push(-1);
173   pc_ = 0;
174   sp_ = 0;
175 
176   auto riter = args.rbegin();
177   for (; riter != args.rend(); ++riter) {
178     if (utils::isa<PyObjectRef>(*riter)) {
179       PyObjectRef py_ref = utils::cast<PyObjectRef>(*riter);
180       py::object value = py_ref.object_;
181       if (py::isinstance<py::bool_>(value)) {
182         auto a = py::cast<bool>(value);
183         Push(static_cast<int64_t>(a));
184         continue;
185       }
186     }
187     Push(*riter);
188   }
189 
190   while (pc_ >= 0) {
191     auto inst = insts_[IntToSize(pc_)];
192     MS_LOG(DEBUG) << "Loop " << insts_.size() << ", pc:" << pc_ << ", inst:" << inst_str[inst.first];
193     ++pc_;
194     auto iter = inst_function_map.find(inst.first);
195     if (iter != inst_function_map.end()) {
196       iter->second(inst.second);
197     } else {
198       MS_LOG(EXCEPTION) << "Unknown instruction {" << inst_str[inst.first] << "}";
199     }
200   }
201 
202   MS_LOG(DEBUG) << "End";
203   return insts_stack_[0];
204 }
205 
InstCall(const VectorRef & args)206 void FinalVM::InstCall(const VectorRef &args) {
207   MS_LOG(DEBUG) << "Start";
208   const size_t args_size = 1;
209   if (args.size() != args_size) {
210     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
211                   << ".";
212     return;
213   }
214 
215   int64_t jmp = utils::cast<int64_t>(args[0]);
216   MS_LOG(DEBUG) << "Call pushp:" << pc_ << ", jmp:" << jmp << ", sp:" << sp_;
217   Pushp();
218   DoJmp(Ref(jmp));
219   MS_LOG(DEBUG) << "Instcall end sp :" << sp_;
220 }
221 
InstTailCall(const VectorRef & args)222 void FinalVM::InstTailCall(const VectorRef &args) {
223   MS_LOG(DEBUG) << "Start";
224   const size_t args_size = 3;
225   if (args.size() != args_size) {
226     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
227                   << ".";
228     return;
229   }
230 
231   const size_t jmp_index = 0;
232   const size_t height_index = 1;
233   const size_t nargs_index = 2;
234   int64_t jmp = utils::cast<int64_t>(args[jmp_index]);
235   int64_t height = utils::cast<int64_t>(args[height_index]);
236   int64_t nargs = utils::cast<int64_t>(args[nargs_index]);
237 
238   auto new_jmp = Ref(jmp);
239   MoveStack(nargs, height);
240   MS_LOG(DEBUG) << "TailCall pushp:" << pc_ << ", jmp:" << jmp;
241   DoJmp(new_jmp);
242   MS_LOG(DEBUG) << "End";
243 }
244 
InstSwitchReturn(const VectorRef & args)245 void FinalVM::InstSwitchReturn(const VectorRef &args) {
246   MS_LOG(DEBUG) << "Start";
247   if (args.size() != 1) {
248     MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
249     return;
250   }
251   Pop(1);
252   Popsp();
253 }
254 
InstReturn(const VectorRef & args)255 void FinalVM::InstReturn(const VectorRef &args) {
256   MS_LOG(DEBUG) << "Start";
257   const size_t args_size = 2;
258   if (args.size() != args_size) {
259     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
260                   << ".";
261     return;
262   }
263 
264   int64_t rpos = utils::cast<int64_t>(args[0]);
265   int64_t height = utils::cast<int64_t>(args[1]);
266 
267   auto rv = Ref(rpos);
268   Pop(height);
269   Push(rv);
270   Popp();
271   MS_LOG(DEBUG) << "End";
272 }
273 
InstRealPartial(const VectorRef & args)274 void FinalVM::InstRealPartial(const VectorRef &args) {
275   const size_t args_size = 1;
276   if (args.size() < args_size) {
277     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
278                   << args.size() << ".";
279     return;
280   }
281 
282   int64_t fn_ = utils::cast<int64_t>(args[0]);
283   auto fn = utils::cast<int64_t>(Ref(fn_));
284   MS_LOG(DEBUG) << "Partial argssize:" << args.size();
285   std::vector<BaseRef> outs(args.size() - 1);
286   (void)std::transform(args.begin() + 1, args.end(), outs.begin(),
287                        [&, this](const BaseRef &a) { return Ref(utils::cast<int64_t>(a)); });
288   Push(std::make_shared<StructPartial>(fn, VectorRef(outs)));
289 }
290 
InstPartial(const VectorRef & args)291 void FinalVM::InstPartial(const VectorRef &args) {
292   MS_LOG(DEBUG) << "Start";
293   InstRealPartial(args);
294   MS_LOG(DEBUG) << "End";
295 }
296 
InstRealSwitch(const VectorRef & args)297 void FinalVM::InstRealSwitch(const VectorRef &args) {
298   const size_t args_size = 3;
299   if (args.size() != args_size) {
300     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
301                   << ".";
302     return;
303   }
304 
305   const size_t cond_index = 0;
306   const size_t vtrue_index = 1;
307   const size_t vfalse_index = 2;
308   int64_t cond = utils::cast<int64_t>(args[cond_index]);
309   int64_t vtrue = utils::cast<int64_t>(args[vtrue_index]);
310   int64_t vfalse = utils::cast<int64_t>(args[vfalse_index]);
311 
312   BaseRef c = Ref(cond);
313   MS_LOG(DEBUG) << vtrue << " false:" << vfalse << " InstSwitch: " << c.ToString();
314   bool bool_value = false;
315   MS_EXCEPTION_IF_NULL(backend_);
316   if (backend_->GetCond(c, &bool_value)) {
317     MS_LOG(DEBUG) << "Cond:" << bool_value;
318     if (bool_value) {
319       Push(Ref(vtrue));
320     } else {
321       Push(Ref(vfalse));
322     }
323   } else {
324     MS_LOG(EXCEPTION) << "Not supported type to be casted to bool";
325   }
326 }
327 
InstSwitch(const VectorRef & args)328 void FinalVM::InstSwitch(const VectorRef &args) {
329   MS_LOG(DEBUG) << "Start";
330   InstRealSwitch(args);
331   MS_LOG(DEBUG) << "End";
332 }
333 
InstSwitchLayer(const VectorRef & args)334 void FinalVM::InstSwitchLayer(const VectorRef &args) {
335   MS_LOG(DEBUG) << "Start";
336   const size_t args_size = 2;
337   if (args.size() != args_size) {
338     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
339                   << ".";
340     return;
341   }
342 
343   int64_t idx = utils::cast<int64_t>(args[0]);
344   VectorRef branches = utils::cast<VectorRef>(Ref(utils::cast<int64_t>(args[1])));
345   int64_t size = static_cast<int64_t>(branches.size());
346 
347   BaseRef index = Ref(idx);
348   int64_t idx_value = 0;
349   MS_EXCEPTION_IF_NULL(backend_);
350   if (!backend_->GetIndex(index, &idx_value)) {
351     MS_LOG(EXCEPTION) << "Not supported type to be casted to int64_t.";
352   }
353   auto ori_value = idx_value;
354   if (idx_value < 0) {
355     // Add support negative index range [-size, -1].
356     idx_value += size;
357   }
358   if (idx_value < 0 || idx_value >= size) {
359     MS_EXCEPTION(IndexError) << __FUNCTION__ << " given index " << ori_value
360                              << " out of range. Please make sure the value "
361                              << "of index in [" << -size << ", " << size << "), and the type is int32.";
362   }
363   Push(branches[idx_value]);
364   MS_LOG(DEBUG) << "End";
365 }
366 
InstTuple(const VectorRef & args)367 void FinalVM::InstTuple(const VectorRef &args) {
368   MS_LOG(DEBUG) << "Start";
369   VectorRef tuple;
370   auto iter = args.begin();
371   for (; iter != args.end(); ++iter) {
372     auto a = utils::cast<int64_t>(*iter);
373     tuple.push_back(Ref(a));
374   }
375   Push(tuple);
376   MS_LOG(DEBUG) << "End";
377 }
378 
InstPush(const VectorRef & args)379 void FinalVM::InstPush(const VectorRef &args) {
380   MS_LOG(DEBUG) << "Start";
381   const size_t args_size = 1;
382   if (args.size() != args_size) {
383     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
384                   << ".";
385     return;
386   }
387 
388   auto v = args[0];
389   Push(v);
390   MS_LOG(DEBUG) << "End";
391 }
392 
InstInput(const VectorRef & args)393 void FinalVM::InstInput(const VectorRef &args) {
394   MS_LOG(DEBUG) << "Start";
395   const size_t args_size = 1;
396   if (args.size() != args_size) {
397     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
398                   << ".";
399     return;
400   }
401 
402   int64_t rpos = utils::cast<int64_t>(args[0]);
403   Push(Ref(rpos));
404   MS_LOG(DEBUG) << "End";
405 }
406 
InstPadStack(const VectorRef & args)407 void FinalVM::InstPadStack(const VectorRef &args) {
408   MS_LOG(DEBUG) << "Start";
409   const size_t args_size = 1;
410   if (args.size() != args_size) {
411     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
412                   << ".";
413     return;
414   }
415 
416   int64_t sz = utils::cast<int64_t>(args[0]);
417   MS_LOG(DEBUG) << insts_stack_.size() << " need padstack " << sz << " sp_ " << sp_;
418   size_t stack_size = insts_stack_.size();
419   int64_t need = sz - (static_cast<int64_t>(stack_size) - sp_);
420   if (need > 0) {
421     MS_LOG(DEBUG) << "InstPadStack resize: size:" << insts_stack_.size() << " need pad:" << need;
422     insts_stack_.resize(stack_size + IntToSize(need));
423   }
424   MS_LOG(DEBUG) << "End";
425 }
426 
InstExternal(const VectorRef & args)427 void FinalVM::InstExternal(const VectorRef &args) {
428   MS_LOG(DEBUG) << "Start:" << args.size();
429 
430   if (args.empty()) {
431     MS_LOG(EXCEPTION) << "Args is empty!";
432   }
433 
434   VectorRef tuple;
435   RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]);
436   compile::RunFuncPtr fn = run_ref.func_;
437   const size_t arg_start_index = 2;
438   for (size_t i = arg_start_index; i < args.size(); ++i) {
439     auto index = utils::cast<int64_t>(args[i]);
440     tuple.push_back(Ref(index));
441   }
442 
443   if (!fn) {
444     MS_LOG(EXCEPTION) << "Function not callable";
445   }
446 
447   auto outs = (*fn)(tuple);
448   MS_LOG(DEBUG) << "The 'fn' out size:" << outs.size();
449   for (auto &o : outs) {
450     MS_LOG(DEBUG) << "InstExternal value:" << o.ToString();
451     Push(o);
452   }
453   MS_LOG(DEBUG) << "End";
454 }
455 
InstPushPrim(const VectorRef & args)456 void FinalVM::InstPushPrim(const VectorRef &args) {
457   MS_LOG(DEBUG) << "Start: " << args.size();
458   const size_t args_size = 2;
459   if (args.size() < args_size) {
460     MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
461                   << args.size() << ".";
462     return;
463   }
464 
465   auto prim = utils::cast<PrimitivePtr>(args[0]);
466   VectorRef tuple;
467   for (size_t i = 1; i < args.size(); ++i) {
468     auto index = utils::cast<int64_t>(args[i]);
469     tuple.push_back(Ref(index));
470   }
471 
472   if (prim->name() == "bprop_cut") {
473     auto outs = RunHook(prim, tuple);
474     Push(outs);
475   } else {
476     auto outs = RunOperation(prim, tuple);
477     Push(outs);
478   }
479 
480   MS_LOG(DEBUG) << "End";
481 }
482 
SyncData(const py::object & arg)483 void FinalVM::SyncData(const py::object &arg) {
484   if (py::isinstance<py::tuple>(arg)) {
485     auto arg_list = py::cast<py::tuple>(arg);
486     for (size_t i = 0; i < arg_list.size(); i++) {
487       SyncData(arg_list[i]);
488     }
489   }
490   if (py::isinstance<tensor::Tensor>(arg)) {
491     auto tensor = py::cast<tensor::TensorPtr>(arg);
492     tensor->data_sync();
493   }
494 }
495 
RunHook(const PrimitivePtr & prim,const VectorRef & args)496 BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
497   MS_LOG(DEBUG) << "Input for operation:";
498   MS_EXCEPTION_IF_NULL(prim);
499   return prim->RunHookFunction(args);
500 }
501 }  // namespace compile
502 }  // namespace mindspore
503