• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_TASK_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_TASK_H_
19 
20 #include <functional>
21 #include <utility>
22 #include <vector>
23 #include <memory>
24 #include "runtime/pipeline/task/task.h"
25 #include "pipeline/pynative/base.h"
26 #include "backend/common/session/session_basic.h"
27 
28 namespace mindspore {
29 namespace pynative {
30 class FrontendTask : public runtime::AsyncTask {
31  public:
FrontendTask(std::function<void (const FrontendOpRunInfoPtr & op_run_info)> run_func,FrontendOpRunInfoPtr op_run_info)32   FrontendTask(std::function<void(const FrontendOpRunInfoPtr &op_run_info)> run_func, FrontendOpRunInfoPtr op_run_info)
33       : AsyncTask(runtime::kFrontendTask), run_func_(std::move(run_func)), op_run_info_(std::move(op_run_info)) {}
34   ~FrontendTask() override = default;
35   void Run() override;
36   void SetException(const std::exception_ptr &e) override;
37 
38  private:
39   std::function<void(const FrontendOpRunInfoPtr &op_run_info)> run_func_;
40   FrontendOpRunInfoPtr op_run_info_;
41 };
42 
43 class PassthroughFrontendTask : public runtime::AsyncTask {
44  public:
PassthroughFrontendTask(std::function<void (void)> run_func)45   explicit PassthroughFrontendTask(std::function<void(void)> run_func)
46       : AsyncTask(runtime::kFrontendTask), run_func_(std::move(run_func)) {}
47   ~PassthroughFrontendTask() override = default;
48   void Run() override;
49 
50  private:
51   std::function<void(void)> run_func_;
52 };
53 
54 class SliceOpFrontendTask : public runtime::AsyncTask {
55  public:
SliceOpFrontendTask(std::function<void (const std::vector<ValuePtr> & input_values,const std::vector<SliceOpInfoPtr> & slice_op_infos,bool requires_grad,const stub::StubNodePtr & stub_output)> run_func,std::vector<ValuePtr> input_values,std::vector<SliceOpInfoPtr> slice_op_infos,bool requires_grad,const stub::StubNodePtr & stub_output)56   SliceOpFrontendTask(
57     std::function<void(const std::vector<ValuePtr> &input_values, const std::vector<SliceOpInfoPtr> &slice_op_infos,
58                        bool requires_grad, const stub::StubNodePtr &stub_output)>
59       run_func,
60     std::vector<ValuePtr> input_values, std::vector<SliceOpInfoPtr> slice_op_infos, bool requires_grad,
61     const stub::StubNodePtr &stub_output)
62       : AsyncTask(runtime::kFrontendTask),
63         run_func_(std::move(run_func)),
64         input_values_(std::move(input_values)),
65         slice_op_infos_(std::move(slice_op_infos)),
66         requires_grad_(requires_grad),
67         stub_output_(stub_output) {}
68   ~SliceOpFrontendTask() override = default;
69   void Run() override;
70   void SetException(const std::exception_ptr &e) override;
71 
72  private:
73   std::function<void(const std::vector<ValuePtr> &input_values, const std::vector<SliceOpInfoPtr> &slice_op_infos,
74                      bool requires_grad, const stub::StubNodePtr &stub_output)>
75     run_func_;
76   std::vector<ValuePtr> input_values_;
77   std::vector<SliceOpInfoPtr> slice_op_infos_;
78   bool requires_grad_{false};
79   stub::StubNodePtr stub_output_;
80 };
81 
82 using BackendOpRunInfoPtr = session::BackendOpRunInfoPtr;
83 class BackendTask : public runtime::AsyncTask {
84  public:
BackendTask(std::function<void (const FrontendOpRunInfoPtr & op_run_info,const BackendOpRunInfoPtr & backend_op_run_info)> run_func,FrontendOpRunInfoPtr op_run_info,BackendOpRunInfoPtr backend_op_run_info)85   BackendTask(
86     std::function<void(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info)>
87       run_func,
88     FrontendOpRunInfoPtr op_run_info, BackendOpRunInfoPtr backend_op_run_info)
89       : AsyncTask(runtime::kBackendTask),
90         run_func_(std::move(run_func)),
91         op_run_info_(std::move(op_run_info)),
92         backend_op_run_info_(std::move(backend_op_run_info)) {}
93   ~BackendTask() override = default;
94   void Run() override;
95 
96  private:
97   std::function<void(const FrontendOpRunInfoPtr &op_run_info, const BackendOpRunInfoPtr &backend_op_run_info)>
98     run_func_;
99   FrontendOpRunInfoPtr op_run_info_;
100   BackendOpRunInfoPtr backend_op_run_info_;
101 };
102 
103 class ViewKernelBackendTask : public runtime::AsyncTask {
104  public:
ViewKernelBackendTask(std::function<void (const FrontendOpRunInfoPtr & op_run_info,const runtime::KernelTaskType & task_type)> run_func,FrontendOpRunInfoPtr op_run_info,const runtime::KernelTaskType & task_type)105   ViewKernelBackendTask(
106     std::function<void(const FrontendOpRunInfoPtr &op_run_info, const runtime::KernelTaskType &task_type)> run_func,
107     FrontendOpRunInfoPtr op_run_info, const runtime::KernelTaskType &task_type)
108       : AsyncTask(runtime::kBackendTask),
109         run_func_(std::move(run_func)),
110         op_run_info_(std::move(op_run_info)),
111         task_type_(task_type) {}
112   ~ViewKernelBackendTask() override = default;
113   void Run() override;
114 
115  private:
116   std::function<void(const FrontendOpRunInfoPtr &op_run_info, const runtime::KernelTaskType &task_type)> run_func_;
117   FrontendOpRunInfoPtr op_run_info_;
118   runtime::KernelTaskType task_type_;
119 };
120 
121 class AllocViewMemBackendTask : public runtime::AsyncTask {
122  public:
AllocViewMemBackendTask(std::function<void (const FrontendOpRunInfoPtr & op_run_info,const tensor::TensorPtr & input_tensor,const size_t & input_idx,bool need_wait)> run_func,FrontendOpRunInfoPtr op_run_info,const tensor::TensorPtr & input_tensor,const size_t & input_idx,bool need_wait)123   AllocViewMemBackendTask(
124     std::function<void(const FrontendOpRunInfoPtr &op_run_info, const tensor::TensorPtr &input_tensor,
125                        const size_t &input_idx, bool need_wait)>
126       run_func,
127     FrontendOpRunInfoPtr op_run_info, const tensor::TensorPtr &input_tensor, const size_t &input_idx, bool need_wait)
128       : AsyncTask(runtime::kBackendTask),
129         run_func_(std::move(run_func)),
130         op_run_info_(std::move(op_run_info)),
131         input_tensor_(input_tensor),
132         input_idx_(input_idx),
133         need_wait_(need_wait) {}
134   ~AllocViewMemBackendTask() override = default;
135   void Run() override;
136   void SetException(const std::exception_ptr &e) override;
137 
138  private:
139   std::function<void(const FrontendOpRunInfoPtr &op_run_info, const tensor::TensorPtr &input_tensor,
140                      const size_t &input_idx, bool need_wait)>
141     run_func_;
142   FrontendOpRunInfoPtr op_run_info_;
143   tensor::TensorPtr input_tensor_;
144   size_t input_idx_{0};
145   bool need_wait_{false};
146 };
147 
148 class ContiguousBackendTask : public runtime::AsyncTask {
149  public:
ContiguousBackendTask(std::function<void (const tensor::TensorPtr & tensor)> run_func,const tensor::TensorPtr & tensor)150   ContiguousBackendTask(std::function<void(const tensor::TensorPtr &tensor)> run_func, const tensor::TensorPtr &tensor)
151       : AsyncTask(runtime::kBackendTask), run_func_(std::move(run_func)), tensor_(tensor) {}
152   ~ContiguousBackendTask() override = default;
153   void Run() override;
154 
155  private:
156   std::function<void(const tensor::TensorPtr &tensor)> run_func_;
157   tensor::TensorPtr tensor_;
158 };
159 }  // namespace pynative
160 }  // namespace mindspore
161 #endif  // MINDSPORE_MINDSPORE_CCSRC_PIPELINE_PYNATIVE_FORWARD_FORWARD_TASK_H_
162