1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_GPU_SESSION_H 17 #define MINDSPORE_CCSRC_BACKEND_SESSION_GPU_SESSION_H 18 19 #include <vector> 20 #include <memory> 21 #include <algorithm> 22 #include <string> 23 #include <map> 24 #include "backend/session/session_basic.h" 25 #include "backend/session/kernel_graph.h" 26 #include "backend/session/session_factory.h" 27 using KernelGraph = mindspore::session::KernelGraph; 28 29 namespace mindspore { 30 namespace session { 31 namespace gpu { 32 class GPUSession : public SessionBasic { 33 public: 34 GPUSession() = default; 35 ~GPUSession() override = default; 36 void Init(uint32_t device_id) override; 37 void SyncStream() const override; 38 39 protected: UnifyMindIR(const KernelGraphPtr & graph)40 void UnifyMindIR(const KernelGraphPtr &graph) override { SessionBasic::UnifyMindIR(graph); } 41 GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; 42 GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override; 43 void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs, 44 VectorRef *const outputs) override; 45 void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs, 46 VectorRef *const outputs) override; 47 void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override; 48 KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, 49 const std::vector<tensor::TensorPtr> &input_tensors, 50 const std::vector<int64_t> &tensors_mask) override; 51 void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors, 52 VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override; 53 void RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_run_info, 54 std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, 55 const std::vector<int64_t> &tensors_mask) override; 56 std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override; GetCommWorldGroup()57 std::string GetCommWorldGroup() override { return kNcclWorldGroup; } 58 void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, 59 const std::vector<tensor::TensorPtr> &inputs_const) const override; 60 void UpdateOutputTensors(const VectorRef *outputs, 61 const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node, 62 std::map<DeviceAddressPtr, DeviceAddressPtr> *new_to_old_device_address) override; 63 64 private: 65 void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; 66 67 void StartKernelRT() const; 68 69 void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph); 70 71 void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph); 72 73 void RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph); 74 75 void RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph); 76 77 void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph); 78 79 void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph); 80 81 void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; 82 83 void AllocateMemory(const KernelGraph *kernel_graph) const; 84 85 void RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, const KernelGraph *kernel_graph) const; 86 87 void RunOpClearMemory(const KernelGraph *kernel_graph) const; 88 89 void RunOpGenKernelEvent(const KernelGraph *graph) const; 90 91 void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const; 92 93 #ifdef ENABLE_DEBUGGER 94 void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const; 95 96 void DumpSetup(const std::shared_ptr<KernelGraph> &kernel_graph) const; 97 98 bool DumpDataEnabledIteration() const; 99 #endif 100 101 GraphId CompileGraphImpl(const KernelGraphPtr &kernel_graph); 102 }; 103 using GPUSessionPtr = std::shared_ptr<GPUSession>; 104 MS_REG_SESSION(kGPUDevice, GPUSession); 105 } // namespace gpu 106 } // namespace session 107 } // namespace mindspore 108 #endif // MINDSPORE_CCSRC_BACKEND_SESSION_GPU_SESSION_H 109