• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_VM_VM_H_
20 #define MINDSPORE_CCSRC_VM_VM_H_
21 
#include <cstdint>
#include <deque>
#include <functional>
#include <iosfwd>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
30 
31 #include "utils/hash_map.h"
32 #include "mindspore/core/ops/framework_ops.h"
33 #include "pybind11/pybind11.h"
34 
35 #include "ir/anf.h"
36 #include "base/base_ref.h"
37 #include "include/backend/visible.h"
38 
39 namespace py = pybind11;
40 
41 namespace mindspore {
42 namespace compile {
43 
// Only a forward declaration is needed here: this header refers to Backend exclusively
// through shared ownership (see BackendPtr and FinalVM::backend_).
class Backend;
// Shared-ownership handle to a Backend.
using BackendPtr = std::shared_ptr<Backend>;
46 
// Opcodes of the FinalVM instruction set.
// The values are spelled out explicitly; they equal the implicit 0..13 sequence.
// `inst_str` lists the opcode names in this same order, presumably so it can be indexed
// by these values — keep the two lists in sync when adding or reordering opcodes.
enum Instruction {
  kCall = 0,
  kTailCall = 1,
  kReturn = 2,
  kPartial = 3,
  kSwitch = 4,
  kSwitchReturn = 5,
  kTuple = 6,
  kInput = 7,
  kExternal = 8,
  kPush = 9,
  kPrim = 10,
  kGraph = 11,
  kPadStack = 12,
  kSwitchLayer = 13
};
63 
64 using InstType = std::pair<Instruction, VectorRef>;
65 using InstSet = std::vector<InstType>;
66 using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>;
67 
// Display names for each Instruction opcode, one entry per enumerator and in the same
// order (14 entries). The capitalization ("Return", "Switch") is preserved verbatim:
// these strings may be printed or compared elsewhere, so they are treated as runtime data.
const std::vector<std::string> inst_str{
    "call",           // kCall
    "tail_call",      // kTailCall
    "Return",         // kReturn
    "partial",        // kPartial
    "Switch",         // kSwitch
    "switch_return",  // kSwitchReturn
    "tuple",          // kTuple
    "input",          // kInput
    "external",       // kExternal
    "push",           // kPush
    "primitive",      // kPrim
    "graph",          // kGraph
    "pad_stack",      // kPadStack
    "switch_layer"    // kSwitchLayer
};
71 class StructPartial : public Base {
72  public:
73   // Initialize StructPartial.
74   StructPartial(int64_t fn, const VectorRef &args, const FuncGraphPtr &fg = nullptr);
75 
76   virtual ~StructPartial() = default;
77   MS_DECLARE_PARENT(StructPartial, Base)
78 
79   int64_t fn_;
80   VectorRef args_;
81   FuncGraphPtr fg_;
82 };
83 
84 std::ostream &operator<<(std::ostream &os, const StructPartial &other);
85 bool operator==(const StructPartial &lhs, const StructPartial &rhs);
86 
87 class StructSimuSwitch : public Base {
88  public:
89   StructSimuSwitch(const BaseRef &fn, const BaseRef &value);
90 
91   virtual ~StructSimuSwitch() = default;
92   MS_DECLARE_PARENT(StructSimuSwitch, Base)
93 
94   BaseRef fn_;
95   BaseRef value_;
96 };
97 
98 std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other);
99 bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs);
100 
101 class BACKEND_EXPORT FinalVM {
102  public:
103   // Create a VM with the specified instructions and backend.
104   explicit FinalVM(const InstSet &insts, const BackendPtr &backend);
105   virtual ~FinalVM() = default;
106 
107   BaseRef Eval(const VectorRef &args);
108   void InstCall(const VectorRef &args);
109   void InstTailCall(const VectorRef &args);
110   void InstReturn(const VectorRef &args);
111   void InstPartial(const VectorRef &args);
112   void InstRealPartial(const VectorRef &args);
113   void InstSwitch(const VectorRef &args);
114   void InstRealSwitch(const VectorRef &args);
115   void InstTuple(const VectorRef &args);
116   void InstPush(const VectorRef &args);
117   void InstInput(const VectorRef &args);
118   void InstPadStack(const VectorRef &args);
119   void InstExternal(const VectorRef &args);
120   void InstPushPrim(const VectorRef &args);
121   void InstSwitchReturn(const VectorRef &args);
122   void InstSwitchLayer(const VectorRef &args);
123   void set_insts(const InstSet &value) { insts_ = value; }
124 
125  protected:
126   BaseRef Ref(int64_t i);
127   void Push(const BaseRef &v);
128   void Pop(int64_t n = 1);
129   void MoveStack(int64_t nitems, int64_t height);
130   void Pushp();
131   void Popp();
132   void Pushsp();
133   void Popsp();
134   void DoJmp(const BaseRef &jmp_orig);
135 
136  private:
137   InstSet insts_;
138   std::deque<BaseRef> insts_stack_;
139   std::stack<int64_t> retp_;
140   std::stack<int64_t> retsp_;
141   int64_t pc_;
142   int64_t sp_;
143   BackendPtr backend_;
144   const InstFunctionMap inst_function_map = {
145     {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }},
146     {Instruction::kTailCall, [this](const VectorRef &args) { InstTailCall(args); }},
147     {Instruction::kReturn, [this](const VectorRef &args) { InstReturn(args); }},
148     {Instruction::kPartial, [this](const VectorRef &args) { InstPartial(args); }},
149     {Instruction::kSwitch, [this](const VectorRef &args) { InstSwitch(args); }},
150     {Instruction::kTuple, [this](const VectorRef &args) { InstTuple(args); }},
151     {Instruction::kPush, [this](const VectorRef &args) { InstPush(args); }},
152     {Instruction::kInput, [this](const VectorRef &args) { InstInput(args); }},
153     {Instruction::kPadStack, [this](const VectorRef &args) { InstPadStack(args); }},
154     {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }},
155     {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
156     {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
157     {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}};
158 };
159 
160 using FinalVMPtr = std::shared_ptr<FinalVM>;
161 
162 }  // namespace compile
163 }  // namespace mindspore
164 
165 #endif  // MINDSPORE_CCSRC_VM_VM_H_
166