/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/device_address_utils.h"

#include <algorithm>
#include <string>
#include <map>
#include <vector>
#include <memory>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/op_def.h"
#include "ir/tensor.h"
#include "include/backend/device_address.h"
#include "include/backend/kernel_info.h"
#include "include/backend/py_execute_utils.h"
#include "runtime/device/hash_table.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/pynative/op_runner.h"
#include "runtime/pynative/op_executor.h"
#include "pybind_api/gil_scoped_long_running.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#ifdef ENABLE_DEBUGGER
#include "include/backend/debug/debugger/debugger.h"
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "include/backend/device_type.h"
#endif

namespace mindspore {
using tensor::TensorPtr;
namespace runtime {
namespace {
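// Create a device address for a ValueNode that holds a Scalar, a String or None.
// Returns nullptr for any other value type. On Ascend, a string needs extra room
// for its ge::StringHead header.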
device::DeviceAddressPtr CreateDeviceAddressForScalarAndString(const DeviceContext *device_context,
                                                               const ValueNodePtr &value_node) {
  device::DeviceAddressPtr address = nullptr;
  const auto &node_value = value_node->value();
  MS_EXCEPTION_IF_NULL(node_value);
  if (node_value->isa<StringImm>()) {
    auto value = GetValue<std::string>(node_value);
    // Allocate one more byte for the trailing '\0'.
    size_t tensor_size = value.size() + 1;
    if (device_context->device_context_key().device_name_ == kAscendDevice) {
      // Size of ge::StringHead, which is defined in Ascend/latest.aarch64-linux/include/types.h.
      constexpr size_t GE_STRING_HEAD_SIZE = 16;
      // NOTE: on Ascend, a string value needs a header of type ge::StringHead.
      tensor_size += GE_STRING_HEAD_SIZE;
    }
    const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
      {value_node, 0}, nullptr, tensor_size, kOpFormat_DEFAULT, kObjectTypeString, ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
    address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  } else if (node_value->isa<Scalar>()) {
    auto scalar_value = node_value->cast<ScalarPtr>();
    MS_EXCEPTION_IF_NULL(scalar_value);
    TypePtr data_type = scalar_value->type();
    MS_EXCEPTION_IF_NULL(data_type);
    TypeId type_id = data_type->type_id();
    const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
      {value_node, 0}, nullptr, GetTypeByte(TypeIdToType(type_id)), kOpFormat_DEFAULT, type_id, ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
    address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  } else if (node_value->isa<None>()) {
    const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
      {value_node, 0}, nullptr, 0, kOpFormat_DEFAULT, kTypeNone->type_id(), ShapeVector(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
    kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
    address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  }

  return address;
}

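// Choose a device format from the tensor rank. Only Ascend differentiates:
// 4-D shapes map to NCHW, 5-D shapes to NCDHW and everything else to ND;
// other backends use the default format.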
Format GetFormatByTensorShape(const DeviceContext *device_context, const ShapeVector &tensor_shape) {
  if (device_context->device_context_key().device_name_ != kAscendDevice) {
    return Format::DEFAULT_FORMAT;
  }

  switch (tensor_shape.size()) {
    case kShape4dDims:
      return Format::NCHW;
    case kShape5dDims:
      return Format::NCDHW;
    default:
      return Format::ND;
  }
}
}  // namespace

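// Check whether the node output at `index` already owns a device address of the
// same device type as the given device context, refreshing the host info of the
// existing address on the way.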
bool DeviceAddressUtils::NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &node,
                                                size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  if (AnfAlgo::OutputAddrExist(node, index)) {
    const auto address = AnfAlgo::GetMutableOutputAddr(node, index, false);
    MS_EXCEPTION_IF_NULL(address);
    CreateKernelTensor(address, session::AnfRuntimeAlgorithm::GetNodeAbstractByIndex(node, index));
    return address->GetDeviceType() == device_context->GetDeviceType();
  }
  return false;
}

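// Allocate device memory for a constant value's device address if necessary and
// sync the host value to the device. Addresses flagged to ignore their device
// pointer are skipped.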
void DeviceAddressUtils::CopyNoneTensorDataToDevice(const device::DeviceContext *device_context,
                                                    const device::DeviceAddressPtr &device_address,
                                                    const ShapeVector &shape) {
  MS_EXCEPTION_IF_NULL(device_address);
  // Skip copying data to the device if the device address carries the ignore flag.
  if (TEST_FLAG(device_address->flag(), device::kDeviceAddressFlagIgnoreDevicePtr)) {
    MS_LOG(DEBUG) << "Address " << device_address << " has flag ignore device address, so skip copy tensor to device";
    return;
  }

  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kConstantValue,
                                                 device_address->GetSize(), device_address.get());
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  if ((device_address->GetPtr() == nullptr) &&
      (!device_context->device_res_manager_->AllocateMemory(device_address.get()))) {
    MS_LOG(EXCEPTION) << "Allocate memory failed";
  }

  // Copy data from host to device.
  const auto &kernel_tensor = device_address->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  auto data_size = kernel_tensor->size();
  if (data_size == 0) {
    MS_LOG(INFO) << "Constant size is zero.";
    return;
  }
  const void *node_value = kernel_tensor->GetValuePtr();
  MS_EXCEPTION_IF_NULL(node_value);
  auto data_type_id = kernel_tensor->dtype_id();
  auto format = kernel_tensor->GetStringFormat();
  if (!device_address->SyncHostToDevice(shape, data_size, data_type_id, node_value, format)) {
    MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
  }
}

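// Create the device address for a map tensor (hash table) node. The key/value
// dtypes, value shape and filter values are packed into UserData carried by the
// kernel tensor.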
void DeviceAddressUtils::CreateDeviceAddressByMapTensorNode(const DeviceContext *device_context, const AnfNodePtr &node,
                                                            size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &abstract_base = AnfAlgo::GetNodeAbstractByIndex(node, index);
  if (!abstract_base->isa<abstract::AbstractMapTensor>()) {
    MS_LOG(EXCEPTION) << "Parameter:" << node->DebugString() << " is not a map tensor type.";
  }

  const auto &abstract = abstract_base->cast<abstract::AbstractMapTensorPtr>();
  MS_EXCEPTION_IF_NULL(abstract);

  // Parse the attrs for user data from the abstract.
  const auto &value_shape = abstract->value_shape();
  MS_EXCEPTION_IF_NULL(value_shape);
  const auto &shape_vector = value_shape->shape();
  const auto &map_tensor_type = abstract->map_tensor_type();
  MS_EXCEPTION_IF_NULL(map_tensor_type);
  MS_EXCEPTION_IF_NULL(map_tensor_type->key_dtype());
  MS_EXCEPTION_IF_NULL(map_tensor_type->value_dtype());

  auto user_data = std::make_shared<UserData>();
  user_data->set(kUserDataType, std::make_shared<UserDataType>(UserDataType::kUserTypeHashTable));
  user_data->set(kHashTableKeyType, std::make_shared<TypeId>(map_tensor_type->key_dtype()->type_id()));
  user_data->set(kHashTableValueType, std::make_shared<TypeId>(map_tensor_type->value_dtype()->type_id()));
  user_data->set(kHashTableShapeVector, std::make_shared<ShapeVector>(shape_vector));
  user_data->set(kHashTableDefaultValue, abstract->default_value());
  user_data->set(kHashTablePermitFilter, abstract->permit_filter_value());
  user_data->set(kHashTableEvictFilter, abstract->evict_filter_value());
  // Create the device address for the map tensor node; the ptr size is 1 byte.
  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {node, index}, nullptr, 1, kOpFormat_DEFAULT, TypeId::kObjectTypeMapTensorType, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, user_data);
  kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(node));
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_LOG(DEBUG) << "Create device tensor:" << device_address << " type:" << device_address->type_id();
  AnfAlgo::SetOutputAddr(device_address, index, node.get());
}

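// Create device addresses for the parameter inputs of the graph, including the
// outputs of MakeTuple inputs and child graph results. Parameters that are not
// weights and not used by any real kernel are flagged as unused.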
void DeviceAddressUtils::CreateParameterDeviceAddress(const DeviceContext *device_context,
                                                      const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> graph_inputs = graph->inputs();
  const std::vector<bool> &graph_valid_input = graph->valid_inputs();
  (void)graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());

  // ANF nodes that need a device address created.
  std::vector<AnfNodePtr> nodes_list;
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    AnfNodePtr item = graph_inputs[i];
    MS_EXCEPTION_IF_NULL(item);
    if (i < graph_valid_input.size() && !graph_valid_input[i]) {
      continue;
    }

    const auto &real_device_context = device::FetchRealDeviceContext(item, device_context);
    MS_EXCEPTION_IF_NULL(real_device_context);
    if (common::AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
      std::vector<AnfNodePtr> outs = common::AnfAlgo::GetAllOutput(item);
      for (const auto &out : outs) {
        MS_EXCEPTION_IF_NULL(out);
        if (!out->isa<Parameter>() || NodeDeviceAddressExist(real_device_context, out, 0)) {
          continue;
        }
        nodes_list.push_back(out);
      }
    }
    if (!item->isa<Parameter>() || NodeDeviceAddressExist(real_device_context, item, 0)) {
      continue;
    }
    nodes_list.push_back(item);
  }

  // Create device addresses for the ANF nodes in nodes_list.
  for (const auto &item : nodes_list) {
    MS_EXCEPTION_IF_NULL(item);
    const auto &real_device_context = device::FetchRealDeviceContext(item, device_context);
    MS_EXCEPTION_IF_NULL(real_device_context);
    auto output_size = AnfAlgo::GetOutputTensorNum(item);
    for (size_t index = 0; index < output_size; index++) {
      const auto &abstract = AnfAlgo::GetNodeAbstractByIndex(item, index);
      if (abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>()) {
        CreateDeviceAddressByMapTensorNode(real_device_context, item, index);
        continue;
      }

      TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
      if (output_type_id == kTypeUnknown) {
        output_type_id = common::AnfAlgo::GetOutputInferDataType(item, index);
      }

      size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
      const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
        {item, index}, nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id,
        trans::GetRuntimePaddingShape(item, index), real_device_context->device_context_key().device_name_,
        real_device_context->device_context_key().device_id_);
      MS_EXCEPTION_IF_NULL(kernel_tensor);
      kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(item));
      auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
      MS_EXCEPTION_IF_NULL(device_address);
      MS_LOG(DEBUG) << "Create device address:" << device_address << " for item:" << item->DebugString();
      // Set the flag for parameters that have no users.
      if (item->isa<Parameter>()) {
        auto input_param = item->cast<ParameterPtr>();
        MS_EXCEPTION_IF_NULL(input_param);
        // An unused address does not allocate memory, which easily causes problems for weight nodes, so skip them.
        if (!common::AnfAlgo::IsParameterWeight(input_param) &&
            !input_param->IsUsedByRealKernelInGraph(graph->graph_id())) {
          MS_LOG(INFO) << "Node:" << item->fullname_with_scope() << " debug name:" << item->DebugString()
                       << " is not used in the graph " << graph->graph_id();
          device_address->UpdateFlag(device::kDeviceAddressFlagNotUsed);
        }
      }
      device_address->SetNodeIndex(item, index);
      device_address->set_from_persistent_mem(item->isa<Parameter>());
      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(item)
                    << " addr:" << device_address << " type:" << device_address->type_id();
      AnfAlgo::SetOutputAddr(device_address, index, item.get());
    }
  }
}

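// Refresh the host info (kernel tensor) of an existing device address from the
// abstract of the node output it belongs to.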
void DeviceAddressUtils::UpdateDeviceAddressHostInfoByNode(const device::DeviceAddressPtr &addr, const AnfNodePtr &node,
                                                           size_t output_idx) {
  MS_EXCEPTION_IF_NULL(addr);
  CreateKernelTensor(addr, session::AnfRuntimeAlgorithm::GetNodeAbstractByIndex(node, output_idx));
}

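// Create device addresses for a tensor (or value sequence) ValueNode. An address
// already held by the tensor is reused when its device type matches; otherwise
// the data is synced back to host and a new address is created.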
device::DeviceAddressPtrList DeviceAddressUtils::CreateDeviceAddressForTensorValue(const DeviceContext *device_context,
                                                                                   const ValuePtr &node_value,
                                                                                   size_t output_idx,
                                                                                   const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  device::DeviceAddressPtrList address_list;
  if (node_value->isa<tensor::BaseTensor>()) {
    auto tensor = node_value->cast<tensor::BaseTensorPtr>();
    MS_EXCEPTION_IF_NULL(tensor);
    auto output_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (output_address != nullptr) {
      if (output_address->GetDeviceType() == device_context->GetDeviceType()) {
        // We need to set tensor->device_address to ValueNode even if the tensor is a forward_output tensor
        // in PyNative Bprop graph. ValueNode device_address is necessary for GraphSchedule::Transform.
        UpdateDeviceAddressHostInfoByNode(output_address, value_node, output_idx);
        AnfAlgo::SetOutputAddr(std::static_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
                               value_node.get());
        (void)address_list.emplace_back(output_address);
        return address_list;
      }
      tensor->data_sync();
    }
  }

  size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
  TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
  if (output_type_id == kTypeUnknown) {
    output_type_id = common::AnfAlgo::GetOutputInferDataType(value_node, output_idx);
    if (output_type_id == kTypeUnknown && value_node->value() != nullptr && value_node->value()->isa<ValueTuple>() &&
        value_node->value()->cast<ValueTuplePtr>()->size() == 0) {
      MS_LOG(DEBUG) << "Set int64 type for empty value tuple node:" << value_node->DebugString();
      output_type_id = TypeId::kNumberTypeInt64;
    }
  }
  std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);

  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {value_node, output_idx}, nullptr, tensor_size, output_format, output_type_id, {},
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_host_shape(kernel_tensor->GetShapeVector());
  kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
  device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node) << " addr:" << address
                << " size:" << tensor_size << " format:" << output_format << " type:" << output_type_id
                << " shape:" << kernel_tensor->GetShapeVector();
  MS_EXCEPTION_IF_NULL(address);
  address->set_from_persistent_mem(true);
  AnfAlgo::SetOutputAddr(address, output_idx++, value_node.get());
  (void)address_list.emplace_back(address);
  return address_list;
}

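// Walk the graph in topological order and collect all inputs that are real
// (non-init-arg) inputs according to the op definition; these value nodes need
// an actual device pointer at launch time.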
mindspore::HashSet<mindspore::AnfNodePtr> FetchValueNodesNeedDevicePtr(const KernelGraphPtr &graph) {
  mindspore::HashSet<mindspore::AnfNodePtr> nodes;
  auto topo_nodes = TopoSort(graph->get_return());
  for (auto const &n : topo_nodes) {
    if (!n->isa<CNode>()) {
      continue;
    }
    auto node = n->cast<CNodePtr>();
    auto op_name = common::AnfAlgo::GetCNodeName(node);
    auto input_num = common::AnfAlgo::GetInputTensorNum(node);
    mindspore::ops::OpDefPtr op_def = mindspore::ops::GetOpDef(op_name);
    if (op_def == nullptr) {
      MS_LOG(DEBUG) << op_name << " is not found in OpDef.";
      for (size_t i = 0; i < input_num; i++) {
        auto input = common::AnfAlgo::GetInputNode(node, i);
        (void)nodes.insert(input);
      }
      continue;
    }
    auto args = op_def->args_;
    if (input_num != args.size()) {
      int input_with_init_args = std::count_if(args.begin(), args.end(), [](auto arg) { return arg.as_init_arg_; });
      size_t total = input_num - IntToSize(input_with_init_args);
      for (size_t i = 0; i < total; i++) {
        (void)nodes.insert(common::AnfAlgo::GetInputNode(node, i));
      }
      MS_LOG(DEBUG) << "Node " << op_name << " has " << input_num << " inputs but " << args.size()
                    << " inputs in op_def, which means all-same inputs; number of inputs with init args: "
                    << input_with_init_args;
      continue;
    }
    for (size_t i = 0; i < input_num; i++) {
      if (args[i].as_init_arg_ == 0) {
        auto input = common::AnfAlgo::GetInputNode(node, i);
        (void)nodes.insert(input);
      }
    }
  }
  return nodes;
}

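// Create a zero-sized device address for a ValueNode that holds a Type.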
device::DeviceAddressPtr CreateDeviceAddressForTypeValue(const DeviceContext *device_context,
                                                         const ValueNodePtr &value_node) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(value_node);
  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {value_node, 0}, nullptr, 0, kOpFormat_DEFAULT, kMetaTypeTypeType, {},
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(value_node));
  device::DeviceAddressPtr address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_LOG(DEBUG) << "Create addr for node:" << value_node->DebugString() << " addr:" << address;
  MS_EXCEPTION_IF_NULL(address);
  address->set_from_persistent_mem(true);
  AnfAlgo::SetOutputAddr(address, 0, value_node.get());
  return address;
}

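// Create device addresses for all value nodes of a graph: map tensors, tensors
// and tuples, types, scalars and strings. Value nodes used only as init args
// get the ignore-device-ptr flag unless the debugger needs to dump them.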
void DeviceAddressUtils::CreateValueNodeDeviceAddress(const DeviceContext *device_context,
                                                      const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
#ifdef ENABLE_DEBUGGER
  auto debugger = Debugger::GetInstance();
  auto &dump_json_parser = DumpJsonParser::GetInstance();
  bool enable_debug = debugger->debugger_enabled() || dump_json_parser.InputNeedDump();
#endif
  // Store the nodes without init args; these need a device address.
  auto value_nodes_without_init_args = FetchValueNodesNeedDevicePtr(graph);
  for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    if (NodeDeviceAddressExist(device_context, value_node, 0)) {
      continue;
    }

    const auto &abstract = value_node->abstract();
    if (abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>()) {
      CreateDeviceAddressByMapTensorNode(device_context, value_node, 0);
      continue;
    }
    const auto &node_value = value_node->value();
    MS_EXCEPTION_IF_NULL(node_value);
    if (node_value->isa<tensor::BaseTensor>() || node_value->isa<ValueSequence>()) {
      auto address_list = CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
      // Deal with tensors and tuples.
      if (value_nodes_without_init_args.find(value_node) == value_nodes_without_init_args.end()) {
        for (const auto &address : address_list) {
#ifdef ENABLE_DEBUGGER
          if (enable_debug) {
            continue;
          }
#endif
          address->UpdateFlag(device::kDeviceAddressFlagIgnoreDevicePtr);
          MS_LOG(DEBUG) << "Find node " << value_node->DebugString() << " has init args";
        }
      }
      continue;
    } else if (node_value->isa<Type>()) {
      CreateDeviceAddressForTypeValue(device_context, value_node);
      continue;
    }

    device::DeviceAddressPtr address = CreateDeviceAddressForScalarAndString(device_context, value_node);
    // Deal with strings and scalars; the address will be nullptr if the input is a type.
    if (address && (value_nodes_without_init_args.find(value_node) == value_nodes_without_init_args.end())) {
      address->UpdateFlag(device::kDeviceAddressFlagIgnoreDevicePtr);
      MS_LOG(DEBUG) << "Find node " << value_node->DebugString() << " has init args";
#ifdef ENABLE_DEBUGGER
      if (enable_debug) {
        address->ClearFlag(device::kDeviceAddressFlagIgnoreDevicePtr);
      }
#endif
    }
    if (address != nullptr) {
      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(value_node)
                    << " addr:" << address;
      address->set_from_persistent_mem(true);
      AnfAlgo::SetOutputAddr(address, 0, value_node.get());
    } else {
      MS_LOG(INFO) << "No device address for value node:" << value_node->fullname_with_scope()
                   << ", debug name:" << common::AnfAlgo::GetNodeDebugString(value_node);
    }
  }
}

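// Create output device addresses for every kernel in the graph's execution
// order, skipped when memory is managed by GE. Gradient outputs and bprop graph
// outputs are placed in persistent memory; kernels that need user data get a
// sync handler attached.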
void DeviceAddressUtils::CreateKernelOutputDeviceAddress(const DeviceContext *device_context,
                                                         const KernelGraphPtr &graph, bool is_gradient_out) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);

  if (graph->memory_managed_by_ge()) {
    return;
  }
  MS_LOG(DEBUG) << "Start create kernel output device address for graph:" << graph->ToString();
  bool is_pynative_bprop_graph = graph->has_flag(kFlagIsPynativeBpropGraph);
  auto outputs = common::AnfAlgo::GetAllOutput(graph->output());

  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (common::AnfAlgo::IsBpropCutOpExecInBackend(kernel)) {
      continue;
    }

    bool is_from_persistent_mem =
      (is_gradient_out || (is_pynative_bprop_graph && (find(outputs.begin(), outputs.end(), kernel) != outputs.end())));

    auto output_size = AnfAlgo::GetOutputAddressNum(kernel);
    for (size_t i = 0; i < output_size; ++i) {
      if (AnfAlgo::OutputAddrExist(kernel, i)) {
        continue;
      }

      const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_context);
      MS_EXCEPTION_IF_NULL(real_device_context);
      const auto &abstract = AnfAlgo::GetNodeAbstractByIndex(kernel, i);
      if (abstract != nullptr && abstract->isa<abstract::AbstractMapTensor>()) {
        CreateDeviceAddressByMapTensorNode(real_device_context, kernel, i);
        continue;
      }
      auto output_format = AnfAlgo::GetOutputFormat(kernel, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
      auto address_size = AnfAlgo::GetOutputTensorMemSize(kernel, i);
      UserDataPtr user_data = nullptr;
      auto kernel_info = dynamic_cast<device::KernelInfo *>(kernel->kernel_info());
      MS_EXCEPTION_IF_NULL(kernel_info);
      if (kernel_info->kernel_mod() != nullptr && kernel_info->kernel_mod()->need_user_data()) {
        user_data = std::make_shared<UserData>();
        user_data->set(kSyncUserDataHandler,
                       std::make_shared<device::DeviceAddress::SyncUserDataHandler>(pyexecute::UserDataToRawMemory));
        graph->set_has_kernel_need_user_data(true);
      }
      const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
        {kernel, i}, nullptr, address_size, output_format, output_type, trans::GetRuntimePaddingShape(kernel, i),
        real_device_context->device_context_key().device_name_, real_device_context->device_context_key().device_id_,
        user_data);
      kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(kernel));
      MS_LOG(DEBUG) << "Kernel tensor created without stream id set; the stream id is set after the device address is created.";
      auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
      device_address->SetNodeIndex(kernel, i);
      if (is_from_persistent_mem) {
        device_address->set_from_persistent_mem(true);
      }
      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
                    << " addr:" << device_address << " type:" << device_address->type_id()
                    << ", kernel tensor addr:" << kernel_tensor.get()
                    << ", kernel tensor: " << kernel_tensor->ToString() << " addr size:" << address_size
                    << " real size:" << device_address->GetSize()
                    << " origin ref count:" << device_address->original_ref_count();
      device_address->set_stream_id(AnfAlgo::GetStreamId(kernel));
      AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
    }
  }
  MS_LOG(DEBUG) << "End create kernel output device address for graph:" << graph->ToString();
}

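// Create device addresses for graph outputs that do not have one yet, skipping
// bprop-cut kernels and monad outputs.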
void DeviceAddressUtils::CreateGraphOutputDeviceAddress(const DeviceContext *device_context,
                                                        const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  auto output_with_indexs = common::AnfAlgo::GetAllOutputWithIndex(graph->output());
  for (const auto &output_with_index : output_with_indexs) {
    const auto &output = output_with_index.first;
    MS_EXCEPTION_IF_NULL(output);
    if (common::AnfAlgo::IsBpropCutOpExecInBackend(output) || HasAbstractMonad(output)) {
      continue;
    }
    auto output_size = AnfAlgo::GetOutputAddressNum(output);
    for (size_t i = 0; i < output_size; ++i) {
      if (AnfAlgo::OutputAddrExist(output, i)) {
        continue;
      }

      const auto &real_device_context = device::FetchRealDeviceContext(output, device_context);
      MS_EXCEPTION_IF_NULL(real_device_context);
      MS_EXCEPTION_IF_NULL(real_device_context->device_res_manager_);
      auto output_format = AnfAlgo::GetOutputFormat(output, i);
      auto output_type = AnfAlgo::GetOutputDeviceDataType(output, i);
      auto address_size = AnfAlgo::GetOutputTensorMemSize(output, i);
      const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
        {output, i}, nullptr, address_size, output_format, output_type, trans::GetRuntimePaddingShape(output, i),
        real_device_context->device_context_key().device_name_, real_device_context->device_context_key().device_id_);
      kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(output));
      auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
      MS_LOG(DEBUG) << "Create addr for node:" << output->DebugString() << " addr:" << device_address
                    << " type:" << device_address->type_id();
      AnfAlgo::SetOutputAddr(device_address, i, output.get());
    }
  }
}

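// Compute the device-side byte size of a tensor. On Ascend the shape is padded
// and transformed to the device format before multiplying the element count by
// the element size.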
size_t DeviceAddressUtils::GetTensorDeviceSize(const DeviceContext *device_context, const AnfNodePtr &node,
                                               const ShapeVector &shape, const string &format, TypeId dtype,
                                               size_t output_index) {
  MS_EXCEPTION_IF_NULL(device_context);
  auto device_shape = shape;
  if (device_context->GetDeviceType() == device::DeviceType::kAscend) {
    if (device_shape.empty() && format != kOpFormat_DEFAULT) {
      device_shape = trans::PaddingShape(device_shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
      device_shape = trans::TransShapeToDevice(device_shape, format, node, output_index, dtype);
    } else {
      if (trans::IsNeedPadding(format, device_shape)) {
        device_shape =
          trans::PaddingShape(device_shape, format, AnfAlgo::GetOutputReshapeType(node, output_index), node);
      }
      device_shape = trans::TransShapeToDevice(device_shape, format, node, output_index, dtype);
    }
  }
  size_t type_size = GetTypeByte(TypeIdToType(dtype));
  size_t tensor_size = type_size * SizeOf(device_shape);
  return tensor_size;
}

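// Create (or reuse) the output device addresses of a compiled single-op graph.
// Shapes come from the inferred output abstract; format, dtype and user data
// come from the cached output address.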
vector<device::DeviceAddressPtr> DeviceAddressUtils::CreateGraphOutputDeviceAddress(
  const OpCompilerInfoPtr &op_compiler_info, const abstract::AbstractBasePtr &out_abstract, size_t stream_id) {
  auto device_context = op_compiler_info->device_context_;
  const auto &output_edges = op_compiler_info->simple_graph_->outputs_;
  size_t output_num = output_edges.size();

  std::vector<device::DeviceAddressPtr> output_address_list;
  output_address_list.reserve(output_num);

  for (size_t i = 0; i < output_num; ++i) {
    const auto &edge = output_edges[i];
    const auto &address = edge->address_;
    if (address != nullptr) {
      MS_LOG(DEBUG) << "Already have output device address for ref output";
      output_address_list.push_back(address);
      continue;
    }

    const auto &[output_node, index] = edge->node_with_index_;
    const auto &cache_output_address = edge->origin_address_;

    auto real_abstract = out_abstract;
    if (out_abstract->isa<abstract::AbstractTuple>()) {
      auto abstract_tuple = out_abstract->cast<abstract::AbstractTuplePtr>();
      if (i >= abstract_tuple->elements().size()) {
        MS_LOG(EXCEPTION) << "abstract_tuple size is " << abstract_tuple->elements().size() << ", but the index is "
                          << i;
      }
      real_abstract = abstract_tuple->elements()[i];
    }
    auto output_shape_ptr = real_abstract->BuildShape();
    MS_EXCEPTION_IF_NULL(output_shape_ptr);
    auto shape_vector = output_shape_ptr->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape_vector);
    const auto &shape = shape_vector->shape();
    auto output_type = cache_output_address->type_id();
    const auto &output_format = cache_output_address->format();
    auto address_size = GetTensorDeviceSize(device_context, output_node, shape, output_format, output_type, index);
    const auto &kernel_tensor = std::make_shared<kernel::KernelTensor>(
      real_abstract->GetShape()->Clone(), real_abstract->GetType()->Clone(), real_abstract->GetValue(), nullptr,
      address_size, output_format, output_type, shape, device_context->device_context_key().device_name_,
      device_context->device_context_key().device_id_, cache_output_address->user_data());
    kernel_tensor->set_stream_id(stream_id);
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(output_node)
                  << " addr:" << device_address;
    output_address_list.push_back(device_address);
    edge->address_ = device_address;
  }
  return output_address_list;
}

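// Create workspace device addresses for every kernel in the execution order,
// one for each workspace size reported by the kernel mod (skipped when memory
// is managed by GE).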
void DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context,
                                                            const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);

  if (graph->memory_managed_by_ge()) {
    return;
  }

  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    MS_EXCEPTION_IF_NULL(kernel);
    if (common::AnfAlgo::IsBpropCutOpExecInBackend(kernel)) {
      continue;
    }
    const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_context);
    MS_EXCEPTION_IF_NULL(real_device_context);
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
    for (size_t i = 0; i < workspace_sizes.size(); ++i) {
      if (AnfAlgo::WorkspaceAddrExist(kernel, i)) {
        break;
      }
      auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
        nullptr, workspace_sizes[i], Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
        real_device_context->device_context_key().device_name_, real_device_context->device_context_key().device_id_);
      kernel_tensor->set_stream_id(AnfAlgo::GetStreamId(kernel));
      auto device_address = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
      MS_LOG(DEBUG) << "Create addr for node:" << common::AnfAlgo::GetNodeDebugString(kernel)
                    << " addr:" << device_address;
      AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
    }
  }
}

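// Group inplace kernels by their "inplace_group" attr and let each group share
// the pointer ref count of its first node, so all nodes of a group write to the
// same device memory.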
void DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  if (graph->memory_managed_by_ge()) {
    return;
  }

  // Collect the inplace groups.
  std::map<uint32_t, std::vector<CNodePtr>> inplace_groups;
  const std::vector<CNodePtr> &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    if (!common::AnfAlgo::IsInplaceNode(kernel, "inplace_algo")) {
      continue;
    }
    auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel);
    MS_EXCEPTION_IF_NULL(primitive);
    auto inplace_group_attr = primitive->GetAttr("inplace_group");
    MS_EXCEPTION_IF_NULL(inplace_group_attr);
    auto group_id = GetValue<uint32_t>(inplace_group_attr);
    (void)inplace_groups[group_id].emplace_back(kernel);
  }

  constexpr size_t kMinInplaceGroupSize = 2;
  for (const auto &inplace_group : inplace_groups) {
    auto &group_nodes = inplace_group.second;
    if (group_nodes.size() < kMinInplaceGroupSize) {
      continue;
    }
    // Get the device address of the first node in the inplace group.
    auto node_primitive = common::AnfAlgo::GetCNodePrimitive(group_nodes[0]);
    MS_EXCEPTION_IF_NULL(node_primitive);
    auto output_index = GetValue<uint32_t>(node_primitive->GetAttr("inplace_output_index"));
    auto device_address = AnfAlgo::GetMutableOutputAddr(group_nodes[0], output_index, false);
    MS_EXCEPTION_IF_NULL(device_address);

    // Update the device addresses of the other nodes with that of the first node in the inplace group.
    for (size_t i = 1; i < group_nodes.size(); ++i) {
      auto &group_node = group_nodes[i];
      auto prim = common::AnfAlgo::GetCNodePrimitive(group_node);
      MS_EXCEPTION_IF_NULL(prim);
      auto index = GetValue<uint32_t>(prim->GetAttr("inplace_output_index"));
      auto group_node_device_address = AnfAlgo::GetMutableOutputAddr(group_node, index, false);
      MS_EXCEPTION_IF_NULL(group_node_device_address);
      // Update the reference count of the device address.
      device_address->IncreaseOriginalRefCount();
      MS_LOG(DEBUG) << "After increase ref count for device address:" << device_address
                    << " ref count:" << device_address->original_ref_count();
      device_address->ResetRefCount();
      group_node_device_address->set_pointer_ref_count(device_address->pointer_ref_count());
    }
  }
}

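// Bind the device address of a ref-node output to that of its origin node. Both
// ends of the ref pair must live on the same device type; their ref counts are
// rebalanced and both addresses are flagged as ref nodes.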
void DeviceAddressUtils::UpdateDeviceAddress(const session::AnfWithOutIndex &cur_pair,
                                             const session::AnfWithOutIndex &origin_pair) {
  MS_EXCEPTION_IF_NULL(cur_pair.first);
  MS_EXCEPTION_IF_NULL(origin_pair.first);
  MS_LOG(INFO) << "Ref node pair: origin kernel is " << origin_pair.first->fullname_with_scope() << ", index is "
               << origin_pair.second << "; cur kernel is " << cur_pair.first->fullname_with_scope() << ", index is "
               << cur_pair.second;
  // If the origin of the ref node is a parameter, the monad attr needs to be added (for example, a TransData/Cast
  // node that refs a parameter).
  if (!common::AnfAlgo::HasMonadInput(cur_pair.first) && origin_pair.first->isa<Parameter>()) {
    MS_LOG(INFO) << cur_pair.first->fullname_with_scope() << " with index " << cur_pair.second
                 << " ref node to parameter " << origin_pair.first->fullname_with_scope() << " and add the monad attr.";
    common::AnfAlgo::SetNodeAttr(kAttrRefNodeMonadOutputIdx, MakeValue(cur_pair.second), cur_pair.first);
  }

  auto origin_node_output_addr = AnfAlgo::GetMutableOutputAddr(origin_pair.first, origin_pair.second, false);
  MS_EXCEPTION_IF_NULL(origin_node_output_addr);
  auto cur_node_output_addr = AnfAlgo::GetMutableOutputAddr(cur_pair.first, cur_pair.second, false);
  MS_EXCEPTION_IF_NULL(cur_node_output_addr);
  auto origin_stream_id = origin_node_output_addr->stream_id();
  auto cur_stream_id = cur_node_output_addr->stream_id();
  if (origin_stream_id != cur_stream_id) {
    MS_LOG(DEBUG) << "Origin node output addr : " << origin_node_output_addr << " stream id : " << origin_stream_id
                  << " is not equal to cur node output addr stream id : " << cur_stream_id << ".";
  }

  // Update the device address flag.
  origin_node_output_addr->UpdateFlag(device::kDeviceAddressFlagRefNode);

  if (origin_node_output_addr->pointer_ref_count() != cur_node_output_addr->pointer_ref_count()) {
    // Check whether the device targets are consistent.
    if (origin_node_output_addr->GetDeviceType() != cur_node_output_addr->GetDeviceType()) {
      std::string error_info =
        "Device target is not consistent: ref origin kernel is " + origin_pair.first->fullname_with_scope() +
        ", index is " + std::to_string(origin_pair.second) + ", device target is " +
        device::GetDeviceNameByType(origin_node_output_addr->GetDeviceType()) + "; cur kernel is " +
        cur_pair.first->fullname_with_scope() + ", index is " + std::to_string(cur_pair.second) +
        ", device target is " + device::GetDeviceNameByType(cur_node_output_addr->GetDeviceType());

      MS_LOG(ERROR) << error_info;
      if (AnfAlgo::IsKernelSelectBackoffOp(origin_pair.first)) {
        const auto &backoff_info = AnfAlgo::GetKernelSelectBackoffInfo(origin_pair.first);
        MS_EXCEPTION(backoff_info.second) << "#umsg#Kernel select failed:#umsg#" << backoff_info.second;
      } else if (AnfAlgo::IsKernelSelectBackoffOp(cur_pair.first)) {
        const auto &backoff_info = AnfAlgo::GetKernelSelectBackoffInfo(cur_pair.first);
        MS_EXCEPTION(backoff_info.second) << "#umsg#Kernel select failed:#umsg#" << backoff_info.second;
      } else {
        MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#" << error_info;
      }
    }
    MS_LOG(INFO) << "Update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
                 << ", index is " << origin_pair.second << "; cur kernel is " << cur_pair.first->fullname_with_scope()
                 << ", index is " << cur_pair.second;
    // Update the reference counts of the device addresses.
    cur_node_output_addr->DecreaseOriginalRefCount();
    cur_node_output_addr->ResetRefCount();
    origin_node_output_addr->IncreaseOriginalRefCount();
    MS_LOG(DEBUG) << "After increase ref count for device address:" << origin_node_output_addr
                  << " ref count:" << origin_node_output_addr->original_ref_count();
    origin_node_output_addr->ResetRefCount();
    cur_node_output_addr->set_pointer_ref_count(origin_node_output_addr->pointer_ref_count());
    cur_node_output_addr->UpdateFlag(device::kDeviceAddressFlagRefNode);
  } else {
    MS_LOG(DEBUG) << "No need update device address: ref origin kernel is " << origin_pair.first->fullname_with_scope()
                  << ", index is " << origin_pair.second << "; cur kernel is " << cur_pair.first->fullname_with_scope()
                  << ", index is " << cur_pair.second;
  }
}

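// Resolve the graph's ref map recursively, update the device address of every
// ref output, and record the output-to-input relation in kernel info for the
// memory swap scenario.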
void DeviceAddressUtils::UpdateDeviceAddressForRefNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);

  if (graph->memory_managed_by_ge()) {
    return;
  }

  AnfAlgo::UpdateGraphValidRefPair(graph);
  for (const auto &ref_pair : graph->GetRefMap()) {
    const auto &out_pair = ref_pair.first;
    const auto &origin_pair = ref_pair.second;
    const auto &recursive_origin_pair = graph->GetRefNodeRecursive(out_pair);
    UpdateDeviceAddress(out_pair, recursive_origin_pair);
    // Update the ref map in kernel info; it is used by the kernel actor in the swap scenario.
    for (size_t input_index = 0; input_index < common::AnfAlgo::GetInputTensorNum(out_pair.first); ++input_index) {
      const auto &prev_node_output = common::AnfAlgo::GetPrevNodeOutput(out_pair.first, input_index, false);
      if (prev_node_output == origin_pair) {
        auto kernel_info = dynamic_cast<device::KernelInfo *>(out_pair.first->kernel_info());
        MS_EXCEPTION_IF_NULL(kernel_info);
        kernel_info->AddRefMap(out_pair.second, input_index);
        break;
      }
    }
  }
}

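// Clone a device address without its device memory: the kernel tensor is copied,
// the device pointer is cleared, and ref counts, node index and padding type are
// carried over.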
device::DeviceAddressPtr DeviceAddressUtils::CloneEmptyDeviceAddress(const device::DeviceAddressPtr &old_device_address,
                                                                     const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(old_device_address);
  MS_EXCEPTION_IF_NULL(device_context);
  const auto &kernel_tensor = old_device_address->kernel_tensor();
  MS_EXCEPTION_IF_NULL(kernel_tensor);
  auto new_kernel_tensor = kernel_tensor->CloneKernelTensor();
  MS_EXCEPTION_IF_NULL(new_kernel_tensor);

  new_kernel_tensor->set_device_name(device_context->device_context_key().device_name_);
  new_kernel_tensor->set_device_id(device_context->device_context_key().device_id_);
  new_kernel_tensor->set_device_ptr(nullptr);
  auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(new_kernel_tensor);
  MS_EXCEPTION_IF_NULL(new_device_address);
  MS_LOG(DEBUG) << "Create device tensor:" << new_device_address << " type:" << new_device_address->type_id();

  new_device_address->set_original_ref_count(old_device_address->original_ref_count());
  new_device_address->ResetRefCount();
  auto node = old_device_address->GetNodeIndex();
  new_device_address->SetNodeIndex(node.first, node.second);
  new_device_address->set_padding_type(old_device_address->padding_type());
  return new_device_address;
}

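// Make sure an input tensor owns a device address usable on the given device
// context. An address of a different device type triggers a host sync before a
// new address is created.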
void DeviceAddressUtils::CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                                  const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(tensor);

  auto addr = tensor->device_address();
  if (addr != nullptr) {
    auto device_address = std::static_pointer_cast<device::DeviceAddress>(addr);
    if (device_address->GetDeviceType() != device::DeviceType::kAscend) {
      // On CPU or GPU, a view op may have created the device address without a KernelTensor.
      CreateKernelTensor(device_address, tensor);
    }
    if (device_address->GetDeviceType() == device_context->GetDeviceType()) {
      MS_LOG(DEBUG) << "Already have device address of tensor " << tensor->id();
      return;
    }
    MS_LOG(DEBUG) << "Input tensor device type is " << device_address->GetDeviceType()
                  << " but current device context is " << device_context->GetDeviceType();
    tensor->data_sync();
    tensor->set_device_address(nullptr);
  }
  auto tensor_size = LongToSize(tensor->data().nbytes());
  const auto &format = GetFormatByTensorShape(device_context, tensor->shape());
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, tensor_size, tensor->shape(), format, tensor->data_type(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
  if (device_address->GetDeviceType() != device::DeviceType::kAscend) {
    // CPU and GPU need a KernelTensor to launch the kernel.
    CreateKernelTensor(device_address, tensor);
  }

  MS_EXCEPTION_IF_NULL(device_address);
  device_address->set_from_persistent_mem(tensor->is_parameter());
  tensor->set_device_address(device_address);
  MS_LOG(DEBUG) << "Create input tensor device address " << device_address << " for " << index
                << "th input, Shape: " << tensor->shape() << ", Type: " << TypeIdToType(tensor->data_type())->ToString()
                << ", Size:" << tensor_size;
}

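// Allocate device memory for an input tensor's device address and sync the host
// data to the device. On Ascend the sync is dispatched as a launch task; CPU
// view inputs not taken from the memory pool are reallocated first.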
void DeviceAddressUtils::MallocForInput(const DeviceContext *device_context, const tensor::BaseTensorPtr &tensor,
                                        bool is_view) {
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &device_sync = tensor->device_address();
  auto device_address = std::static_pointer_cast<device::DeviceAddress>(device_sync);
  MS_EXCEPTION_IF_NULL(device_address);
  device_address->set_is_view(is_view);

  if (device::tracker::MemTrackerManager::GetInstance().IsEnabled()) {
    auto mem_type =
      tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
                                                   device_address.get());
  }
  if (device_address->GetMutablePtr() != nullptr) {
    if (!is_view || device_address->GetDeviceType() != device::DeviceType::kCPU || device_address->from_mem_pool()) {
      return;
    }
    // If not from the pool, the lifetime of the device ptr is guaranteed elsewhere.
    // Before applying for a new address, clear the old one; otherwise a warning is generated.
    device_address->set_ptr(nullptr);
    const auto new_device_context = device_context->GetDeviceType() == device_address->GetDeviceType()
                                      ? device_context
                                      : runtime::OpRunner::GetDeviceContext(kCPUDevice);

    MS_EXCEPTION_IF_NULL(new_device_context);
    if (!new_device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate memory failed";
    }
  } else {
    if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate memory failed";
    }
  }

  auto tensor_size = LongToSize(tensor->data().nbytes());
  if (device_address->GetDeviceType() == device::DeviceType::kAscend) {
    OpExecutor::DispatchLaunchTask([=]() {
      if (!device_address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), device_address->format(),
                                            tensor->data_ptr())) {
        MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
      }
    });
  } else {
    if (!device_address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), device_address->format(),
                                          tensor->data_ptr())) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
    }
  }
}

void DeviceAddressUtils::MallocForInput(const DeviceContext *device_context,
                                        const std::vector<tensor::BaseTensorPtr> &tensors, bool is_view) {
  for (const auto &tensor : tensors) {
    MallocForInput(device_context, tensor, is_view);
  }
}

void DeviceAddressUtils::MallocForInput(const DeviceContext *device_context,
                                        const std::optional<tensor::BaseTensorPtr> &val, bool is_view) {
  if (!val.has_value()) {
    return;
  }
  MallocForInput(device_context, val.value(), is_view);
}

void DeviceAddressUtils::CreateInputTensorAddress(const DeviceContext *device_context, size_t stream_id, size_t index,
                                                  const std::optional<tensor::BaseTensorPtr> &val) {
  if (!val.has_value()) {
    return;
  }
  CreateInputTensorAddress(device_context, stream_id, index, val.value());
}

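// The CreateKernelTensor overloads attach a kernel tensor to a device address
// that lacks one, building the host info (shape and type) from a tensor, a
// value or an abstract.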
CreateKernelTensor(const device::DeviceAddressPtr & device_address,const tensor::BaseTensorPtr & tensor)964 void DeviceAddressUtils::CreateKernelTensor(const device::DeviceAddressPtr &device_address,
965                                             const tensor::BaseTensorPtr &tensor) {
966   MS_EXCEPTION_IF_NULL(device_address);
967   MS_EXCEPTION_IF_NULL(tensor);
968   if (device_address->kernel_tensor() != nullptr) {
969     return;
970   }
971   const auto &address_common = device_address->address_common();
972   MS_EXCEPTION_IF_NULL(address_common);
973   auto real_kernel_tensor = std::make_shared<kernel::KernelTensor>(
974     address_common, std::make_shared<abstract::TensorShape>(tensor->shape()),
975     std::make_shared<TensorType>(TypeIdToType(tensor->data_type())), nullptr, tensor->shape());
976   device_address->set_kernel_tensor(real_kernel_tensor);
977   device_address->DeviceSynchronizerInit();
978 }
979 
CreateKernelTensor(const ValuePtr & input_value)980 void DeviceAddressUtils::CreateKernelTensor(const ValuePtr &input_value) {
981   MS_EXCEPTION_IF_NULL(input_value);
982   if (input_value->isa<tensor::BaseTensor>()) {
983     auto tensor = input_value->cast<tensor::BaseTensorPtr>();
984     if (tensor->device_address() != nullptr) {
985       auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
986       MS_EXCEPTION_IF_NULL(device_address);
987       CreateKernelTensor(device_address, tensor);
988     }
989   }
990 }
991 
void DeviceAddressUtils::CreateKernelTensor(const tensor::TensorPtr &input_tensor) {
  MS_EXCEPTION_IF_NULL(input_tensor);
  if (input_tensor->device_address() != nullptr) {
    auto device_address = std::static_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
    MS_EXCEPTION_IF_NULL(device_address);
    CreateKernelTensor(device_address, input_tensor);
  }
}

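// Abstract overload: derives the KernelTensor's shape and type from an abstract value instead of a host tensor.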
void DeviceAddressUtils::CreateKernelTensor(const device::DeviceAddressPtr &device_address,
                                            const AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->kernel_tensor() != nullptr) {
    return;
  }
  const auto address_common = device_address->address_common();
  MS_EXCEPTION_IF_NULL(address_common);
  MS_EXCEPTION_IF_NULL(abs);
  const auto &shape = abs->GetShape();
  const auto &type = abs->GetType();
  auto real_kernel_tensor =
    std::make_shared<kernel::KernelTensor>(address_common, shape, type, nullptr, shape->GetShapeVector());
  device_address->set_kernel_tensor(real_kernel_tensor);
  device_address->DeviceSynchronizerInit();
}

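// Creates (or reuses) a device address for an input tensor: if the tensor already holds an address with allocated
// device memory it is returned as-is; otherwise a new address is created, memory is allocated, and the host data
// is copied to device.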
device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                                const abstract::AbstractBasePtr &abs, size_t index,
                                                                const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(tensor);
  auto addr = tensor->device_address();
  if (addr != nullptr) {
    auto device_address = std::static_pointer_cast<device::DeviceAddress>(addr);
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetPtr() != nullptr) {
      MS_LOG(DEBUG) << "Input tensor already has address " << device_address.get() << " and device ptr "
                    << device_address->GetPtr();
      return device_address;
    }
  }
  BaseShapePtr shape;
  TypePtr type;
  if (abs != nullptr) {
    shape = abs->GetShape();
    type = abs->GetType();
  } else {
    shape = std::make_shared<abstract::Shape>(tensor->shape());
    type = tensor->Dtype();
  }

  const auto &tensor_size = LongToSize(tensor->data().nbytes());
  const auto &format = GetFormatByTensorShape(device_context, tensor->shape());
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    shape, type, nullptr, nullptr, tensor_size, kernel::GetFormatFromEnumToStr(format), tensor->data_type(),
    tensor->shape(), device_context->device_context_key().device_name_,
    device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  device::DeviceAddressPtr device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(device_address);
  device_address->set_from_persistent_mem(tensor->is_parameter());
  tensor->set_device_address(device_address);

  auto mem_type = tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kConstantValue;
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", mem_type, device_address->GetSize(),
                                                 device_address.get());
  if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
    MS_LOG(EXCEPTION) << "Allocate memory failed";
  }
  if (!device_address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(),
                                        kernel::GetFormatFromEnumToStr(format), tensor->data_ptr())) {
    MS_LOG(EXCEPTION) << "SyncHostToDevice failed";
  }
  MS_LOG(DEBUG) << "Create input tensor device address " << device_address << " for " << index
                << "th input, Shape: " << shape->ToString()
                << ", Type: " << TypeIdToType(tensor->data_type())->ToString() << ", host shape: " << tensor->shape()
                << ", dev ptr " << device_address->GetPtr();
  return device_address;
}

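// Creates a persistent device address for a scalar input; the scalar value is embedded in the KernelTensor and
// copied to device if no device memory is attached yet.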
device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                                const abstract::AbstractBasePtr &abs, size_t index,
                                                                const ScalarPtr &scalar_value) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(scalar_value);
  const auto type = scalar_value->type();
  MS_EXCEPTION_IF_NULL(type);
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    abstract::kNoShape, type, scalar_value, nullptr, GetTypeByte(TypeIdToType(type->type_id())), kOpFormat_DEFAULT,
    type->type_id(), ShapeVector(), device_context->device_context_key().device_name_,
    device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  device_address->set_from_persistent_mem(true);

  if (device_address->GetPtr() == nullptr) {
    CopyNoneTensorDataToDevice(device_context, device_address);
  }
  MS_LOG(DEBUG) << "Create input scalar device address " << device_address << " for " << index
                << "th input, Shape: " << abstract::kNoShape->ToString() << ", Type: " << type->ToString()
                << ", Value: " << (scalar_value ? scalar_value->ToString() : "nullptr") << ", dev ptr "
                << device_address->GetPtr();
  return device_address;
}

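// Optional-input overload: returns nullptr when no input tensor is provided.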
device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                                const abstract::AbstractBasePtr &abs, size_t index,
                                                                const std::optional<tensor::BaseTensorPtr> &val) {
  if (!val.has_value()) {
    return nullptr;
  }
  return CreateInputAddress(device_context, stream_id, abs, index, val.value());
}

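// Creates a persistent device address for a string input; the buffer reserves one extra byte for the terminating
// '\0'.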
device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                                const abstract::AbstractBasePtr &abs, size_t index,
                                                                const StringImmPtr &string_imm) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(string_imm);
  const auto &type = string_imm->type();
  MS_EXCEPTION_IF_NULL(type);
  const auto &tensor_value = GetValue<std::string>(string_imm);
  // Allocate one more byte for the terminating '\0'
  size_t size = tensor_value.size() + 1;
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    abstract::kNoShape, type, string_imm, nullptr, size, kOpFormat_DEFAULT, kObjectTypeString, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  device_address->set_from_persistent_mem(true);

  if (device_address->GetPtr() == nullptr) {
    CopyNoneTensorDataToDevice(device_context, device_address);
  }
  MS_LOG(DEBUG) << "Create input string device address " << device_address << " for " << index
                << "th input, Shape: " << abstract::kNoShape->ToString() << ", Type: " << type->ToString()
                << ", Value: " << (string_imm ? string_imm->ToString() : "nullptr") << ", dev ptr "
                << device_address->GetPtr();
  return device_address;
}

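// Creates a persistent device address for a Type input (e.g. a dtype argument); no host payload is attached.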
device::DeviceAddressPtr DeviceAddressUtils::CreateInputAddress(const DeviceContext *device_context, size_t stream_id,
                                                                const abstract::AbstractBasePtr &abs, size_t index,
                                                                const TypePtr &type_ptr) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(type_ptr);
  const auto &type = type_ptr->type();
  MS_EXCEPTION_IF_NULL(type);
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    abstract::kNoShape, type, nullptr, nullptr, GetTypeByte(TypeIdToType(type->type_id())), kOpFormat_DEFAULT,
    type_ptr->type_id(), ShapeVector(), device_context->device_context_key().device_name_,
    device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  device_address->set_from_persistent_mem(true);

  if (device_address->GetPtr() == nullptr) {
    CopyNoneTensorDataToDevice(device_context, device_address);
  }
  MS_LOG(DEBUG) << "Create input " << type_ptr->ToString() << " device address for " << index
                << "th input, Shape: " << abstract::kNoShape->ToString() << ", Type: " << type->ToString()
                << ", Value: nullptr, device address:" << device_address;
  return device_address;
}

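// Creates device addresses for all output tensors of an op; device memory is allocated later (see
// MallocForOutputs below).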
void DeviceAddressUtils::CreateOutputTensorAddress(const DeviceContext *device_context, size_t stream_id,
                                                   const std::vector<tensor::BaseTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(device_context);
  for (size_t i = 0; i < outputs.size(); ++i) {
    const auto &tensor = outputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto tensor_size = LongToSize(tensor->data().nbytes());
    const auto &format = GetFormatByTensorShape(device_context, tensor->shape());
    auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
      nullptr, tensor_size, tensor->shape(), format, tensor->data_type(),
      device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetDeviceType() != device::DeviceType::kAscend) {
      // CPU and GPU need a KernelTensor to launch kernels
      CreateKernelTensor(device_address, tensor);
    }
    tensor->set_device_address(device_address);
    MS_LOG(DEBUG) << "Create output tensor device address " << device_address << " for " << i
                  << "th output, Shape: " << tensor->shape()
                  << ", Type: " << TypeIdToType(tensor->data_type())->ToString() << ", Size:" << tensor_size;
  }
}

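// Single-output overload that takes an explicit buffer size in bytes.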
void DeviceAddressUtils::CreateOutputTensorAddress(const DeviceContext *device_context, size_t stream_id,
                                                   const tensor::BaseTensorPtr &output_tensor, size_t size) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(output_tensor);
  const auto &format = GetFormatByTensorShape(device_context, output_tensor->shape());
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, size, output_tensor->shape(), format, output_tensor->data_type(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetDeviceType() != device::DeviceType::kAscend) {
    // CPU and GPU need a KernelTensor to launch kernels
    CreateKernelTensor(device_address, output_tensor);
  }
  output_tensor->set_device_address(device_address);
  MS_LOG(DEBUG) << "Create output tensor device address " << device_address << " for the output, element count: "
                << static_cast<int64_t>(size / GetTypeByte(TypeIdToType(output_tensor->data_type())))
                << ", Type: " << TypeIdToType(output_tensor->data_type())->ToString() << ", Size:" << size;
}

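// Creates a device address for a tensor using an explicitly supplied device-side shape; no device memory is
// allocated here.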
device::DeviceAddressPtr DeviceAddressUtils::CreateDeviceAddress(const DeviceContext *device_context,
                                                                 const tensor::BaseTensorPtr &tensor,
                                                                 const ShapeVector &real_shape,
                                                                 const size_t &stream_id) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(tensor);
  auto tensor_size = GetTypeByte(TypeIdToType(tensor->data_type())) * SizeOf(real_shape);
  const auto &device_format = GetFormatByTensorShape(device_context, tensor->shape());
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    nullptr, tensor_size, device_format, tensor->data_type(), real_shape,
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);
  device::DeviceAddressPtr device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_LOG(DEBUG) << "Create tensor device address " << device_address << ", Shape: " << tensor->shape()
                << ", Type: " << TypeIdToType(tensor->data_type())->ToString();
  return device_address;
}

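// Allocates device memory for each output's device address; outputs that already own memory (ref outputs) are
// skipped.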
void DeviceAddressUtils::MallocForOutputs(const DeviceContext *device_context,
                                          const std::vector<tensor::BaseTensorPtr> &outputs) {
  MS_EXCEPTION_IF_NULL(device_context);
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    auto device_address = std::static_pointer_cast<device::DeviceAddress>(output->device_address());
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetPtr() != nullptr) {
      // Ref output: the memory is shared with an input and is already allocated.
      continue;
    }
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kPyNativeOutput,
                                                   device_address->GetSize(), device_address.get());
    if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(EXCEPTION) << "Allocate memory failed";
    }
  }
}

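// Creates and immediately allocates a workspace address without wrapping it in a KernelTensor.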
device::DeviceAddressPtr DeviceAddressUtils::CreateWorkspaceAddressWithoutKernelTensor(
  const DeviceContext *device_context, size_t stream_id, const size_t &workspace_size) {
  MS_EXCEPTION_IF_NULL(device_context);
  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(
    nullptr, workspace_size, ShapeVector(), Format::DEFAULT_FORMAT, kTypeUnknown,
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_, stream_id);
  MS_EXCEPTION_IF_NULL(device_address);
  device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "PyNative", device::tracker::MemType::kWorkSpace,
                                                 device_address->GetSize(), device_address.get());
  if (device_address->GetPtr() == nullptr &&
      !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
    MS_LOG(EXCEPTION) << "Allocate dynamic workspace memory failed";
  }
  MS_LOG(DEBUG) << "Create workspace device address:" << device_address;
  return device_address;
}

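// Creates and allocates a workspace address backed by a KernelTensor with unknown dtype and empty shape.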
device::DeviceAddressPtr DeviceAddressUtils::CreateWorkspaceAddress(const DeviceContext *device_context,
                                                                    size_t stream_id, const size_t &workspace_size) {
  MS_EXCEPTION_IF_NULL(device_context);

  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    nullptr, workspace_size, Format::DEFAULT_FORMAT, kTypeUnknown, ShapeVector(),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->set_stream_id(stream_id);

  auto device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetPtr() == nullptr &&
      !device_context->device_res_manager_->AllocateMemory(device_address.get())) {
    MS_LOG(EXCEPTION) << "Allocate dynamic workspace memory failed";
  }
  MS_LOG(DEBUG) << "Create workspace device address:" << device_address;
  return device_address;
}

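// Synchronously replaces a view tensor's device address with a contiguous copy; no-op for tensors without
// storage info.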
void DeviceAddressUtils::ConvertContiguousTensorSync(const tensor::BaseTensorPtr &tensor) {
  if (tensor == nullptr || tensor->storage_info() == nullptr) {
    return;
  }

  MS_LOG(DEBUG) << "Tensor storage_info is not nullptr, tensor needs to be made contiguous, id:" << tensor->id();
  const auto &new_device_address = ConvertContiguousDeviceAddress(
    nullptr, std::static_pointer_cast<device::DeviceAddress>(tensor->device_address()), true);
  MS_EXCEPTION_IF_NULL(new_device_address);
  tensor->set_device_address(new_device_address);
}

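// Materializes a contiguous copy of a strided (view) device address by launching a contiguous kernel task,
// either synchronously (draining the op queue around the task) or asynchronously on the op executor.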
device::DeviceAddressPtr DeviceAddressUtils::ConvertContiguousDeviceAddress(
  const DeviceContext *input_device_context, const device::DeviceAddressPtr &old_device_address, bool is_sync) {
  MS_EXCEPTION_IF_NULL(old_device_address);

  const DeviceContext *device_context = input_device_context == nullptr
                                          ? runtime::OpRunner::GetDeviceContext(old_device_address->device_name())
                                          : input_device_context;
  MS_EXCEPTION_IF_NULL(device_context);
  auto stream_id = device_context->device_res_manager_->GetCurrentStreamId();

  GilReleaseWithCheck release_gil;
  const auto &old_storage_info = old_device_address->GetTensorStorageInfo();
  MS_EXCEPTION_IF_NULL(old_storage_info);

  auto address_size = GetTypeByte(TypeIdToType(old_device_address->type_id())) * SizeOf(old_storage_info->shape);
  auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
    nullptr, address_size, Format::DEFAULT_FORMAT, old_device_address->type_id(), old_storage_info->shape,
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  kernel_tensor->SetType(std::make_shared<TensorType>(TypeIdToType(old_device_address->type_id())));
  kernel_tensor->SetShape(std::make_shared<abstract::TensorShape>(old_storage_info->shape));
  kernel_tensor->set_stream_id(stream_id);

  auto new_device_address = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
  MS_EXCEPTION_IF_NULL(new_device_address);
  new_device_address->set_device_shape(old_storage_info->shape);
  new_device_address->set_original_ref_count(SIZE_MAX);
  new_device_address->ResetRefCount();

  if (is_sync) {
    // ExecuteKernelTask is synchronous here: wait until all tasks in the queue are complete.
    runtime::OpExecutor::GetInstance().WaitAll();
    if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(
          runtime::KernelTaskType::kCONTIGUOUS_TASK, {old_device_address}, {new_device_address}, stream_id)) {
      MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << runtime::KernelTaskType::kCONTIGUOUS_TASK;
    }
    runtime::OpExecutor::GetInstance().WaitAll();
  } else {
    auto async_task = [device_context, old_device_address, new_device_address, stream_id]() {
      if (!device_context->GetKernelExecutor(false)->ExecuteKernelTask(
            runtime::KernelTaskType::kCONTIGUOUS_TASK, {old_device_address}, {new_device_address}, stream_id)) {
        MS_LOG(EXCEPTION) << "ExecuteKernelTask failed, task_type:" << runtime::KernelTaskType::kCONTIGUOUS_TASK;
      }
    };

    runtime::OpExecutor::GetInstance().PushSimpleOpRunTask(
      std::make_shared<runtime::PassthroughDeviceTask>(async_task));
  }

  return new_device_address;
}

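// Collects (stream id, device pointer) pairs for input tensors whose device address was created on a different
// stream than the op.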
void DeviceAddressUtils::GetCrossStreamAddressInfoFromInput(
  size_t op_stream_id, std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
  const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  if (tensor->device_address() == nullptr) {
    return;
  }

  auto device_address = std::static_pointer_cast<device::DeviceAddress>(tensor->device_address());
  MS_EXCEPTION_IF_NULL(device_address);
  if (op_stream_id != device_address->stream_id()) {
    // Device address is cross stream.
    (void)cross_stream_addresses->emplace_back(device_address->stream_id(), device_address->GetMutablePtr());
  }
}

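// KernelTensor overload: compares the kernel tensor's stream id against the op stream directly.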
void DeviceAddressUtils::GetCrossStreamAddressInfoFromInput(
  size_t op_stream_id, std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
  const mindspore::kernel::KernelTensor *tensor) {
  MS_EXCEPTION_IF_NULL(tensor);
  if (op_stream_id != tensor->stream_id()) {
    (void)cross_stream_addresses->emplace_back(tensor->stream_id(), tensor->device_ptr());
  }
}

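// DeviceAddress overload: records the address when its stream id differs from the op stream.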
void DeviceAddressUtils::GetCrossStreamAddressInfoFromInput(
  size_t op_stream_id, std::vector<std::pair<uint32_t, void *>> *cross_stream_addresses,
  const device::DeviceAddressPtr &device_address) {
  MS_EXCEPTION_IF_NULL(device_address);
  if (op_stream_id != device_address->stream_id()) {
    (void)cross_stream_addresses->emplace_back(device_address->stream_id(), device_address->GetMutablePtr());
  }
}
}  // namespace runtime
}  // namespace mindspore