/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_SESSION_PYNATIVE_TASK_MANAGER_H_
#define MINDSPORE_MINDSPORE_CCSRC_BACKEND_SESSION_PYNATIVE_TASK_MANAGER_H_

#include <vector>
#include <memory>
#include <queue>
#include <map>
#include <string>
#include <utility>
#include <functional>
#include "backend/session/kernel_graph.h"
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {
namespace session {
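// Bundles the information needed to build and launch one single-op kernel graph in
// PyNative mode: the graph, its input tensors, the tensor mask, and the mapping from
// output tensors to graph nodes.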
class RunOpContext {
 public:
  RunOpContext(std::string graph_info, bool is_dynamic_shape, KernelGraphPtr graph, std::vector<int64_t> tensors_mask,
               std::vector<tensor::TensorPtr> input_tensors,
               std::map<tensor::TensorPtr, KernelWithIndex> tensor_to_node)
      : graph_info_(std::move(graph_info)),
        is_dynamic_shape_(is_dynamic_shape),
        graph_(std::move(graph)),
        tensors_mask_(std::move(tensors_mask)),
        input_tensors_(std::move(input_tensors)),
        tensor_to_node_(std::move(tensor_to_node)) {}
  ~RunOpContext() = default;

  const KernelGraphPtr &graph() const { return graph_; }
  bool is_dynamic_shape() const { return is_dynamic_shape_; }
  const std::vector<int64_t> &tensor_mask() const { return tensors_mask_; }
  const std::vector<tensor::TensorPtr> &input_tensors() const { return input_tensors_; }
  const std::map<tensor::TensorPtr, KernelWithIndex> &tensor_to_node() const { return tensor_to_node_; }

 private:
  std::string graph_info_;
  bool is_dynamic_shape_;
  KernelGraphPtr graph_;
  std::vector<int64_t> tensors_mask_;
  std::vector<tensor::TensorPtr> input_tensors_;
  std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
};

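// Kind of work a SessionTask carries.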
enum SessionTaskType {
  kUnknowTask = 0,
  kBuildTask,
  kLaunchTask,
};

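// Base class for deferred session work: a task keeps its RunOpContext alive and is
// executed later through the pure virtual Run().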
class SessionTask {
 public:
  explicit SessionTask(SessionTaskType type, std::shared_ptr<RunOpContext> context)
      : type_(type), context_(std::move(context)) {}
  virtual ~SessionTask() = default;
  virtual void Run() = 0;
  const std::shared_ptr<RunOpContext> &context() { return context_; }

 protected:
  SessionTaskType type_;
  std::shared_ptr<RunOpContext> context_;
};

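// Task queued for (parallel) kernel build; Run() is an empty placeholder in this header.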
class BuildTask : public SessionTask {
 public:
  explicit BuildTask(std::shared_ptr<RunOpContext> context)
      : SessionTask(SessionTaskType::kBuildTask, std::move(context)) {}
  ~BuildTask() override = default;
  // Parallel build
  void Run() override {}
};

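// Task queued for kernel launch; Run() is likewise an empty placeholder in this header.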
class LaunchTask : public SessionTask {
 public:
  explicit LaunchTask(std::shared_ptr<RunOpContext> context)
      : SessionTask(SessionTaskType::kLaunchTask, std::move(context)) {}
  ~LaunchTask() override = default;
  void Run() override {}
};

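// Process-wide singleton that queues build and launch tasks produced in PyNative mode
// and flushes them through the callback registered via Init(). The nested ExecuteGuard
// marks the manager as executing so ExecuteRemainingTasks() is not re-entered while the
// callback runs.
//
// Illustrative usage sketch (hypothetical call site, not taken from this file):
//
//   auto &manager = PynativeTaskManager::GetInstance();
//   manager.Init([]() { /* build queued kernels, then launch them */ });
//   manager.PushBuildTask(std::make_shared<BuildTask>(context));
//   manager.PushLaunchTask(std::make_shared<LaunchTask>(context));
//   if (manager.QueueFull()) {
//     manager.ExecuteRemainingTasks();
//   }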
class PynativeTaskManager {
 public:
  static PynativeTaskManager &GetInstance() {
    static PynativeTaskManager instance;
    return instance;
  }

  class ExecuteGuard {
   public:
    ExecuteGuard() { PynativeTaskManager::GetInstance().executing_ = true; }
    ~ExecuteGuard() { PynativeTaskManager::GetInstance().executing_ = false; }
  };

  void Init(const std::function<void()> &execute_all) {
    execute_all_ = execute_all;
    inited_ = true;
  }
  const std::vector<std::shared_ptr<SessionTask>> &GetAllBuildTasks() const { return build_tasks_; }
  const std::queue<std::shared_ptr<SessionTask>> &GetAllLaunchTasks() const { return launch_tasks_; }
  void PopLaunchTask() { launch_tasks_.pop(); }
  void ClearAllBuildTasks() { build_tasks_.clear(); }
  void Reset() {
    ClearAllResources();
    execute_all_ = nullptr;
    inited_ = false;
  }
  void ClearAllResources() {
    build_tasks_.clear();
    std::queue<std::shared_ptr<SessionTask>> empty;
    std::swap(launch_tasks_, empty);
  }
  void ExecuteRemainingTasks() {
    if (!executing_) {
      ExecuteGuard guard;
      if (execute_all_ != nullptr) {
        execute_all_();
      }
    }
  }

  void PushBuildTask(const std::shared_ptr<SessionTask> &build_task) { build_tasks_.push_back(build_task); }
  void PushLaunchTask(const std::shared_ptr<SessionTask> &launch_task) { launch_tasks_.push(launch_task); }
  [[nodiscard]] bool QueueEmpty() const { return launch_tasks_.empty() && build_tasks_.empty(); }
  [[nodiscard]] bool QueueFull() const {
    return build_tasks_.size() > kMaxQueueSize || launch_tasks_.size() > kMaxQueueSize;
  }
  [[nodiscard]] bool inited() const { return inited_; }

 private:
  std::vector<std::shared_ptr<SessionTask>> build_tasks_;
  std::queue<std::shared_ptr<SessionTask>> launch_tasks_;
  std::function<void()> execute_all_{nullptr};
  inline static size_t kMaxQueueSize = 100;
  bool executing_{false};
  bool inited_{false};
};
}  // namespace session
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_BACKEND_SESSION_PYNATIVE_TASK_MANAGER_H_