/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_VM_BACKEND_H_
#define MINDSPORE_CCSRC_VM_BACKEND_H_

#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <map>
#include <utility>
#include <vector>

#include "utils/contract.h"
#include "ir/anf.h"
#include "vm/segment_runner.h"
#include "vm/graph_partition.h"
#include "vm/vm.h"
#include "backend/session/session_basic.h"
#include "runtime/hardware/device_context.h"
#include "runtime/framework/graph_scheduler.h"

namespace mindspore {
namespace compile {
using OpRunInfo = session::OpRunInfo;
using GraphOutputInfo = session::GraphOutputInfo;
using DeviceContext = device::DeviceContext;
using ActorInfo = runtime::ActorInfo;
using GraphCompiler = runtime::GraphCompiler;
using GraphCompilerInfo = runtime::GraphCompilerInfo;
using ControlNodeParser = runtime::ControlNodeParser;
using ControlNodeParserPtr = runtime::ControlNodeParserPtr;
using KernelWithIndex = session::KernelWithIndex;

enum SwitchCondStatus {
  kCondOk = 0,
  kCondAlreadyRun,
};

class Backend {
 public:
  explicit Backend(const std::string &name);

  virtual ~Backend() = default;

  LinkFuncType convert_fn() { return convert_fn_; }
  std::string name() { return name_; }
  virtual bool GetCond(const BaseRef &c, bool *value);
  virtual bool GetIndex(const BaseRef &c, int64_t *value);
  virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
  virtual void SetDebugger() {}

  bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
  void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }

 protected:
  std::string name_;
  LinkFuncType convert_fn_;
  bool is_multi_graph_sink_;
};
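
// A minimal usage sketch of the Backend condition accessors (illustrative only;
// the `backend` and `cond_ref` variables are assumptions, not part of this
// header). A caller typically queries the condition of a switch node via
// GetCond before selecting a branch:
//
//   bool cond = false;
//   if (backend->GetCond(cond_ref, &cond)) {
//     // Choose the branch graph according to `cond`.
//   }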

class MsBackend : public Backend {
 public:
  MsBackend(const std::string &name, const std::string &target, uint32_t device_id);
  ~MsBackend() override = default;

  LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = "");
  virtual VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");

  VectorRef MsSimuRunGraph(const GraphId &g);
  GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
  VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
  void ClearSessionGraphs();
  void CreateOtherSession(const std::string &target);

#ifdef ENABLE_DEBUGGER
  void SetDebugger() override;
#endif

 protected:
  session::SessionPtr target_sess_;
  session::SessionPtr other_sess_;
  std::string target_device_;
  std::string other_device_;
  std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
};
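
// A minimal compile-and-run sketch for MsBackend (illustrative only; the
// `func_graph` and `args` values and the chosen target/device id are
// assumptions, not defined in this header):
//
//   auto backend = std::make_shared<MsBackend>("ms", "GPU", 0);
//   GraphId graph_id = backend->CompileGraph(NotNull<FuncGraphPtr>(func_graph));
//   VectorRef outputs = backend->RunGraph(graph_id, args);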

class MindRTBackend : public Backend {
 public:
  MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id);
  ~MindRTBackend() override = default;

  // The parameter root_graph is a root graph that may contain multiple sub graphs;
  // this traverses all of the sub graphs and calls CompileGraph on each of them.
  const ActorInfo &CompileGraphs(const FuncGraphPtr &root_graph);

  // Compile a single-op kernel graph in PyNative mode.
  const ActorInfo &CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                                const std::vector<int64_t> *tensors_mask,
                                std::vector<tensor::TensorPtr> *input_tensors);

  // Run the graph in graph mode.
  void RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs);

  // Run the graph in PyNative mode.
  void RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info, const std::vector<int64_t> *tensors_mask,
                const std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs);
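
  // Typical graph-mode flow (a minimal sketch; `root_graph` and `args` are
  // assumptions, not defined in this header):
  //
  //   MindRTBackend backend("ms", "GPU", 0);
  //   const ActorInfo &actor_info = backend.CompileGraphs(root_graph);
  //   VectorRef outputs;
  //   backend.RunGraph(actor_info, args, &outputs);
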
#ifdef ENABLE_DEBUGGER
  void SetDebuggerInit();
#endif

 private:
  // The parameter func_graph can be either a root graph or a sub graph.
  // The results of graph compilation are stored in graph_id_to_device_context_ and control_nodes_.
  void CompileGraph(const FuncGraphPtr &func_graph);

  // Restore the output tuple from the original FuncGraph's output node and the output tensors.
  void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
                        size_t *output_position, VectorRef *outputs);

  // Construct the GraphCompilerInfo from the graph compilation results; used in graph mode.
  std::unique_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph);

  // Construct the GraphCompilerInfo from the graph compilation results; used in PyNative mode.
  std::unique_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const ActorInfo &actor_info,
                                                                const std::vector<int64_t> *tensors_mask,
                                                                const std::vector<tensor::TensorPtr> *input_tensors,
                                                                bool need_erase);

  // In PyNative mode, the single-op cache list keeps growing, which increases memory cost,
  // so the latest single-op cache entry should be erased once the cache list size exceeds a threshold.
  void EraseSingleOpCache(const ActorInfo &actor_info, const KernelGraphPtr &graph);
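
  // A minimal sketch of the eviction idea (the threshold name and the way the
  // entry to erase is picked are assumptions, not defined in this header):
  //
  //   if (actor_to_graph_compiler_info_.size() > kOpCacheThreshold) {
  //     EraseSingleOpCache(actor_info, graph);
  //   }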

  // Split the complete kernel graph into single-op graphs during PyNative back
  // propagation, then compile and run each single-op graph.
  void RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                          const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);

  // When compiling a FuncGraph, it is partitioned according to the control nodes, yielding the control
  // nodes and several node segments. The node segments are compiled into kernel graphs, each identified
  // by a GraphId and bound to its corresponding device context.
  std::map<GraphId, DeviceContext *> graph_id_to_device_context_;
  std::map<GraphInfo, DeviceContext *> graph_info_to_device_context_;
  std::vector<AnfNodePtr> control_nodes_;

  std::unordered_map<ActorInfo, std::unique_ptr<GraphCompilerInfo>> actor_to_graph_compiler_info_;

  // Cache the output tensor reference counts of kernels for the back propagation graph in PyNative mode.
  std::map<GraphId, std::map<KernelWithIndex, size_t>> cnode_ref_counts_;

  FuncGraph *root_graph_;
  GraphPartitionPtr graph_partition_;
  std::shared_ptr<GraphCompiler> graph_compiler_;
  std::string device_name_;
  uint32_t device_id_;
  int ms_execution_mode_{kGraphMode};
  int real_execution_mode_{kGraphMode};
};
}  // namespace compile
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_VM_BACKEND_H_