1 /** 2 * Copyright 2022-2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_DEVICE_TASK_H_ 18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_DEVICE_TASK_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <memory> 23 #include <future> 24 25 #include "runtime/pipeline/task/task.h" 26 #include "backend/common/session/session_basic.h" 27 #include "runtime/pynative/op_compiler.h" 28 29 namespace mindspore { 30 namespace runtime { 31 class BACKEND_EXPORT OpTaskContext { 32 public: OpTaskContext(GraphId graph_id,KernelGraphPtr graph,session::BackendOpRunInfoPtr op_run_info,OpCompilerInfoPtr op_compiler_info,bool is_pynative_infer)33 OpTaskContext(GraphId graph_id, KernelGraphPtr graph, session::BackendOpRunInfoPtr op_run_info, 34 OpCompilerInfoPtr op_compiler_info, bool is_pynative_infer) 35 : graph_id_(graph_id), 36 graph_(std::move(graph)), 37 op_run_info_(std::move(op_run_info)), 38 op_compiler_info_(std::move(op_compiler_info)), 39 is_pyantive_infer_(is_pynative_infer) {} 40 ~OpTaskContext() = default; 41 graph_id()42 GraphId graph_id() const { return graph_id_; } graph()43 const KernelGraphPtr &graph() const { return graph_; } op_run_info()44 const session::BackendOpRunInfoPtr &op_run_info() const { return op_run_info_; } device_context()45 const device::DeviceContext *device_context() const { return op_compiler_info_->device_context_; } is_pynative_infer()46 bool is_pynative_infer() const { return is_pyantive_infer_; } op_compiler_info()47 const OpCompilerInfoPtr &op_compiler_info() const { return op_compiler_info_; } 48 49 private: 50 GraphId graph_id_; 51 KernelGraphPtr graph_; 52 session::BackendOpRunInfoPtr op_run_info_; 53 OpCompilerInfoPtr op_compiler_info_; 54 bool is_pyantive_infer_; 55 }; 56 57 class BACKEND_EXPORT DeviceOpTask : public AsyncTask { 58 public: DeviceOpTask(std::shared_ptr<OpTaskContext> context,TaskType task_type)59 DeviceOpTask(std::shared_ptr<OpTaskContext> context, TaskType task_type) 60 : AsyncTask(task_type), context_(std::move(context)) {} 61 ~DeviceOpTask() override = default; 62 Run()63 void Run() override {} 64 context()65 const std::shared_ptr<OpTaskContext> &context() { return context_; } 66 67 protected: 68 std::shared_ptr<OpTaskContext> context_; 69 }; 70 71 class BACKEND_EXPORT DeviceOpRunTask : public DeviceOpTask { 72 public: 73 DeviceOpRunTask(std::shared_ptr<OpTaskContext> context, 74 std::function<void(const std::shared_ptr<OpTaskContext> &context)> run_func); 75 ~DeviceOpRunTask() override; 76 void Run() override; 77 78 private: 79 std::function<void(const std::shared_ptr<OpTaskContext> &context)> run_func_; 80 }; 81 82 class BACKEND_EXPORT PyBoostDeviceTask : public AsyncTask { 83 public: PyBoostDeviceTask(std::function<void ()> run_func)84 explicit PyBoostDeviceTask(std::function<void()> run_func) : AsyncTask(kPyBoostOpTask), run_func_(run_func) {} 85 ~PyBoostDeviceTask() = default; 86 87 void Run() override; 88 89 private: 90 std::function<void()> run_func_; 91 }; 92 93 class BACKEND_EXPORT DeviceLaunchTask : public AsyncTask { 94 public: DeviceLaunchTask(std::function<void ()> run_func)95 explicit DeviceLaunchTask(std::function<void()> run_func) : AsyncTask(kKernelTask), run_func_(std::move(run_func)) {} 96 ~DeviceLaunchTask() = default; 97 98 void Run() override; 99 100 private: 101 std::function<void()> run_func_; 102 }; 103 104 class BACKEND_EXPORT PassthroughDeviceTask : public AsyncTask { 105 public: PassthroughDeviceTask(std::function<void (void)> run_func)106 explicit PassthroughDeviceTask(std::function<void(void)> run_func) 107 : AsyncTask(kDeviceOpTask), run_func_(std::move(run_func)) {} 108 ~PassthroughDeviceTask() override = default; 109 void Run() override; 110 111 private: 112 std::function<void(void)> run_func_; 113 }; 114 } // namespace runtime 115 } // namespace mindspore 116 #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_DEVICE_TASK_H_ 117