/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_
17 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_
18 #include <dirent.h>
19 #include <memory>
20 #include <vector>
21 #include <string>
22 #include <map>
23 #include <utility>
24 #include <unordered_map>
25 #include <unordered_set>
26 #include "runtime/device/kernel_runtime.h"
27 #include "runtime/context.h"
28 #include "runtime/device/ascend/ge_runtime/davinci_model.h"
29 #include "runtime/device/kernel_runtime_manager.h"
30 #include "backend/session/session_basic.h"
31 #ifndef ENABLE_SECURITY
32 #include "runtime/device/ascend/dump/data_dumper.h"
33 #endif
34 
35 using std::unordered_map;
36 using std::vector;
37 namespace mindspore::device::ascend {
38 using ge::model_runner::TaskInfo;
39 class AscendKernelRuntime : public KernelRuntime {
40  public:
41   AscendKernelRuntime() = default;
42   ~AscendKernelRuntime() override;
43   bool Init() override;
44   bool LoadData(const session::KernelGraph &graph) override;
45   bool GenTask(const session::KernelGraph &graph);
46   void GenKernelEvents(const session::KernelGraph &graph) override;
47   void SetKernelModStream(const std::vector<CNodePtr> &kernels, std::vector<size_t> *last_stream_nodes);
48   void ProcessBoundaryEvent(const std::vector<CNodePtr> &kernels,
49                             std::vector<std::vector<std::function<void()>>> *kernel_run_events,
50                             const std::vector<size_t> &last_stream_nodes);
51   bool GenDynamicKernel(const session::KernelGraph &graph) override;
52   bool RunDynamicKernelAsync(const session::KernelGraph &graph) override;
53   bool LoadTask(const session::KernelGraph &graph);
54   bool RunTask(const session::KernelGraph &graph);
55   bool Load(const session::KernelGraph &graph, bool is_task_sink) override;
56   bool Run(const session::KernelGraph &graph, bool is_task_sink) override;
57   void ClearGraphRuntimeResource(uint32_t graph_id) override;
58   void ClearGlobalIdleMem() override;
59   bool SyncStream() override;
60   bool MemcpyAsync(void *dst, const void *src, uint64_t size, int32_t kind) override;
61   void SetContext() override;
62   void CreateContext() override;
context()63   const void *context() const override { return rt_context_; }
64 #ifndef ENABLE_SECURITY
65   void PreInit() override;
66 #endif
67   uint64_t GetAvailableMemMaxSize() const override;
GetTargetDeviceAddressType()68   DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kAscend; };
69   std::shared_ptr<DeviceEvent> CreateDeviceEvent() override;
70   std::shared_ptr<DeviceEvent> CreateDeviceTimeEvent() override;
compute_stream()71   void *compute_stream() const override { return stream_; }
communication_stream()72   void *communication_stream() const override { return communication_stream_; }
73   void *GetModelStream(uint32_t graph_id) const override;
74 
75  protected:
76   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
77                                        TypeId type_id) const override;
78   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id,
79                                        const KernelWithIndex &node_index) const override;
80   bool KernelMemNotReuse(const AnfNodePtr &node) override;
81 
82   void KernelLaunchProfiling(const std::string &kernel_name) override;
83 
84  private:
85   bool InitDevice();
86   bool ResetDevice(uint32_t device_id);
87   static bool HcclInit();
88   static bool NeedDestroyHccl();
89   static bool DestroyHccl();
90   void SetCurrentContext();
91 
92   void ClearGraphModelMap();
93   void ReleaseDeviceRes() override;
94   bool GraphWithEmptyTaskList(const session::KernelGraph &graph) const;
95   bool CheckGraphIdValid(GraphId graph_id) const;
96 #ifndef ENABLE_SECURITY
97   void DistributeDebugTask(const session::KernelGraph &graph, const NotNull<std::function<void *()>> &model_handle);
98   void LaunchDataDump(GraphId graph_id);
99   void ReportProfilingData();
100 #endif
101   static CNodePtr GetErrorNodeName(uint32_t streamid, uint32_t taskid);
102   static std::string GetDumpPath();
103 #ifndef ENABLE_SECURITY
104   static void DumpTaskExceptionInfo(const session::KernelGraph &graph);
105 #endif
106   static void TaskFailCallback(rtExceptionInfo *task_fail_info);
107   static bool DeleteDumpDir(const std::string &path);
108   static int DeleteDumpFile(std::string path);
109   static std::string GetRealPath(const std::string &path);
110 
111   rtContext_t rt_context_{nullptr};
112   bool initialized_{false};
113   unordered_map<GraphId, vector<std::shared_ptr<TaskInfo>>> task_map_;
114   unordered_map<GraphId, std::shared_ptr<ge::model_runner::DavinciModel>> graph_model_map_;
115 #ifndef ENABLE_SECURITY
116   unordered_map<GraphId, std::shared_ptr<DataDumper>> graph_data_dumper_;
117 #endif
118   std::map<std::pair<uint32_t, uint32_t>, std::string> stream_id_task_id_op_name_map_;
119   static std::map<std::string, uint32_t> overflow_tasks_;
120   static std::vector<rtExceptionInfo> task_fail_infoes_;
121   std::map<uint32_t, void *> stream_id_map_;
122   std::map<std::string, uint32_t> group_stream_id_map_;
123 };
124 
125 MS_REG_KERNEL_RUNTIME(kAscendDevice, AscendKernelRuntime);
126 }  // namespace mindspore::device::ascend
127 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_KERNEL_RUNTIME_H_
128