/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_
#include <vector>
#include <memory>
#include <string>
#include <map>
#include <utility>
#include <unordered_set>
#include "include/backend/device_address.h"
#include "ir/tensor.h"
#include "include/common/utils/convert_utils.h"
#ifdef ENABLE_DEBUGGER
#include "include/backend/debug/debugger/debugger.h"
#endif
#include "include/backend/kernel_graph.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "kernel/kernel.h"
#include "utils/ms_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/memory_scheduler.h"
#include "include/backend/visible.h"

using mindspore::tensor::Tensor;
using std::vector;
using TensorPtr = std::shared_ptr<Tensor>;
using mindspore::kernel::AddressPtr;
using mindspore::kernel::AddressPtrList;
using mindspore::kernel::KernelLaunchInfo;

namespace mindspore {
#ifndef ENABLE_DEBUGGER
// Forward declaration so debugger_-related signatures compile when the
// debugger component is not built in.
class Debugger;
#endif
namespace device {
// Abstract per-device kernel runtime.  Concrete backends (GPU/Ascend/CPU)
// derive from this class and implement the pure-virtual hooks (Init, Run,
// SyncStream, MemcpyAsync, CreateDeviceAddress, ...).  The base class owns
// the shared memory-assignment logic: static/dynamic/workspace/communication
// memory for a session::KernelGraph is allocated through mem_manager_, and an
// optional MemScheduler (see UseMemScheduler) can swap tensors in and out
// around kernel launches.
class BACKEND_EXPORT KernelRuntime {
 public:
  KernelRuntime() = default;
  virtual ~KernelRuntime();
  // Backend-specific one-time initialization; must be called before use.
  virtual bool Init() = 0;
  // Assigns all memory (static + dynamic) for a whole graph.
  virtual void AssignMemory(const session::KernelGraph &graph);
  // Memory assignment for single-op (PyNative) execution.  tensor_to_node maps
  // forwarded input tensors to the graph nodes that consume them.
  void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph,
                         bool is_gradient_out,
                         const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
  void AssignCommunicationOutputFromMemoryPool(const AnfNodePtr &node) const;
  void AssignCommunicationInputFromMemoryPool(const AnfNodePtr &node) const;
  // Releases the single-op memory taken in RunOpAssignMemory.
  void RunOpClearMemory(const session::KernelGraph &graph) const;
  // Hook invoked around TBE kernel-mod launches; installed globally via
  // tbe_call_setter and stored in the static tbe_call_.
  using TbeLaunchKernelModCallBack =
    std::function<void(const AnfNodePtr &, const kernel::KernelMod *kernel_mod, std::vector<KernelTensor *> *)>;
  static void tbe_call_setter(const TbeLaunchKernelModCallBack &call) { tbe_call_ = call; }
#ifdef ENABLE_DEBUGGER
  BACKEND_EXPORT static bool DumpDataEnabled();
  BACKEND_EXPORT static bool DumpDataEnabledIteration();
#endif
  virtual bool LoadData(const session::KernelGraph &graph);
  virtual bool Load(const session::KernelGraph &graph, bool is_task_sink);
  // Executes the graph; semantics of is_task_sink are backend-defined.
  virtual bool Run(const session::KernelGraph &graph, bool is_task_sink) = 0;
  virtual bool RunDynamicKernelAsync(const session::KernelGraph &graph) = 0;
  bool LaunchKernels(const session::KernelGraph &graph);
  virtual void AssignStaticMemoryInput(const session::KernelGraph &graph);
  virtual void AssignStaticMemoryValueNode(const session::KernelGraph &graph);

  virtual void ClearGraphRuntimeResource(uint32_t graph_id);
  virtual bool SyncStream() = 0;
  // Asynchronous device copy; 'kind' is the backend copy-direction code.
  virtual bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind, void *stream) = 0;
  // Optional backend hooks; the defaults are deliberate no-ops.
  virtual void ClearGlobalIdleMem() {}
  virtual void CreateContext() {}
  virtual void SetContext() {}
  virtual void SetContextForce() {}
  virtual void ResetStreamAndCtx() {}
  virtual const void *context() const { return nullptr; }
  // Thin delegations to the shared MemoryManager (mem_manager_ must be set).
  uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
    return mem_manager_->MallocMem(type, size, address);
  }
  bool MallocContinuousMemFromMemPool(const DeviceAddressPtrList &addr_list, size_t total_size,
                                      const std::vector<size_t> &size_list, uint32_t stream_id = kDefaultStreamIndex) {
    return mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list, stream_id);
  }
  static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
                            KernelLaunchInfo *kernel_launch_info);

  // Device-resource teardown; implemented by GPU and Ascend ("D") backends.
  virtual void ReleaseDeviceRes() {}
  void set_device_id(uint32_t device_id) { device_id_ = device_id; }
  uint32_t device_id() const { return device_id_; }
  // True when the memory-scheduler (swap in/out) path should be used.
  static bool UseMemScheduler();
  void SyncParameter(const session::KernelGraph &graph, const std::shared_ptr<MemScheduler> &mem_scheduler) const;

#ifdef ENABLE_DEBUGGER
  // Binds the singleton Debugger instance (not available on Windows).
  void SetDebugger() {
#if !defined(_WIN32) && !defined(_WIN64)
    debugger_ = Debugger::GetInstance();
#endif
  }
#endif

#ifndef ENABLE_SECURITY
  virtual void PreInit() {}
#endif
  // Memory-size queries; base defaults report 0 (unknown/unsupported).
  virtual uint64_t GetAvailableMemMaxSize() const { return 0; }
  virtual uint64_t GetMsUsedHbmSize() const { return 0; }
  virtual void GenKernelEvents(const session::KernelGraph &graph);
  virtual std::shared_ptr<DeviceEvent> CreateDeviceEvent() { return nullptr; }
  virtual std::shared_ptr<DeviceEvent> CreateDeviceTimeEvent() { return nullptr; }
  virtual DeviceType GetTargetDeviceType() const = 0;
  // Stream accessors; nullptr means the backend exposes no such stream.
  virtual void *compute_stream() const { return nullptr; }
  virtual void *copy_data_stream() const { return nullptr; }
  virtual void *communication_stream() const { return nullptr; }
  void UpdateRefNodeOutputMem(const session::KernelGraph &graph) const;
  void UpdateSingleRefNodeMem(const CNodePtr &kernel, const session::KernelGraph &graph, bool reverse) const;
  virtual DeviceAddressPtr AssignExtraStaticMem(const TensorPtr &tensor, const AnfNodePtr &node, size_t index);
  virtual void *GetModelStream(uint32_t graph_id) const { return nullptr; }
  virtual DeviceAddressPtr GetInternalDeviceAddress(const session::KernelGraph &, const AnfNodePtr &) {
    return nullptr;
  }
  virtual void GetShadowBackendNodeMap(const session::KernelGraph &, std::map<AnfNodePtr, AnfNodePtr> *) { return; }

  // Entry points added for the MindRT actor-based runtime.
  std::shared_ptr<MemoryManager> GetMemoryManager() { return mem_manager_; }
  void AssignStaticMemoryOutput(const session::KernelGraph &graph);
  void AssignDynamicMemory(const session::KernelGraph &graph);

  // Serializes runtime access per stream; the returned guard holds the lock
  // for the caller's scope.
  static std::lock_guard<std::mutex> LockRuntime(const void *stream);

 protected:
  // Factories for backend-specific DeviceAddress objects; the second overload
  // also records the producing (node, output-index) pair.
  virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                               TypeId type_id) const = 0;
  virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                               TypeId type_id, const KernelWithIndex &node_index) const = 0;
  virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
  virtual bool KernelMemNotReuse(const AnfNodePtr &node);

  void AssignStaticMemory(const session::KernelGraph &graph);
  void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index);
  void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node);

  // Communication (collective) nodes get their inputs/outputs packed into
  // contiguous buffers; see also GetCommunicationInputInfo/OutputInfo.
  void AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node);
  void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node);
  void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
  bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
                                         const KernelLaunchInfo &kernel_launch_address, void *stream);

  virtual void KernelLaunchProfiling(const std::string &kernel_name) {}
  virtual void *GetKernelStream(const AnfNodePtr &kernel) const { return nullptr; }
  void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler,
                             const session::KernelGraph &graph) const;

 private:
  // Global TBE launch hook installed via tbe_call_setter.
  static TbeLaunchKernelModCallBack tbe_call_;
  void GetDeviceAddress(const AnfNodePtr &item, const std::map<AnfNodePtr, AnfNodePtr> shadow_backend_node_map,
                        size_t index, const session::KernelGraph &graph, DeviceAddressPtr *device_address);
  void UseMemSchedulerIfNeeded(const session::KernelGraph &graph);
  // 'mock = true' appears to drive a dry-run pass (no real launch) used by the
  // mem scheduler to record access order — NOTE(review): confirm against .cc.
  bool LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                    const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock = false);
  void ResetNodeAddress(const session::KernelGraph &graph);
  void AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
                           KernelLaunchInfo *kernel_launch_address) const;
  static void GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                 const DeviceAddress *device_address, const kernel::KernelTensorPtr &kernel_tensor);
  void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph,
                             const AnfNodePtr &kernel);
  void SyncNodeOutputTensor(const std::shared_ptr<MemScheduler> &mem_scheduler, const KernelWithIndex &output,
                            const session::KernelGraph &graph);

  void AddCommunicationMemInfo(const session::KernelGraph &graph);
  bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);
  void LaunchKernelEvent(const std::map<AnfNodePtr, std::vector<std::function<void()>>> &run_events,
                         const AnfNodePtr &node) const;
  void DebugStreamSync(const CNodePtr &kernel);
  static void GenKernelTensorLaunchArgs(const CNodePtr &cnode, std::vector<kernel::KernelTensor *> *kernel_inputs,
                                        const std::shared_ptr<MemScheduler> &mem_schedule = nullptr);
  // Single-op (PyNative) memory-assignment helpers used by RunOpAssignMemory.
  void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph);
  void RunOpAssignOutputMemory(const AnfNodePtr &kernel,
                               const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                               bool is_gradient_out);
  void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
  void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, const session::KernelGraph &graph) const;
  void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
  DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) const;
  void GetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
                                 std::vector<size_t> *align_size_list) const;
  void GetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size, DeviceAddressPtrList *address_list,
                                  std::vector<size_t> *align_size_list) const;
  DeviceAddressPtr CreateDeviceAddressForStringValue(const ValuePtr &value, bool use_mem_pool, uint32_t graph_id);
  // Pre/post hooks around a kernel launch when the mem scheduler is active.
  bool MemSchedulerPreCompute(const AnfNodePtr &kernel, const std::shared_ptr<MemScheduler> &mem_scheduler,
                              void *stream, bool mock, KernelLaunchInfo *kernel_launch_info);
  bool MemSchedulerPostCompute(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                               const std::shared_ptr<MemScheduler> &mem_scheduler, void *stream, bool mock);

 protected:
  uint32_t device_id_{0};
  bool pynative_mode_profiling_flag_{false};
#if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
  std::shared_ptr<Debugger> debugger_;
#endif
  // Backend stream handles (opaque void*); nullptr until the backend sets them.
  void *stream_{nullptr};
  void *communication_stream_{nullptr};
  void *copy_data_stream_{nullptr};
  void *forward_send_stream_{nullptr};
  void *backward_send_stream_{nullptr};
  void *forward_recv_stream_{nullptr};
  void *backward_recv_stream_{nullptr};
  std::shared_ptr<MemoryManager> mem_manager_{nullptr};
  // graph_id -> (pre-launch events, post-launch events) keyed by kernel node;
  // populated by GenKernelEvents and consumed by LaunchKernelEvent.
  std::map<uint32_t, std::pair<std::map<AnfNodePtr, std::vector<std::function<void()>>>,
                               std::map<AnfNodePtr, std::vector<std::function<void()>>>>>
    graph_kernel_events_map_;
  // Cached (input, output) buffer pair for reused communication memory,
  // keyed by an int64 id — NOTE(review): key semantics defined in the .cc.
  mindspore::HashMap<int64_t, std::pair<uint8_t *, uint8_t *>> reuse_communication_address_;
  MemSchedulerManager mem_scheduler_manager_;
};
using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
}  // namespace device
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_