/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
17 #define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
18 
19 #include <condition_variable>
20 #include <list>
21 #include <map>
22 #include <memory>
23 #include <mutex>
24 #include <atomic>
25 #include <queue>
26 #include <string>
27 #include <thread>
28 #include <vector>
29 #include <utility>
30 #include "backend/common/session/session_basic.h"
31 #include "ir/anf.h"
32 #include "ir/tensor.h"
33 #include "utils/any.h"
34 #include "include/common/utils/comm_manager.h"
35 #include "include/common/utils/contract.h"
36 #include "include/backend/visible.h"
37 
38 namespace mindspore::session {
// Tag stored in Task::type_ that identifies which kind of task an object is.
// No enumerator has an explicit value, so they number consecutively starting
// at kUnKnown == 0; keep the order stable if the numeric values are ever relied on.
enum TaskType {
  kUnKnown,           // Default value of Task::type_.
  kExit,              // Set by ExitTask.
  kCompileNodes,      // Set by CompileNodesTask.
  kCompileGraph,      // Set by CompileGraphTask.
  kBuildGraph,        // Set by BuildGraphTask.
  kRunGraph,          // Set by RunGraphTask.
  kRunOp,             // Not assigned by any task class declared in this header.
  kCreateCommGroup,   // Set by CreateCommGroupTask.
  kDestroyCommGroup,  // Set by DestroyCommGroupTask.
  kRunOpsInGraph      // Not assigned by any task class declared in this header.
};
51 
52 class Task {
53  public:
54   Task() = default;
55   virtual ~Task() = default;
56   SessionPtr session_{nullptr};
57   TaskType type_{kUnKnown};
58   bool sync_run_{false};
Run()59   virtual void Run() {}
60 };
61 
62 class CompileNodesTask : public Task {
63  public:
CompileNodesTask()64   CompileNodesTask() { type_ = kCompileNodes; }
65   ~CompileNodesTask() override = default;
66   void Run() override;
67   GraphSegmentPtr segment_;
68   AnfNodePtrList output_nodes_;
69   GraphId graph_id_{0};
70 };
71 
72 class CompileGraphTask : public Task {
73  public:
CompileGraphTask()74   CompileGraphTask() { type_ = kCompileGraph; }
75   ~CompileGraphTask() override = default;
76   void Run() override;
77   FuncGraphPtr func_graph_{nullptr};
78   GraphId graph_id_{0};
79 };
80 
81 class BuildGraphTask : public Task {
82  public:
BuildGraphTask()83   BuildGraphTask() { type_ = kBuildGraph; }
84   ~BuildGraphTask() override = default;
85   void Run() override;
86   GraphId graph_id_{0};
87 };
88 
89 class RunGraphTask : public Task {
90  public:
RunGraphTask()91   RunGraphTask() { type_ = kRunGraph; }
92   ~RunGraphTask() override = default;
93   void Run() override;
94   std::vector<tensor::TensorPtr> input_tensors_;
95   VectorRef outputs_;
96   GraphId graph_id_{0};
97   std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
98   KernelMapTensor node_to_tensor_;
99 };
100 
101 class CreateCommGroupTask : public Task {
102  public:
CreateCommGroupTask()103   CreateCommGroupTask() { type_ = kCreateCommGroup; }
104   ~CreateCommGroupTask() override = default;
105   void Run() override;
106   std::string group_name_;
107   std::vector<uint32_t> ranks_;
108   bool result_{false};
109 };
110 
111 class DestroyCommGroupTask : public Task {
112  public:
DestroyCommGroupTask()113   DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
114   ~DestroyCommGroupTask() override = default;
115   void Run() override;
116   std::string group_name_;
117   bool result_{false};
118 };
119 
120 class ExitTask : public Task {
121  public:
ExitTask()122   ExitTask() { type_ = kExit; }
123   ~ExitTask() override = default;
124 };
125 
// Events accepted by Executor::OnEvent. Executor declares one private handler per
// enumerator (OnClear, OnRunGraphFinished, OnException).
enum class ExecutorEvent { kClear, kRunGraphFinished, kException };
127 
128 class BACKEND_EXPORT Executor {
129  public:
Executor(std::string device_name,uint32_t device_id)130   Executor(std::string device_name, uint32_t device_id) : device_name_(std::move(device_name)), device_id_(device_id) {
131     worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
132   }
133   ~Executor();
134   void WorkerLoop();
135   void WorkerJoin();
136   GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs);
137   GraphId CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph);
138   void BuildGraph(const SessionPtr &session, GraphId graphId);
139   void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
140                 VectorRef *outputs);
141   void RunGraphAsync(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
142                      VectorRef *outputs);
143   bool CreateCommGroup(const std::string &group_name, const std::vector<uint32_t> &ranks);
144   bool DestroyCommGroup(const std::string &group_name);
145   void OnEvent(const ExecutorEvent &event);
146   void ClearDoneTasks();
147 
148  private:
149   void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false);
150   std::vector<std::shared_ptr<RunGraphTask>> GetReadyTasksFromPendingList();
151   void OnWorkerExit();
152   void OnClear();
153   void OnRunGraphFinished();
154   void OnException();
155 
156   std::string device_name_;
157   uint32_t device_id_;
158   std::mutex task_mutex_;
159   std::mutex done_task_mutex_;
160   std::mutex pending_task_mutex_;
161   std::mutex reenter_mutex_;
162   std::condition_variable task_cond_var_;
163   std::condition_variable sync_cond_var_;
164   std::condition_variable reenter_cond_var_;
165   std::queue<std::shared_ptr<Task>> ready_tasks_;
166   std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
167   std::vector<std::shared_ptr<Task>> done_tasks_;
168   std::shared_ptr<std::thread> worker_;
169   bool sync_run_task_finished_{false};
170 };
171 }  // namespace mindspore::session
172 #endif  // MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
173