1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_VM_BACKENDBASE_H_ 17 #define MINDSPORE_CCSRC_VM_BACKENDBASE_H_ 18 19 #include <list> 20 #include <memory> 21 #include <string> 22 #include <map> 23 #include <set> 24 #include <utility> 25 #include <vector> 26 27 #include "utils/hash_map.h" 28 #include "ir/anf.h" 29 #include "backend/common/session/session_basic.h" 30 #include "runtime/hardware/device_context.h" 31 #include "backend/graph_compiler/segment_runner.h" 32 #include "runtime/graph_scheduler/actor/actor_set.h" 33 #include "include/common/profiler.h" 34 #include "include/backend/py_execute_utils.h" 35 36 namespace mindspore { 37 namespace compile { 38 using GraphOutputInfo = session::GraphOutputInfo; 39 using DeviceContext = device::DeviceContext; 40 using ActorInfo = runtime::ActorInfo; 41 using GraphCompiler = runtime::GraphCompiler; 42 using GraphCompilerInfo = runtime::GraphCompilerInfo; 43 using ControlNodeParser = runtime::ControlNodeParser; 44 using FuncGraphToKernelGraphGroup = runtime::FuncGraphToKernelGraphGroup; 45 using ControlNodeParserPtr = runtime::ControlNodeParserPtr; 46 using KernelWithIndex = session::KernelWithIndex; 47 48 enum SwitchCondStatus { 49 kCondOk = 0, 50 kCondAlreadyRun, 51 }; 52 53 class BACKEND_EXPORT Backend { 54 public: 55 explicit Backend(const std::string &name); 56 57 virtual ~Backend() = default; 58 convert_fn()59 LinkFuncType convert_fn() { return convert_fn_; } name()60 std::string name() { return name_; } 61 virtual bool GetCond(const BaseRef &c, bool *value); 62 virtual bool GetIndex(const BaseRef &c, int64_t *value); CompileGraph(const NotNull<FuncGraphPtr> & fg)63 virtual GraphId CompileGraph(const NotNull<FuncGraphPtr> &fg) { return kInvalidGraphId; } SetDebugger()64 virtual void SetDebugger() {} 65 is_multi_graph_sink()66 bool is_multi_graph_sink() const { return is_multi_graph_sink_; } set_is_multi_graph_sink(bool flag)67 void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; } 68 69 protected: 70 std::string name_; 71 LinkFuncType convert_fn_; 72 bool is_multi_graph_sink_; 73 }; 74 75 BACKEND_EXPORT void set_pydata_converter(const pyexecute::PyDataConverter &pydata_converter); 76 77 std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompilerInfo &graph_compiler_info, 78 const VectorRef &args); 79 runtime::KernelMapPosition FetchOriginOutputOrder(const AnfNodePtr &root_output); 80 81 class BACKEND_EXPORT MindRTBackendBase : public Backend { 82 public: 83 MindRTBackendBase(const std::string &backend_name, const std::string &device_name, uint32_t device_id); 84 ~MindRTBackendBase() override = default; 85 86 // The parameter root_graph is a root graph, and the root graph maybe contain multiple sub graphs, It will traverse 87 // all sub graphs to call CompileGraph. 88 const ActorInfo &CompileGraphs(const FuncGraphPtr &func_graph); 89 90 // Run Graph in the graph mode. 91 void RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs); 92 93 #ifdef ENABLE_DEBUGGER 94 void SetDebuggerInit() const; 95 #endif 96 97 // Get serialized random status of all random kernels in this graph 98 std::string GetRandomStatus(const ActorInfo &actor_info); 99 100 // Get the device target. GetDeviceTarget()101 std::string GetDeviceTarget() { return device_name_; } 102 WaitTaskFinish()103 virtual void WaitTaskFinish() const {} RunGraphByCondition(const ActorInfo & actor_info,const GraphCompilerInfo & graph_compiler_info,const VectorRef & args,VectorRef * outputs)104 virtual void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info, 105 const VectorRef &args, VectorRef *outputs) {} 106 107 protected: 108 // Convert the nodes which are not supported in the backend. 109 void UnifyMindIR(const FuncGraphPtr &func_graph) const; 110 111 // The parameter func_graph is a graph, it can be either a root graph or a sub graph, 112 // The result of graph compiler is stored in graph_id_to_device_context_ and control_nodes_. 113 void CompileGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode); 114 115 // Compile the kernel graph by the segment which is from the function graph partition. 116 void CompileGraphFromSegment(const GraphSegmentPtr &segment, device::RunMode run_mode); 117 118 // Compile the kernel graph which generated directly from front end(PyNative), and no need do graph partition. 119 void CompileKernelGraph(const KernelGraphPtr &kernel_graph, const std::pair<AnfNodePtrList, AnfNodePtrList> &io_nodes, 120 DeviceContext *device_context, device::RunMode run_mode); 121 122 void CacheFuncGraphWithKernelGraphId(const FuncGraphPtr &func_graph, const GraphId &graph_id, 123 DeviceContext *device_context); 124 125 void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph); 126 127 // Restore the outputs tuple by the origin funcGraph output node and output tensors. 128 void ConstructOutputs(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors, 129 size_t *output_position, VectorRef *outputs, std::vector<tensor::TensorPtr> *tuple_tensors); 130 // Spit the tuple tensor to multi tensors for restoring the tuple output. 131 void ConstructOutputByTupleTensor(tensor::TensorPtr output_tensor, const abstract::SequenceShapePtr &tensor_shape, 132 VectorRef *outputs, std::vector<tensor::TensorPtr> *tuple_tensors) const; 133 // In the control flow, the output of the call node needs to be created by abstract. 134 BaseRef ConstructOutputByAbstract(const abstract::AbstractBasePtr &abstract, 135 const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position, 136 std::vector<tensor::TensorPtr> *tuple_tensors); 137 // Construct the GraphCompilerInfo by the compilation results of graph, used in Graph mode. 138 std::shared_ptr<GraphCompilerInfo> ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph); 139 140 void ParseControlNodes(const GraphCompilerInfo &graph_compile_info); 141 142 void UpdateGraphCompilerInfo(const ActorInfo &actor_info); 143 144 void ContiguousArgs(const VectorRef &args, const GraphCompilerInfo &graph_compiler_info); 145 146 // Wait multi stream finish. 147 void WaitMultiStream(const GraphCompilerInfo &graph_compiler_info); 148 149 // When compiling FuncGraph, it is divided according to the control nodes, and obtain the control nodes and several 150 // node segments. Node segments will be compiled into kernelGraphs which are expressed as GraphId and bound to 151 // the corresponding device_context. 152 std::map<GraphId, DeviceContext *> graph_id_to_device_context_; 153 // Funcgraph will be cut into multiple kernel graphs, and the map is used to save the correspondence. 154 // The kernel graphs which not cut by control flow are placed in the same group. 155 std::map<FuncGraphPtr, std::vector<std::vector<GraphId>>> func_graph_to_kernel_graph_ids_; 156 std::map<GraphInfo, DeviceContext *> graph_info_to_device_context_; 157 std::vector<AnfNodePtr> control_nodes_; 158 159 mindspore::HashMap<ActorInfo, std::shared_ptr<GraphCompilerInfo>> actor_to_graph_compiler_info_; 160 161 // Save the mapping between cell id and actor info. 162 FuncGraphPtr root_graph_; 163 AnfNodePtr output_node_; 164 GraphPartitionPtr graph_partition_; 165 std::shared_ptr<GraphCompiler> graph_compiler_; 166 std::string device_name_; 167 uint32_t device_id_; 168 int ms_execution_mode_{kGraphMode}; 169 void CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode = device::RunMode::kUnknown); 170 void ProcessNotSupportCnode(const FuncGraphPtr &func_graph, const device::DeviceType &old_target, 171 const device::DeviceType &new_target) const; 172 }; 173 } // namespace compile 174 } // namespace mindspore 175 #endif 176