1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_TASK_H_ 18 #define MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_TASK_H_ 19 20 #include <functional> 21 #include <utility> 22 #include <vector> 23 #include <memory> 24 #include "runtime/pipeline/task/task.h" 25 #include "pipeline/pynative/base.h" 26 #include "backend/common/session/session_basic.h" 27 28 namespace mindspore { 29 namespace pynative { 30 class FrontendTask : public runtime::AsyncTask { 31 public: FrontendTask(std::function<void (const FrontendOpRunInfoPtr & op_run_info)> run_func,FrontendOpRunInfoPtr op_run_info)32 FrontendTask(std::function<void(const FrontendOpRunInfoPtr &op_run_info)> run_func, FrontendOpRunInfoPtr op_run_info) 33 : AsyncTask(runtime::kFrontendTask), run_func_(std::move(run_func)), op_run_info_(std::move(op_run_info)) {} 34 ~FrontendTask() override = default; 35 void Run() override; 36 void SetException(const std::exception_ptr &e) override; 37 38 private: 39 std::function<void(const FrontendOpRunInfoPtr &op_run_info)> run_func_; 40 FrontendOpRunInfoPtr op_run_info_; 41 }; 42 43 class PassthroughFrontendTask : public runtime::AsyncTask { 44 public: PassthroughFrontendTask(std::function<void (void)> run_func)45 explicit PassthroughFrontendTask(std::function<void(void)> run_func) 46 : AsyncTask(runtime::kFrontendTask), run_func_(std::move(run_func)) {} 47 ~PassthroughFrontendTask() override = default; 48 void Run() override; 49 50 private: 51 std::function<void(void)> run_func_; 52 }; 53 54 class SliceOpFrontendTask : public runtime::AsyncTask { 55 public: SliceOpFrontendTask(std::function<void (const std::vector<ValuePtr> & input_values,const std::vector<SliceOpInfoPtr> & slice_op_infos,bool requires_grad,const stub::StubNodePtr & stub_output)> run_func,std::vector<ValuePtr> input_values,std::vector<SliceOpInfoPtr> slice_op_infos,bool requires_grad,const stub::StubNodePtr & stub_output)56 SliceOpFrontendTask( 57 std::function<void(const std::vector<ValuePtr> &input_values, const std::vector<SliceOpInfoPtr> &slice_op_infos, 58 bool requires_grad, const stub::StubNodePtr &stub_output)> 59 run_func, 60 std::vector<ValuePtr> input_values, std::vector<SliceOpInfoPtr> slice_op_infos, bool requires_grad, 61 const stub::StubNodePtr &stub_output) 62 : AsyncTask(runtime::kFrontendTask), 63 run_func_(std::move(run_func)), 64 input_values_(std::move(input_values)), 65 slice_op_infos_(std::move(slice_op_infos)), 66 requires_grad_(requires_grad), 67 stub_output_(stub_output) {} 68 ~SliceOpFrontendTask() override = default; 69 void Run() override; 70 void SetException(const std::exception_ptr &e) override; 71 72 private: 73 std::function<void(const std::vector<ValuePtr> &input_values, const std::vector<SliceOpInfoPtr> &slice_op_infos, 74 bool requires_grad, const stub::StubNodePtr &stub_output)> 75 run_func_; 76 std::vector<ValuePtr> input_values_; 77 std::vector<SliceOpInfoPtr> slice_op_infos_; 78 bool requires_grad_{false}; 79 stub::StubNodePtr stub_output_; 80 }; 81 82 using BackendOpRunInfoPtr = session::BackendOpRunInfoPtr; 83 class BackendTask : public runtime::AsyncTask { 84 public: BackendTask(std::function<void (const FrontendOpRunInfoPtr & op_run_info,const BackendOpRunInfoPtr & backend_op_run_info)> run_func,FrontendOpRunInfoPtr op_run_info,BackendOpRunInfoPtr backend_op_run_info)85 BackendTask( 86 std::function<void(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info)> 87 run_func, 88 FrontendOpRunInfoPtr op_run_info, BackendOpRunInfoPtr backend_op_run_info) 89 : AsyncTask(runtime::kBackendTask), 90 run_func_(std::move(run_func)), 91 op_run_info_(std::move(op_run_info)), 92 backend_op_run_info_(std::move(backend_op_run_info)) {} 93 ~BackendTask() override = default; 94 void Run() override; 95 96 private: 97 std::function<void(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info)> 98 run_func_; 99 FrontendOpRunInfoPtr op_run_info_; 100 BackendOpRunInfoPtr backend_op_run_info_; 101 }; 102 103 class ViewKernelBackendTask : public runtime::AsyncTask { 104 public: ViewKernelBackendTask(std::function<void (const FrontendOpRunInfoPtr & op_run_info,const runtime::KernelTaskType & task_type)> run_func,FrontendOpRunInfoPtr op_run_info,const runtime::KernelTaskType & task_type)105 ViewKernelBackendTask( 106 std::function<void(const FrontendOpRunInfoPtr &op_run_info, const runtime::KernelTaskType &task_type)> run_func, 107 FrontendOpRunInfoPtr op_run_info, const runtime::KernelTaskType &task_type) 108 : AsyncTask(runtime::kBackendTask), 109 run_func_(std::move(run_func)), 110 op_run_info_(std::move(op_run_info)), 111 task_type_(task_type) {} 112 ~ViewKernelBackendTask() override = default; 113 void Run() override; 114 115 private: 116 std::function<void(const FrontendOpRunInfoPtr &op_run_info, const runtime::KernelTaskType &task_type)> run_func_; 117 FrontendOpRunInfoPtr op_run_info_; 118 runtime::KernelTaskType task_type_; 119 }; 120 121 class AllocViewMemBackendTask : public runtime::AsyncTask { 122 public: AllocViewMemBackendTask(std::function<void (const FrontendOpRunInfoPtr & op_run_info,const tensor::TensorPtr & input_tensor,const size_t & input_idx,bool need_wait)> run_func,FrontendOpRunInfoPtr op_run_info,const tensor::TensorPtr & input_tensor,const size_t & input_idx,bool need_wait)123 AllocViewMemBackendTask( 124 std::function<void(const FrontendOpRunInfoPtr &op_run_info, const tensor::TensorPtr &input_tensor, 125 const size_t &input_idx, bool need_wait)> 126 run_func, 127 FrontendOpRunInfoPtr op_run_info, const tensor::TensorPtr &input_tensor, const size_t &input_idx, bool need_wait) 128 : AsyncTask(runtime::kBackendTask), 129 run_func_(std::move(run_func)), 130 op_run_info_(std::move(op_run_info)), 131 input_tensor_(input_tensor), 132 input_idx_(input_idx), 133 need_wait_(need_wait) {} 134 ~AllocViewMemBackendTask() override = default; 135 void Run() override; 136 void SetException(const std::exception_ptr &e) override; 137 138 private: 139 std::function<void(const FrontendOpRunInfoPtr &op_run_info, const tensor::TensorPtr &input_tensor, 140 const size_t &input_idx, bool need_wait)> 141 run_func_; 142 FrontendOpRunInfoPtr op_run_info_; 143 tensor::TensorPtr input_tensor_; 144 size_t input_idx_{0}; 145 bool need_wait_{false}; 146 }; 147 148 class ContiguousBackendTask : public runtime::AsyncTask { 149 public: ContiguousBackendTask(std::function<void (const tensor::TensorPtr & tensor)> run_func,const tensor::TensorPtr & tensor)150 ContiguousBackendTask(std::function<void(const tensor::TensorPtr &tensor)> run_func, const tensor::TensorPtr &tensor) 151 : AsyncTask(runtime::kBackendTask), run_func_(std::move(run_func)), tensor_(tensor) {} 152 ~ContiguousBackendTask() override = default; 153 void Run() override; 154 155 private: 156 std::function<void(const tensor::TensorPtr &tensor)> run_func_; 157 tensor::TensorPtr tensor_; 158 }; 159 } // namespace pynative 160 } // namespace mindspore 161 #endif // MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_TASK_H_ 162