/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_
#include <vector>
#include <memory>
#include <string>
#include <map>
#include <utility>
#include <unordered_set>
#include "runtime/device/device_address.h"
#include "ir/tensor.h"
#include "utils/convert_utils.h"
#ifdef ENABLE_DEBUGGER
#include "debug/debugger/debugger.h"
#endif
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/kernel.h"
#include "utils/ms_context.h"
#include "runtime/device/memory_manager.h"
#include "runtime/device/memory_scheduler.h"
#include "runtime/device/executor/dynamic_kernel.h"
#include "ir/device_event.h"

using mindspore::tensor::Tensor;
using std::vector;
using TensorPtr = std::shared_ptr<Tensor>;
using mindspore::kernel::AddressPtr;
using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;

namespace mindspore {
#ifndef ENABLE_DEBUGGER
class Debugger;
#endif
namespace device {
class KernelRuntime {
 public:
  KernelRuntime() = default;
  virtual ~KernelRuntime();
  virtual bool Init() = 0;
  virtual void AssignMemory(const session::KernelGraph &graph);
  void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph,
                         const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
  void RunOpAssignCommunicationOutput(const AnfNodePtr &node) const;
  void RunOpAssignCommunicationInput(const AnfNodePtr &node) const;
  void RunOpClearMemory(const session::KernelGraph &graph) const;
  void RunOpMallocPre(const session::KernelGraph &graph, const std::vector<tensor::TensorPtr> &input_tensors);
#ifdef ENABLE_DEBUGGER
  static bool DumpDataEnabled();
  static bool DumpDataEnabledIteration();
#endif
  virtual bool LoadData(const session::KernelGraph &graph);
  virtual bool Load(const session::KernelGraph &graph, bool is_task_sink);
  virtual bool Run(const session::KernelGraph &graph, bool is_task_sink) = 0;
  virtual bool GenDynamicKernel(const session::KernelGraph &graph) = 0;
  virtual bool RunDynamicKernelAsync(const session::KernelGraph &graph) = 0;
  bool LaunchKernels(const session::KernelGraph &graph);
  virtual void AssignStaticMemoryInput(const session::KernelGraph &graph);
  virtual void AssignStaticMemoryValueNode(const session::KernelGraph &graph);

  virtual void ClearGraphRuntimeResource(uint32_t graph_id);
  virtual bool SyncStream() = 0;
  virtual bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) = 0;
  virtual void ClearGlobalIdleMem() {}
  virtual void CreateContext() {}
  virtual void SetContext() {}
  virtual const void *context() const { return nullptr; }
  uint8_t *MallocMem(MemType type, size_t size, const DeviceAddressPtr &address) {
    return mem_manager_->MallocMem(type, size, address);
  }
  uint8_t *MallocCommunicationMemFromMemPool(size_t size) {
    return mem_manager_->MallocCommunicationMemFromMemPool(size);
  }
  static void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel,
                            AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
                            AddressPtrList *kernel_outputs);

  // for GPU and D to impl
  virtual void ReleaseDeviceRes() {}
  void set_device_id(uint32_t device_id) { device_id_ = device_id; }
  uint32_t device_id() { return device_id_; }

#ifdef ENABLE_DEBUGGER
  // set debugger
  void SetDebugger() {
#if !defined(_WIN32) && !defined(_WIN64)
    debugger_ = Debugger::GetInstance();
#endif
  }
#endif

#ifndef ENABLE_SECURITY
  virtual void PreInit() {}
#endif
  virtual uint64_t GetAvailableMemMaxSize() const { return 0; }
  virtual void GenKernelEvents(const session::KernelGraph &graph);
  virtual std::shared_ptr<DeviceEvent> CreateDeviceEvent() { return nullptr; }
  virtual std::shared_ptr<DeviceEvent> CreateDeviceTimeEvent() { return nullptr; }
  virtual DeviceAddressType GetTargetDeviceAddressType() const = 0;
  virtual void *compute_stream() const { return nullptr; }
  virtual void *communication_stream() const { return nullptr; }
  void UpdateRefNodeOutputMem(const session::KernelGraph &graph);
  virtual DeviceAddressPtr AssignExtraStaticMem(const TensorPtr &tensor, const AnfNodePtr &node, size_t index);
  virtual void *GetModelStream(uint32_t graph_id) const { return nullptr; }

 protected:
  virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                               TypeId type_id) const = 0;
  virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                               TypeId type_id, const KernelWithIndex &node_index) const = 0;
  virtual bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index);
  virtual bool KernelMemNotReuse(const AnfNodePtr &node);

  void AssignStaticMemory(const session::KernelGraph &graph);
  void AssignDynamicMemory(const session::KernelGraph &graph);
  void AssignNodeOutputMem(MemType type, const AnfNodePtr &node, int index);
  void AssignWorkSpaceMem(MemType type, const AnfNodePtr &node);

  void AssignCommunicationNodeOutputMem(MemType type, const AnfNodePtr &node);
  void AssignCommunicationNodeInputMem(MemType type, const AnfNodePtr &node);
  void AssignCommunicationNodeMem(MemType type, const AnfNodePtr &node);
  bool LaunchKernelWithPynativeProfiling(kernel::KernelMod *kernel_mod, const std::string &op_name,
                                         const std::vector<AddressPtr> &inputs,
                                         const std::vector<AddressPtr> &workspace,
                                         const std::vector<AddressPtr> &outputs, void *stream);

  virtual void KernelLaunchProfiling(const std::string &kernel_name) {}

 private:
  void UseMemSchedulerIfNeeded(const session::KernelGraph &graph);
  bool LaunchKernel(const session::KernelGraph &graph, const AnfNodePtr &kernel,
                    const std::shared_ptr<MemScheduler> &mem_scheduler, bool mock = false);
  void ResetNodeAddress(const session::KernelGraph &graph);
  void AssignKernelAddress(const std::shared_ptr<MemScheduler> &mem_scheduler, const AnfNodePtr &kernel,
                           AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces,
                           AddressPtrList *kernel_outputs);
  static void GetOrMallocAddress(const std::shared_ptr<MemScheduler> &mem_scheduler,
                                 const DeviceAddress *device_address, const kernel::AddressPtr &kernel_addr);
  void InitGraphInputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph);
  void SyncNodeOutputTensors(const std::shared_ptr<MemScheduler> &mem_scheduler, const session::KernelGraph &graph,
                             const AnfNodePtr &kernel, bool mock);
  void AssignStaticMemoryOutput(const session::KernelGraph &graph);
  bool LaunchKernelMod(const session::KernelGraph &graph, bool mock = false);
  void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index) const;
  void DebugStreamSync(const CNodePtr &kernel);
  static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs,
                                     const std::shared_ptr<MemScheduler> &mem_schedule = nullptr);
  void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph &graph);
  void RunOpAssignOutputMemory(const AnfNodePtr &kernel,
                               const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node = {});
  void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);
  void RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, const session::KernelGraph &graph);
  void AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value, size_t output_idx);
  DeviceAddressPtr PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) const;
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  void GetFirstPSEmbeddingCache(const session::KernelGraph &graph, AnfNodePtr *const first_cache_input_index,
                                size_t *const first_cache_size);
  void CheckIfSupportPSEmbeddingCache(const session::KernelGraph &graph);
  void CheckSparsePSEmbeddingCache(const CNodePtr &node);
#endif
  void RunOpGetCommunicationInputInfo(const AnfNodePtr &node, size_t *total_size,
                                      std::vector<DeviceAddressPtr> *address_list,
                                      std::vector<size_t> *align_size_list) const;
  void RunOpGetCommunicationOutputInfo(const AnfNodePtr &node, size_t *total_size,
                                       std::vector<size_t> *align_size_list,
                                       std::vector<DeviceAddressPtr> *device_address_list) const;

 protected:
  uint32_t device_id_{0};
  bool pynative_mode_profiling_flag_{false};
#if defined(ENABLE_DEBUGGER) && !defined(_WIN32) && !defined(_WIN64)
  std::shared_ptr<Debugger> debugger_;
#endif
  void *stream_{nullptr};
  void *independent_stream_{nullptr};
  void *communication_stream_{nullptr};
  std::shared_ptr<MemoryManager> mem_manager_{nullptr};
  std::map<uint32_t, std::vector<DynamicKernelPtr>> graph_dynamic_kernel_map_;
  std::map<uint32_t,
           std::pair<std::vector<std::vector<std::function<void()>>>, std::vector<std::vector<std::function<void()>>>>>
      graph_kernel_events_map_;
  MemSchedulerManager mem_scheduler_manager_;
};
using KernelRuntimePtr = std::shared_ptr<KernelRuntime>;
}  // namespace device
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_KERNEL_RUNTIME_H_
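// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the upstream header): one way a device
// backend might subclass KernelRuntime and satisfy its pure-virtual interface.
// The class name DemoKernelRuntime, the trivial return values, and the choice
// of DeviceAddressType::kCPU are hypothetical; a real backend (GPU/Ascend)
// adds device-specific memory, stream, and kernel-launch logic on top of this
// skeleton. Wrapped in "#if 0" so it never takes part in compilation.
// ---------------------------------------------------------------------------
#if 0
namespace mindspore {
namespace device {
class DemoKernelRuntime : public KernelRuntime {
 public:
  // Set up device resources (streams, memory manager); a no-op in this sketch.
  bool Init() override { return true; }
  // Run the graph by reusing the base-class kernel-by-kernel launch path.
  bool Run(const session::KernelGraph &graph, bool is_task_sink) override { return LaunchKernels(graph); }
  bool GenDynamicKernel(const session::KernelGraph &graph) override { return true; }
  bool RunDynamicKernelAsync(const session::KernelGraph &graph) override { return true; }
  bool SyncStream() override { return true; }
  bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override { return false; }
  DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kCPU; }

 protected:
  // Device-address factories: a real backend returns addresses wrapping its own device memory.
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                       TypeId type_id) const override {
    return nullptr;
  }
  DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
                                       const KernelWithIndex &node_index) const override {
    return nullptr;
  }
};
}  // namespace device
}  // namespace mindspore
#endif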