/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_

#include <algorithm>
#include <optional>
#include <type_traits>
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include "runtime/hardware/device_context.h"
#include "runtime/pynative/op_compiler.h"
#include "runtime/device/multi_stream_controller.h"
#include "kernel/kernel.h"
#include "mindapi/base/type_traits.h"

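// Detects std::optional specializations; used together with is_vector (presumably provided by
// mindapi/base/type_traits.h) to constrain the GetCrossStreamAddressInfo overloads below.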
template <typename T>
struct is_optional : public std::false_type {};
template <typename T>
struct is_optional<std::optional<T>> : public std::true_type {};

namespace mindspore {
using device::DeviceContext;
namespace runtime {
// DeviceAddressUtils collects the DeviceAddress-related helpers extracted from GraphCompiler.
class BACKEND_EXPORT DeviceAddressUtils {
 public:
  static void CreateKernelTensor(const device::DeviceAddressPtr &device_address, const tensor::BaseTensorPtr &tensor);
  static void CreateKernelTensor(const device::DeviceAddressPtr &device_address, const AbstractBasePtr &abs);
  static void CreateKernelTensor(const ValuePtr &input_value);
  static void CreateKernelTensor(const tensor::TensorPtr &input_tensor);
  static void CopyNoneTensorDataToDevice(const device::DeviceContext *device_context,
                                         const device::DeviceAddressPtr &device_address, const ShapeVector &shape = {});
  static void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
  static device::DeviceAddressPtrList CreateDeviceAddressForTensorValue(const DeviceContext *device_context,
                                                                        const ValuePtr &node_value, size_t output_idx,
                                                                        const ValueNodePtr &value_node);
  static void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
  static void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph,
                                              bool is_gradient_out);

  static std::vector<device::DeviceAddressPtr> CreateGraphOutputDeviceAddress(
    const OpCompilerInfoPtr &op_compiler_info, const abstract::AbstractBasePtr &out_abstract, size_t stream_id);

  static void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
  static void CreateDeviceAddressByMapTensorNode(const DeviceContext *device_context, const AnfNodePtr &node,
                                                 size_t index);
  static void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph);
  static void UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair,
                                  const session::AnfWithOutIndex &origin_pair);
  static void UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph);
  static device::DeviceAddressPtr CloneEmptyDeviceAddress(const device::DeviceAddressPtr &old_device_address,
                                                          const DeviceContext *device_context);
  static void CreateGraphOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
  static size_t GetTensorDeviceSize(const DeviceContext *device_context, const AnfNodePtr &node,
                                    const ShapeVector &shape, const std::string &format, TypeId dtype,
                                    size_t output_index);

  // Overloads for creating device addresses and allocating memory for operator inputs.
  static void CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                       const tensor::BaseTensorPtr &tensor);
  static void MallocForInput(const DeviceContext *device_context, const tensor::BaseTensorPtr &tensor, bool is_view);
  static void MallocForInput(const DeviceContext *device_context, const std::optional<tensor::BaseTensorPtr> &val,
                             bool is_view);
  static void MallocForInput(const DeviceContext *device_context, const std::vector<tensor::BaseTensorPtr> &tensors,
                             bool is_view);
  static void CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                       const std::optional<tensor::BaseTensorPtr> &val);
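  // Vector overload: creates an input tensor address for every element; note that all elements
  // share the same caller-supplied index.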
  template <typename T>
  static void CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                       const std::vector<T> &inputs) {
    for (size_t i = 0; i < inputs.size(); ++i) {
      CreateInputTensorAddress(device_context, stream_id, index, inputs[i]);
    }
  }

  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const tensor::BaseTensorPtr &tensor);
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const std::optional<tensor::BaseTensorPtr> &val);
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const ScalarPtr &scalar_value);
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const StringImmPtr &string_imm);
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const TypePtr &type_ptr);
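  // Generic fallback for input values without a dedicated overload above: derives an abstract from
  // the value when abs is null, wraps its shape/type/value in a KernelTensor, creates a
  // persistent-memory device address, and calls CopyNoneTensorDataToDevice when the address has no
  // device pointer yet.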
  template <typename T>
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index, const T &t) {
    MS_EXCEPTION_IF_NULL(device_context);
    auto tmp_abs = abs;
    if (abs == nullptr) {
      tmp_abs = t->ToAbstract()->Broaden();
    }
    auto shape = tmp_abs->GetShape();
    auto type = tmp_abs->GetType();
    auto value = tmp_abs->GetValue();
    auto kernel_tensor = std::make_shared<kernel::KernelTensor>(shape, type, value);
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    device_address->set_from_persistent_mem(true);

    if (device_address->GetPtr() == nullptr) {
      CopyNoneTensorDataToDevice(device_context, device_address);
    }
    MS_LOG(DEBUG) << "Create input " << tmp_abs->ToString() << " device address for " << index
                  << "th input, Shape: " << shape->ToString() << ", Type: " << type->ToString()
                  << ", Value: " << (value ? value->ToString() : "nullptr") << " device address:" << device_address;
    return device_address;
  }

  static void CreateOutputTensorAddress(const DeviceContext *device_context, size_t stream_id,
                                        const std::vector<tensor::BaseTensorPtr> &outputs);
  static void CreateOutputTensorAddress(const DeviceContext *device_context, size_t stream_id,
                                        const tensor::BaseTensorPtr &output_tensor, size_t size);

  static void MallocForOutputs(const DeviceContext *device_context, const std::vector<tensor::BaseTensorPtr> &outputs);

  static device::DeviceAddressPtr CreateWorkspaceAddressWithoutKernelTensor(const DeviceContext *device_context,
                                                                            size_t stream_id,
                                                                            const size_t &workspace_size);

  static device::DeviceAddressPtr CreateWorkspaceAddress(const DeviceContext *device_context, size_t stream_id,
                                                         const size_t &workspace_size);

  static void UpdateDeviceAddressHostInfoByNode(const device::DeviceAddressPtr &addr, const AnfNodePtr &node,
                                                size_t output_idx);
  static device::DeviceAddressPtr CreateDeviceAddress(const DeviceContext *device_context,
                                                      const tensor::BaseTensorPtr &tensor,
                                                      const ShapeVector &real_shape, const size_t &stream_id);

  // Convert a tensor's storage to a contiguous layout synchronously.
  static void ConvertContiguousTensorSync(const tensor::BaseTensorPtr &tensor);

  // Convert old_device_address to a contiguous device address; is_sync controls whether the
  // conversion runs synchronously.
  static device::DeviceAddressPtr ConvertContiguousDeviceAddress(const DeviceContext *device_context,
                                                                 const device::DeviceAddressPtr &old_device_address,
                                                                 bool is_sync);

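  // Collects the (memory stream id, device pointer) pairs referenced by the given arguments and,
  // when any are present, records an event for them on the MultiStreamController against the task
  // launched on op_stream_id.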
  template <typename... T>
  static void ProcessCrossStreamAddress(const std::string &op_name, const DeviceContext *device_context,
                                        size_t op_stream_id, const T &... args) {
    // Each cross_stream_addresses entry is a (memory_stream_id, address) pair.
    std::vector<std::pair<uint32_t, void *>> cross_stream_addresses;
    (GetCrossStreamAddressInfo(op_stream_id, &cross_stream_addresses, args), ...);
    if (cross_stream_addresses.empty()) {
      return;
    }

    device::MultiStreamController::GetInstance()->Refresh(device_context);
    auto task_id_on_stream =
      device::MultiStreamController::GetInstance()->LaunchTaskIdOnStream(device_context, op_stream_id);
    MS_LOG(DEBUG) << "Launch stream_id:" << op_stream_id << ", task id:" << task_id_on_stream
                  << ", op_name:" << op_name << ", cross_stream_addresses size:" << cross_stream_addresses.size();
    device::MultiStreamController::GetInstance()->RecordEvent(device_context, task_id_on_stream, op_stream_id,
                                                              cross_stream_addresses);
  }
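  // Illustrative call (hypothetical call site):
  //   DeviceAddressUtils::ProcessCrossStreamAddress(op_name, device_context, op_stream_id,
  //                                                 input_tensors, output_tensor);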

 private:
  // Check whether the device address of the ANF node already exists and whether its device type is
  // consistent with the device context, e.g. a DeviceType::kGPU address must be used on a GPU device.
  static bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &node, size_t index);

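  // Append the cross-stream address information contributed by a single input of each supported kind.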
  static void GetCrossStreamAddressInfoFromInput(size_t op_stream_id,
                                                 std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                                 const tensor::BaseTensorPtr &tensor);

  static void GetCrossStreamAddressInfoFromInput(size_t op_stream_id,
                                                 std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                                 const mindspore::kernel::KernelTensor *tensor);

  static void GetCrossStreamAddressInfoFromInput(size_t op_stream_id,
                                                 std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                                 const device::DeviceAddressPtr &device_address);

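  // std::optional inputs contribute address information only when a value is present.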
  template <typename T>
  static void GetCrossStreamAddressInfo(size_t op_stream_id,
                                        std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                        const std::optional<T> &opt) {
    if (opt.has_value()) {
      return GetCrossStreamAddressInfo(op_stream_id, cross_stream_addresses, opt.value());
    }
  }

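  // Vector inputs: recurse into each element; element types other than tensor, kernel tensor, and
  // device address pointers contribute nothing.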
  template <typename T>
  static void GetCrossStreamAddressInfo(size_t op_stream_id,
                                        std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                        const std::vector<T> &inputs) {
    if constexpr (!std::is_same_v<T, tensor::BaseTensorPtr> && !std::is_same_v<T, tensor::TensorPtr> &&
                  !std::is_same_v<T, mindspore::kernel::KernelTensor *> &&
                  !std::is_same_v<T, device::DeviceAddressPtr>) {
      return;
    }
    std::for_each(inputs.begin(), inputs.end(), [op_stream_id, cross_stream_addresses](auto item) {
      GetCrossStreamAddressInfo(op_stream_id, cross_stream_addresses, item);
    });
  }

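  // Single-value overload, enabled only for non-vector, non-optional types; tensor, kernel tensor,
  // and device address pointers are forwarded to GetCrossStreamAddressInfoFromInput, everything
  // else is ignored.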
  template <typename T, typename = typename std::enable_if_t<!is_vector<T>::value && !is_optional<T>::value, T>>
  static void GetCrossStreamAddressInfo(size_t op_stream_id,
                                        std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                        const T &input) {
    if constexpr (std::is_same_v<T, tensor::BaseTensorPtr> || std::is_same_v<T, tensor::TensorPtr> ||
                  std::is_same_v<T, mindspore::kernel::KernelTensor *> || std::is_same_v<T, device::DeviceAddressPtr>) {
      GetCrossStreamAddressInfoFromInput(op_stream_id, cross_stream_addresses, input);
    }
  }
};
}  // namespace runtime
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_