/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/device_address_utils.h"

#include <algorithm>
#include <string>
#include <map>
#include <vector>
#include <memory>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/op_def.h"
#include "ir/tensor.h"
#include "include/backend/device_address.h"
#include "include/backend/kernel_info.h"
#include "include/backend/py_execute_utils.h"
#include "runtime/device/hash_table.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/pynative/op_runner.h"
#include "runtime/pynative/op_executor.h"
#include "pybind_api/gil_scoped_long_running.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#ifdef ENABLE_DEBUGGER
#include "include/backend/debug/debugger/debugger.h"
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "include/backend/device_type.h"
#endif

namespace mindspore {
using tensor::TensorPtr;
namespace runtime {
namespace {
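// Creates a device address for a StringImm, Scalar or None value node. For strings on Ascend, the
// allocation is enlarged by the size of the ge::StringHead header that precedes the payload.
// Returns nullptr for any other value type.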
device::DeviceAddressPtr CreateDeviceAddressForScalarAndString(const DeviceContext *device_context,
                                                               const ValueNodePtr &value_node) {
  device::DeviceAddressPtr address = nullptr;
  const auto &node_value = value_node->value();
  MS_EXCEPTION_IF_NULL(node_value);
  if (node_value->isa<StringImm>()) {
    auto value = GetValue<std::string>(node_value);
    // Allocate one more byte for the terminating '\0'.
    size_t tensor_size = value.size() + 1;
    if (device_context->device_context_key().device_name_ == kAscendDevice) {
      // Size of ge::StringHead, which is defined in Ascend/latest.aarch64-linux/include/types.h
      constexpr size_t GE_STRING_HEAD_SIZE = 16;
      // NOTE: on Ascend, a string value needs a header of type ge::StringHead.
      tensor_size += GE_STRING_HEAD_SIZE;
    }
    const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
      {value_node, 0}, nullptr, tensor_size, kOpFormat_DEFAULT, kObjectTypeString, ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
    address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  } else if (node_value->isa<Scalar>()) {
    auto scalar_value = node_value->cast<ScalarPtr>();
    MS_EXCEPTION_IF_NULL(scalar_value);
    TypePtr data_type = scalar_value->type();
    MS_EXCEPTION_IF_NULL(data_type);
    TypeId type_id = data_type->type_id();
    const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
      {value_node, 0}, nullptr, GetTypeByte(TypeIdToType(type_id)), kOpFormat_DEFAULT, type_id, ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
    address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  } else if (node_value->isa<None>()) {
    const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
      {value_node, 0}, nullptr, 0, kOpFormat_DEFAULT, kTypeNone->type_id(), ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
    address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  }

  return address;
}

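// Infers a default device format from the tensor rank. Only Ascend distinguishes formats here:
// 4-D shapes map to NCHW, 5-D shapes to NCDHW, everything else to ND.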
Format GetFormatByTensorShape(const DeviceContext *device_context, const ShapeVector &tensor_shape) {
  if (device_context->device_context_key().device_name_ != kAscendDevice) {
    return Format::DEFAULT_FORMAT;
  }

  switch (tensor_shape.size()) {
    case kShape4dDims:
      return Format::NCHW;
    case kShape5dDims:
      return Format::NCDHW;
    default:
      return Format::ND;
  }
}
}  // namespace

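// Checks whether the node output at `index` already has a device address of the same device type
// as `device_context`. As a side effect, backfills the kernel tensor host info from the node abstract.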
bool DeviceAddressUtils::NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &node,
                                                size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  if (AnfAlgo::OutputAddrExist(node, index)) {
    const auto address = AnfAlgo::GetMutableOutputAddr(node, index, false);
    MS_EXCEPTION_IF_NULL(address);
    CreateKernelTensor(address, session::AnfRuntimeAlgorithm::GetNodeAbstractByIndex(node, index));
    return address->GetDeviceType() == device_context->GetDeviceType();
  }
  return false;
}

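// Copies the host-side value held by the kernel tensor (scalar, string, etc.) into device memory,
// allocating the device buffer first if necessary. Addresses flagged with
// kDeviceAddressFlagIgnoreDevicePtr are skipped because their device pointer is never consumed.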
void DeviceAddressUtils::CopyNoneTensorDataToDevice(const device::DeviceContext *device_context,
                                                    const device::DeviceAddressPtr &device_address,
                                                    const ShapeVector &shape) {
  MS_EXCEPTION_IF_NULL(device_address);
  // Skip copying data to the device if the device_address has the ignore flag.
  if (TEST_FLAG(device_address->flag(), device::kDeviceAddressFlagIgnoreDevicePtr)) {
    MS_LOG(DEBUG) << "Address " << device_address << " has flag ignore device address, so skip copy tensor to device";
    return;
  }

  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kConstantValue,
                                                 device_address->GetSize(), device_address.get());
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  if ((device_address->GetPtr() == nullptr) &&
      (!device_context->device_res_manager_->AllocateMemory(device_address.get()))) {
    MS_LOG(EXCEPTION) << "Allocate memory failed";
  }

  // Copy data from host to device.
  const auto &kernel_tensor = device_address->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  auto data_size = kernel_tensor->size();
  if (data_size == 0) {
    MS_LOG(INFO) << "Constant size is zero.";
    return;
  }
  const void *node_value = kernel_tensor->GetValuePtr();
  MS_EXCEPTION_IF_NULL(node_value);
  auto data_type_id = kernel_tensor->dtype_id();
  auto format = kernel_tensor->GetStringFormat();
  if (!device_address->SyncHostToDevice(shape, data_size, data_type_id, node_value, format)) {
    MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
  }
}

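// Creates a device address for a map-tensor (hash table) node. The hash table meta information
// (key/value dtypes, value shape, default value and filter values) is parsed from the node's
// AbstractMapTensor and attached as user data, so the real table can be built lazily on device.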
void DeviceAddressUtils::CreateDeviceAddressByMapTensorNode(const DeviceContext *device_context, const AnfNodePtr &node,
                                                            size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &abstract_base = AnfAlgo::GetNodeAbstractByIndex(node, index);
  if (!abstract_base->isa<abstract::AbstractMapTensor>()) {
    MS_LOG(EXCEPTION) << "Parameter:" << node->DebugString() << " is not a map tensor type.";
  }

  const auto &abstract = abstract_base->cast<abstract::AbstractMapTensorPtr>();
  MS_EXCEPTION_IF_NULL(abstract);

  // Parse attrs for user data by abstract.
  const auto &value_shape = abstract->value_shape();
  MS_EXCEPTION_IF_NULL(value_shape);
  const auto &shape_vector = value_shape->shape();
  const auto &map_tensor_type = abstract->map_tensor_type();
  MS_EXCEPTION_IF_NULL(map_tensor_type);
  MS_EXCEPTION_IF_NULL(map_tensor_type->key_dtype());
  MS_EXCEPTION_IF_NULL(map_tensor_type->value_dtype());

  auto user_data = std::make_shared<UserData>();
  user_data->set(kUserDataType, std::make_shared<UserDataType>(UserDataType::kUserTypeHashTable));
  user_data->set(kHashTableKeyType, std::make_shared<TypeId>(map_tensor_type->key_dtype()->type_id()));
  user_data->set(kHashTableValueType, std::make_shared<TypeId>(map_tensor_type->value_dtype()->type_id()));
  user_data->set(kHashTableShapeVector, std::make_shared<ShapeVector>(shape_vector));
  user_data->set(kHashTableDefaultValue, abstract->default_value());
  user_data->set(kHashTablePermitFilter, abstract->permit_filter_value());
  user_data->set(kHashTableEvictFilter, abstract->evict_filter_value());
  // Create the device address for the map tensor node; the ptr size is 1 byte.
  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {node, index}, nullptr, 1, kOpFormat_DEFAULT, TypeId::kObjectTypeMapTensorType, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, user_data);
  kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(node));
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_LOG(DEBUG) << "Create device tensor:" << device_address << " type:" << device_address->type_id();
  AnfAlgo::SetOutputAddr(device_address, index, node.get());
}

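// Creates device addresses for all graph input parameters that do not yet have one. Inputs hidden
// behind MakeTuple are expanded first, and map-tensor parameters are delegated to
// CreateDeviceAddressByMapTensorNode. Non-weight parameters that no real kernel in the graph uses
// are flagged as unused so no memory is allocated for them later.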
void DeviceAddressUtils::CreateParameterDeviceAddress(const DeviceContext *device_context,
                                                      const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  const std::vector<bool> &graph_valid_input = graph->valid_inputs();
  (void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());

  // Anf nodes for which a device address needs to be created.
  std::vector<AnfNodePtr> nodes_list;
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    AnfNodePtr item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }

    const auto &real_device_context = device::FetchRealDeviceContext(item, device_context);
    MS_EXCEPTION_IF_NULL(real_device_context);
    if (common::AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
      std::vector<AnfNodePtr> outs = common::AnfAlgo::GetAllOutput(item);
      for (const auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        if (!out->isa<Parameter>() || NodeDeviceAddressExist(real_device_context, out, 0)) {
          continue;
        }
        nodes_list.push_back(out);
      }
    }
    if (!item->isa<Parameter>() || NodeDeviceAddressExist(real_device_context, item, 0)) {
      continue;
    }
    nodes_list.push_back(item);
  }

  // Create a device address for each anf node in nodes_list.
  for (const auto &item : nodes_list) {
    MS_EXCEPTION_IF_NULL(item);
    const auto &real_device_context = device::FetchRealDeviceContext(item, device_context);
    MS_EXCEPTION_IF_NULL(real_device_context);
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      const auto &abstract = AnfAlgo::GetNodeAbstractByIndex(item, index);
      if (abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>()) {
        CreateDeviceAddressByMapTensorNode(real_device_context, item, index);
        continue;
      }

      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        output_type_id = common::AnfAlgo::GetOutputInferDataType(item, index);
      }

      size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
        {item, index}, nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id,
        trans::GetRuntimePaddingShape(item, index), real_device_context->device_context_key().device_name_,
        real_device_context->device_context_key().device_id_);
      MS_EXCEPTION_IF_NULL(kernel_tensor);
      kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(item));
      auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
      MS_EXCEPTION_IF_NULL(device_address);
      MS_LOG(DEBUG) << "Create device address:" << device_address << " for item:" << item->DebugString();
      // Set the flag of no user parameter.
      if (item->isa<Parameter>()) {
        auto input_param = item->cast<ParameterPtr>();
        MS_EXCEPTION_IF_NULL(input_param);
        // An unused address will not allocate memory, which can easily cause problems for a weight node,
        // so skip weight nodes here.
        if (!common::AnfAlgo::IsParameterWeight(input_param) &&
            !input_param->IsUsedByRealKernelInGraph(graph->graph_id())) {
          MS_LOG(INFO) << "Node:" << item->fullname_with_scope() << " debug name:" << item->DebugString()
                       << " is not used in the graph " << graph->graph_id();
          device_address->UpdateFlag(device::kDeviceAddressFlagNotUsed);
        }
      }
      device_address->SetNodeIndex(item, index);
      device_address->set_from_persistent_mem(item->isa<Parameter>());
      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(item)
                    << " addr:" << device_address << " type:" << device_address->type_id();
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}

void DeviceAddressUtils::UpdateDeviceAddressHostInfoByNode(const device::DeviceAddressPtr &addr, const AnfNodePtr &node,
                                                           size_t output_idx) {
  MS_EXCEPTION_IF_NULL(addr);
  CreateKernelTensor(addr, session::AnfRuntimeAlgorithm::GetNodeAbstractByIndex(node, output_idx));
}

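// Creates device addresses for a tensor (or value sequence) held by a value node. If the tensor
// already carries an address on the same device type, that address is reused and only its host
// info is refreshed; otherwise the data is synced back to host and a fresh address is created.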
device::DeviceAddressPtrList DeviceAddressUtils::CreateDeviceAddressForTensorValue(const DeviceContext *device_context,
                                                                                   const ValuePtr &node_value,
                                                                                   size_t output_idx,
                                                                                   const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  device::DeviceAddressPtrList address_list;
  if (node_value->isa<tensor::BaseTensor>()) {
    auto tensor = node_value->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    auto output_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr) {
      if (output_address->GetDeviceType() == device_context->GetDeviceType()) {
        // We need to set tensor->device_address to ValueNode even if the tensor is a forward_output tensor
        // in PyNative Bprop graph. ValueNode device_address is necessary for GraphSchedule::Transform.
        UpdateDeviceAddressHostInfoByNode(output_address, value_node, output_idx);
        AnfAlgo::SetOutputAddr(std::static_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
                               value_node.get());
        (void)address_list.emplace_back(output_address);
        return address_list;
      }
      tensor->data_sync();
    }
  }

  size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  if (output_type_id == kTypeUnknown) {
    output_type_id = common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    if (output_type_id == kTypeUnknown && value_node->value() != nullptr && value_node->value()->isa<ValueTuple>() &&
        value_node->value()->cast<ValueTuplePtr>()->size() == 0) {
      MS_LOG(DEBUG) << "Set int64 type for empty value tuple node:" << value_node->DebugString();
      output_type_id = TypeId::kNumberTypeInt64;
    }
  }
  std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);

  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {value_node, output_idx}, nullptr, tensor_size, output_format, output_type_id, {},
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_host_shape(kernel_tensor->GetShapeVector());
  kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
  device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address
                << " size:" << tensor_size << " format:" << output_format << " type:" << output_type_id
                << " shape:" << kernel_tensor->GetShapeVector();
  MS_EXCEPTION_IF_NULL(address);
  address->set_from_persistent_mem(true);
  AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
  (void)address_list.emplace_back(address);
  return address_list;
}

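// Collects the inputs of all CNodes that really need a device pointer. Inputs that correspond to
// init args in the op definition are consumed on the host side only, so they are excluded; if an
// op has no OpDef (or uses all-same inputs), its inputs are conservatively treated as needing one.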
mindspore::HashSet<mindspore::AnfNodePtr> FetchValueNodesNeedDevicePtr(const KernelGraphPtr &graph) {
  mindspore::HashSet<mindspore::AnfNodePtr> nodes;
  auto topo_nodes = TopoSort(graph->get_return());
  for (auto const &n : topo_nodes) {
    if (!n->isa<CNode>()) {
      continue;
    }
    auto node = n->cast<CNodePtr>();
    auto op_name = common::AnfAlgo::GetCNodeName(node);
    auto input_num = common::AnfAlgo::GetInputTensorNum(node);
    mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
    if (op_def == nullptr) {
      MS_LOG(DEBUG) << op_name << " is not found in OpDef.";
      for (size_t i = 0; i < input_num; i++) {
        auto input = common::AnfAlgo::GetInputNode(node, i);
        (void)nodes.insert(input);
      }
      continue;
    }
    auto args = op_def->args_;
    if (input_num != args.size()) {
      int input_with_init_args = std::count_if(args.begin(), args.end(), [](auto arg) { return arg.as_init_arg_; });
      size_t total = input_num - IntToSize(input_with_init_args);
      for (size_t i = 0; i < total; i++) {
        (void)nodes.insert(common::AnfAlgo::GetInputNode(node, i));
      }
      MS_LOG(DEBUG) << "Node " << op_name << " has " << input_num << " inputs, but has " << args.size()
                    << " inputs in op_def, which means all-same input, input with init args number: "
                    << input_with_init_args;
      continue;
    }
    for (size_t i = 0; i < input_num; i++) {
      if (args[i].as_init_arg_ == 0) {
        auto input = common::AnfAlgo::GetInputNode(node, i);
        (void)nodes.insert(input);
      }
    }
  }
  return nodes;
}

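// Creates a zero-sized device address for a value node that holds a Type (metatype) value.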
device::DeviceAddressPtr CreateDeviceAddressForTypeValue(const DeviceContext *device_context,
                                                         const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {value_node, 0}, nullptr, 0, kOpFormat_DEFAULT, kMetaTypeTypeType, {},
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
  device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_LOG(DEBUG) << "Create addr for node:" << value_node->DebugString() << " addr:" << address;
  MS_EXCEPTION_IF_NULL(address);
  address->set_from_persistent_mem(true);
  AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  return address;
}

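// Creates device addresses for all value nodes of the graph. Nodes whose values are only used as
// init args get the kDeviceAddressFlagIgnoreDevicePtr flag so no device memory is spent on them,
// unless the debugger or input dump is enabled and needs the data on device.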
void DeviceAddressUtils::CreateValueNodeDeviceAddress(const DeviceContext *device_context,
                                                      const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
#ifdef ENABLE_DEBUGGER
  auto debugger = Debugger::GetInstance();
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  bool enable_debug = debugger->debugger_enabled() || dump_json_parser.InputNeedDump();
#endif
  // Store the nodes without init args, i.e. those that need a device address.
  auto value_nodes_without_init_args = FetchValueNodesNeedDevicePtr(graph);
  for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeDeviceAddressExist(device_context, value_node, 0)) {
      continue;
    }

    const auto &abstract = value_node->abstract();
    if (abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>()) {
      CreateDeviceAddressByMapTensorNode(device_context, value_node, 0);
      continue;
    }
    const auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (node_value->isa<tensor::BaseTensor>() || node_value->isa<ValueSequence>()) {
      auto address_list = CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
      // Deal with tensor and tuple.
      if (value_nodes_without_init_args.find(value_node) == value_nodes_without_init_args.end()) {
        for (const auto &address : address_list) {
#ifdef ENABLE_DEBUGGER
          if (enable_debug) {
            continue;
          }
#endif
          address->UpdateFlag(device::kDeviceAddressFlagIgnoreDevicePtr);
          MS_LOG(DEBUG) << "Find node " << value_node->DebugString() << " has init args";
        }
      }
      continue;
    } else if (node_value->isa<Type>()) {
      CreateDeviceAddressForTypeValue(device_context, value_node);
      continue;
    }

    device::DeviceAddressPtr address = CreateDeviceAddressForScalarAndString(device_context, value_node);
    // Deal with string and scalar; the address will be nullptr if the input is a type.
    if (address && (value_nodes_without_init_args.find(value_node) == value_nodes_without_init_args.end())) {
      address->UpdateFlag(device::kDeviceAddressFlagIgnoreDevicePtr);
      MS_LOG(DEBUG) << "Find node " << value_node->DebugString() << " has init args";
#ifdef ENABLE_DEBUGGER
      if (enable_debug) {
        address->ClearFlag(device::kDeviceAddressFlagIgnoreDevicePtr);
      }
#endif
    }
    if (address != nullptr) {
      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node)
                    << " addr:" << address;
      address->set_from_persistent_mem(true);
      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
    } else {
      MS_LOG(INFO) << "No device address for value node:" << value_node->fullname_with_scope()
                   << ", debug name:" << common::AnfAlgo::GetNodeDebugString(value_node);
    }
  }
}

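// Creates output device addresses for every kernel in the execution order (skipped when memory is
// managed by GE). Outputs are marked persistent when the graph produces gradients or when the
// kernel is an output of a PyNative bprop graph; kernels that need user data additionally get a
// handler for syncing that user data to raw device memory.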
void DeviceAddressUtils::CreateKernelOutputDeviceAddress(const DeviceContext *device_context,
                                                         const KernelGraphPtr &graph, bool is_gradient_out) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);

  if (graph->memory_managed_by_ge()) {
    return;
  }
  MS_LOG(DEBUG) << "Start create kernel output device address for graph:" << graph->ToString();
  bool is_pynative_bprop_graph = graph->has_flag(kFlagIsPynativeBpropGraph);
  auto outputs = common::AnfAlgo::GetAllOutput(graph->output());

  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (common::AnfAlgo::IsBpropCutOpExecInBackend(kernel)) {
      continue;
    }

    bool is_from_persistent_mem =
      (is_gradient_out || (is_pynative_bprop_graph && (find(outputs.begin(), outputs.end(), kernel) != outputs.end())));

    auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
    for (size_t i = 0; i < output_size; ++i) {
      if (AnfAlgo::OutputAddrExist(kernel, i)) {
        continue;
      }

      const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_context);
      MS_EXCEPTION_IF_NULL(real_device_context);
      const auto &abstract = AnfAlgo::GetNodeAbstractByIndex(kernel, i);
      if (abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>()) {
        CreateDeviceAddressByMapTensorNode(real_device_context, kernel, i);
        continue;
      }
      auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
      auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
      UserDataPtr user_data = nullptr;
      auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel->kernel_info());
      MS_EXCEPTION_IF_NULL(kernel_info);
      if (kernel_info->kernel_mod() != nullptr && kernel_info->kernel_mod()->need_user_data()) {
        user_data = std::make_shared<UserData>();
        user_data->set(kSyncUserDataHandler,
                       std::make_shared<device::DeviceAddress::SyncUserDataHandler>(pyexecute::UserDataToRawMemory));
        graph->set_has_kernel_need_user_data(true);
      }
      const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
        {kernel, i}, nullptr, address_size, output_format, output_type, trans::GetRuntimePaddingShape(kernel, i),
        real_device_context->device_context_key().device_name_, real_device_context->device_context_key().device_id_,
        user_data);
      kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(kernel));
      MS_LOG(DEBUG) << "Kernel tensor created without set stream id, but set after device address created.";
      auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
      device_address->SetNodeIndex(kernel, i);
      if (is_from_persistent_mem) {
        device_address->set_from_persistent_mem(true);
      }
      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
                    << " addr:" << device_address << " type:" << device_address->type_id()
                    << ", kernel tensor addr:" << kernel_tensor.get()
                    << ", kernel tensor: " << kernel_tensor->ToString() << " addr size:" << address_size
                    << " real size:" << device_address->GetSize()
                    << " origin ref count:" << device_address->original_ref_count();
      device_address->set_stream_id(AnfAlgo::GetStreamId(kernel));
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
    }
  }
  MS_LOG(DEBUG) << "End create kernel output device address for graph:" << graph->ToString();
}

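// Creates device addresses for graph outputs that still lack one, e.g. outputs produced by nodes
// that are not part of the execution order. Bprop-cut nodes and monad outputs are skipped.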
void DeviceAddressUtils::CreateGraphOutputDeviceAddress(const DeviceContext *device_context,
                                                        const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  auto output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
  for (const auto &output_with_index : output_with_indexs) {
    const auto &output = output_with_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (common::AnfAlgo::IsBpropCutOpExecInBackend(output) || HasAbstractMonad(output)) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputAddressNum(output);
    for (size_t i = 0; i < output_size; ++i) {
      if (AnfAlgo::OutputAddrExist(output, i)) {
        continue;
      }

      const auto &real_device_context = device::FetchRealDeviceContext(output, device_context);
      MS_EXCEPTION_IF_NULL(real_device_context);
      MS_EXCEPTION_IF_NULL(real_device_context->device_res_manager_);
      auto output_format = AnfAlgo::GetOutputFormat(output, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(output, i);
      auto address_size = AnfAlgo::GetOutputTensorMemSize(output, i);
      const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
        {output, i}, nullptr, address_size, output_format, output_type, trans::GetRuntimePaddingShape(output, i),
        real_device_context->device_context_key().device_name_, real_device_context->device_context_key().device_id_);
      kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(output));
      auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
      MS_LOG(DEBUG) << "Create addr for node:" << output->DebugString() << " addr:" << device_address
                    << " type:" << device_address->type_id();
      AnfAlgo::SetOutputAddr(device_address, i, output.get());
    }
  }
}

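// Computes the device-side byte size of a tensor, taking the device format into account. On
// Ascend, the host shape may first be padded and transformed into the device shape for the format.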
size_t DeviceAddressUtils::GetTensorDeviceSize(const DeviceContext *device_context, const AnfNodePtr &node,
                                               const ShapeVector &shape, const string &format, TypeId dtype,
                                               size_t output_index) {
  MS_EXCEPTION_IF_NULL(device_context);
  auto device_shape = shape;
  if (device_context->GetDeviceType() == device::DeviceType::kAscend) {
    if (device_shape.empty() && format != kOpFormat_DEFAULT) {
      device_shape = trans::PaddingShape(device_shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
      device_shape = trans::TransShapeToDevice(device_shape, format, node, output_index, dtype);
    } else {
      if (trans::IsNeedPadding(format, device_shape)) {
        device_shape =
          trans::PaddingShape(device_shape, format, AnfAlgo::GetOutputReshapeType(node, output_index), node);
      }
      device_shape = trans::TransShapeToDevice(device_shape, format, node, output_index, dtype);
    }
  }
  size_t type_size = GetTypeByte(TypeIdToType(dtype));
  size_t tensor_size = type_size * SizeOf(device_shape);
  return tensor_size;
}

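// Creates device addresses for the outputs of a compiled op graph in PyNative mode. Ref outputs
// that already carry an address are reused; otherwise the shape comes from the inferred abstract,
// while dtype, format and user data are taken from the cached output address.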
vector<device::DeviceAddressPtr> DeviceAddressUtils::CreateGraphOutputDeviceAddress(
  const OpCompilerInfoPtr &op_compiler_info, const abstract::AbstractBasePtr &out_abstract, size_t stream_id) {
  auto device_context = op_compiler_info->device_context_;
  const auto &output_edges = op_compiler_info->simple_graph_->outputs_;
  size_t output_num = output_edges.size();

  std::vector<device::DeviceAddressPtr> output_address_list;
  output_address_list.reserve(output_num);

  for (size_t i = 0; i < output_num; ++i) {
    const auto &edge = output_edges[i];
    const auto &address = edge->address_;
    if (address != nullptr) {
      MS_LOG(DEBUG) << "Already have output device address for ref output";
      output_address_list.push_back(address);
      continue;
    }

    const auto &[output_node, index] = edge->node_with_index_;
    const auto &cache_output_address = edge->origin_address_;

    auto real_abstract = out_abstract;
    if (out_abstract->isa<abstract::AbstractTuple>()) {
      auto abstract_tuple = out_abstract->cast<abstract::AbstractTuplePtr>();
      if (i >= abstract_tuple->elements().size()) {
        MS_LOG(EXCEPTION) << "abstract_tuple size is " << abstract_tuple->elements().size() << ", but get index is "
                          << i;
      }
      real_abstract = abstract_tuple->elements()[i];
    }
    auto output_shape_ptr = real_abstract->BuildShape();
    MS_EXCEPTION_IF_NULL(output_shape_ptr);
    auto shape_vector = output_shape_ptr->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape_vector);
    const auto &shape = shape_vector->shape();
    auto output_type = cache_output_address->type_id();
    const auto &output_format = cache_output_address->format();
    auto address_size = GetTensorDeviceSize(device_context, output_node, shape, output_format, output_type, index);
    const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
      real_abstract->GetShape()->Clone(), real_abstract->GetType()->Clone(), real_abstract->GetValue(), nullptr,
      address_size, output_format, output_type, shape, device_context->device_context_key().device_name_,
      device_context->device_context_key().device_id_, cache_output_address->user_data());
    kernel_tensor->set_stream_id(stream_id);
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(output_node)
                  << " addr:" << device_address;
    output_address_list.push_back(device_address);
    edge->address_ = device_address;
  }
  return output_address_list;
}

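// Creates workspace device addresses for each kernel from the sizes reported by its KernelMod.
// Note the loop breaks at the first existing workspace address, assuming the remaining indices
// were created together with it.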
void DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context,
                                                            const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);

  if (graph->memory_managed_by_ge()) {
    return;
  }

  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (common::AnfAlgo::IsBpropCutOpExecInBackend(kernel)) {
      continue;
    }
    const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_context);
    MS_EXCEPTION_IF_NULL(real_device_context);
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_sizes.size(); ++i) {
      if (AnfAlgo::WorkspaceAddrExist(kernel, i)) {
        break;
      }
      auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
        nullptr, workspace_sizes[i], Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
        real_device_context->device_context_key().device_name_, real_device_context->device_context_key().device_id_);
      kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(kernel));
      auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
                    << " addr:" << device_address;
      AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
    }
  }
}

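// Makes all nodes of one inplace group share the device memory of the group's first node: the
// first node's pointer ref count is propagated to the other members, and its original ref count
// is increased once per extra member so the buffer is not freed early.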
void DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  if (graph->memory_managed_by_ge()) {
    return;
  }

  // Collect the inplace groups.
  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    if (!common::AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
      continue;
    }
    auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel);
    MS_EXCEPTION_IF_NULL(primitive);
    auto inplace_group_attr = primitive->GetAttr("inplace_group");
    MS_EXCEPTION_IF_NULL(inplace_group_attr);
    auto group_id = GetValue<uint32_t>(inplace_group_attr);
    (void)inplace_groups[group_id].emplace_back(kernel);
  }

  constexpr size_t kMinInplaceGroupSize = 2;
  for (const auto &inplace_group : inplace_groups) {
    auto &group_nodes = inplace_group.second;
    if (group_nodes.size() < kMinInplaceGroupSize) {
      continue;
    }
    // Get the device address of the first node in the inplace group.
    auto node_primitive = common::AnfAlgo::GetCNodePrimitive(group_nodes[0]);
    MS_EXCEPTION_IF_NULL(node_primitive);
    auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
    auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
    MS_EXCEPTION_IF_NULL(device_address);

    // Update the device addresses of the other nodes using the device address of the first node in the group.
    for (size_t i = 1; i < group_nodes.size(); ++i) {
      auto &group_node = group_nodes[i];
      auto prim = common::AnfAlgo::GetCNodePrimitive(group_node);
      MS_EXCEPTION_IF_NULL(prim);
      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
      auto group_node_device_address = AnfAlgo::GetMutableOutputAddr(group_node, index, false);
      MS_EXCEPTION_IF_NULL(group_node_device_address);
      // Update the reference count of the device address.
      device_address->IncreaseOriginalRefCount();
      MS_LOG(DEBUG) << "After increase ref count for device address:" << device_address
                    << " ref count:" << device_address->original_ref_count();
      device_address->ResetRefCount();
      group_node_device_address->set_pointer_ref_count(device_address->pointer_ref_count());
    }
  }
}

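// Binds the device address of a ref-node output (cur_pair) to the address of the node it refers to
// (origin_pair): when the two addresses do not already share a pointer ref count, the ref counts
// are transferred to the origin address and both addresses are flagged as ref-node addresses. A
// device-type mismatch between the two is treated as a fatal error.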
void DeviceAddressUtils::UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair,
                                             const session::AnfWithOutIndex &origin_pair) {
  MS_EXCEPTION_IF_NULL(cur_pair.first);
  MS_EXCEPTION_IF_NULL(origin_pair.first);
  MS_LOG(INFO) << "Ref node pair: origin kernel is " << origin_pair.first->fullname_with_scope() << ", index is "
               << origin_pair.second << "; cur kernel is " << cur_pair.first->fullname_with_scope() << ", index is "
               << cur_pair.second;
  // If the output of the ref node is a parameter, we need to add the monad attr (for example, a Transdata/Cast node
  // to a ref parameter).
  if (!common::AnfAlgo::HasMonadInput(cur_pair.first) && origin_pair.first->isa<Parameter>()) {
    MS_LOG(INFO) << cur_pair.first->fullname_with_scope() << " with index " << cur_pair.second
                 << " ref node to parameter " << origin_pair.first->fullname_with_scope() << " and add the monad attr.";
    common::AnfAlgo::SetNodeAttr(kAttrRefNodeMonadOutputIdx, MakeValue(cur_pair.second), cur_pair.first);
  }

  auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second, false);
  MS_EXCEPTION_IF_NULL(origin_node_output_addr);
  auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(cur_pair.first, cur_pair.second, false);
  MS_EXCEPTION_IF_NULL(cur_node_output_addr);
  auto origin_stream_id = origin_node_output_addr->stream_id();
  auto cur_stream_id = cur_node_output_addr->stream_id();
  if (origin_stream_id != cur_stream_id) {
    MS_LOG(DEBUG) << "Origin node output addr : " << origin_node_output_addr << " stream id : " << origin_stream_id
                  << " is not equal to cur node output addr stream id : " << cur_stream_id << ".";
  }

  // Update the device address flag.
  origin_node_output_addr->UpdateFlag(device::kDeviceAddressFlagRefNode);

  if (origin_node_output_addr->pointer_ref_count() != cur_node_output_addr->pointer_ref_count()) {
    // Check whether the device targets are consistent.
    if (origin_node_output_addr->GetDeviceType() != cur_node_output_addr->GetDeviceType()) {
      std::string error_info =
        "Device target is not consistent: ref origin kernel is " + origin_pair.first->fullname_with_scope() +
        ", index is " + std::to_string(origin_pair.second) + ", device target is " +
        device::GetDeviceNameByType(origin_node_output_addr->GetDeviceType()) + "; cur kernel is " +
        cur_pair.first->fullname_with_scope() + ", index is " + std::to_string(cur_pair.second) +
        ", device target is " + device::GetDeviceNameByType(cur_node_output_addr->GetDeviceType());

      MS_LOG(ERROR) << error_info;
      if (AnfAlgo::IsKernelSelectBackoffOp(origin_pair.first)) {
        const auto &backoff_info = AnfAlgo::GetKernelSelectBackoffInfo(origin_pair.first);
        MS_EXCEPTION(backoff_info.second) << "#umsg#Kernel select failed:#umsg#" << backoff_info.second;
      } else if (AnfAlgo::IsKernelSelectBackoffOp(cur_pair.first)) {
        const auto &backoff_info = AnfAlgo::GetKernelSelectBackoffInfo(cur_pair.first);
        MS_EXCEPTION(backoff_info.second) << "#umsg#Kernel select failed:#umsg#" << backoff_info.second;
      } else {
        MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << error_info;
      }
    }
    MS_LOG(INFO) << "Update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
                 << ", index is " << origin_pair.second << "; cur kernel is " << cur_pair.first->fullname_with_scope()
                 << ", index is " << cur_pair.second;
    // Update the reference count of the device address.
    cur_node_output_addr->DecreaseOriginalRefCount();
    cur_node_output_addr->ResetRefCount();
    origin_node_output_addr->IncreaseOriginalRefCount();
    MS_LOG(DEBUG) << "After increase ref count for device address:" << origin_node_output_addr
                  << " ref count:" << origin_node_output_addr->original_ref_count();
    origin_node_output_addr->ResetRefCount();
    cur_node_output_addr->set_pointer_ref_count(origin_node_output_addr->pointer_ref_count());
    cur_node_output_addr->UpdateFlag(device::kDeviceAddressFlagRefNode);
  } else {
    MS_LOG(DEBUG) << "No need update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
                  << ", index is " << origin_pair.second << "; cur kernel is " << cur_pair.first->fullname_with_scope()
                  << ", index is " << cur_pair.second;
  }
}

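// Walks the ref map of the graph and rebinds each ref output to its recursive origin node via
// UpdateDeviceAddress. It also records the output-to-input ref relation in the kernel info, which
// the kernel actor relies on in the memory-swap scenario.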
void DeviceAddressUtils::UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  if (graph->memory_managed_by_ge()) {
    return;
  }

  AnfAlgo::UpdateGraphValidRefPair(graph);
  for (const auto &ref_pair : graph->GetRefMap()) {
    const auto &out_pair = ref_pair.first;
    const auto &origin_pair = ref_pair.second;
    const auto &recursive_origin_pair = graph->GetRefNodeRecursive(out_pair);
    UpdateDeviceAddress(out_pair, recursive_origin_pair);
    // Update the ref map in the kernel info, which will be used by the kernel actor in the swap scenario.
    for (size_t input_index = 0; input_index < common::AnfAlgo::GetInputTensorNum(out_pair.first); ++input_index) {
      const auto &prev_node_output = common::AnfAlgo::GetPrevNodeOutput(out_pair.first, input_index, false);
      if (prev_node_output == origin_pair) {
        auto kernel_info = dynamic_cast<device::KernelInfo *>(out_pair.first->kernel_info());
        MS_EXCEPTION_IF_NULL(kernel_info);
        kernel_info->AddRefMap(out_pair.second, input_index);
        break;
      }
    }
  }
}

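// Clones a device address without its device memory: the kernel tensor is copied, the device
// pointer is cleared, and ref counts, node index and padding type are carried over, so the clone
// can be materialized later on the target device context.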
device::DeviceAddressPtr DeviceAddressUtils::CloneEmptyDeviceAddress(const device::DeviceAddressPtr &old_device_address,
                                                                     const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(old_device_address);
  MS_EXCEPTION_IF_NULL(device_context);
  const auto &kernel_tensor = old_device_address->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  auto new_kernel_tensor = kernel_tensor->CloneKernelTensor();
  MS_EXCEPTION_IF_NULL(new_kernel_tensor);

  new_kernel_tensor->set_device_name(device_context->device_context_key().device_name_);
  new_kernel_tensor->set_device_id(device_context->device_context_key().device_id_);
  new_kernel_tensor->set_device_ptr(nullptr);
  auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(new_kernel_tensor);
  MS_EXCEPTION_IF_NULL(new_device_address);
  MS_LOG(DEBUG) << "Create device tensor:" << new_device_address << " type:" << new_device_address->type_id();

  new_device_address->set_original_ref_count(old_device_address->original_ref_count());
  new_device_address->ResetRefCount();
  auto node = old_device_address->GetNodeIndex();
  new_device_address->SetNodeIndex(node.first, node.second);
  new_device_address->set_padding_type(old_device_address->padding_type());
  return new_device_address;
}

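// Ensures a PyNative input tensor has a device address on the current device context. An address
// of a different device type triggers a host sync and is replaced; CPU/GPU addresses also get a
// kernel tensor attached, which they need for kernel launch.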
void DeviceAddressUtils::CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                                  const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(tensor);

  auto addr = tensor->device_address();
  if (addr != nullptr) {
    auto device_address = std::static_pointer_cast<device::DeviceAddress>(addr);
    if (device_address->GetDeviceType() != device::DeviceType::kAscend) {
      // CPU or GPU view CreateDeviceAddress without KernelTensor.
      CreateKernelTensor(device_address, tensor);
    }
    if (device_address->GetDeviceType() == device_context->GetDeviceType()) {
      MS_LOG(DEBUG) << "Already have device address of tensor " << tensor->id();
      return;
    }
    MS_LOG(DEBUG) << "Input tensor device type is " << device_address->GetDeviceType()
                  << " but current device context is " << device_context->GetDeviceType();
    tensor->data_sync();
    tensor->set_device_address(nullptr);
  }
  auto tensor_size = LongToSize(tensor->data().nbytes());
  const auto &format = GetFormatByTensorShape(device_context, tensor->shape());
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, tensor_size, tensor->shape(), format, tensor->data_type(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
  if (device_address->GetDeviceType() != device::DeviceType::kAscend) {
    // CPU or GPU needs a KernelTensor to launch the kernel.
    CreateKernelTensor(device_address, tensor);
  }

  MS_EXCEPTION_IF_NULL(device_address);
  device_address->set_from_persistent_mem(tensor->is_parameter());
  tensor->set_device_address(device_address);
  MS_LOG(DEBUG) << "Create input tensor device address " << device_address << " for " << index
                << "th input, Shape: " << tensor->shape() << ", Type: " << TypeIdToType(tensor->data_type())->ToString()
                << ", Size:" << tensor_size;
}

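// Allocates (or re-allocates) device memory for a PyNative input tensor and copies the host data
// in. For view inputs on CPU whose memory did not come from the pool, the pointer is re-acquired
// from the pool first. On Ascend, the host-to-device copy is dispatched as an async launch task.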
void DeviceAddressUtils::MallocForInput(const DeviceContext *device_context, const tensor::BaseTensorPtr &tensor,
                                        bool is_view) {
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &device_sync = tensor->device_address();
  auto device_address = std::static_pointer_cast<device::DeviceAddress>(device_sync);
  MS_EXCEPTION_IF_NULL(device_address);
  device_address->set_is_view(is_view);

  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
    auto mem_type =
      tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
                                                   device_address.get());
  }
  if (device_address->GetMutablePtr() != nullptr) {
    if (!is_view || device_address->GetDeviceType() != device::DeviceType::kCPU || device_address->from_mem_pool()) {
      return;
    }
    // If not from the pool, the lifetime of the device ptr is guaranteed elsewhere.
    // Before applying for a new address, clear the old one. Otherwise a warning is generated.
    device_address->set_ptr(nullptr);
    const auto new_device_context = device_context->GetDeviceType() == device_address->GetDeviceType()
                                      ? device_context
                                      : runtime::OpRunner::GetDeviceContext(kCPUDevice);

    MS_EXCEPTION_IF_NULL(new_device_context);
    if (!new_device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate memory failed";
    }
  } else {
    if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate memory failed";
    }
  }

  auto tensor_size = LongToSize(tensor->data().nbytes());
  if (device_address->GetDeviceType() == device::DeviceType::kAscend) {
    OpExecutor::DispatchLaunchTask([=]() {
      if (!device_address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), device_address->format(),
                                            tensor->data_ptr())) {
        MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
      }
    });
  } else {
    if (!device_address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), device_address->format(),
                                          tensor->data_ptr())) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
    }
  }
}

void DeviceAddressUtils::MallocForInput(const DeviceContext *device_context,
                                        const std::vector<tensor::BaseTensorPtr> &tensors, bool is_view) {
  for (const auto &tensor : tensors) {
    MallocForInput(device_context, tensor, is_view);
  }
}

void DeviceAddressUtils::MallocForInput(const DeviceContext *device_context,
                                        const std::optional<tensor::BaseTensorPtr> &val, bool is_view) {
  if (!val.has_value()) {
    return;
  }
  MallocForInput(device_context, val.value(), is_view);
}

void DeviceAddressUtils::CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                                  const std::optional<tensor::BaseTensorPtr> &val) {
  if (!val.has_value()) {
    return;
  }
  CreateInputTensorAddress(device_context, stream_id, index, val.value());
}

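// CreateKernelTensor overloads: lazily attach a kernel tensor (host meta info) to a device address
// that was created without one, deriving shape and type from the tensor, value or abstract.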
void DeviceAddressUtils::CreateKernelTensor(const device::DeviceAddressPtr &device_address,
                                            const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(device_address);
  MS_EXCEPTION_IF_NULL(tensor);
  if (device_address->kernel_tensor() != nullptr) {
    return;
  }
  const auto &address_common = device_address->address_common();
  MS_EXCEPTION_IF_NULL(address_common);
  auto real_kernel_tensor = std::make_shared<kernel::KernelTensor>(
    address_common, std::make_shared<abstract::TensorShape>(tensor->shape()),
    std::make_shared<TensorType>(TypeIdToType(tensor->data_type())), nullptr, tensor->shape());
  device_address->set_kernel_tensor(real_kernel_tensor);
  device_address->DeviceSynchronizerInit();
}

void DeviceAddressUtils::CreateKernelTensor(const ValuePtr &input_value) {
  MS_EXCEPTION_IF_NULL(input_value);
  if (input_value->isa<tensor::BaseTensor>()) {
    auto tensor = input_value->cast<tensor::BaseTensorPtr>();
    if (tensor->device_address() != nullptr) {
      auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
      MS_EXCEPTION_IF_NULL(device_address);
      CreateKernelTensor(device_address, tensor);
    }
  }
}

void DeviceAddressUtils::CreateKernelTensor(const tensor::TensorPtr &input_tensor) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  if (input_tensor->device_address() != nullptr) {
    auto device_address = std::static_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
    MS_EXCEPTION_IF_NULL(device_address);
    CreateKernelTensor(device_address, input_tensor);
  }
}

void DeviceAddressUtils::CreateKernelTensor(const device::DeviceAddressPtr &device_address,
                                            const AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->kernel_tensor() != nullptr) {
    return;
  }
  const auto address_common = device_address->address_common();
  MS_EXCEPTION_IF_NULL(address_common);
  MS_EXCEPTION_IF_NULL(abs);
  const auto &shape = abs->GetShape();
  const auto &type = abs->GetType();
  auto real_kernel_tensor =
    std::make_shared<kernel::KernelTensor>(address_common, shape, type, nullptr, shape->GetShapeVector());
  device_address->set_kernel_tensor(real_kernel_tensor);
  device_address->DeviceSynchronizerInit();
}

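// CreateInputAddress overloads: create (and materialize) a device address for one PyNative input,
// which may be a tensor, scalar, string or type value. Non-tensor inputs are copied to the device
// immediately via CopyNoneTensorDataToDevice.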
device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                                const abstract::AbstractBasePtr &abs, size_t index,
                                                                const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(tensor);
  auto addr = tensor->device_address();
  if (addr != nullptr) {
    auto device_address = std::static_pointer_cast<device::DeviceAddress>(addr);
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetPtr() != nullptr) {
      MS_LOG(DEBUG) << "Input tensor already has address " << device_address.get() << " and device Ptr "
                    << device_address->GetPtr();
      return device_address;
    }
  }
  BaseShapePtr shape;
  TypePtr type;
  if (abs != nullptr) {
    shape = abs->GetShape();
    type = abs->GetType();
  } else {
    shape = std::make_shared<abstract::Shape>(tensor->shape());
    type = tensor->Dtype();
  }

  const auto &tensor_size = LongToSize(tensor->data().nbytes());
  const auto &format = GetFormatByTensorShape(device_context, tensor->shape());
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    shape, type, nullptr, nullptr, tensor_size, kernel::GetFormatFromEnumToStr(format), tensor->data_type(),
    tensor->shape(), device_context->device_context_key().device_name_,
    device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  device::DeviceAddressPtr device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(device_address);
  device_address->set_from_persistent_mem(tensor->is_parameter());
  tensor->set_device_address(device_address);

  auto mem_type = tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kConstantValue;
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
                                                 device_address.get());
  if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
    MS_LOG(EXCEPTION) << "Allocate memory failed";
  }
  if (!device_address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(),
                                        kernel::GetFormatFromEnumToStr(format), tensor->data_ptr())) {
    MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
  }
  MS_LOG(DEBUG) << "Create input tensor device address " << device_address << " for " << index
                << "th input, Shape: " << shape->ToString()
                << ", Type: " << TypeIdToType(tensor->data_type())->ToString() << ", host shape: " << tensor->shape()
                << ", dev ptr " << device_address->GetPtr();
  return device_address;
}

CreateInputAddress(const DeviceContext * device_context,size_t stream_id,const abstract::AbstractBasePtr & abs,size_t index,const ScalarPtr & scalar_value)1072 device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
1073 const abstract::AbstractBasePtr &abs, size_t index,
1074 const ScalarPtr &scalar_value) {
1075 MS_EXCEPTION_IF_NULL(device_context);
1076 MS_EXCEPTION_IF_NULL(scalar_value);
1077 const auto type = scalar_value->type();
1078 MS_EXCEPTION_IF_NULL(type);
1079 auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
1080 abstract::kNoShape, type, scalar_value, nullptr, GetTypeByte(TypeIdToType(type->type_id())), kOpFormat_DEFAULT,
1081 type->type_id(), ShapeVector(), device_context->device_context_key().device_name_,
1082 device_context->device_context_key().device_id_);
1083 kernel_tensor->set_stream_id(stream_id);
1084 auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
1085 device_address->set_from_persistent_mem(true);
1086
1087 if (device_address->GetPtr() == nullptr) {
1088 CopyNoneTensorDataToDevice(device_context, device_address);
1089 }
1090 MS_LOG(DEBUG) << "Create input scalar device address " << device_address << " for " << index
1091 << "th input, Shape: " << abstract::kNoShape->ToString() << ", Type: " << type->ToString()
1092 << ", Value: " << (scalar_value ? scalar_value->ToString() : "nullptr") << ", dev ptr "
1093 << device_address->GetPtr();
1094 return device_address;
1095 }
1096
device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                                const abstract::AbstractBasePtr &abs, size_t index,
                                                                const std::optional<tensor::BaseTensorPtr> &val) {
  if (!val.has_value()) {
    return nullptr;
  }
  return CreateInputAddress(device_context, stream_id, abs, index, val.value());
}

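// Create a device address for a string input. The buffer holds the characters plus a terminating '\0'.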
device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                                const abstract::AbstractBasePtr &abs, size_t index,
                                                                const StringImmPtr &string_imm) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(string_imm);
  const auto &type = string_imm->type();
  MS_EXCEPTION_IF_NULL(type);
  const auto &tensor_value = GetValue<std::string>(string_imm);
  // Allocate one more byte for the terminating '\0'
  size_t size = tensor_value.size() + 1;
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    abstract::kNoShape, type, string_imm, nullptr, size, kOpFormat_DEFAULT, kObjectTypeString, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  device_address->set_from_persistent_mem(true);

  if (device_address->GetPtr() == nullptr) {
    CopyNoneTensorDataToDevice(device_context, device_address);
  }
  MS_LOG(DEBUG) << "Create input string device address " << device_address << " for " << index
                << "th input, Shape: " << abstract::kNoShape->ToString() << ", Type: " << type->ToString()
                << ", Value: " << (string_imm ? string_imm->ToString() : "nullptr") << ", dev ptr "
                << device_address->GetPtr();
  return device_address;
}

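// Create a device address for a type (dtype) input; only the type information is attached, no host value.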
device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                                const abstract::AbstractBasePtr &abs, size_t index,
                                                                const TypePtr &type_ptr) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(type_ptr);
  const auto &type = type_ptr->type();
  MS_EXCEPTION_IF_NULL(type);
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    abstract::kNoShape, type, nullptr, nullptr, GetTypeByte(TypeIdToType(type->type_id())), kOpFormat_DEFAULT,
    type_ptr->type_id(), ShapeVector(), device_context->device_context_key().device_name_,
    device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  device_address->set_from_persistent_mem(true);

  if (device_address->GetPtr() == nullptr) {
    CopyNoneTensorDataToDevice(device_context, device_address);
  }
  MS_LOG(DEBUG) << "Create input " << type_ptr->ToString() << " device address for " << index
                << "th input, Shape: " << abstract::kNoShape->ToString() << ", Type: " << type->ToString()
                << ", Value: nullptr, device address:" << device_address;
  return device_address;
}

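// Create device addresses for all output tensors of an op. No device memory is allocated here; see
// MallocForOutputs.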
void DeviceAddressUtils::CreateOutputTensorAddress(const DeviceContext *device_context, size_t stream_id,
                                                   const std::vector<tensor::BaseTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(device_context);
  for (size_t i = 0; i < outputs.size(); ++i) {
    const auto &tensor = outputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto tensor_size = LongToSize(tensor->data().nbytes());
    const auto &format = GetFormatByTensorShape(device_context, tensor->shape());
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
      nullptr, tensor_size, tensor->shape(), format, tensor->data_type(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetDeviceType() != device::DeviceType::kAscend) {
      // CPU and GPU need a KernelTensor to LaunchKernel
      CreateKernelTensor(device_address, tensor);
    }
    tensor->set_device_address(device_address);
    MS_LOG(DEBUG) << "Create output tensor device address " << device_address << " for " << i
                  << "th output, Shape: " << tensor->shape()
                  << ", Type: " << TypeIdToType(tensor->data_type())->ToString() << ", Size:" << tensor_size;
  }
}

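// Single-output overload where the buffer size in bytes is provided by the caller instead of being derived from
// the tensor.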
void DeviceAddressUtils::CreateOutputTensorAddress(const DeviceContext *device_context, size_t stream_id,
                                                   const tensor::BaseTensorPtr &output_tensor, size_t size) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(output_tensor);
  const auto &format = GetFormatByTensorShape(device_context, output_tensor->shape());
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, size, output_tensor->shape(), format, output_tensor->data_type(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetDeviceType() != device::DeviceType::kAscend) {
    // CPU and GPU need a KernelTensor to LaunchKernel
    CreateKernelTensor(device_address, output_tensor);
  }
  output_tensor->set_device_address(device_address);
  MS_LOG(DEBUG) << "Create output tensor device address " << device_address << " for the output, element count: "
                << static_cast<int64_t>(size / GetTypeByte(TypeIdToType(output_tensor->data_type())))
                << ", Type: " << TypeIdToType(output_tensor->data_type())->ToString() << ", Size:" << size;
}

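// Create a device address for a tensor whose device-side size is derived from an explicitly given real shape
// rather than from the tensor's host data. No device memory is allocated here.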
device::DeviceAddressPtr DeviceAddressUtils::CreateDeviceAddress(const DeviceContext *device_context,
                                                                 const tensor::BaseTensorPtr &tensor,
                                                                 const ShapeVector &real_shape,
                                                                 const size_t &stream_id) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(tensor);
  auto tensor_size = GetTypeByte(TypeIdToType(tensor->data_type())) * SizeOf(real_shape);
  const auto &device_format = GetFormatByTensorShape(device_context, tensor->shape());
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    nullptr, tensor_size, device_format, tensor->data_type(), real_shape,
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  device::DeviceAddressPtr device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_LOG(DEBUG) << "Create tensor device address " << device_address << ", Shape: " << tensor->shape()
                << ", Type: " << TypeIdToType(tensor->data_type())->ToString();
  return device_address;
}

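// Allocate device memory for output tensors whose device addresses are still unbound. Addresses that already
// hold a pointer (ref outputs) are skipped.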
void DeviceAddressUtils::MallocForOutputs(const DeviceContext *device_context,
                                          const std::vector<tensor::BaseTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(device_context);
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    auto device_address = std::static_pointer_cast<device::DeviceAddress>(output->device_address());
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetPtr() != nullptr) {
      // Ref output: memory is already bound, nothing to allocate.
      continue;
    }
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kPyNativeOutput,
                                                   device_address->GetSize(), device_address.get());
    if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate memory failed";
    }
  }
}

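// Create a workspace device address directly from the resource manager, without building a KernelTensor first,
// and allocate its memory immediately.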
device::DeviceAddressPtr DeviceAddressUtils::CreateWorkspaceAddressWithoutKernelTensor(
  const DeviceContext *device_context, size_t stream_id, const size_t &workspace_size) {
  MS_EXCEPTION_IF_NULL(device_context);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, workspace_size, ShapeVector(), Format::DEFAULT_FORMAT, kTypeUnknown,
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
  MS_EXCEPTION_IF_NULL(device_address);
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kWorkSpace,
                                                 device_address->GetSize(), device_address.get());
  if (device_address->GetPtr() == nullptr &&
      !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
    MS_LOG(EXCEPTION) << "Allocate dynamic workspace memory failed";
  }
  MS_LOG(DEBUG) << "Create workspace device address:" << device_address;
  return device_address;
}

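// Create a workspace device address backed by a KernelTensor and allocate its memory immediately.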
device::DeviceAddressPtr DeviceAddressUtils::CreateWorkspaceAddress(const DeviceContext *device_context,
                                                                    size_t stream_id, const size_t &workspace_size) {
  MS_EXCEPTION_IF_NULL(device_context);

  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    nullptr, workspace_size, Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);

  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetPtr() == nullptr &&
      !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
    MS_LOG(EXCEPTION) << "Allocate dynamic workspace memory failed";
  }
  MS_LOG(DEBUG) << "Create workspace device address:" << device_address;
  return device_address;
}

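// Synchronously make a tensor with storage info (a non-contiguous view) contiguous on device and rebind the
// tensor to the new device address. Tensors without storage info are left untouched.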
void DeviceAddressUtils::ConvertContiguousTensorSync(const tensor::BaseTensorPtr &tensor) {
  if (tensor == nullptr || tensor->storage_info() == nullptr) {
    return;
  }

  MS_LOG(DEBUG) << "Tensor has storage_info, need to make it contiguous, id:" << tensor->id();
  const auto &new_device_address = ConvertContiguousDeviceAddress(
    nullptr, std::static_pointer_cast<device::DeviceAddress>(tensor->device_address()), true);
  MS_EXCEPTION_IF_NULL(new_device_address);
  tensor->set_device_address(new_device_address);
}

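// Build a contiguous copy of a device address that carries tensor storage info by launching a kCONTIGUOUS_TASK
// kernel task, either synchronously (draining the op queue around the launch) or asynchronously through the
// op executor.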
device::DeviceAddressPtr DeviceAddressUtils::ConvertContiguousDeviceAddress(
  const DeviceContext *input_device_context, const device::DeviceAddressPtr &old_device_address, bool is_sync) {
  MS_EXCEPTION_IF_NULL(old_device_address);

  const DeviceContext *device_context = input_device_context == nullptr
                                          ? runtime::OpRunner::GetDeviceContext(old_device_address->device_name())
                                          : input_device_context;
  MS_EXCEPTION_IF_NULL(device_context);
  auto stream_id = device_context->device_res_manager_->GetCurrentStreamId();

  GilReleaseWithCheck release_gil;
  const auto &old_storage_info = old_device_address->GetTensorStorageInfo();
  MS_EXCEPTION_IF_NULL(old_storage_info);

  auto address_size = GetTypeByte(TypeIdToType(old_device_address->type_id())) * SizeOf(old_storage_info->shape);
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    nullptr, address_size, Format::DEFAULT_FORMAT, old_device_address->type_id(), old_storage_info->shape,
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->SetType(std::make_shared<TensorType>(TypeIdToType(old_device_address->type_id())));
  kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(old_storage_info->shape));
  kernel_tensor->set_stream_id(stream_id);

  auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  new_device_address->set_device_shape(old_storage_info->shape);
  new_device_address->set_original_ref_count(SIZE_MAX);
  new_device_address->ResetRefCount();

  if (is_sync) {
    // Synchronous path: wait until all tasks in the queue are complete before and after the launch.
    runtime::OpExecutor::GetInstance().WaitAll();
    if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(
          runtime::KernelTaskType::kCONTIGUOUS_TASK, {old_device_address}, {new_device_address}, stream_id)) {
      MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << runtime::KernelTaskType::kCONTIGUOUS_TASK;
    }
    runtime::OpExecutor::GetInstance().WaitAll();
  } else {
    auto async_task = [device_context, old_device_address, new_device_address, stream_id]() {
      if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(
            runtime::KernelTaskType::kCONTIGUOUS_TASK, {old_device_address}, {new_device_address}, stream_id)) {
        MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << runtime::KernelTaskType::kCONTIGUOUS_TASK;
      }
    };

    runtime::OpExecutor::GetInstance().PushSimpleOpRunTask(
      std::make_shared<runtime::PassthroughDeviceTask>(async_task));
  }

  return new_device_address;
}

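// Collect (stream id, device pointer) pairs for inputs whose device address lives on a different stream than the
// launching op.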
void DeviceAddressUtils::GetCrossStreamAddressInfoFromInput(
  size_t op_stream_id, std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
  const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  if (tensor->device_address() == nullptr) {
    return;
  }

  auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_NULL(device_address);
  if (op_stream_id != device_address->stream_id()) {
    // Device address is cross stream.
    (void)cross_stream_addresses->emplace_back(device_address->stream_id(), device_address->GetMutablePtr());
  }
}

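// Overload for raw KernelTensor inputs.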
void DeviceAddressUtils::GetCrossStreamAddressInfoFromInput(
  size_t op_stream_id, std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
  const mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  if (op_stream_id != tensor->stream_id()) {
    (void)cross_stream_addresses->emplace_back(tensor->stream_id(), tensor->device_ptr());
  }
}

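// Overload for inputs given directly as device addresses.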
void DeviceAddressUtils::GetCrossStreamAddressInfoFromInput(
  size_t op_stream_id, std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
  const device::DeviceAddressPtr &device_address) {
  MS_EXCEPTION_IF_NULL(device_address);
  if (op_stream_id != device_address->stream_id()) {
    (void)cross_stream_addresses->emplace_back(device_address->stream_id(), device_address->GetMutablePtr());
  }
}
}  // namespace runtime
}  // namespace mindspore