1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H 17 #define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H 18 19 #include <condition_variable> 20 #include <list> 21 #include <map> 22 #include <memory> 23 #include <mutex> 24 #include <atomic> 25 #include <queue> 26 #include <string> 27 #include <thread> 28 #include <utility> 29 #include <vector> 30 #include "backend/session/session_basic.h" 31 #include "ir/anf.h" 32 #include "ir/tensor.h" 33 #include "utils/any.h" 34 #include "utils/comm_manager.h" 35 #include "utils/contract.h" 36 37 namespace mindspore { 38 namespace session { 39 enum TaskType { 40 kUnKnown, 41 kExit, 42 kCompileNodes, 43 kCompileGraph, 44 kBuildGraph, 45 kRunGraph, 46 kRunOp, 47 kCreateCommGroup, 48 kDestroyCommGroup, 49 kRunOpsInGraph 50 }; 51 52 class Task { 53 public: 54 Task() = default; 55 virtual ~Task() = default; 56 SessionPtr session_{nullptr}; 57 TaskType type_{kUnKnown}; 58 bool sync_run_{false}; Run()59 virtual void Run() {} 60 }; 61 62 class CompileNodesTask : public Task { 63 public: CompileNodesTask()64 CompileNodesTask() { type_ = kCompileNodes; } 65 ~CompileNodesTask() override = default; 66 void Run() override; 67 GraphSegmentPtr segment_; 68 AnfNodePtrList output_nodes_; 69 GraphId graph_id_{0}; 70 }; 71 72 class CompileGraphTask : public Task { 73 public: CompileGraphTask()74 CompileGraphTask() { type_ = kCompileGraph; } 75 ~CompileGraphTask() override = default; 76 void Run() override; 77 FuncGraphPtr func_graph_{nullptr}; 78 GraphId graph_id_{0}; 79 }; 80 81 class BuildGraphTask : public Task { 82 public: BuildGraphTask()83 BuildGraphTask() { type_ = kBuildGraph; } 84 ~BuildGraphTask() override = default; 85 void Run() override; 86 GraphId graph_id_{0}; 87 }; 88 89 class RunGraphTask : public Task { 90 public: RunGraphTask()91 RunGraphTask() { type_ = kRunGraph; } 92 ~RunGraphTask() override = default; 93 void Run() override; 94 std::vector<tensor::TensorPtr> input_tensors_; 95 std::vector<tensor::TensorPtr> input_need_wait_tensors_; 96 std::vector<tensor::TensorPtr> input_need_lock_tensors_; 97 VectorRef outputs_; 98 GraphId graph_id_{0}; 99 std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_; 100 }; 101 102 class RunOpsInGraphTask : public Task { 103 public: RunOpsInGraphTask()104 RunOpsInGraphTask() { type_ = kRunOpsInGraph; } 105 ~RunOpsInGraphTask() override = default; 106 void Run() override; 107 std::vector<tensor::TensorPtr> input_tensors_; 108 VectorRef outputs_; 109 GraphId graph_id_{0}; 110 }; 111 112 class RunOpTask : public Task { 113 public: RunOpTask()114 RunOpTask() { type_ = kRunOp; } 115 ~RunOpTask() override = default; 116 void Run() override; 117 OpRunInfo *op_run_info_{nullptr}; 118 GraphInfo graph_info_; 119 std::vector<tensor::TensorPtr> *input_tensors_{nullptr}; 120 VectorRef outputs_; 121 std::vector<int64_t> tensors_mask_; 122 }; 123 124 class CreateCommGroupTask : public Task { 125 public: CreateCommGroupTask()126 CreateCommGroupTask() { type_ = kCreateCommGroup; } 127 ~CreateCommGroupTask() override = default; 128 void Run() override; 129 std::string group_name_; 130 std::vector<uint32_t> ranks_; 131 bool result_{false}; 132 }; 133 134 class DestroyCommGroupTask : public Task { 135 public: DestroyCommGroupTask()136 DestroyCommGroupTask() { type_ = kDestroyCommGroup; } 137 ~DestroyCommGroupTask() override = default; 138 void Run() override; 139 std::string group_name_; 140 bool result_{false}; 141 }; 142 143 class ExitTask : public Task { 144 public: ExitTask()145 ExitTask() { type_ = kExit; } 146 ~ExitTask() override = default; 147 }; 148 149 enum class ExecutorEvent { kClear, kRunGraphFinished, kException }; 150 151 class Executor { 152 public: 153 Executor(const std::string &device_name, uint32_t device_id); 154 ~Executor(); 155 void WorkerLoop(); 156 void WorkerJoin(); 157 GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); 158 GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph); 159 void BuildGraph(const SessionPtr &session, GraphId graphId); 160 void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, 161 VectorRef *outputs); 162 void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, 163 VectorRef *outputs); 164 void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, 165 std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs, 166 const std::vector<int64_t> &tensors_mask); 167 void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, 168 VectorRef *outputs); 169 bool CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks); 170 bool DestroyCommGroup(const std::string &group_name); 171 void OnEvent(const ExecutorEvent &event); 172 173 private: 174 void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false); 175 std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList(); 176 void OnWorkerExit(); 177 void OnClear(); 178 void OnRunGraphFinished(); 179 void OnException(); 180 void ClearDoneTasks(); 181 182 uint32_t device_id_; 183 std::string device_name_; 184 std::mutex task_mutex_; 185 std::mutex done_task_mutex_; 186 std::mutex pending_task_mutex_; 187 std::mutex reenter_mutex_; 188 std::condition_variable task_cond_var_; 189 std::condition_variable sync_cond_var_; 190 std::condition_variable reenter_cond_var_; 191 std::queue<std::shared_ptr<Task>> ready_tasks_; 192 std::list<std::shared_ptr<RunGraphTask>> pending_tasks_; 193 std::vector<std::shared_ptr<Task>> done_tasks_; 194 std::shared_ptr<std::thread> worker_; 195 bool sync_run_task_finished_{false}; 196 }; 197 } // namespace session 198 } // namespace mindspore 199 #endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H 200