1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2021 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_VM_VM_H_ 20 #define MINDSPORE_CCSRC_VM_VM_H_ 21 22 #include <map> 23 #include <memory> 24 #include <stack> 25 #include <string> 26 #include <tuple> 27 #include <utility> 28 #include <vector> 29 #include <deque> 30 31 #include "utils/hash_map.h" 32 #include "mindspore/core/ops/framework_ops.h" 33 #include "pybind11/pybind11.h" 34 35 #include "ir/anf.h" 36 #include "base/base_ref.h" 37 #include "include/backend/visible.h" 38 39 namespace py = pybind11; 40 41 namespace mindspore { 42 namespace compile { 43 44 class Backend; 45 using BackendPtr = std::shared_ptr<Backend>; 46 47 enum Instruction { 48 kCall = 0, 49 kTailCall, 50 kReturn, 51 kPartial, 52 kSwitch, 53 kSwitchReturn, 54 kTuple, 55 kInput, 56 kExternal, 57 kPush, 58 kPrim, 59 kGraph, 60 kPadStack, 61 kSwitchLayer 62 }; 63 64 using InstType = std::pair<Instruction, VectorRef>; 65 using InstSet = std::vector<InstType>; 66 using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>; 67 68 const std::vector<std::string> inst_str{"call", "tail_call", "Return", "partial", "Switch", 69 "switch_return", "tuple", "input", "external", "push", 70 "primitive", "graph", "pad_stack", "switch_layer"}; 71 class StructPartial : public Base { 72 public: 73 // Initialize StructPartial. 74 StructPartial(int64_t fn, const VectorRef &args, const FuncGraphPtr &fg = nullptr); 75 76 virtual ~StructPartial() = default; 77 MS_DECLARE_PARENT(StructPartial, Base) 78 79 int64_t fn_; 80 VectorRef args_; 81 FuncGraphPtr fg_; 82 }; 83 84 std::ostream &operator<<(std::ostream &os, const StructPartial &other); 85 bool operator==(const StructPartial &lhs, const StructPartial &rhs); 86 87 class StructSimuSwitch : public Base { 88 public: 89 StructSimuSwitch(const BaseRef &fn, const BaseRef &value); 90 91 virtual ~StructSimuSwitch() = default; 92 MS_DECLARE_PARENT(StructSimuSwitch, Base) 93 94 BaseRef fn_; 95 BaseRef value_; 96 }; 97 98 std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other); 99 bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs); 100 101 class BACKEND_EXPORT FinalVM { 102 public: 103 // Create a VM with the specified instructions and backend. 104 explicit FinalVM(const InstSet &insts, const BackendPtr &backend); 105 virtual ~FinalVM() = default; 106 107 BaseRef Eval(const VectorRef &args); 108 void InstCall(const VectorRef &args); 109 void InstTailCall(const VectorRef &args); 110 void InstReturn(const VectorRef &args); 111 void InstPartial(const VectorRef &args); 112 void InstRealPartial(const VectorRef &args); 113 void InstSwitch(const VectorRef &args); 114 void InstRealSwitch(const VectorRef &args); 115 void InstTuple(const VectorRef &args); 116 void InstPush(const VectorRef &args); 117 void InstInput(const VectorRef &args); 118 void InstPadStack(const VectorRef &args); 119 void InstExternal(const VectorRef &args); 120 void InstPushPrim(const VectorRef &args); 121 void InstSwitchReturn(const VectorRef &args); 122 void InstSwitchLayer(const VectorRef &args); 123 void set_insts(const InstSet &value) { insts_ = value; } 124 125 protected: 126 BaseRef Ref(int64_t i); 127 void Push(const BaseRef &v); 128 void Pop(int64_t n = 1); 129 void MoveStack(int64_t nitems, int64_t height); 130 void Pushp(); 131 void Popp(); 132 void Pushsp(); 133 void Popsp(); 134 void DoJmp(const BaseRef &jmp_orig); 135 136 private: 137 InstSet insts_; 138 std::deque<BaseRef> insts_stack_; 139 std::stack<int64_t> retp_; 140 std::stack<int64_t> retsp_; 141 int64_t pc_; 142 int64_t sp_; 143 BackendPtr backend_; 144 const InstFunctionMap inst_function_map = { 145 {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, 146 {Instruction::kTailCall, [this](const VectorRef &args) { InstTailCall(args); }}, 147 {Instruction::kReturn, [this](const VectorRef &args) { InstReturn(args); }}, 148 {Instruction::kPartial, [this](const VectorRef &args) { InstPartial(args); }}, 149 {Instruction::kSwitch, [this](const VectorRef &args) { InstSwitch(args); }}, 150 {Instruction::kTuple, [this](const VectorRef &args) { InstTuple(args); }}, 151 {Instruction::kPush, [this](const VectorRef &args) { InstPush(args); }}, 152 {Instruction::kInput, [this](const VectorRef &args) { InstInput(args); }}, 153 {Instruction::kPadStack, [this](const VectorRef &args) { InstPadStack(args); }}, 154 {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, 155 {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, 156 {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, 157 {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}}; 158 }; 159 160 using FinalVMPtr = std::shared_ptr<FinalVM>; 161 162 } // namespace compile 163 } // namespace mindspore 164 165 #endif // MINDSPORE_CCSRC_VM_VM_H_ 166