1 /**
2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3 *
4 * Copyright 2019-2020 Huawei Technologies Co., Ltd
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19 #include "backend/graph_compiler/vm.h"
20 #include <algorithm>
21 #include "ops/nn_op_name.h"
22 #include "backend/graph_compiler/vmimpl.h"
23 #include "backend/graph_compiler/backend.h"
24 #include "pipeline/jit/ps/parse/data_converter.h"
25 #include "pybind_api/ir/base_ref_py.h"
26 #include "pybind_api/ir/primitive_py.h"
27
28 namespace mindspore {
29 namespace compile {
30
31 // Initialize StructPartial.
32 // Arguments:
33 // fn_: Callable function.
34 // args_: Sequence of function args.
35 // fg_: Graph of function.
StructPartial(int64_t fn,const VectorRef & args,const FuncGraphPtr & fg)36 StructPartial::StructPartial(int64_t fn, const VectorRef &args, const FuncGraphPtr &fg)
37 : fn_(fn), args_(args), fg_(fg) {}
38
operator <<(std::ostream & os,const StructPartial & other)39 std::ostream &operator<<(std::ostream &os, const StructPartial &other) {
40 os << "Partial(" << other.fn_ << ", " << other.args_.ToString() << ")";
41 return os;
42 }
43
operator ==(const StructPartial & lhs,const StructPartial & rhs)44 bool operator==(const StructPartial &lhs, const StructPartial &rhs) {
45 return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_ && lhs.fg_ == rhs.fg_);
46 }
47
StructSimuSwitch(const BaseRef & fn,const BaseRef & value)48 StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {}
49
operator <<(std::ostream & os,const StructSimuSwitch & other)50 std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other) {
51 os << "SimulSwitch(" << other.fn_.ToString() << ", " << other.value_.ToString() << ")";
52 return os;
53 }
54
operator ==(const StructSimuSwitch & lhs,const StructSimuSwitch & rhs)55 bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs) {
56 return (lhs.fn_ == rhs.fn_ && lhs.value_ == rhs.value_);
57 }
58
// Print a SwitchCondStatus as "SwitchCondStatus(<numeric value>)".
std::ostream &operator<<(std::ostream &os, const SwitchCondStatus &other) {
  return os << "SwitchCondStatus(" << static_cast<int64_t>(other) << ")";
}
63
64 // Follow the specified instructions to create a VM.
65 // Arguments:
66 // insts_: std::vector<std::map<std::string, VectorRef>>
67 // insts_stack_: The value stack.
68 // retp_: The call stack.
69 // pc_: program counter (next instruction)
70 // sp_: stack pointer (for the value stack)
FinalVM(const InstSet & insts,const BackendPtr & backend)71 FinalVM::FinalVM(const InstSet &insts, const BackendPtr &backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) {
72 MS_LOG(DEBUG) << "InstSet size:" << insts_.size();
73 insts_stack_.emplace_back(BaseRef());
74 retp_.push(-1);
75 }
76
Push(const BaseRef & v)77 void FinalVM::Push(const BaseRef &v) {
78 MS_LOG(DEBUG) << "Push " << v.ToString() << " sp_:" << sp_;
79 insts_stack_[IntToSize(sp_++)] = v;
80 }
81
Pop(int64_t n)82 void FinalVM::Pop(int64_t n) {
83 if (n > sp_) {
84 MS_LOG(EXCEPTION) << "Invalid value of n " << n << ", it should not be more than " << (sp_ - 1);
85 }
86 for (int64_t i = 0; i < n; i++) {
87 insts_stack_[IntToSize(sp_ - i - 1)] = BaseRef();
88 }
89 sp_ -= n;
90 }
91
MoveStack(int64_t nitems,int64_t height)92 void FinalVM::MoveStack(int64_t nitems, int64_t height) {
93 if (nitems > height || height > sp_) {
94 MS_LOG(EXCEPTION) << "MoveStack arg error: nitems=" << nitems << " height=" << height << " sp=" << sp_;
95 }
96 int64_t n = height - nitems;
97 int64_t src = sp_ - height;
98 int64_t dst = sp_ - nitems;
99 for (int64_t i = 0; i < nitems; i++) {
100 insts_stack_[IntToSize(src + i)] = insts_stack_[IntToSize(dst + i)];
101 }
102 Pop(n);
103 }
104
Ref(int64_t i)105 BaseRef FinalVM::Ref(int64_t i) {
106 MS_LOG(DEBUG) << "Ref i:" << i << " sp_:" << sp_;
107 size_t sp_next = LongToSize(sp_ + i);
108 if (sp_next < insts_stack_.size()) {
109 if (utils::isa<PyObjectRef>(insts_stack_[sp_next])) {
110 py::object value = utils::cast<PyObjectRef>(insts_stack_[sp_next]).object_;
111 MS_LOG(DEBUG) << "VM ref python:" << py::str(value);
112 return python_adapter::PyAdapterCallback::PyDataToValue(value);
113 }
114 MS_LOG(DEBUG) << "Ref not python :" << insts_stack_[sp_next].ToString();
115 return insts_stack_[sp_next];
116 }
117
118 MS_LOG(EXCEPTION) << "IndexError: index(" << sp_next << ") out of range [0, " << insts_stack_.size() << ").";
119 }
120
Pushp()121 void FinalVM::Pushp() { retp_.push(pc_); }
122
Popp()123 void FinalVM::Popp() {
124 if (retp_.empty()) {
125 MS_LOG(EXCEPTION) << "Stack retp_ is empty";
126 }
127 pc_ = retp_.top();
128 MS_LOG(DEBUG) << "Pop pc:" << pc_ << ", sp:" << sp_;
129 retp_.pop();
130 }
131
Pushsp()132 void FinalVM::Pushsp() { retsp_.push(sp_); }
133
Popsp()134 void FinalVM::Popsp() {
135 int64_t sp = retsp_.top();
136 MS_LOG(DEBUG) << "Current sp:" << sp_ << ", before sp:" << sp << ", " << sp_ - sp;
137 if (sp_ >= sp) {
138 Pop((sp_ - sp) + 1);
139 retsp_.pop();
140 } else {
141 MS_LOG(EXCEPTION) << "Stack point sp_:" << sp << " must be bigger than sp:" << sp_;
142 }
143 }
144
DoJmp(const BaseRef & jmp_orig)145 void FinalVM::DoJmp(const BaseRef &jmp_orig) {
146 MS_LOG(DEBUG) << "Start";
147
148 BaseRef jmp = jmp_orig;
149 if (utils::isa<StructPartial>(jmp)) { // need to inherit from Base
150 MS_LOG(DEBUG) << "Start jump StructPartial";
151 auto new_jmp = utils::cast<std::shared_ptr<StructPartial>>(jmp);
152 auto args = new_jmp->args_;
153 InstPadStack(VectorRef(std::vector<BaseRef>{static_cast<int64_t>(args.size())}));
154 auto iter = args.rbegin();
155 for (; iter != args.rend(); ++iter) {
156 Push(*iter);
157 }
158 pc_ = new_jmp->fn_;
159 return;
160 }
161
162 if (!utils::isa<int64_t>(jmp)) {
163 MS_LOG(EXCEPTION) << "Jmp inst should be an int64_t";
164 }
165 pc_ = utils::cast<int64_t>(jmp);
166 MS_LOG(DEBUG) << "End do jump pc_:" << pc_;
167 }
168
Eval(const VectorRef & args)169 BaseRef FinalVM::Eval(const VectorRef &args) {
170 MS_LOG(DEBUG) << "Start: " << args.size();
171 insts_stack_.clear();
172 insts_stack_.resize(args.size());
173 std::stack<int64_t>().swap(retp_);
174 retp_.push(-1);
175 pc_ = 0;
176 sp_ = 0;
177
178 auto riter = args.rbegin();
179 for (; riter != args.rend(); ++riter) {
180 if (utils::isa<PyObjectRef>(*riter)) {
181 PyObjectRef py_ref = utils::cast<PyObjectRef>(*riter);
182 py::object value = py_ref.object_;
183 if (py::isinstance<py::bool_>(value)) {
184 auto a = py::cast<bool>(value);
185 Push(static_cast<int64_t>(a));
186 continue;
187 }
188 }
189 Push(*riter);
190 }
191
192 while (pc_ >= 0) {
193 auto inst = insts_[IntToSize(pc_)];
194 MS_LOG(DEBUG) << "Loop " << insts_.size() << ", pc:" << pc_ << ", inst:" << inst_str[inst.first];
195 ++pc_;
196 auto iter = inst_function_map.find(inst.first);
197 if (iter != inst_function_map.end()) {
198 iter->second(inst.second);
199 } else {
200 MS_LOG(EXCEPTION) << "Unknown instruction {" << inst_str[inst.first] << "}";
201 }
202 }
203
204 MS_LOG(DEBUG) << "End";
205 return insts_stack_[0];
206 }
207
InstCall(const VectorRef & args)208 void FinalVM::InstCall(const VectorRef &args) {
209 MS_LOG(DEBUG) << "Start";
210 const size_t args_size = 1;
211 if (args.size() != args_size) {
212 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
213 << ".";
214 return;
215 }
216
217 int64_t jmp = utils::cast<int64_t>(args[0]);
218 MS_LOG(DEBUG) << "Call pushp:" << pc_ << ", jmp:" << jmp << ", sp:" << sp_;
219 Pushp();
220 DoJmp(Ref(jmp));
221 MS_LOG(DEBUG) << "Instcall end sp :" << sp_;
222 }
223
InstTailCall(const VectorRef & args)224 void FinalVM::InstTailCall(const VectorRef &args) {
225 MS_LOG(DEBUG) << "Start";
226 const size_t args_size = 3;
227 if (args.size() != args_size) {
228 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
229 << ".";
230 return;
231 }
232
233 const size_t jmp_index = 0;
234 const size_t height_index = 1;
235 const size_t nargs_index = 2;
236 int64_t jmp = utils::cast<int64_t>(args[jmp_index]);
237 int64_t height = utils::cast<int64_t>(args[height_index]);
238 int64_t nargs = utils::cast<int64_t>(args[nargs_index]);
239
240 auto new_jmp = Ref(jmp);
241 MoveStack(nargs, height);
242 MS_LOG(DEBUG) << "TailCall pushp:" << pc_ << ", jmp:" << jmp;
243 DoJmp(new_jmp);
244 MS_LOG(DEBUG) << "End";
245 }
246
InstSwitchReturn(const VectorRef & args)247 void FinalVM::InstSwitchReturn(const VectorRef &args) {
248 MS_LOG(DEBUG) << "Start";
249 if (args.size() != 1) {
250 MS_LOG(ERROR) << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << ".";
251 return;
252 }
253 Pop(1);
254 Popsp();
255 }
256
InstReturn(const VectorRef & args)257 void FinalVM::InstReturn(const VectorRef &args) {
258 MS_LOG(DEBUG) << "Start";
259 const size_t args_size = 2;
260 if (args.size() != args_size) {
261 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
262 << ".";
263 return;
264 }
265
266 int64_t rpos = utils::cast<int64_t>(args[0]);
267 int64_t height = utils::cast<int64_t>(args[1]);
268
269 auto rv = Ref(rpos);
270 Pop(height);
271 Push(rv);
272 Popp();
273 MS_LOG(DEBUG) << "End";
274 }
275
InstRealPartial(const VectorRef & args)276 void FinalVM::InstRealPartial(const VectorRef &args) {
277 const size_t args_size = 1;
278 if (args.size() < args_size) {
279 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
280 << args.size() << ".";
281 return;
282 }
283
284 int64_t fn_ = utils::cast<int64_t>(args[0]);
285 auto fn = utils::cast<int64_t>(Ref(fn_));
286 MS_LOG(DEBUG) << "Partial argssize:" << args.size();
287 std::vector<BaseRef> outs(args.size() - 1);
288 (void)std::transform(args.begin() + 1, args.end(), outs.begin(),
289 [&, this](const BaseRef &a) { return Ref(utils::cast<int64_t>(a)); });
290 Push(std::make_shared<StructPartial>(fn, VectorRef(outs)));
291 }
292
InstPartial(const VectorRef & args)293 void FinalVM::InstPartial(const VectorRef &args) {
294 MS_LOG(DEBUG) << "Start";
295 InstRealPartial(args);
296 MS_LOG(DEBUG) << "End";
297 }
298
InstRealSwitch(const VectorRef & args)299 void FinalVM::InstRealSwitch(const VectorRef &args) {
300 const size_t args_size = 3;
301 if (args.size() != args_size) {
302 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
303 << ".";
304 return;
305 }
306
307 const size_t cond_index = 0;
308 const size_t vtrue_index = 1;
309 const size_t vfalse_index = 2;
310 int64_t cond = utils::cast<int64_t>(args[cond_index]);
311 int64_t vtrue = utils::cast<int64_t>(args[vtrue_index]);
312 int64_t vfalse = utils::cast<int64_t>(args[vfalse_index]);
313
314 BaseRef c = Ref(cond);
315 MS_LOG(DEBUG) << vtrue << " false:" << vfalse << " InstSwitch: " << c.ToString();
316 bool bool_value = false;
317 MS_EXCEPTION_IF_NULL(backend_);
318 if (backend_->GetCond(c, &bool_value)) {
319 MS_LOG(DEBUG) << "Cond:" << bool_value;
320 if (bool_value) {
321 Push(Ref(vtrue));
322 } else {
323 Push(Ref(vfalse));
324 }
325 } else {
326 MS_LOG(EXCEPTION) << "Not supported type to be casted to bool";
327 }
328 }
329
InstSwitch(const VectorRef & args)330 void FinalVM::InstSwitch(const VectorRef &args) {
331 MS_LOG(DEBUG) << "Start";
332 InstRealSwitch(args);
333 MS_LOG(DEBUG) << "End";
334 }
335
InstSwitchLayer(const VectorRef & args)336 void FinalVM::InstSwitchLayer(const VectorRef &args) {
337 MS_LOG(DEBUG) << "Start";
338 const size_t args_size = 2;
339 if (args.size() != args_size) {
340 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
341 << ".";
342 return;
343 }
344
345 int64_t idx = utils::cast<int64_t>(args[0]);
346 VectorRef branches = utils::cast<VectorRef>(Ref(utils::cast<int64_t>(args[1])));
347 int64_t size = static_cast<int64_t>(branches.size());
348
349 BaseRef index = Ref(idx);
350 int64_t idx_value = 0;
351 MS_EXCEPTION_IF_NULL(backend_);
352 if (!backend_->GetIndex(index, &idx_value)) {
353 MS_LOG(EXCEPTION) << "Not supported type to be casted to int64_t.";
354 }
355 auto ori_value = idx_value;
356 if (idx_value < 0) {
357 // Add support negative index range [-size, -1].
358 idx_value += size;
359 }
360 if (idx_value < 0 || idx_value >= size) {
361 MS_EXCEPTION(IndexError) << __FUNCTION__ << " given index " << ori_value
362 << " out of range. Please make sure the value "
363 << "of index in [" << -size << ", " << size << "), and the type is int32.";
364 }
365 Push(branches[idx_value]);
366 MS_LOG(DEBUG) << "End";
367 }
368
InstTuple(const VectorRef & args)369 void FinalVM::InstTuple(const VectorRef &args) {
370 MS_LOG(DEBUG) << "Start";
371 VectorRef tuple;
372 auto iter = args.begin();
373 for (; iter != args.end(); ++iter) {
374 auto a = utils::cast<int64_t>(*iter);
375 tuple.push_back(Ref(a));
376 }
377 Push(tuple);
378 MS_LOG(DEBUG) << "End";
379 }
380
InstPush(const VectorRef & args)381 void FinalVM::InstPush(const VectorRef &args) {
382 MS_LOG(DEBUG) << "Start";
383 const size_t args_size = 1;
384 if (args.size() != args_size) {
385 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
386 << ".";
387 return;
388 }
389
390 auto v = args[0];
391 Push(v);
392 MS_LOG(DEBUG) << "End";
393 }
394
InstInput(const VectorRef & args)395 void FinalVM::InstInput(const VectorRef &args) {
396 MS_LOG(DEBUG) << "Start";
397 const size_t args_size = 1;
398 if (args.size() != args_size) {
399 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
400 << ".";
401 return;
402 }
403
404 int64_t rpos = utils::cast<int64_t>(args[0]);
405 Push(Ref(rpos));
406 MS_LOG(DEBUG) << "End";
407 }
408
InstPadStack(const VectorRef & args)409 void FinalVM::InstPadStack(const VectorRef &args) {
410 MS_LOG(DEBUG) << "Start";
411 const size_t args_size = 1;
412 if (args.size() != args_size) {
413 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameter, while the input size is " << args.size()
414 << ".";
415 return;
416 }
417
418 int64_t sz = utils::cast<int64_t>(args[0]);
419 MS_LOG(DEBUG) << insts_stack_.size() << " need padstack " << sz << " sp_ " << sp_;
420 size_t stack_size = insts_stack_.size();
421 int64_t need = sz - (static_cast<int64_t>(stack_size) - sp_);
422 if (need > 0) {
423 MS_LOG(DEBUG) << "InstPadStack resize: size:" << insts_stack_.size() << " need pad:" << need;
424 insts_stack_.resize(stack_size + IntToSize(need));
425 }
426 MS_LOG(DEBUG) << "End";
427 }
428
InstExternal(const VectorRef & args)429 void FinalVM::InstExternal(const VectorRef &args) {
430 MS_LOG(DEBUG) << "Start:" << args.size();
431
432 if (args.empty()) {
433 MS_LOG(EXCEPTION) << "Args is empty!";
434 }
435
436 VectorRef tuple;
437 RunFunctionRef run_ref = utils::cast<RunFunctionRef>(args[0]);
438 compile::RunFuncPtr fn = run_ref.func_;
439 const size_t arg_start_index = 2;
440 for (size_t i = arg_start_index; i < args.size(); ++i) {
441 auto index = utils::cast<int64_t>(args[i]);
442 tuple.push_back(Ref(index));
443 }
444
445 if (!fn) {
446 MS_LOG(EXCEPTION) << "Function not callable";
447 }
448
449 auto outs = (*fn)(tuple);
450 MS_LOG(DEBUG) << "The 'fn' out size:" << outs.size();
451 for (auto &o : outs) {
452 MS_LOG(DEBUG) << "InstExternal value:" << o.ToString();
453 Push(o);
454 }
455 MS_LOG(DEBUG) << "End";
456 }
457
InstPushPrim(const VectorRef & args)458 void FinalVM::InstPushPrim(const VectorRef &args) {
459 MS_LOG(DEBUG) << "Start: " << args.size();
460 const size_t args_size = 2;
461 if (args.size() < args_size) {
462 MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " or more parameters, while the input size is "
463 << args.size() << ".";
464 return;
465 }
466
467 auto prim = utils::cast<PrimitivePtr>(args[0]);
468 VectorRef tuple;
469 for (size_t i = 1; i < args.size(); ++i) {
470 auto index = utils::cast<int64_t>(args[i]);
471 tuple.push_back(Ref(index));
472 }
473
474 if (prim->name() == kBpropCutOpName) {
475 BaseRef outs = python_adapter::PyAdapterCallback::RunPrimitivePyHookFunction(prim, tuple);
476 Push(outs);
477 } else {
478 auto outs = RunOperation(prim, tuple);
479 Push(outs);
480 }
481
482 MS_LOG(DEBUG) << "End";
483 }
484 } // namespace compile
485 } // namespace mindspore
486