1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CCSRC_BACKEND_SESSION_PYNATIVE_TASK_MANAGER_H_ 18 #define MINDSPORE_MINDSPORE_CCSRC_BACKEND_SESSION_PYNATIVE_TASK_MANAGER_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <queue> 23 #include <map> 24 #include <string> 25 #include <utility> 26 #include "backend/session/kernel_graph.h" 27 #include "backend/session/anf_runtime_algorithm.h" 28 29 namespace mindspore { 30 namespace session { 31 class RunOpContext { 32 public: RunOpContext(std::string graph_info,bool is_dynamic_shape,KernelGraphPtr graph,std::vector<int64_t> tensors_mask,std::vector<tensor::TensorPtr> input_tensors,std::map<tensor::TensorPtr,KernelWithIndex> tensor_to_node)33 RunOpContext(std::string graph_info, bool is_dynamic_shape, KernelGraphPtr graph, std::vector<int64_t> tensors_mask, 34 std::vector<tensor::TensorPtr> input_tensors, 35 std::map<tensor::TensorPtr, KernelWithIndex> tensor_to_node) 36 : graph_info_(std::move(graph_info)), 37 is_dynamic_shape_(is_dynamic_shape), 38 graph_(std::move(graph)), 39 tensors_mask_(std::move(tensors_mask)), 40 input_tensors_(std::move(input_tensors)), 41 tensor_to_node_(std::move(tensor_to_node)) {} 42 ~RunOpContext() = default; 43 graph()44 const KernelGraphPtr &graph() const { return graph_; } is_dynamic_shape()45 bool is_dynamic_shape() const { return is_dynamic_shape_; } tensor_mask()46 const std::vector<int64_t> &tensor_mask() const { return tensors_mask_; } input_tensors()47 const std::vector<tensor::TensorPtr> &input_tensors() const { return input_tensors_; } tensor_to_node()48 const std::map<tensor::TensorPtr, KernelWithIndex> &tensor_to_node() const { return tensor_to_node_; } 49 50 private: 51 std::string graph_info_; 52 bool is_dynamic_shape_; 53 KernelGraphPtr graph_; 54 std::vector<int64_t> tensors_mask_; 55 std::vector<tensor::TensorPtr> input_tensors_; 56 std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_; 57 }; 58 59 enum SessionTaskType { 60 kUnknowTask = 0, 61 kBuildTask, 62 kLaunchTask, 63 }; 64 65 class SessionTask { 66 public: SessionTask(SessionTaskType type,std::shared_ptr<RunOpContext> context)67 explicit SessionTask(SessionTaskType type, std::shared_ptr<RunOpContext> context) 68 : type_(type), context_(std::move(context)) {} 69 virtual ~SessionTask() = default; 70 virtual void Run() = 0; context()71 const std::shared_ptr<RunOpContext> &context() { return context_; } 72 73 protected: 74 SessionTaskType type_; 75 std::shared_ptr<RunOpContext> context_; 76 }; 77 78 class BuildTask : public SessionTask { 79 public: BuildTask(std::shared_ptr<RunOpContext> context)80 explicit BuildTask(std::shared_ptr<RunOpContext> context) 81 : SessionTask(SessionTaskType::kBuildTask, std::move(context)) {} 82 ~BuildTask() override = default; 83 // Parallel build Run()84 void Run() override {} 85 }; 86 87 class LaunchTask : public SessionTask { 88 public: LaunchTask(std::shared_ptr<RunOpContext> context)89 explicit LaunchTask(std::shared_ptr<RunOpContext> context) 90 : SessionTask(SessionTaskType::kLaunchTask, std::move(context)) {} 91 ~LaunchTask() override = default; Run()92 void Run() override {} 93 }; 94 95 class PynativeTaskManager { 96 public: GetInstance()97 static PynativeTaskManager &GetInstance() { 98 static PynativeTaskManager instance; 99 return instance; 100 } 101 102 class ExecuteGuard { 103 public: ExecuteGuard()104 ExecuteGuard() { PynativeTaskManager::GetInstance().executing_ = true; } ~ExecuteGuard()105 ~ExecuteGuard() { PynativeTaskManager::GetInstance().executing_ = false; } 106 }; 107 Init(const std::function<void ()> & execute_all)108 void Init(const std::function<void()> &execute_all) { 109 execute_all_ = execute_all; 110 inited_ = true; 111 } GetAllBuildTasks()112 const std::vector<std::shared_ptr<SessionTask>> &GetAllBuildTasks() const { return build_tasks_; } GetAllLaunchTasks()113 const std::queue<std::shared_ptr<SessionTask>> &GetAllLaunchTasks() const { return launch_tasks_; } PopLaunchTask()114 void PopLaunchTask() { launch_tasks_.pop(); } ClearAllBuildTasks()115 void ClearAllBuildTasks() { build_tasks_.clear(); } Reset()116 void Reset() { 117 ClearAllResources(); 118 execute_all_ = nullptr; 119 inited_ = false; 120 } ClearAllResources()121 void ClearAllResources() { 122 build_tasks_.clear(); 123 std::queue<std::shared_ptr<SessionTask>> empty; 124 std::swap(launch_tasks_, empty); 125 } ExecuteRemainingTasks()126 void ExecuteRemainingTasks() { 127 if (!executing_) { 128 ExecuteGuard guard; 129 if (execute_all_ != nullptr) { 130 execute_all_(); 131 } 132 } 133 } 134 PushBuildTask(const std::shared_ptr<SessionTask> & build_task)135 void PushBuildTask(const std::shared_ptr<SessionTask> &build_task) { build_tasks_.push_back(build_task); } PushLaunchTask(const std::shared_ptr<SessionTask> & launch_task)136 void PushLaunchTask(const std::shared_ptr<SessionTask> &launch_task) { launch_tasks_.push(launch_task); } QueueEmpty()137 [[nodiscard]] bool QueueEmpty() const { return launch_tasks_.empty() && build_tasks_.empty(); } QueueFull()138 [[nodiscard]] bool QueueFull() const { 139 return build_tasks_.size() > kMaxQueueSize || launch_tasks_.size() > kMaxQueueSize; 140 } inited()141 [[nodiscard]] bool inited() const { return inited_; } 142 143 private: 144 std::vector<std::shared_ptr<SessionTask>> build_tasks_; 145 std::queue<std::shared_ptr<SessionTask>> launch_tasks_; 146 std::function<void()> execute_all_{nullptr}; 147 inline static size_t kMaxQueueSize = 100; 148 bool executing_{false}; 149 bool inited_{false}; 150 }; 151 } // namespace session 152 } // namespace mindspore 153 #endif // MINDSPORE_MINDSPORE_CCSRC_BACKEND_SESSION_PYNATIVE_TASK_MANAGER_H_ 154