1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H 17 #define MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H 18 19 #include <vector> 20 #include <string> 21 #include <utility> 22 #include <memory> 23 #include <map> 24 #include <set> 25 #include "utils/hash_map.h" 26 #include "backend/common/session/kernel_graph_mgr.h" 27 #include "backend/common/session/session_context.h" 28 #include "include/backend/kernel_graph.h" 29 #include "include/backend/anf_runtime_algorithm.h" 30 #include "include/common/utils/anfalgo.h" 31 #include "include/common/utils/tensor_future.h" 32 #include "ir/anf.h" 33 #include "ir/tensor.h" 34 #include "utils/any.h" 35 #include "include/common/utils/contract.h" 36 #include "include/backend/kernel_info.h" 37 #include "utils/ms_context.h" 38 #include "pipeline/pynative/base.h" 39 40 #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64) 41 #include "include/backend/debug/debugger/debugger.h" 42 #endif 43 #include "mindspore/ccsrc/debug/summary/summary.h" 44 #include "runtime/hardware/device_context.h" 45 #include "include/backend/visible.h" 46 47 namespace mindspore { 48 namespace runtime { 49 class GraphCompiler; 50 } // namespace runtime 51 } // namespace mindspore 52 53 namespace mindspore { 54 const char kSessionBasic[] = "SessionBasic"; 55 56 namespace session { 57 using mindspore::debug::CallBackFunc; 58 #ifndef ENABLE_SECURITY 59 using mindspore::debug::Summary; 60 #endif 61 62 using AnyList = std::vector<Any>; 63 using AnyListPtr = std::shared_ptr<AnyList>; 64 65 struct BackendOpRunInfo { 66 ~BackendOpRunInfo() = default; BackendOpRunInfoBackendOpRunInfo67 BackendOpRunInfo(pynative::BaseOpRunInfo base_op_run_info, PrimitivePtr prim, bool is_infer, bool is_gradient_out) 68 : base_op_run_info(std::move(base_op_run_info)), 69 op_prim(std::move(prim)), 70 is_infer(is_infer), 71 is_gradient_out(is_gradient_out) {} 72 73 pynative::BaseOpRunInfo base_op_run_info; 74 PrimitivePtr op_prim; 75 bool is_infer = false; 76 bool is_gradient_out = false; 77 }; 78 using BackendOpRunInfoPtr = std::shared_ptr<BackendOpRunInfo>; 79 80 struct InputInfo { 81 std::vector<ValuePtr> input_values; 82 std::vector<InputType> input_types; 83 std::set<KernelWithIndex> input_kernel; 84 abstract::AbstractBasePtrList input_abs; 85 }; 86 87 struct OutputTensorInfo { 88 tensor::TensorPtr output_stub_tensor; 89 bool is_weight; 90 }; 91 92 struct GraphOutputInfo { 93 VectorRef *graph_outputs; 94 std::map<KernelWithIndex, std::vector<std::vector<size_t>>> output_indexes; 95 std::vector<tensor::BaseTensorPtr> graph_output_tensors; 96 }; 97 98 class Executor; 99 100 class BACKEND_EXPORT SessionBasic : public KernelGraphMgr, public std::enable_shared_from_this<SessionBasic> { 101 public: 102 using KernelGraphMgr::ConstructKernelGraph; SessionBasic()103 SessionBasic() : context_(nullptr), summary_callback_(nullptr), device_id_(0) { 104 #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64) 105 debugger_ = nullptr; 106 #endif 107 } 108 Init(uint32_t device_id)109 virtual void Init(uint32_t device_id) { device_id_ = device_id; } 110 void InitExecutor(const std::string &device_name, uint32_t device_id); SyncStream()111 virtual void SyncStream() const {} ~SessionBasic()112 virtual ~SessionBasic() { summary_callback_ = nullptr; } 113 114 GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); 115 GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph); 116 void BuildGraph(GraphId graphId); 117 void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); 118 void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); 119 120 #ifndef ENABLE_SECURITY 121 virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); 122 #endif GetFinalRunGraph()123 virtual GraphId GetFinalRunGraph() const { return kInvalidGraphId; } 124 bool IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name) const; CheckModelInputs(uint32_t graph_id,const std::vector<tensor::TensorPtr> & inputs,std::string * error_msg)125 virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<tensor::TensorPtr> &inputs, 126 std::string *error_msg) const { 127 return true; 128 } 129 void GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs, 130 std::vector<std::string> *inputs_name) const; 131 void GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs, 132 std::vector<std::string> *output_names) const; 133 // create a single run op graph 134 std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const BackendOpRunInfoPtr &op_run_info, 135 const std::vector<ValuePtr> &input_values, 136 const std::vector<InputType> &input_type); 137 void EraseValueNodeTensor(const std::vector<InputType> &input_types, 138 std::vector<tensor::TensorPtr> *input_tensors) const; 139 void RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const; 140 static void RunOpHideNopNode(const KernelGraphPtr &kernel_graph); ReportWarningMessage()141 virtual void ReportWarningMessage() {} ReportErrorMessage()142 virtual void ReportErrorMessage() {} SetThreadContext()143 virtual void SetThreadContext() {} 144 #ifdef ENABLE_DEBUGGER 145 // set debugger SetDebugger()146 void SetDebugger() { 147 debugger_ = Debugger::GetInstance(); 148 auto ms_context = MsContext::GetInstance(); 149 MS_EXCEPTION_IF_NULL(ms_context); 150 MS_EXCEPTION_IF_NULL(debugger_); 151 debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET)); 152 } 153 #endif 154 static BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph, 155 const std::vector<tensor::TensorPtr> &input_tensors, 156 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node, 157 KernelMapTensor *node_to_tensor); 158 159 private: 160 void GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs, 161 std::map<AnfNodePtr, size_t> *parameter_index) const; 162 void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors, 163 VectorRef *const outputs, 164 std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) const; 165 void GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) const; 166 // Cut op not flatten, so we need calculate maketuple input ref count. 167 void CalculateRefCount(const AnfNodePtr &node, std::map<KernelWithIndex, size_t> *ref_count) const; 168 void GetForwardOpOutputRefCount(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs, 169 std::map<std::string, size_t> *forward_op_output_tensor_id, 170 const std::map<AnfNodePtr, size_t> ¶meter_index) const; 171 void ReleaseForwardOpOutput(const std::vector<ValuePtr> &input_tensors, 172 std::map<std::string, size_t> *forward_op_output_tensor_id) const; 173 void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count, 174 std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map) const; 175 176 void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs, 177 const std::map<KernelWithIndex, size_t> &ref_count, 178 std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map, 179 GraphOutputInfo *const graph_output_info) const; 180 181 protected: 182 friend class Executor; 183 friend class CompileNodesTask; 184 friend class CompileGraphTask; 185 friend class BuildGraphTask; 186 friend class RunGraphTask; 187 friend class mindspore::runtime::GraphCompiler; IsSupportSummary()188 virtual bool IsSupportSummary() { return true; } 189 virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors, 190 VectorRef *outputs, 191 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node, 192 KernelMapTensor *node_to_tensor); 193 // When the device address of the node is used as the output of the graph, the device address will be passed 194 // to the output tensor, and the output node will recreate a new device address. This third parameter records 195 // the relationship between the new and old device address. 196 virtual void UpdateOutputTensors(const VectorRef *outputs, 197 const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node, 198 std::map<DeviceAddressPtr, DeviceAddressPtr> *); 199 virtual void FinalOptimize(const KernelGraphPtr &graph) const; CompileGraphImpl(const AnfNodePtrList & lst,const AnfNodePtrList & outputs)200 virtual GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { return 0; } CompileGraphImpl(NotNull<FuncGraphPtr>)201 virtual GraphId CompileGraphImpl(NotNull<FuncGraphPtr>) { return kInvalidGraphId; } BuildGraphImpl(GraphId)202 virtual void BuildGraphImpl(GraphId) {} PreExecuteGraph(const std::shared_ptr<KernelGraph> & kernel_graph,const std::vector<tensor::TensorPtr> & inputs,VectorRef * const outputs)203 virtual void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, 204 const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) { 205 MS_EXCEPTION_IF_NULL(kernel_graph); 206 MS_EXCEPTION_IF_NULL(outputs); 207 MS_LOG(INFO) << "Call default PreExecuteGraph with input size: " << inputs.size(); 208 } 209 PostExecuteGraph(const std::shared_ptr<KernelGraph> & kernel_graph,const std::vector<tensor::TensorPtr> & inputs,VectorRef * const outputs)210 virtual void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, 211 const std::vector<tensor::TensorPtr> &inputs, VectorRef *const outputs) { 212 MS_EXCEPTION_IF_NULL(kernel_graph); 213 MS_EXCEPTION_IF_NULL(outputs); 214 MS_LOG(INFO) << "Call default PostExecuteGraph with input size: " << inputs.size(); 215 } 216 ExecuteGraph(const std::shared_ptr<KernelGraph> & kernel_graph)217 virtual void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); } 218 219 void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs); 220 221 void ProcessInputTensorsForHeterogeneous(const std::string &cur_target, 222 const std::vector<tensor::TensorPtr> &input_tensors) const; 223 #ifndef ENABLE_SECURITY 224 virtual void SetSummaryNodes(KernelGraph *graph); 225 void RecurseSetSummaryNodesForAllGraphs(KernelGraph *graph); 226 #endif 227 LoadInputs(const GraphId & graph_id,const std::vector<tensor::TensorPtr> & inputs_const)228 void LoadInputs(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs_const) const { 229 MS_LOG(INFO) << "Status record: start load input. graph id: " << graph_id; 230 auto kernel_graph = GetGraph(graph_id); 231 MS_EXCEPTION_IF_NULL(kernel_graph); 232 if (!kernel_graph->executable()) { 233 return; 234 } 235 LoadInputData(kernel_graph, inputs_const); 236 MS_LOG(INFO) << "Status record: end load input. graph id: " << graph_id; 237 } 238 LoadInputData(const std::shared_ptr<KernelGraph> & kernel_graph,const std::vector<tensor::TensorPtr> & inputs_const)239 virtual void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, 240 const std::vector<tensor::TensorPtr> &inputs_const) const { 241 MS_EXCEPTION_IF_NULL(kernel_graph); 242 MS_LOG(INFO) << "Call default LoadInputData with input size: " << inputs_const.size(); 243 } 244 245 void UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs, 246 const std::vector<tensor::TensorPtr> &input_tensors, 247 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const; 248 #ifndef ENABLE_SECURITY 249 void Summary(KernelGraph *graph); 250 #endif 251 // create graph output for RunOp 252 void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) const; 253 254 BackendOpRunInfoPtr GetSingleOpRunInfo(const CNodePtr &cnode, const InputInfo &input_info, 255 const GraphOutputInfo *const graph_output_info) const; 256 ValuePtr GetValueNodeOutput(const AnfNodePtr &node, size_t output_index) const; 257 tensor::TensorPtr GetParameterOutputTensor(const AnfNodePtr &node, 258 const std::map<AnfNodePtr, size_t> ¶meter_index, 259 const std::vector<tensor::TensorPtr> &graph_inputs) const; 260 tensor::BaseTensorPtr GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index, 261 const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output) const; 262 void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output, 263 const std::map<AnfNodePtr, size_t> ¶meter_index, 264 const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info) const; 265 void GetOpInputTensorsFromCNode(const CNodePtr &cnode, 266 const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output, 267 const std::map<AnfNodePtr, size_t> ¶meter_index, 268 const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info) const; 269 tensor::BaseTensorPtr GetOpInputTensorByIndex(const CNodePtr &cnode, 270 const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output, 271 const std::map<AnfNodePtr, size_t> ¶meter_index, 272 const std::vector<tensor::TensorPtr> &graph_inputs, 273 InputInfo *input_info, size_t input_index) const; 274 275 AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) const; 276 std::vector<uint32_t> GetAllReduceSplitIndex(); GetCommWorldGroup()277 virtual std::string GetCommWorldGroup() { return std::string(); } 278 void DumpGraphs(const std::vector<KernelGraphPtr> &graphs) const; 279 void GetConstValueDepend(const CNodePtr &cnode, std::set<int64_t> *const_input_attr_index) const; 280 mindspore::HashMap<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_; 281 std::shared_ptr<Context> context_; 282 CallBackFunc summary_callback_; 283 uint32_t device_id_; 284 // rank id of physical device 285 uint32_t rank_id_{0}; 286 std::shared_ptr<Executor> executor_; 287 #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64) 288 std::shared_ptr<Debugger> debugger_; 289 #endif 290 }; 291 292 using SessionPtr = std::shared_ptr<session::SessionBasic>; 293 using NamedSummaryOutputs = std::map<std::string, std::pair<AnfNodePtr, int>>; 294 } // namespace session 295 BACKEND_EXPORT void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir, 296 const std::vector<CNodePtr> &execution_order); 297 BACKEND_EXPORT uint32_t GetRankId(); 298 } // namespace mindspore 299 #endif // MINDSPORE_CCSRC_BACKEND_SESSION_SESSION_BASIC_H 300