/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_

#include <algorithm>
#include <vector>
#include <string>
#include <memory>
#include <optional>
#include <utility>
#include "runtime/hardware/device_context.h"
#include "runtime/pynative/op_compiler.h"
#include "runtime/device/multi_stream_controller.h"
#include "kernel/kernel.h"
#include "mindapi/base/type_traits.h"

template <typename T>
struct is_optional : public std::false_type {};
template <typename T>
struct is_optional<std::optional<T>> : public std::true_type {};

namespace mindspore {
using device::DeviceContext;
namespace runtime {
// DeviceAddress-related methods extracted from GraphCompiler into the DeviceAddressUtils class.
class BACKEND_EXPORT DeviceAddressUtils {
 public:
  static void CreateKernelTensor(const device::DeviceAddressPtr &device_address, const tensor::BaseTensorPtr &tensor);
  static void CreateKernelTensor(const device::DeviceAddressPtr &device_address, const AbstractBasePtr &abs);
  static void CreateKernelTensor(const ValuePtr &input_value);
  static void CreateKernelTensor(const tensor::TensorPtr &input_tensor);
  static void CopyNoneTensorDataToDevice(const device::DeviceContext *device_context,
                                         const device::DeviceAddressPtr &device_address,
                                         const ShapeVector &shape = {});
  static void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
  static device::DeviceAddressPtrList CreateDeviceAddressForTensorValue(const DeviceContext *device_context,
                                                                        const ValuePtr &node_value, size_t output_idx,
                                                                        const ValueNodePtr &value_node);
  static void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
  static void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph,
                                              bool is_gradient_out);

  static std::vector<device::DeviceAddressPtr> CreateGraphOutputDeviceAddress(
    const OpCompilerInfoPtr &op_compiler_info, const abstract::AbstractBasePtr &out_abstract, size_t stream_id);

  static void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
  static void CreateDeviceAddressByMapTensorNode(const DeviceContext *device_context, const AnfNodePtr &node,
                                                 size_t index);
  static void UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph);
  static void UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair,
                                  const session::AnfWithOutIndex &origin_pair);
  static void UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph);
  static device::DeviceAddressPtr CloneEmptyDeviceAddress(const device::DeviceAddressPtr &old_device_address,
                                                          const DeviceContext *device_context);
  static void CreateGraphOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph);
  static size_t GetTensorDeviceSize(const DeviceContext *device_context, const AnfNodePtr &node,
                                    const ShapeVector &shape, const std::string &format, TypeId dtype,
                                    size_t output_index);

  // Overloads that create device addresses for op inputs of various value categories.
  static void CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                       const tensor::BaseTensorPtr &tensor);
  static void MallocForInput(const DeviceContext *device_context, const tensor::BaseTensorPtr &tensor, bool is_view);
  static void MallocForInput(const DeviceContext *device_context, const std::optional<tensor::BaseTensorPtr> &val,
                             bool is_view);
  static void MallocForInput(const DeviceContext *device_context, const std::vector<tensor::BaseTensorPtr> &tensors,
                             bool is_view);
  static void CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                       const std::optional<tensor::BaseTensorPtr> &val);
  template <typename T>
  static void CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                       const std::vector<T> &inputs) {
    // Forward each element to the matching scalar/optional overload.
    for (size_t i = 0; i < inputs.size(); ++i) {
      CreateInputTensorAddress(device_context, stream_id, index, inputs[i]);
    }
  }
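  // A minimal usage sketch for the overload set above (illustrative only; the device
  // context, stream id, input index, and tensor list are assumed to come from the
  // caller's op launcher):
  //   std::vector<tensor::BaseTensorPtr> inputs = ...;  // gathered op inputs
  //   DeviceAddressUtils::CreateInputTensorAddress(device_context, stream_id, 0, inputs);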

  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const tensor::BaseTensorPtr &tensor);
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const std::optional<tensor::BaseTensorPtr> &val);
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const ScalarPtr &scalar_value);
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const StringImmPtr &string_imm);
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index,
                                                     const TypePtr &type_ptr);
  template <typename T>
  static device::DeviceAddressPtr CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                     const abstract::AbstractBasePtr &abs, size_t index, const T &t) {
    MS_EXCEPTION_IF_NULL(device_context);
    auto tmp_abs = abs;
    if (abs == nullptr) {
      // Derive a broadened abstract from the value itself when no abstract is provided.
      tmp_abs = t->ToAbstract()->Broaden();
    }
    auto shape = tmp_abs->GetShape();
    auto type = tmp_abs->GetType();
    auto value = tmp_abs->GetValue();
    auto kernel_tensor = std::make_shared<kernel::KernelTensor>(shape, type, value);
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    device_address->set_from_persistent_mem(true);

    if (device_address->GetPtr() == nullptr) {
      CopyNoneTensorDataToDevice(device_context, device_address);
    }
    MS_LOG(DEBUG) << "Create input " << tmp_abs->ToString() << " device address for " << index
                  << "th input, Shape: " << shape->ToString() << ", Type: " << type->ToString()
                  << ", Value: " << (value ? value->ToString() : "nullptr") << " device address:" << device_address;
    return device_address;
  }
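  // A minimal usage sketch for the templated overload above (illustrative only; the
  // Int64Imm value and null abstract are hypothetical):
  //   auto imm = std::make_shared<Int64Imm>(1);
  //   auto addr = DeviceAddressUtils::CreateInputAddress(device_context, stream_id, nullptr, 0, imm);
  // With a null abstract, the template falls back to t->ToAbstract()->Broaden() to derive
  // the shape, type, and value used to build the kernel tensor.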

  static void CreateOutputTensorAddress(const DeviceContext *device_context, size_t stream_id,
                                        const std::vector<tensor::BaseTensorPtr> &outputs);
  static void CreateOutputTensorAddress(const DeviceContext *device_context, size_t stream_id,
                                        const tensor::BaseTensorPtr &output_tensor, size_t size);

  static void MallocForOutputs(const DeviceContext *device_context, const std::vector<tensor::BaseTensorPtr> &outputs);

  static device::DeviceAddressPtr CreateWorkspaceAddressWithoutKernelTensor(const DeviceContext *device_context,
                                                                            size_t stream_id,
                                                                            const size_t &workspace_size);

  static device::DeviceAddressPtr CreateWorkspaceAddress(const DeviceContext *device_context, size_t stream_id,
                                                         const size_t &workspace_size);

  static void UpdateDeviceAddressHostInfoByNode(const device::DeviceAddressPtr &addr, const AnfNodePtr &node,
                                                size_t output_idx);
  static device::DeviceAddressPtr CreateDeviceAddress(const DeviceContext *device_context,
                                                      const tensor::BaseTensorPtr &tensor,
                                                      const ShapeVector &real_shape, const size_t &stream_id);

  // Convert a tensor to a contiguous tensor.
  static void ConvertContiguousTensorSync(const tensor::BaseTensorPtr &tensor);

  // Convert old_device_address to a contiguous device address.
  static device::DeviceAddressPtr ConvertContiguousDeviceAddress(const DeviceContext *device_context,
                                                                 const device::DeviceAddressPtr &old_device_address,
                                                                 bool is_sync);

  template <typename... T>
  static void ProcessCrossStreamAddress(const std::string &op_name, const DeviceContext *device_context,
                                        size_t op_stream_id, const T &... args) {
    // Each cross-stream address is recorded as a (memory_stream_id, device_ptr) pair.
    std::vector<std::pair<uint32_t, void *>> cross_stream_addresses;
    (GetCrossStreamAddressInfo(op_stream_id, &cross_stream_addresses, args), ...);
    if (cross_stream_addresses.empty()) {
      return;
    }

    device::MultiStreamController::GetInstance()->Refresh(device_context);
    auto task_id_on_stream =
      device::MultiStreamController::GetInstance()->LaunchTaskIdOnStream(device_context, op_stream_id);
    MS_LOG(DEBUG) << "Launch stream_id:" << op_stream_id << ", task id:" << task_id_on_stream
                  << ", op_name:" << op_name << ", cross_stream_addresses size:" << cross_stream_addresses.size();
    device::MultiStreamController::GetInstance()->RecordEvent(device_context, task_id_on_stream, op_stream_id,
                                                              cross_stream_addresses);
  }
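  // A minimal usage sketch (illustrative only; the op name and argument lists are
  // hypothetical, and each argument may be a tensor, kernel tensor, device address, or an
  // optional/vector wrapper of these):
  //   DeviceAddressUtils::ProcessCrossStreamAddress("Add", device_context, op_stream_id, inputs, outputs);
  // The fold expression above expands every argument; the collected addresses are handed
  // to MultiStreamController::RecordEvent so cross-stream memory reuse stays synchronized
  // with the launched task.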

 private:
  // Check whether the device address of the anf node exists and its device type is
  // consistent with the device context, e.g. an address of DeviceType::kGPU must be
  // used on a GPU device.
  static bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &node, size_t index);

  static void GetCrossStreamAddressInfoFromInput(size_t op_stream_id,
                                                 std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                                 const tensor::BaseTensorPtr &tensor);

  static void GetCrossStreamAddressInfoFromInput(size_t op_stream_id,
                                                 std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                                 const mindspore::kernel::KernelTensor *tensor);

  static void GetCrossStreamAddressInfoFromInput(size_t op_stream_id,
                                                 std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                                 const device::DeviceAddressPtr &device_address);

  template <typename T>
  static void GetCrossStreamAddressInfo(size_t op_stream_id,
                                        std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                        const std::optional<T> &opt) {
    // Unwrap the optional and recurse on the contained value.
    if (opt.has_value()) {
      return GetCrossStreamAddressInfo(op_stream_id, cross_stream_addresses, opt.value());
    }
  }

  template <typename T>
  static void GetCrossStreamAddressInfo(size_t op_stream_id,
                                        std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                        const std::vector<T> &inputs) {
    // Only address-carrying element types contribute cross-stream addresses.
    if constexpr (!std::is_same_v<T, tensor::BaseTensorPtr> && !std::is_same_v<T, tensor::TensorPtr> &&
                  !std::is_same_v<T, mindspore::kernel::KernelTensor *> &&
                  !std::is_same_v<T, device::DeviceAddressPtr>) {
      return;
    }
    std::for_each(inputs.begin(), inputs.end(), [op_stream_id, cross_stream_addresses](auto item) {
      GetCrossStreamAddressInfo(op_stream_id, cross_stream_addresses, item);
    });
  }

  template <typename T, typename = typename std::enable_if_t<!is_vector<T>::value && !is_optional<T>::value, T>>
  static void GetCrossStreamAddressInfo(size_t op_stream_id,
                                        std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
                                        const T &input) {
    if constexpr (std::is_same_v<T, tensor::BaseTensorPtr> || std::is_same_v<T, tensor::TensorPtr> ||
                  std::is_same_v<T, mindspore::kernel::KernelTensor *> ||
                  std::is_same_v<T, device::DeviceAddressPtr>) {
      GetCrossStreamAddressInfoFromInput(op_stream_id, cross_stream_addresses, input);
    }
  }
};
}  // namespace runtime
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_GRAPH_SCHEDULER_COMMON_UTILS_H_