1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_KERNEL_TASK_H_ 18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_KERNEL_TASK_H_ 19 20 #include <utility> 21 #include <vector> 22 #include <memory> 23 #include <future> 24 25 #include "runtime/pipeline/task/task.h" 26 27 namespace mindspore { 28 namespace runtime { 29 30 class BACKEND_EXPORT KernelTaskContext { 31 public: KernelTaskContext(const device::DeviceContext * device_context,device::DeviceAddressPtrList input_addr_list,device::DeviceAddressPtrList output_addr_list,void * stream)32 KernelTaskContext(const device::DeviceContext *device_context, device::DeviceAddressPtrList input_addr_list, 33 device::DeviceAddressPtrList output_addr_list, void *stream) 34 : device_context_(device_context), 35 input_addr_list_(std::move(input_addr_list)), 36 output_addr_list_(std::move(output_addr_list)), 37 stream_(stream) {} 38 ~KernelTaskContext() = default; 39 device_context()40 const device::DeviceContext *device_context() { return device_context_; } stream()41 void *stream() { return stream_; } 42 GetInputAddr(size_t idx)43 const device::DeviceAddressPtr GetInputAddr(size_t idx) { 44 if (idx >= input_addr_list_.size()) { 45 MS_LOG(EXCEPTION) << "input_addr_list size is invalid, size:" << input_addr_list_.size() << ", idx:" << idx; 46 } 47 auto addr = input_addr_list_[idx]; 48 MS_EXCEPTION_IF_NULL(addr); 49 return addr; 50 } 51 GetOutputAddr(size_t idx)52 const device::DeviceAddressPtr GetOutputAddr(size_t idx) { 53 if (idx >= output_addr_list_.size()) { 54 MS_LOG(EXCEPTION) << "output_addr_list_ size is invalid, size:" << output_addr_list_.size() << ", idx:" << idx; 55 } 56 auto addr = output_addr_list_[idx]; 57 MS_EXCEPTION_IF_NULL(addr); 58 return addr; 59 } 60 61 private: 62 const device::DeviceContext *device_context_; 63 device::DeviceAddressPtrList input_addr_list_; 64 device::DeviceAddressPtrList output_addr_list_; 65 void *stream_; 66 }; 67 68 class BACKEND_EXPORT KernelTask : public AsyncTask { 69 public: KernelTask(std::shared_ptr<KernelTaskContext> context)70 explicit KernelTask(std::shared_ptr<KernelTaskContext> context) 71 : AsyncTask(kKernelTask), context_(std::move(context)) {} 72 ~KernelTask() override = default; Run()73 void Run() override {} 74 75 protected: 76 std::shared_ptr<KernelTaskContext> context_; 77 }; 78 using KernelTaskPtr = std::shared_ptr<KernelTask>; 79 80 } // namespace runtime 81 } // namespace mindspore 82 #endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_KERNEL_TASK_H_ 83