• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_DEVICE_TASK_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_DEVICE_TASK_H_
19 
#include <functional>
#include <future>
#include <memory>
#include <utility>
#include <vector>
24 
25 #include "runtime/pipeline/task/task.h"
26 #include "backend/common/session/session_basic.h"
27 #include "runtime/pynative/op_compiler.h"
28 
29 namespace mindspore {
30 namespace runtime {
31 class BACKEND_EXPORT OpTaskContext {
32  public:
OpTaskContext(GraphId graph_id,KernelGraphPtr graph,session::BackendOpRunInfoPtr op_run_info,OpCompilerInfoPtr op_compiler_info,bool is_pynative_infer)33   OpTaskContext(GraphId graph_id, KernelGraphPtr graph, session::BackendOpRunInfoPtr op_run_info,
34                 OpCompilerInfoPtr op_compiler_info, bool is_pynative_infer)
35       : graph_id_(graph_id),
36         graph_(std::move(graph)),
37         op_run_info_(std::move(op_run_info)),
38         op_compiler_info_(std::move(op_compiler_info)),
39         is_pyantive_infer_(is_pynative_infer) {}
40   ~OpTaskContext() = default;
41 
graph_id()42   GraphId graph_id() const { return graph_id_; }
graph()43   const KernelGraphPtr &graph() const { return graph_; }
op_run_info()44   const session::BackendOpRunInfoPtr &op_run_info() const { return op_run_info_; }
device_context()45   const device::DeviceContext *device_context() const { return op_compiler_info_->device_context_; }
is_pynative_infer()46   bool is_pynative_infer() const { return is_pyantive_infer_; }
op_compiler_info()47   const OpCompilerInfoPtr &op_compiler_info() const { return op_compiler_info_; }
48 
49  private:
50   GraphId graph_id_;
51   KernelGraphPtr graph_;
52   session::BackendOpRunInfoPtr op_run_info_;
53   OpCompilerInfoPtr op_compiler_info_;
54   bool is_pyantive_infer_;
55 };
56 
57 class BACKEND_EXPORT DeviceOpTask : public AsyncTask {
58  public:
DeviceOpTask(std::shared_ptr<OpTaskContext> context,TaskType task_type)59   DeviceOpTask(std::shared_ptr<OpTaskContext> context, TaskType task_type)
60       : AsyncTask(task_type), context_(std::move(context)) {}
61   ~DeviceOpTask() override = default;
62 
Run()63   void Run() override {}
64 
context()65   const std::shared_ptr<OpTaskContext> &context() { return context_; }
66 
67  protected:
68   std::shared_ptr<OpTaskContext> context_;
69 };
70 
71 class BACKEND_EXPORT DeviceOpRunTask : public DeviceOpTask {
72  public:
73   DeviceOpRunTask(std::shared_ptr<OpTaskContext> context,
74                   std::function<void(const std::shared_ptr<OpTaskContext> &context)> run_func);
75   ~DeviceOpRunTask() override;
76   void Run() override;
77 
78  private:
79   std::function<void(const std::shared_ptr<OpTaskContext> &context)> run_func_;
80 };
81 
82 class BACKEND_EXPORT PyBoostDeviceTask : public AsyncTask {
83  public:
PyBoostDeviceTask(std::function<void ()> run_func)84   explicit PyBoostDeviceTask(std::function<void()> run_func) : AsyncTask(kPyBoostOpTask), run_func_(run_func) {}
85   ~PyBoostDeviceTask() = default;
86 
87   void Run() override;
88 
89  private:
90   std::function<void()> run_func_;
91 };
92 
93 class BACKEND_EXPORT DeviceLaunchTask : public AsyncTask {
94  public:
DeviceLaunchTask(std::function<void ()> run_func)95   explicit DeviceLaunchTask(std::function<void()> run_func) : AsyncTask(kKernelTask), run_func_(std::move(run_func)) {}
96   ~DeviceLaunchTask() = default;
97 
98   void Run() override;
99 
100  private:
101   std::function<void()> run_func_;
102 };
103 
104 class BACKEND_EXPORT PassthroughDeviceTask : public AsyncTask {
105  public:
PassthroughDeviceTask(std::function<void (void)> run_func)106   explicit PassthroughDeviceTask(std::function<void(void)> run_func)
107       : AsyncTask(kDeviceOpTask), run_func_(std::move(run_func)) {}
108   ~PassthroughDeviceTask() override = default;
109   void Run() override;
110 
111  private:
112   std::function<void(void)> run_func_;
113 };
114 }  // namespace runtime
115 }  // namespace mindspore
116 #endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_DEVICE_TASK_H_
117