1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H 17 #define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H 18 19 #include <condition_variable> 20 #include <list> 21 #include <map> 22 #include <memory> 23 #include <mutex> 24 #include <atomic> 25 #include <queue> 26 #include <string> 27 #include <thread> 28 #include <vector> 29 #include <utility> 30 #include "backend/common/session/session_basic.h" 31 #include "ir/anf.h" 32 #include "ir/tensor.h" 33 #include "utils/any.h" 34 #include "include/common/utils/comm_manager.h" 35 #include "include/common/utils/contract.h" 36 #include "include/backend/visible.h" 37 38 namespace mindspore::session { 39 enum TaskType { 40 kUnKnown, 41 kExit, 42 kCompileNodes, 43 kCompileGraph, 44 kBuildGraph, 45 kRunGraph, 46 kRunOp, 47 kCreateCommGroup, 48 kDestroyCommGroup, 49 kRunOpsInGraph 50 }; 51 52 class Task { 53 public: 54 Task() = default; 55 virtual ~Task() = default; 56 SessionPtr session_{nullptr}; 57 TaskType type_{kUnKnown}; 58 bool sync_run_{false}; Run()59 virtual void Run() {} 60 }; 61 62 class CompileNodesTask : public Task { 63 public: CompileNodesTask()64 CompileNodesTask() { type_ = kCompileNodes; } 65 ~CompileNodesTask() override = default; 66 void Run() override; 67 GraphSegmentPtr segment_; 68 AnfNodePtrList output_nodes_; 69 GraphId graph_id_{0}; 70 }; 71 72 class CompileGraphTask : public Task { 73 public: CompileGraphTask()74 CompileGraphTask() { type_ = kCompileGraph; } 75 ~CompileGraphTask() override = default; 76 void Run() override; 77 FuncGraphPtr func_graph_{nullptr}; 78 GraphId graph_id_{0}; 79 }; 80 81 class BuildGraphTask : public Task { 82 public: BuildGraphTask()83 BuildGraphTask() { type_ = kBuildGraph; } 84 ~BuildGraphTask() override = default; 85 void Run() override; 86 GraphId graph_id_{0}; 87 }; 88 89 class RunGraphTask : public Task { 90 public: RunGraphTask()91 RunGraphTask() { type_ = kRunGraph; } 92 ~RunGraphTask() override = default; 93 void Run() override; 94 std::vector<tensor::TensorPtr> input_tensors_; 95 VectorRef outputs_; 96 GraphId graph_id_{0}; 97 std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_; 98 KernelMapTensor node_to_tensor_; 99 }; 100 101 class CreateCommGroupTask : public Task { 102 public: CreateCommGroupTask()103 CreateCommGroupTask() { type_ = kCreateCommGroup; } 104 ~CreateCommGroupTask() override = default; 105 void Run() override; 106 std::string group_name_; 107 std::vector<uint32_t> ranks_; 108 bool result_{false}; 109 }; 110 111 class DestroyCommGroupTask : public Task { 112 public: DestroyCommGroupTask()113 DestroyCommGroupTask() { type_ = kDestroyCommGroup; } 114 ~DestroyCommGroupTask() override = default; 115 void Run() override; 116 std::string group_name_; 117 bool result_{false}; 118 }; 119 120 class ExitTask : public Task { 121 public: ExitTask()122 ExitTask() { type_ = kExit; } 123 ~ExitTask() override = default; 124 }; 125 126 enum class ExecutorEvent { kClear, kRunGraphFinished, kException }; 127 128 class BACKEND_EXPORT Executor { 129 public: Executor(std::string device_name,uint32_t device_id)130 Executor(std::string device_name, uint32_t device_id) : device_name_(std::move(device_name)), device_id_(device_id) { 131 worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this); 132 } 133 ~Executor(); 134 void WorkerLoop(); 135 void WorkerJoin(); 136 GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); 137 GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph); 138 void BuildGraph(const SessionPtr &session, GraphId graphId); 139 void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, 140 VectorRef *outputs); 141 void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, 142 VectorRef *outputs); 143 bool CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks); 144 bool DestroyCommGroup(const std::string &group_name); 145 void OnEvent(const ExecutorEvent &event); 146 void ClearDoneTasks(); 147 148 private: 149 void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false); 150 std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList(); 151 void OnWorkerExit(); 152 void OnClear(); 153 void OnRunGraphFinished(); 154 void OnException(); 155 156 std::string device_name_; 157 uint32_t device_id_; 158 std::mutex task_mutex_; 159 std::mutex done_task_mutex_; 160 std::mutex pending_task_mutex_; 161 std::mutex reenter_mutex_; 162 std::condition_variable task_cond_var_; 163 std::condition_variable sync_cond_var_; 164 std::condition_variable reenter_cond_var_; 165 std::queue<std::shared_ptr<Task>> ready_tasks_; 166 std::list<std::shared_ptr<RunGraphTask>> pending_tasks_; 167 std::vector<std::shared_ptr<Task>> done_tasks_; 168 std::shared_ptr<std::thread> worker_; 169 bool sync_run_task_finished_{false}; 170 }; 171 } // namespace mindspore::session 172 #endif // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H 173