• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
17 #define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
18 
19 #include <condition_variable>
20 #include <list>
21 #include <map>
22 #include <memory>
23 #include <mutex>
24 #include <atomic>
25 #include <queue>
26 #include <string>
27 #include <thread>
28 #include <utility>
29 #include <vector>
30 #include "backend/session/session_basic.h"
31 #include "ir/anf.h"
32 #include "ir/tensor.h"
33 #include "utils/any.h"
34 #include "utils/comm_manager.h"
35 #include "utils/contract.h"
36 
37 namespace mindspore {
38 namespace session {
// Tag identifying the kind of work a Task carries (stored in Task::type_).
// The explicit values below are the same ones the compiler assigns implicitly
// from declaration order, spelled out so they are visible at a glance.
enum TaskType {
  kUnKnown = 0,           // default value of Task::type_
  kExit = 1,              // set by ExitTask
  kCompileNodes = 2,      // set by CompileNodesTask
  kCompileGraph = 3,      // set by CompileGraphTask
  kBuildGraph = 4,        // set by BuildGraphTask
  kRunGraph = 5,          // set by RunGraphTask
  kRunOp = 6,             // set by RunOpTask
  kCreateCommGroup = 7,   // set by CreateCommGroupTask
  kDestroyCommGroup = 8,  // set by DestroyCommGroupTask
  kRunOpsInGraph = 9      // set by RunOpsInGraphTask
};
51 
52 class Task {
53  public:
54   Task() = default;
55   virtual ~Task() = default;
56   SessionPtr session_{nullptr};
57   TaskType type_{kUnKnown};
58   bool sync_run_{false};
Run()59   virtual void Run() {}
60 };
61 
62 class CompileNodesTask : public Task {
63  public:
CompileNodesTask()64   CompileNodesTask() { type_ = kCompileNodes; }
65   ~CompileNodesTask() override = default;
66   void Run() override;
67   GraphSegmentPtr segment_;
68   AnfNodePtrList output_nodes_;
69   GraphId graph_id_{0};
70 };
71 
72 class CompileGraphTask : public Task {
73  public:
CompileGraphTask()74   CompileGraphTask() { type_ = kCompileGraph; }
75   ~CompileGraphTask() override = default;
76   void Run() override;
77   FuncGraphPtr func_graph_{nullptr};
78   GraphId graph_id_{0};
79 };
80 
81 class BuildGraphTask : public Task {
82  public:
BuildGraphTask()83   BuildGraphTask() { type_ = kBuildGraph; }
84   ~BuildGraphTask() override = default;
85   void Run() override;
86   GraphId graph_id_{0};
87 };
88 
89 class RunGraphTask : public Task {
90  public:
RunGraphTask()91   RunGraphTask() { type_ = kRunGraph; }
92   ~RunGraphTask() override = default;
93   void Run() override;
94   std::vector<tensor::TensorPtr> input_tensors_;
95   std::vector<tensor::TensorPtr> input_need_wait_tensors_;
96   std::vector<tensor::TensorPtr> input_need_lock_tensors_;
97   VectorRef outputs_;
98   GraphId graph_id_{0};
99   std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
100 };
101 
102 class RunOpsInGraphTask : public Task {
103  public:
RunOpsInGraphTask()104   RunOpsInGraphTask() { type_ = kRunOpsInGraph; }
105   ~RunOpsInGraphTask() override = default;
106   void Run() override;
107   std::vector<tensor::TensorPtr> input_tensors_;
108   VectorRef outputs_;
109   GraphId graph_id_{0};
110 };
111 
112 class RunOpTask : public Task {
113  public:
RunOpTask()114   RunOpTask() { type_ = kRunOp; }
115   ~RunOpTask() override = default;
116   void Run() override;
117   OpRunInfo *op_run_info_{nullptr};
118   GraphInfo graph_info_;
119   std::vector<tensor::TensorPtr> *input_tensors_{nullptr};
120   VectorRef outputs_;
121   std::vector<int64_t> tensors_mask_;
122 };
123 
124 class CreateCommGroupTask : public Task {
125  public:
CreateCommGroupTask()126   CreateCommGroupTask() { type_ = kCreateCommGroup; }
127   ~CreateCommGroupTask() override = default;
128   void Run() override;
129   std::string group_name_;
130   std::vector<uint32_t> ranks_;
131   bool result_{false};
132 };
133 
134 class DestroyCommGroupTask : public Task {
135  public:
DestroyCommGroupTask()136   DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
137   ~DestroyCommGroupTask() override = default;
138   void Run() override;
139   std::string group_name_;
140   bool result_{false};
141 };
142 
143 class ExitTask : public Task {
144  public:
ExitTask()145   ExitTask() { type_ = kExit; }
146   ~ExitTask() override = default;
147 };
148 
// Events accepted by Executor::OnEvent. The explicit values equal the implicit
// numbering of the original single-line declaration.
enum class ExecutorEvent {
  kClear = 0,
  kRunGraphFinished = 1,
  kException = 2
};
150 
151 class Executor {
152  public:
153   Executor(const std::string &device_name, uint32_t device_id);
154   ~Executor();
155   void WorkerLoop();
156   void WorkerJoin();
157   GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
158   GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph);
159   void BuildGraph(const SessionPtr &session, GraphId graphId);
160   void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
161                 VectorRef *outputs);
162   void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
163                      VectorRef *outputs);
164   void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
165              std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
166              const std::vector<int64_t> &tensors_mask);
167   void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
168                      VectorRef *outputs);
169   bool CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks);
170   bool DestroyCommGroup(const std::string &group_name);
171   void OnEvent(const ExecutorEvent &event);
172 
173  private:
174   void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false);
175   std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList();
176   void OnWorkerExit();
177   void OnClear();
178   void OnRunGraphFinished();
179   void OnException();
180   void ClearDoneTasks();
181 
182   uint32_t device_id_;
183   std::string device_name_;
184   std::mutex task_mutex_;
185   std::mutex done_task_mutex_;
186   std::mutex pending_task_mutex_;
187   std::mutex reenter_mutex_;
188   std::condition_variable task_cond_var_;
189   std::condition_variable sync_cond_var_;
190   std::condition_variable reenter_cond_var_;
191   std::queue<std::shared_ptr<Task>> ready_tasks_;
192   std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
193   std::vector<std::shared_ptr<Task>> done_tasks_;
194   std::shared_ptr<std::thread> worker_;
195   bool sync_run_task_finished_{false};
196 };
197 }  // namespace session
198 }  // namespace mindspore
199 #endif  // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
200