/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_GPU_SESSION_H
17 #define MINDSPORE_CCSRC_BACKEND_SESSION_GPU_SESSION_H
18 
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include <string>
23 #include <map>
24 #include "backend/session/session_basic.h"
25 #include "backend/session/kernel_graph.h"
26 #include "backend/session/session_factory.h"
27 using KernelGraph = mindspore::session::KernelGraph;
28 
29 namespace mindspore {
30 namespace session {
31 namespace gpu {
32 class GPUSession : public SessionBasic {
33  public:
34   GPUSession() = default;
35   ~GPUSession() override = default;
36   void Init(uint32_t device_id) override;
37   void SyncStream() const override;
38 
39  protected:
UnifyMindIR(const KernelGraphPtr & graph)40   void UnifyMindIR(const KernelGraphPtr &graph) override { SessionBasic::UnifyMindIR(graph); }
41   GraphId CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
42   GraphId CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) override;
43   void PreExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
44                        VectorRef *const outputs) override;
45   void PostExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
46                         VectorRef *const outputs) override;
47   void ExecuteGraph(const std::shared_ptr<KernelGraph> &kernel_graph) override;
48   KernelGraphPtr BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
49                              const std::vector<tensor::TensorPtr> &input_tensors,
50                              const std::vector<int64_t> &tensors_mask) override;
51   void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector<tensor::TensorPtr> *input_tensors,
52                  VectorRef *outputs, const std::vector<int64_t> &tensors_mask) override;
53   void RunOpImplOrigin(const GraphInfo &graph_info, OpRunInfo *op_run_info,
54                        std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
55                        const std::vector<int64_t> &tensors_mask) override;
56   std::shared_ptr<device::Bucket> CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override;
GetCommWorldGroup()57   std::string GetCommWorldGroup() override { return kNcclWorldGroup; }
58   void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
59                      const std::vector<tensor::TensorPtr> &inputs_const) const override;
60   void UpdateOutputTensors(const VectorRef *outputs,
61                            const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
62                            std::map<DeviceAddressPtr, DeviceAddressPtr> *new_to_old_device_address) override;
63 
64  private:
65   void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
66 
67   void StartKernelRT() const;
68 
69   void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
70 
71   void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
72 
73   void RunOpOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
74 
75   void RunOpHardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
76 
77   void GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_graph);
78 
79   void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
80 
81   void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
82 
83   void AllocateMemory(const KernelGraph *kernel_graph) const;
84 
85   void RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, const KernelGraph *kernel_graph) const;
86 
87   void RunOpClearMemory(const KernelGraph *kernel_graph) const;
88 
89   void RunOpGenKernelEvent(const KernelGraph *graph) const;
90 
91   void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
92 
93 #ifdef ENABLE_DEBUGGER
94   void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
95 
96   void DumpSetup(const std::shared_ptr<KernelGraph> &kernel_graph) const;
97 
98   bool DumpDataEnabledIteration() const;
99 #endif
100 
101   GraphId CompileGraphImpl(const KernelGraphPtr &kernel_graph);
102 };
103 using GPUSessionPtr = std::shared_ptr<GPUSession>;
104 MS_REG_SESSION(kGPUDevice, GPUSession);
105 }  // namespace gpu
106 }  // namespace session
107 }  // namespace mindspore
108 #endif  // MINDSPORE_CCSRC_BACKEND_SESSION_GPU_SESSION_H
109