• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2020 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_VM_VM_H_
20 #define MINDSPORE_CCSRC_VM_VM_H_
21 
#include <cstdint>
#include <deque>
#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "pybind11/pybind11.h"

#include "ir/anf.h"
#include "base/base_ref.h"
36 
37 namespace py = pybind11;
38 
39 namespace mindspore {
40 namespace compile {
41 
// Backend is only used through a shared handle in this header, so a forward
// declaration is enough here.
class Backend;
using BackendPtr = std::shared_ptr<Backend>;

// Opcodes understood by FinalVM.
// The values are written out explicitly; they are exactly the implicit
// 0..13 numbering, and inst_str lists the instruction names in this same order.
enum Instruction {
  kCall = 0,
  kTailCall = 1,
  kReturn = 2,
  kPartial = 3,
  kSwitch = 4,
  kSwitchReturn = 5,
  kTuple = 6,
  kInput = 7,
  kExternal = 8,
  kPush = 9,
  kPrim = 10,
  kGraph = 11,
  kPadStack = 12,
  kSwitchLayer = 13
};
61 
62 using InstType = std::pair<Instruction, VectorRef>;
63 using InstSet = std::vector<InstType>;
64 using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>;
65 
// Printable names of the Instruction opcodes, listed in enumerator order so
// that inst_str[k] names opcode k.
// NOTE(review): "Return" and "Switch" are capitalised unlike the other
// entries; they are kept verbatim because these are runtime strings — confirm
// whether anything depends on the exact spelling before normalising.
const std::vector<std::string> inst_str{
  "call",           // kCall
  "tail_call",      // kTailCall
  "Return",         // kReturn
  "partial",        // kPartial
  "Switch",         // kSwitch
  "switch_return",  // kSwitchReturn
  "tuple",          // kTuple
  "input",          // kInput
  "external",       // kExternal
  "push",           // kPush
  "primitive",      // kPrim
  "graph",          // kGraph
  "pad_stack",      // kPadStack
  "switch_layer",   // kSwitchLayer
};
69 class StructPartial : public Base {
70  public:
71   // Initialize StructPartial.
72   StructPartial(int64_t fn, const VectorRef &args, const FuncGraphPtr &fg = nullptr);
73 
74   virtual ~StructPartial() = default;
75   MS_DECLARE_PARENT(StructPartial, Base)
76 
77   int64_t fn_;
78   VectorRef args_;
79   FuncGraphPtr fg_;
80 };
81 
82 std::ostream &operator<<(std::ostream &os, const StructPartial &other);
83 bool operator==(const StructPartial &lhs, const StructPartial &rhs);
84 
85 class StructSimuSwitch : public Base {
86  public:
87   StructSimuSwitch(const BaseRef &fn, const BaseRef &value);
88 
89   virtual ~StructSimuSwitch() = default;
90   MS_DECLARE_PARENT(StructSimuSwitch, Base)
91 
92   BaseRef fn_;
93   BaseRef value_;
94 };
95 
96 std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other);
97 bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs);
98 
99 class FinalVM {
100  public:
101   // Create a VM with the specified instructions and backend.
102   explicit FinalVM(const InstSet &insts, const BackendPtr &backend);
103   virtual ~FinalVM() = default;
104 
105   BaseRef Eval(const VectorRef &args);
106   void InstCall(const VectorRef &args);
107   void InstTailCall(const VectorRef &args);
108   void InstReturn(const VectorRef &args);
109   void InstPartial(const VectorRef &args);
110   void InstRealPartial(const VectorRef &args);
111   void InstSwitch(const VectorRef &args);
112   void InstRealSwitch(const VectorRef &args);
113   void InstTuple(const VectorRef &args);
114   void InstPush(const VectorRef &args);
115   void InstInput(const VectorRef &args);
116   void InstPadStack(const VectorRef &args);
117   void InstExternal(const VectorRef &args);
118   void InstPushPrim(const VectorRef &args);
119   void InstSwitchReturn(const VectorRef &args);
120   void InstSwitchLayer(const VectorRef &args);
121   void set_insts(const InstSet &value) { insts_ = value; }
122   BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg);
123 
124  protected:
125   BaseRef Ref(int64_t i);
126   void Push(const BaseRef &v);
127   void Pop(int64_t n = 1);
128   void MoveStack(int64_t nitems, int64_t height);
129   void Pushp();
130   void Popp();
131   void Pushsp();
132   void Popsp();
133   void DoJmp(const BaseRef &jmp);
134   void SyncData(const py::object &args);
135 
136  private:
137   InstSet insts_;
138   std::deque<BaseRef> insts_stack_;
139   std::stack<int64_t> retp_;
140   std::stack<int64_t> retsp_;
141   int64_t pc_;
142   int64_t sp_;
143   BackendPtr backend_;
144   const InstFunctionMap inst_function_map = {
145     {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
146     {Instruction::kTailCall, [this](const VectorRef &args) { InstTailCall(args); }},
147     {Instruction::kReturn, [this](const VectorRef &args) { InstReturn(args); }},
148     {Instruction::kPartial, [this](const VectorRef &args) { InstPartial(args); }},
149     {Instruction::kSwitch, [this](const VectorRef &args) { InstSwitch(args); }},
150     {Instruction::kTuple, [this](const VectorRef &args) { InstTuple(args); }},
151     {Instruction::kPush, [this](const VectorRef &args) { InstPush(args); }},
152     {Instruction::kInput, [this](const VectorRef &args) { InstInput(args); }},
153     {Instruction::kPadStack, [this](const VectorRef &args) { InstPadStack(args); }},
154     {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }},
155     {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
156     {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
157     {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}};
158 };
159 
160 using FinalVMPtr = std::shared_ptr<FinalVM>;
161 
162 }  // namespace compile
163 }  // namespace mindspore
164 
165 #endif  // MINDSPORE_CCSRC_VM_VM_H_
166