• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_KERNEL_TASK_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_KERNEL_TASK_H_
19 
20 #include <utility>
21 #include <vector>
22 #include <memory>
23 #include <future>
24 
25 #include "runtime/pipeline/task/task.h"
26 
27 namespace mindspore {
28 namespace runtime {
29 
30 class BACKEND_EXPORT KernelTaskContext {
31  public:
KernelTaskContext(const device::DeviceContext * device_context,device::DeviceAddressPtrList input_addr_list,device::DeviceAddressPtrList output_addr_list,void * stream)32   KernelTaskContext(const device::DeviceContext *device_context, device::DeviceAddressPtrList input_addr_list,
33                     device::DeviceAddressPtrList output_addr_list, void *stream)
34       : device_context_(device_context),
35         input_addr_list_(std::move(input_addr_list)),
36         output_addr_list_(std::move(output_addr_list)),
37         stream_(stream) {}
38   ~KernelTaskContext() = default;
39 
device_context()40   const device::DeviceContext *device_context() { return device_context_; }
stream()41   void *stream() { return stream_; }
42 
GetInputAddr(size_t idx)43   const device::DeviceAddressPtr GetInputAddr(size_t idx) {
44     if (idx >= input_addr_list_.size()) {
45       MS_LOG(EXCEPTION) << "input_addr_list size is invalid, size:" << input_addr_list_.size() << ", idx:" << idx;
46     }
47     auto addr = input_addr_list_[idx];
48     MS_EXCEPTION_IF_NULL(addr);
49     return addr;
50   }
51 
GetOutputAddr(size_t idx)52   const device::DeviceAddressPtr GetOutputAddr(size_t idx) {
53     if (idx >= output_addr_list_.size()) {
54       MS_LOG(EXCEPTION) << "output_addr_list_ size is invalid, size:" << output_addr_list_.size() << ", idx:" << idx;
55     }
56     auto addr = output_addr_list_[idx];
57     MS_EXCEPTION_IF_NULL(addr);
58     return addr;
59   }
60 
61  private:
62   const device::DeviceContext *device_context_;
63   device::DeviceAddressPtrList input_addr_list_;
64   device::DeviceAddressPtrList output_addr_list_;
65   void *stream_;
66 };
67 
68 class BACKEND_EXPORT KernelTask : public AsyncTask {
69  public:
KernelTask(std::shared_ptr<KernelTaskContext> context)70   explicit KernelTask(std::shared_ptr<KernelTaskContext> context)
71       : AsyncTask(kKernelTask), context_(std::move(context)) {}
72   ~KernelTask() override = default;
Run()73   void Run() override {}
74 
75  protected:
76   std::shared_ptr<KernelTaskContext> context_;
77 };
78 using KernelTaskPtr = std::shared_ptr<KernelTask>;
79 
80 }  // namespace runtime
81 }  // namespace mindspore
82 #endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_PYNATIVE_ASYNC_KERNEL_TASK_H_
83