1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ 17 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ 18 #include <dirent.h> 19 #include <memory> 20 #include <vector> 21 #include <string> 22 #include <map> 23 #include <utility> 24 #include <unordered_map> 25 #include <unordered_set> 26 #include "runtime/device/kernel_runtime.h" 27 #include "runtime/context.h" 28 #include "runtime/device/ascend/ge_runtime/davinci_model.h" 29 #include "runtime/device/kernel_runtime_manager.h" 30 #include "backend/session/session_basic.h" 31 #ifndef ENABLE_SECURITY 32 #include "runtime/device/ascend/dump/data_dumper.h" 33 #endif 34 35 using std::unordered_map; 36 using std::vector; 37 namespace mindspore::device::ascend { 38 using ge::model_runner::TaskInfo; 39 class AscendKernelRuntime : public KernelRuntime { 40 public: 41 AscendKernelRuntime() = default; 42 ~AscendKernelRuntime() override; 43 bool Init() override; 44 bool LoadData(const session::KernelGraph &graph) override; 45 bool GenTask(const session::KernelGraph &graph); 46 void GenKernelEvents(const session::KernelGraph &graph) override; 47 void SetKernelModStream(const std::vector<CNodePtr> &kernels, std::vector<size_t> *last_stream_nodes); 48 void ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels, 49 std::vector<std::vector<std::function<void()>>> *kernel_run_events, 50 const std::vector<size_t> &last_stream_nodes); 51 bool GenDynamicKernel(const session::KernelGraph &graph) override; 52 bool RunDynamicKernelAsync(const session::KernelGraph &graph) override; 53 bool LoadTask(const session::KernelGraph &graph); 54 bool RunTask(const session::KernelGraph &graph); 55 bool Load(const session::KernelGraph &graph, bool is_task_sink) override; 56 bool Run(const session::KernelGraph &graph, bool is_task_sink) override; 57 void ClearGraphRuntimeResource(uint32_t graph_id) override; 58 void ClearGlobalIdleMem() override; 59 bool SyncStream() override; 60 bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override; 61 void SetContext() override; 62 void CreateContext() override; context()63 const void *context() const override { return rt_context_; } 64 #ifndef ENABLE_SECURITY 65 void PreInit() override; 66 #endif 67 uint64_t GetAvailableMemMaxSize() const override; GetTargetDeviceAddressType()68 DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kAscend; }; 69 std::shared_ptr<DeviceEvent> CreateDeviceEvent() override; 70 std::shared_ptr<DeviceEvent> CreateDeviceTimeEvent() override; compute_stream()71 void *compute_stream() const override { return stream_; } communication_stream()72 void *communication_stream() const override { return communication_stream_; } 73 void *GetModelStream(uint32_t graph_id) const override; 74 75 protected: 76 DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, 77 TypeId type_id) const override; 78 DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id, 79 const KernelWithIndex &node_index) const override; 80 bool KernelMemNotReuse(const AnfNodePtr &node) override; 81 82 void KernelLaunchProfiling(const std::string &kernel_name) override; 83 84 private: 85 bool InitDevice(); 86 bool ResetDevice(uint32_t device_id); 87 static bool HcclInit(); 88 static bool NeedDestroyHccl(); 89 static bool DestroyHccl(); 90 void SetCurrentContext(); 91 92 void ClearGraphModelMap(); 93 void ReleaseDeviceRes() override; 94 bool GraphWithEmptyTaskList(const session::KernelGraph &graph) const; 95 bool CheckGraphIdValid(GraphId graph_id) const; 96 #ifndef ENABLE_SECURITY 97 void DistributeDebugTask(const session::KernelGraph &graph, const NotNull<std::function<void *()>> &model_handle); 98 void LaunchDataDump(GraphId graph_id); 99 void ReportProfilingData(); 100 #endif 101 static CNodePtr GetErrorNodeName(uint32_t streamid, uint32_t taskid); 102 static std::string GetDumpPath(); 103 #ifndef ENABLE_SECURITY 104 static void DumpTaskExceptionInfo(const session::KernelGraph &graph); 105 #endif 106 static void TaskFailCallback(rtExceptionInfo *task_fail_info); 107 static bool DeleteDumpDir(const std::string &path); 108 static int DeleteDumpFile(std::string path); 109 static std::string GetRealPath(const std::string &path); 110 111 rtContext_t rt_context_{nullptr}; 112 bool initialized_{false}; 113 unordered_map<GraphId, vector<std::shared_ptr<TaskInfo>>> task_map_; 114 unordered_map<GraphId, std::shared_ptr<ge::model_runner::DavinciModel>> graph_model_map_; 115 #ifndef ENABLE_SECURITY 116 unordered_map<GraphId, std::shared_ptr<DataDumper>> graph_data_dumper_; 117 #endif 118 std::map<std::pair<uint32_t, uint32_t>, std::string> stream_id_task_id_op_name_map_; 119 static std::map<std::string, uint32_t> overflow_tasks_; 120 static std::vector<rtExceptionInfo> task_fail_infoes_; 121 std::map<uint32_t, void *> stream_id_map_; 122 std::map<std::string, uint32_t> group_stream_id_map_; 123 }; 124 125 MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime); 126 } // namespace mindspore::device::ascend 127 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_ 128