• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_RUN_OP_HELPER_H_
18 #define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_RUN_OP_HELPER_H_
19 
20 #include <vector>
21 #include <string>
22 #include "include/backend/kernel_graph.h"
23 #include "runtime/pynative/op_compiler.h"
24 #include "runtime/hardware/device_context.h"
25 
26 namespace mindspore::runtime {
27 struct OpRunnerInfo {
28   const PrimitivePtr &prim;
29   const std::string &device_target;
30   const vector<ValuePtr> &inputs;
31   const abstract::AbstractBasePtrList &inputs_abs;
32   const std::vector<InputType> &inputs_mask;
33   abstract::AbstractBasePtr output_abs;
34   ValueSimpleInfoPtr output_value_simple_info{nullptr};
35 };
36 
// Static-only helper class for compiling and launching single-operator
// kernel graphs (PyNative mode). Declarations only; definitions live in the
// corresponding .cc file.
class OpRunner {
 public:
  // Update Tensor or input node DeviceAddress before PyNative async running.
  // `tensors_without_value_mask` are the inputs remaining after value-masked
  // entries are filtered out (see GetTensorWithoutValueMask below);
  // `is_sync` selects synchronous vs. asynchronous update — TODO confirm exact
  // semantics in the implementation.
  static void UpdateDeviceAddress(const KernelGraphPtr &graph,
                                  const std::vector<tensor::BaseTensorPtr> &tensors_without_value_mask,
                                  const device::DeviceContext *device_context, bool is_sync);

  // Execute the pre-compiled single-op graph described by `op_compiler_info`
  // with the given input tensors.
  static void RunSingleOpGraph(const session::BackendOpRunInfoPtr &op_run_info,
                               const OpCompilerInfoPtr &op_compiler_info,
                               const std::vector<tensor::BaseTensorPtr> &input_tensors);

  // Return the op's input tensors with value-masked entries removed.
  static std::vector<tensor::BaseTensorPtr> GetTensorWithoutValueMask(const session::BackendOpRunInfoPtr &op_run_info);
  // Dispatch a device kernel task of `task_type` on `stream_id`, reading from
  // `input_addr_list` and writing to `output_addr_list`.
  static void LaunchKernelTask(const runtime::KernelTaskType &task_type, DeviceContext *device_context,
                               const device::DeviceAddressPtrList &input_addr_list,
                               const device::DeviceAddressPtrList &output_addr_list, size_t stream_id);
  // Look up (and presumably create/initialize) the DeviceContext for a device
  // type string — exported for use outside the backend library.
  BACKEND_EXPORT static DeviceContext *GetDeviceContext(const std::string &device_type);
  // Re-initialize runtime state in a child process after fork() — exported.
  BACKEND_EXPORT static void ChildAfterFork();
};
55 
// Static-only helper class mirroring OpRunner for ops run through the
// dynamic-shape path. Declarations only; definitions live in the .cc file.
class DynamicOpRunner {
 public:
  // Refresh the device addresses of `input_tensors` recorded in
  // `op_compiler_info`; `is_sync` selects synchronous vs. asynchronous update —
  // TODO confirm exact semantics in the implementation.
  static void UpdateInputDeviceAddress(const OpCompilerInfoPtr &op_compiler_info,
                                       const std::vector<tensor::BaseTensorPtr> &input_tensors, bool is_sync);
  // Execute the compiled single-op graph with the given inputs
  // (dynamic-shape counterpart of OpRunner::RunSingleOpGraph).
  static void RunSingleOpGraph(const session::BackendOpRunInfoPtr &op_run_info,
                               const OpCompilerInfoPtr &op_compiler_info,
                               const std::vector<tensor::BaseTensorPtr> &input_tensors);
  // Copy host-side input tensor data to device memory for this op.
  static void CopyHostToDevice(const OpCompilerInfoPtr &op_compiler_info,
                               const std::vector<tensor::BaseTensorPtr> &input_tensors);
};
66 }  // namespace mindspore::runtime
67 #endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_RUN_OP_RUN_OP_HELPER_H_
68