/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_VM_BACKEND_H_
#define MINDSPORE_CCSRC_VM_BACKEND_H_

#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <map>
#include <utility>
#include <vector>

#include "utils/contract.h"
#include "ir/anf.h"
#include "vm/segment_runner.h"
#include "vm/graph_partition.h"
#include "vm/vm.h"
#include "backend/session/session_basic.h"
#include "runtime/hardware/device_context.h"
#include "runtime/framework/graph_scheduler.h"

namespace mindspore {
namespace compile {
using OpRunInfo = session::OpRunInfo;
using GraphOutputInfo = session::GraphOutputInfo;
using DeviceContext = device::DeviceContext;
using ActorInfo = runtime::ActorInfo;
using GraphCompiler = runtime::GraphCompiler;
using GraphCompilerInfo = runtime::GraphCompilerInfo;
using ControlNodeParser = runtime::ControlNodeParser;
using ControlNodeParserPtr = runtime::ControlNodeParserPtr;
using KernelWithIndex = session::KernelWithIndex;

enum SwitchCondStatus {
  kCondOk = 0,
  kCondAlreadyRun,
};

class Backend {
 public:
  explicit Backend(const std::string &name);

  virtual ~Backend() = default;

  LinkFuncType convert_fn() { return convert_fn_; }
  std::string name() { return name_; }
  virtual bool GetCond(const BaseRef &c, bool *value);
  virtual bool GetIndex(const BaseRef &c, int64_t *value);
  virtual GraphId CompileGraph(NotNull<FuncGraphPtr> fg) { return kInvalidGraphId; }
  virtual void SetDebugger() {}

  bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
  void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }

 protected:
  std::string name_;
  LinkFuncType convert_fn_;
  bool is_multi_graph_sink_;
};

class MsBackend : public Backend {
 public:
  MsBackend(const std::string &name, const std::string &target, uint32_t device_id);
  ~MsBackend() override = default;

  LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = "");
  virtual VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = "");

  VectorRef MsSimuRunGraph(const GraphId &g);
  GraphId CompileGraph(NotNull<FuncGraphPtr> fg) override;
  VectorRef RunGraph(GraphId graph_id, const VectorRef &args);
  void ClearSessionGraphs();
  void CreateOtherSession(const std::string &target);

#ifdef ENABLE_DEBUGGER
  void SetDebugger() override;
#endif

 protected:
  session::SessionPtr target_sess_;
  session::SessionPtr other_sess_;
  std::string target_device_;
  std::string other_device_;
  std::unordered_map<GraphId, LinConvertResult> graph_id_map_;
};

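// Usage sketch for the MindRTBackend declared below (illustrative only, not part of the original header):
// a backend is constructed with a backend name, a device name and a device id, a FuncGraph is compiled once
// into actors, and the returned ActorInfo is then used to run the graph. The concrete values ("ms", "GPU", 0)
// and the variables `func_graph` and `args` are assumptions for this example:
//
//   auto backend = std::make_shared<MindRTBackend>("ms", "GPU", 0);
//   const ActorInfo &actor_info = backend->CompileGraphs(func_graph);  // func_graph: FuncGraphPtr
//   VectorRef outputs;
//   backend->RunGraph(actor_info, args, &outputs);                     // args: VectorRef of graph inputs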
class MindRTBackend : public Backend {
 public:
  MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id);
  ~MindRTBackend() override = default;

  // The parameter root_graph is a root graph which may contain multiple sub graphs; this traverses
  // all sub graphs and calls CompileGraph on each of them.
  const ActorInfo &CompileGraphs(const FuncGraphPtr &root_graph);

  // Compile a single op kernel graph in PyNative mode.
  const ActorInfo &CompileGraph(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                                const std::vector<int64_t> *tensors_mask,
                                std::vector<tensor::TensorPtr> *input_tensors);

  // Run a graph in graph mode.
  void RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs);

  // Run a graph in PyNative mode.
  void RunGraph(const ActorInfo &actor_info, OpRunInfo *op_run_info, const std::vector<int64_t> *tensors_mask,
                const std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs);
#ifdef ENABLE_DEBUGGER
  void SetDebuggerInit();
#endif

 private:
  // The parameter func_graph can be either a root graph or a sub graph.
  // The compilation result is stored in graph_id_to_device_context_ and control_nodes_.
  void CompileGraph(const FuncGraphPtr &func_graph);

  // Restore the output tuple from the original FuncGraph output node and the output tensors.
  void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
                        size_t *output_position, VectorRef *outputs);

  // Construct the GraphCompilerInfo from the graph compilation results; used in graph mode.
  std::unique_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph);

  // Construct the GraphCompilerInfo from the graph compilation results; used in PyNative mode.
  std::unique_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const ActorInfo &actor_info,
                                                                const std::vector<int64_t> *tensors_mask,
                                                                const std::vector<tensor::TensorPtr> *input_tensors,
                                                                bool need_erase);

  // In PyNative mode, the single op cache list keeps growing, which increases the memory cost, so the latest
  // single op cache is erased when the cache list size exceeds the threshold value.
  void EraseSingleOpCache(const ActorInfo &actor_info, const KernelGraphPtr &graph);

  // Split a complete kernel graph into single op graphs in PyNative back
  // propagation, then compile and run each single op graph.
  void RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                          const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs);

  // When a FuncGraph is compiled, it is partitioned according to its control nodes into the control nodes and
  // several node segments. The node segments are compiled into kernel graphs, which are identified by GraphId
  // and bound to their corresponding device_context.
  std::map<GraphId, DeviceContext *> graph_id_to_device_context_;
  std::map<GraphInfo, DeviceContext *> graph_info_to_device_context_;
  std::vector<AnfNodePtr> control_nodes_;

  std::unordered_map<ActorInfo, std::unique_ptr<GraphCompilerInfo>> actor_to_graph_compiler_info_;

  // Cache the output tensor ref counts of kernels for the back propagation graph in PyNative mode.
  std::map<GraphId, std::map<KernelWithIndex, size_t>> cnode_ref_counts_;

  FuncGraph *root_graph_;
  GraphPartitionPtr graph_partition_;
  std::shared_ptr<GraphCompiler> graph_compiler_;
  std::string device_name_;
  uint32_t device_id_;
  int ms_execution_mode_{kGraphMode};
  int real_execution_mode_{kGraphMode};
};
}  // namespace compile
}  // namespace mindspore
#endif