/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "include/backend/device_address.h"
#include "ir/tensor.h"
#include "include/common/utils/convert_utils.h"
#ifdef ENABLE_DEBUGGER
#include "include/backend/debug/debugger/debugger.h"
#endif
#include "include/backend/kernel_graph.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/kernel.h"
#include "utils/ms_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/memory_scheduler.h"
#include "include/backend/visible.h"

using mindspore::tensor::Tensor;
using std::vector;
using TensorPtr = std::shared_ptr<Tensor>;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::AddressPtrList;
using mindspore::kernel::KernelLaunchInfo;

47 namespace mindspore {
48 #ifndef ENABLE_DEBUGGER
49 class Debugger;
50 #endif
51 namespace device {
52 class BACKEND_EXPORT KernelRuntime {
53  public:
54   KernelRuntime() = default;
55   virtual ~KernelRuntime();
56   virtual bool Init() = 0;
57   virtual void AssignMemory(const session::KernelGraph &graph);
58   void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph,
59                          bool is_gradient_out,
60                          const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
61   void AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const;
62   void AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const;
63   void RunOpClearMemory(const session::KernelGraph &graph) const;
64   using TbeLaunchKernelModCallBack =
65     std::function<void(const AnfNodePtr &, const kernel::KernelMod *kernel_mod, std::vector<KernelTensor *> *)>;
tbe_call_setter(const TbeLaunchKernelModCallBack & call)66   static void tbe_call_setter(const TbeLaunchKernelModCallBack &call) { tbe_call_ = call; }
67 #ifdef ENABLE_DEBUGGER
68   BACKEND_EXPORT static bool DumpDataEnabled();
69   BACKEND_EXPORT static bool DumpDataEnabledIteration();
70 #endif
71   virtual bool LoadData(const session::KernelGraph &graph);
72   virtual bool Load(const session::KernelGraph &graph, bool is_task_sink);
73   virtual bool Run(const session::KernelGraph &graph, bool is_task_sink) = 0;
74   virtual bool RunDynamicKernelAsync(const session::KernelGraph &graph) = 0;
75   bool LaunchKernels(const session::KernelGraph &graph);
76   virtual void AssignStaticMemoryInput(const session::KernelGraph &graph);
77   virtual void AssignStaticMemoryValueNode(const session::KernelGraph &graph);
78 
79   virtual void ClearGraphRuntimeResource(uint32_t graph_id);
80   virtual bool SyncStream() = 0;
81   virtual bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind, void *stream) = 0;
ClearGlobalIdleMem()82   virtual void ClearGlobalIdleMem() {}
CreateContext()83   virtual void CreateContext() {}
SetContext()84   virtual void SetContext() {}
SetContextForce()85   virtual void SetContextForce() {}
ResetStreamAndCtx()86   virtual void ResetStreamAndCtx() {}
context()87   virtual const void *context() const { return nullptr; }
MallocMem(MemType type,size_t size,const DeviceAddressPtr & address)88   uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
89     return mem_manager_->MallocMem(type, size, address);
90   }
91   bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
92                                       const std::vector<size_t> &size_list, uint32_t stream_id = kDefaultStreamIndex) {
93     return mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list, stream_id);
94   }
95   static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
96                             KernelLaunchInfo *kernel_launch_info);
97 
98   // for GPU and D to impl
ReleaseDeviceRes()99   virtual void ReleaseDeviceRes() {}
set_device_id(uint32_t device_id)100   void set_device_id(uint32_t device_id) { device_id_ = device_id; }
device_id()101   uint32_t device_id() const { return device_id_; }
102   static bool UseMemScheduler();
103   void SyncParameter(const session::KernelGraph &graph, const std::shared_ptr<MemScheduler> &mem_scheduler) const;
104 
105 #ifdef ENABLE_DEBUGGER
106   // set debugger
SetDebugger()107   void SetDebugger() {
108 #if !defined(_WIN32) && !defined(_WIN64)
109     debugger_ = Debugger::GetInstance();
110 #endif
111   }
112 #endif
113 
114 #ifndef ENABLE_SECURITY
PreInit()115   virtual void PreInit() {}
116 #endif
GetAvailableMemMaxSize()117   virtual uint64_t GetAvailableMemMaxSize() const { return 0; }
GetMsUsedHbmSize()118   virtual uint64_t GetMsUsedHbmSize() const { return 0; }
119   virtual void GenKernelEvents(const session::KernelGraph &graph);
CreateDeviceEvent()120   virtual std::shared_ptr<DeviceEvent> CreateDeviceEvent() { return nullptr; }
CreateDeviceTimeEvent()121   virtual std::shared_ptr<DeviceEvent> CreateDeviceTimeEvent() { return nullptr; }
122   virtual DeviceType GetTargetDeviceType() const = 0;
compute_stream()123   virtual void *compute_stream() const { return nullptr; }
copy_data_stream()124   virtual void *copy_data_stream() const { return nullptr; }
communication_stream()125   virtual void *communication_stream() const { return nullptr; }
126   void UpdateRefNodeOutputMem(const session::KernelGraph &graph) const;
127   void UpdateSingleRefNodeMem(const CNodePtr &kernel, const session::KernelGraph &graph, bool reverse) const;
128   virtual DeviceAddressPtr AssignExtraStaticMem(const TensorPtr &tensor, const AnfNodePtr &node, size_t index);
GetModelStream(uint32_t graph_id)129   virtual void *GetModelStream(uint32_t graph_id) const { return nullptr; }
GetInternalDeviceAddress(const session::KernelGraph &,const AnfNodePtr &)130   virtual DeviceAddressPtr GetInternalDeviceAddress(const session::KernelGraph &, const AnfNodePtr &) {
131     return nullptr;
132   }
GetShadowBackendNodeMap(const session::KernelGraph &,std::map<AnfNodePtr,AnfNodePtr> *)133   virtual void GetShadowBackendNodeMap(const session::KernelGraph &, std::map<AnfNodePtr, AnfNodePtr> *) { return; }
134 
135   // add for MindRT
GetMemoryManager()136   std::shared_ptr<MemoryManager> GetMemoryManager() { return mem_manager_; }
137   void AssignStaticMemoryOutput(const session::KernelGraph &graph);
138   void AssignDynamicMemory(const session::KernelGraph &graph);
139 
140   // lock runtime
141   static std::lock_guard<std::mutex> LockRuntime(const void *stream);
142 
143  protected:
144   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
145                                                TypeId type_id) const = 0;
146   virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
147                                                TypeId type_id, const KernelWithIndex &node_index) const = 0;
148   virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
149   virtual bool KernelMemNotReuse(const AnfNodePtr &node);
150 
151   void AssignStaticMemory(const session::KernelGraph &graph);
152   void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index);
153   void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node);
154 
155   void AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node);
156   void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node);
157   void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
158   bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
159                                          const KernelLaunchInfo &kernel_launch_address, void *stream);
160 
KernelLaunchProfiling(const std::string & kernel_name)161   virtual void KernelLaunchProfiling(const std::string &kernel_name) {}
GetKernelStream(const AnfNodePtr & kernel)162   virtual void *GetKernelStream(const AnfNodePtr &kernel) const { return nullptr; }
163   void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
164                              const session::KernelGraph &graph) const;
165 
166  private:
167   static TbeLaunchKernelModCallBack tbe_call_;
168   void GetDeviceAddress(const AnfNodePtr &item, const std::map<AnfNodePtr, AnfNodePtr> shadow_backend_node_map,
169                         size_t index, const session::KernelGraph &graph, DeviceAddressPtr *device_address);
170   void UseMemSchedulerIfNeeded(const session::KernelGraph &graph);
171   bool LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel,
172                     const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock = false);
173   void ResetNodeAddress(const session::KernelGraph &graph);
174   void AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
175                            KernelLaunchInfo *kernel_launch_address) const;
176   static void GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
177                                  const DeviceAddress *device_address, const kernel::KernelTensorPtr &kernel_tensor);
178   void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph,
179                              const AnfNodePtr &kernel);
180   void SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, const KernelWithIndex &output,
181                             const session::KernelGraph &graph);
182 
183   void AddCommunicationMemInfo(const session::KernelGraph &graph);
184   bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);
185   void LaunchKernelEvent(const std::map<AnfNodePtr, std::vector<std::function<void()>>> &run_events,
186                          const AnfNodePtr &node) const;
187   void DebugStreamSync(const CNodePtr &kernel);
188   static void GenKernelTensorLaunchArgs(const CNodePtr &cnode, std::vector<kernel::KernelTensor *> *kernel_inputs,
189                                         const std::shared_ptr<MemScheduler> &mem_schedule = nullptr);
190   void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph);
191   void RunOpAssignOutputMemory(const AnfNodePtr &kernel,
192                                const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
193                                bool is_gradient_out);
194   void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
195   void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, const session::KernelGraph &graph) const;
196   void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
197   DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) const;
198   void GetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
199                                  std::vector<size_t> *align_size_list) const;
200   void GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
201                                   std::vector<size_t> *align_size_list) const;
202   DeviceAddressPtr CreateDeviceAddressForStringValue(const ValuePtr &value, bool use_mem_pool, uint32_t graph_id);
203   bool MemSchedulerPreCompute(const AnfNodePtr &kernel, const std::shared_ptr<MemScheduler> &mem_scheduler,
204                               void *stream, bool mock, KernelLaunchInfo *kernel_launch_info);
205   bool MemSchedulerPostCompute(const session::KernelGraph &graph, const AnfNodePtr &kernel,
206                                const std::shared_ptr<MemScheduler> &mem_scheduler, void *stream, bool mock);
207 
208  protected:
209   uint32_t device_id_{0};
210   bool pynative_mode_profiling_flag_{false};
211 #if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
212   std::shared_ptr<Debugger> debugger_;
213 #endif
214   void *stream_{nullptr};
215   void *communication_stream_{nullptr};
216   void *copy_data_stream_{nullptr};
217   void *forward_send_stream_{nullptr};
218   void *backward_send_stream_{nullptr};
219   void *forward_recv_stream_{nullptr};
220   void *backward_recv_stream_{nullptr};
221   std::shared_ptr<MemoryManager> mem_manager_{nullptr};
222   std::map<uint32_t, std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>,
223                                std::map<AnfNodePtr, std::vector<std::function<void()>>>>>
224     graph_kernel_events_map_;
225   mindspore::HashMap<int64_t, std::pair<uint8_t *, uint8_t *>> reuse_communication_address_;
226   MemSchedulerManager mem_scheduler_manager_;
227 };
228 using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
229 }  // namespace device
230 }  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_