1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2020 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_VM_VM_H_ 20 #define MINDSPORE_CCSRC_VM_VM_H_ 21 22 #include <map> 23 #include <memory> 24 #include <stack> 25 #include <string> 26 #include <tuple> 27 #include <utility> 28 #include <vector> 29 #include <deque> 30 #include <unordered_map> 31 32 #include "pybind11/pybind11.h" 33 34 #include "ir/anf.h" 35 #include "base/base_ref.h" 36 37 namespace py = pybind11; 38 39 namespace mindspore { 40 namespace compile { 41 42 class Backend; 43 using BackendPtr = std::shared_ptr<Backend>; 44 45 enum Instruction { 46 kCall = 0, 47 kTailCall, 48 kReturn, 49 kPartial, 50 kSwitch, 51 kSwitchReturn, 52 kTuple, 53 kInput, 54 kExternal, 55 kPush, 56 kPrim, 57 kGraph, 58 kPadStack, 59 kSwitchLayer 60 }; 61 62 using InstType = std::pair<Instruction, VectorRef>; 63 using InstSet = std::vector<InstType>; 64 using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>; 65 66 const std::vector<std::string> inst_str{"call", "tail_call", "Return", "partial", "Switch", 67 "switch_return", "tuple", "input", "external", "push", 68 "primitive", "graph", "pad_stack", "switch_layer"}; 69 class StructPartial : public Base { 70 public: 71 // Initialize StructPartial. 72 StructPartial(int64_t fn, const VectorRef &args, const FuncGraphPtr &fg = nullptr); 73 74 virtual ~StructPartial() = default; 75 MS_DECLARE_PARENT(StructPartial, Base) 76 77 int64_t fn_; 78 VectorRef args_; 79 FuncGraphPtr fg_; 80 }; 81 82 std::ostream &operator<<(std::ostream &os, const StructPartial &other); 83 bool operator==(const StructPartial &lhs, const StructPartial &rhs); 84 85 class StructSimuSwitch : public Base { 86 public: 87 StructSimuSwitch(const BaseRef &fn, const BaseRef &value); 88 89 virtual ~StructSimuSwitch() = default; 90 MS_DECLARE_PARENT(StructSimuSwitch, Base) 91 92 BaseRef fn_; 93 BaseRef value_; 94 }; 95 96 std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other); 97 bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs); 98 99 class FinalVM { 100 public: 101 // Create a VM with the specified instructions and backend. 102 explicit FinalVM(const InstSet &insts, const BackendPtr &backend); 103 virtual ~FinalVM() = default; 104 105 BaseRef Eval(const VectorRef &args); 106 void InstCall(const VectorRef &args); 107 void InstTailCall(const VectorRef &args); 108 void InstReturn(const VectorRef &args); 109 void InstPartial(const VectorRef &args); 110 void InstRealPartial(const VectorRef &args); 111 void InstSwitch(const VectorRef &args); 112 void InstRealSwitch(const VectorRef &args); 113 void InstTuple(const VectorRef &args); 114 void InstPush(const VectorRef &args); 115 void InstInput(const VectorRef &args); 116 void InstPadStack(const VectorRef &args); 117 void InstExternal(const VectorRef &args); 118 void InstPushPrim(const VectorRef &args); 119 void InstSwitchReturn(const VectorRef &args); 120 void InstSwitchLayer(const VectorRef &args); 121 void set_insts(const InstSet &value) { insts_ = value; } 122 BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg); 123 124 protected: 125 BaseRef Ref(int64_t i); 126 void Push(const BaseRef &v); 127 void Pop(int64_t n = 1); 128 void MoveStack(int64_t nitems, int64_t height); 129 void Pushp(); 130 void Popp(); 131 void Pushsp(); 132 void Popsp(); 133 void DoJmp(const BaseRef &jmp); 134 void SyncData(const py::object &args); 135 136 private: 137 InstSet insts_; 138 std::deque<BaseRef> insts_stack_; 139 std::stack<int64_t> retp_; 140 std::stack<int64_t> retsp_; 141 int64_t pc_; 142 int64_t sp_; 143 BackendPtr backend_; 144 const InstFunctionMap inst_function_map = { 145 {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, 146 {Instruction::kTailCall, [this](const VectorRef &args) { InstTailCall(args); }}, 147 {Instruction::kReturn, [this](const VectorRef &args) { InstReturn(args); }}, 148 {Instruction::kPartial, [this](const VectorRef &args) { InstPartial(args); }}, 149 {Instruction::kSwitch, [this](const VectorRef &args) { InstSwitch(args); }}, 150 {Instruction::kTuple, [this](const VectorRef &args) { InstTuple(args); }}, 151 {Instruction::kPush, [this](const VectorRef &args) { InstPush(args); }}, 152 {Instruction::kInput, [this](const VectorRef &args) { InstInput(args); }}, 153 {Instruction::kPadStack, [this](const VectorRef &args) { InstPadStack(args); }}, 154 {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, 155 {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, 156 {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, 157 {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}}; 158 }; 159 160 using FinalVMPtr = std::shared_ptr<FinalVM>; 161 162 } // namespace compile 163 } // namespace mindspore 164 165 #endif // MINDSPORE_CCSRC_VM_VM_H_ 166