• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_VM_TRANSFORM_H_
20 #define MINDSPORE_CCSRC_VM_TRANSFORM_H_
21 
22 #include <string>
23 #include <memory>
24 #include <functional>
25 #include <utility>
26 #include <unordered_map>
27 #include <vector>
28 
29 #include "vm/vm.h"
30 #include "ir/anf.h"
31 #include "frontend/operator/ops.h"
32 #include "vm/segment_runner.h"
33 #include "vm/backend.h"
34 #include "vm/graph_partition.h"
35 
// The mindspore namespace is the top-level namespace of the MindSpore project.
// Every other namespace in the ME project should be nested inside it.
38 namespace mindspore {
// String identifiers for the two VM flavours ("Ms" and "Ge"); the values are
// defined in the implementation file.
// NOTE(review): presumably compared against a configured backend name — confirm.
extern const char kMsVm[];
extern const char kGeVm[];
41 
// compile namespace
// A sub-namespace of ME that holds the compilation-related definitions.
44 namespace compile {
45 extern std::vector<PrimitivePtr> nonlinear_ops;
46 extern std::vector<PrimitivePtr> control_ops;
47 const std::vector<PrimitivePtr> &GetMsNonlinearOps();
48 FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph);
49 using VmEvalFunc = std::function<BaseRef(const VectorRef &)>;
50 using VmEvalFuncPtr = std::shared_ptr<std::function<BaseRef(const VectorRef &)>>;
51 
52 class CompileGraph {
53  public:
54   explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
55 
56   virtual ~CompileGraph() = default;
57 
58   InstSet Run(const FuncGraphPtr &func_graph);
59   bool IsCut(const AnfNodePtr &node);
60   void Push(const AnfNodePtr &node);
Tie(const AnfNodePtr & n1,const AnfNodePtr & n2)61   void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
62   void Ret(int64_t nargs);
63   virtual int64_t Ref(const AnfNodePtr &node);
64 
set_height(int64_t h)65   void set_height(int64_t h) {
66     height_ = h;
67     if (height_ > max_height_) {
68       max_height_ = height_;
69     }
70   }
71 
Reset()72   void Reset() {
73     height_ = 0;
74     max_height_ = 0;
75     slots_.clear();
76     inst_.clear();
77   }
78 
79  protected:
80   virtual void PushParameters(const FuncGraphPtr &func_graph);
81   bool Compile(const FuncGraphPtr &func_graph);
82   int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = "");
83   int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node);
84   virtual int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
85   void AddPadStack(int64_t param_height);
86   void AddTailCall(const AnfNodePtr &fn, size_t size);
87   virtual void AddPartial(const CNodePtr &node);
88   void AddMakeTuple(const CNodePtr &node);
89   void AddSwitch(const CNodePtr &node);
90   void AddSwitchLayer(const CNodePtr &node);
91   void AddReturn(const CNodePtr &node);
92   void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
93   virtual void AddInput(const AnfNodePtr &node);
94   virtual void AddExternal(const LinConvertResult &result);
95   void AddInst(const Instruction &inst, const int64_t &arg);
96   void AddInst(const Instruction &inst, const ValuePtr &arg);
97   void AddInst(const Instruction &inst, const VectorRef &args);
98 
99   BackendPtr backend_;
100   GraphPartitionPtr graph_partition_;
101   LinkFuncType lin_convert_;
102 
103   int64_t height_{0};
104   int64_t max_height_{0};
105 
106   std::unordered_map<AnfNodePtr, int64_t> slots_;
107   InstSet inst_;
108 };
109 
110 using CompileGraphPtr = std::shared_ptr<CompileGraph>;
111 
112 // CompileGraphs is used to Convert a graph cluster into instruction lists.
113 class CompileGraphs {
114  public:
115   explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops);
116 
117   virtual ~CompileGraphs() = default;
118 
Reset()119   void Reset() {
120     insts_.clear();
121     mapping_.clear();
122   }
123 
124   void Compile(const FuncGraphPtr &func_graph);
125   FinalVMPtr Link();
126   FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph);
127 
128  protected:
129   InstSet insts_;
130   std::unordered_map<FuncGraphPtr, int64_t> mapping_;
131   CompileGraphPtr transform_;
132   BackendPtr backend_;
133 };
134 
135 BackendPtr CreateBackend();
136 
137 // Set mindRT whether enable. GPU and CPU use mindRT currently, and other hardwares will use it in the future.
138 void SetMindRTEnable();
139 
140 }  // namespace compile
141 }  // namespace mindspore
142 
143 #endif  // MINDSPORE_CCSRC_VM_TRANSFORM_H_
144