• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
3  *
4  * Copyright 2019-2021 Huawei Technologies Co., Ltd
5  *
6  * Licensed under the Apache License, Version 2.0 (the "License");
7  * you may not use this file except in compliance with the License.
8  * You may obtain a copy of the License at
9  *
10  * http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #ifndef MINDSPORE_CCSRC_VM_TRANSFORM_H_
20 #define MINDSPORE_CCSRC_VM_TRANSFORM_H_
21 
22 #include <string>
23 #include <memory>
24 #include <functional>
25 #include <utility>
26 #include <vector>
27 
28 #include "utils/hash_map.h"
29 #include "backend/graph_compiler/vm.h"
30 #include "ir/anf.h"
31 #include "frontend/operator/ops.h"
32 #include "backend/graph_compiler/segment_runner.h"
33 #include "backend/graph_compiler/backend.h"
34 #include "backend/graph_compiler/graph_partition.h"
35 #include "include/backend/visible.h"
36 
// mindspore is the top-level namespace of the MindSpore project.
// Every other namespace in the ME project should be nested inside the mindspore namespace.
39 namespace mindspore {
// compile namespace
// A sub-namespace of ME that holds compilation-related definitions.
42 namespace compile {
43 BACKEND_EXPORT const std::vector<PrimitivePtr> &GetNonlinearOps();
44 BACKEND_EXPORT const std::vector<PrimitivePtr> &GetControlOps();
45 BACKEND_EXPORT const std::vector<PrimitivePtr> &GetMsNonlinearOps();
46 FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph);
47 using VmEvalFunc = std::function<BaseRef(const VectorRef &)>;
48 using VmEvalFuncPtr = std::shared_ptr<std::function<BaseRef(const VectorRef &)>>;
49 
50 class BACKEND_EXPORT CompileGraph {
51  public:
52   explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = GetNonlinearOps());
53 
54   virtual ~CompileGraph() = default;
55 
56   InstSet Run(const FuncGraphPtr &graph);
57   bool IsCut(const AnfNodePtr &node);
58   void Push(const AnfNodePtr &node);
Tie(const AnfNodePtr & n1,const AnfNodePtr & n2)59   void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; }
60   void Ret(int64_t nargs);
61   virtual int64_t Ref(const AnfNodePtr &node);
62 
set_height(int64_t h)63   void set_height(int64_t h) {
64     height_ = h;
65     if (height_ > max_height_) {
66       max_height_ = height_;
67     }
68   }
69 
Reset()70   void Reset() {
71     height_ = 0;
72     max_height_ = 0;
73     slots_.clear();
74     inst_.clear();
75   }
76 
77  protected:
78   virtual void PushParameters(const FuncGraphPtr &graph);
79   bool Compile(const FuncGraphPtr &graph);
80   int64_t LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target = "");
81   int64_t InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node);
82   virtual int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node);
83   void AddPadStack(int64_t param_height);
84   void AddTailCall(const AnfNodePtr &fn, size_t size);
85   virtual void AddPartial(const CNodePtr &node);
86   void AddMakeTuple(const CNodePtr &node);
87   void AddSwitch(const CNodePtr &node);
88   void AddSwitchLayer(const CNodePtr &node);
89   void AddReturn(const CNodePtr &node);
90   void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
91   virtual void AddInput(const AnfNodePtr &node);
92   virtual void AddExternal(const LinConvertResult &result);
93   void AddInst(const Instruction &inst, const int64_t &arg);
94   void AddInst(const Instruction &inst, const ValuePtr &arg);
95   void AddInst(const Instruction &inst, const VectorRef &args);
96 
97   BackendPtr backend_;
98   GraphPartitionPtr graph_partition_;
99   LinkFuncType lin_convert_;
100 
101   int64_t height_{0};
102   int64_t max_height_{0};
103 
104   mindspore::HashMap<AnfNodePtr, int64_t> slots_;
105   InstSet inst_;
106 };
107 
108 using CompileGraphPtr = std::shared_ptr<CompileGraph>;
109 
110 // CompileGraphs is used to Convert a graph cluster into instruction lists.
111 class BACKEND_EXPORT CompileGraphs {
112  public:
113   explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = GetNonlinearOps());
114 
115   virtual ~CompileGraphs() = default;
116 
Reset()117   void Reset() {
118     insts_.clear();
119     mapping_.clear();
120   }
121 
122   void Compile(const FuncGraphPtr &graph);
123   FinalVMPtr Link();
124   FinalVMPtr CompileAndLink(const FuncGraphPtr &graph);
125 
126  protected:
127   InstSet insts_;
128   mindspore::HashMap<FuncGraphPtr, int64_t> mapping_;
129   CompileGraphPtr transform_;
130   BackendPtr backend_;
131 };
132 
133 BACKEND_EXPORT BackendPtr CreateBackend();
134 
135 // Set mindRT whether enable. GPU and CPU use mindRT currently, and other hardwares will use it in the future.
136 BACKEND_EXPORT void SetMindRTEnable();
137 
138 }  // namespace compile
139 }  // namespace mindspore
140 
141 #endif  // MINDSPORE_CCSRC_VM_TRANSFORM_H_
142