/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/pynative/graph_adapter.h"

#include <string>
#include <memory>
#include <vector>
#include "ir/tensor.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/parallel_context.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#include "runtime/graph_scheduler/device_tensor_store.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "runtime/graph_scheduler/scheduler_helper.h"
#include "runtime/device/device_address_utils.h"
#include "kernel/pyboost/pyboost_utils.h"

namespace mindspore::pynative {
namespace {
constexpr auto kAttrBpropValueNodeRefCount = "bprop_value_node_ref_count";
constexpr auto kAttrValueNodeForwardOuputFlags = "value_node_forward_output_flags";

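// Returns the tensor held by a ValueNode, or nullptr if the node is not a ValueNode
// or its value is not a tensor.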
tensor::BaseTensorPtr GetTensorFromValueNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<ValueNode>()) {
    return nullptr;
  }
  auto value_node = node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  // ValueTuple is already expanded into tensors in backend.
  if (!value->isa<tensor::BaseTensor>()) {
    MS_LOG(DEBUG) << "Only need to process forward output tensor. value:" << value->ToString();
    return nullptr;
  }

  auto tensor = value->cast<tensor::BaseTensorPtr>();
  return tensor;
}

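// Counts how many times each ValueNode is referenced, either as a real input of an
// execution-order kernel or as a graph output. ValueNodes absent from the result are unused.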
HashMap<ValueNodePtr, size_t> GetGraphValueNodeRefCounts(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  HashMap<ValueNodePtr, size_t> value_node_ref_counts;
  // For example:
  //   %1 MakeTuple(V1, V2)
  //   %2 TupleGetItem(0, %1)
  //   %3 Kernel(%2)
  // V2 is not used by any kernel and needs to be removed.
  auto execution_nodes = graph->execution_order();
  for (auto &node : execution_nodes) {
    std::vector<session::KernelWithIndex> real_inputs;
    common::AnfAlgo::GetRealInputs(node, &real_inputs);
    for (auto &real_input : real_inputs) {
      auto input = real_input.first;
      MS_EXCEPTION_IF_NULL(input);
      if (input->isa<ValueNode>()) {
        auto value_node = input->cast<ValueNodePtr>();
        value_node_ref_counts[value_node] += 1;
      }
    }
  }

  // ValueNodes used as graph outputs.
  auto outputs = common::AnfAlgo::GetAllOutput(graph->output());
  for (auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (output->isa<ValueNode>()) {
      auto value_node = output->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      value_node_ref_counts[value_node] += 1;
    }
  }

  return value_node_ref_counts;
}

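// Creates a device address for the first output of a ValueNode, deriving size, dtype,
// format, and padded shape from the node itself.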
device::DeviceAddressPtr CreateValueNodeAddress(const ValueNodePtr &value_node,
                                                const device::DeviceContext *device_context) {
  size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, 0);
  TypeId data_type = AnfAlgo::GetOutputDeviceDataType(value_node, 0);
  if (data_type == kTypeUnknown) {
    data_type = common::AnfAlgo::GetOutputInferDataType(value_node, 0);
  }
  auto output_format = AnfAlgo::GetOutputFormat(value_node, 0);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {value_node, 0}, nullptr, tensor_size, output_format, data_type, trans::GetRuntimePaddingShape(value_node, 0),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  return device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
}

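// Allocates device memory for the address if it has none yet and copies the host tensor's
// data to it. Returns false if allocation or the host-to-device copy fails.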
bool CopyTensorData(const tensor::BaseTensorPtr &tensor, const device::DeviceAddressPtr &device_address,
                    const AnfNodePtr &node, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(device_address);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), device::AllocatorType::kConstantValue,
                                                     0);
  if (device_address->GetPtr() == nullptr) {
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "CopyTensorData", "CopyTensorData", "");
    auto mem_type =
      tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "CopyTensorData", mem_type, device_address->GetSize(),
                                                   device_address.get());
    if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(ERROR) << "Allocate memory failed, allocate size " << device_address->GetSize();
      return false;
    }
  }

  // Copy data from host tensor to device.
  auto host_tensor_size = LongToSize(tensor->data().nbytes());
  auto host_tensor_type = tensor->data_type();
  if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), host_tensor_size, host_tensor_type,
                                        kOpFormat_DEFAULT, tensor->data_ptr())) {
    std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope() +
                             ", tensor size: " + std::to_string(host_tensor_size) +
                             ", tensor type: " + std::to_string(static_cast<int>(host_tensor_type)) +
                             ", device address size: " + std::to_string(device_address->GetSize());
    MS_LOG(ERROR) << error_info;
    return false;
  }
  return true;
}

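// Ensures the forward-output tensor has a device address usable on the target device.
// Creates a new address (and copies the data) when the tensor has no address, or when
// its existing address lives on a different device type.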
device::DeviceAddressPtr HandleAddressForHeterogeneous(const tensor::BaseTensorPtr &tensor,
                                                       const ValueNodePtr &value_node,
                                                       const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(device_context);
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  if (device_address == nullptr) {
    MS_LOG(INFO) << "Forward output " << tensor->ToString() << " device address is null";
    device_address = CreateValueNodeAddress(value_node, device_context);
    if (!CopyTensorData(tensor, device_address, value_node, device_context)) {
      MS_LOG(EXCEPTION) << "CopyTensorData failed, value_node " << value_node->DebugString();
    }
  }
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetDeviceType() != device_context->GetDeviceType()) {
    tensor->data_sync();
    auto new_device_address = CreateValueNodeAddress(value_node, device_context);
    MS_EXCEPTION_IF_NULL(new_device_address);
    if (!CopyTensorData(tensor, new_device_address, value_node, device_context)) {
      MS_LOG(EXCEPTION) << "CopyTensorData failed, value_node " << value_node->DebugString();
    }
    return new_device_address;
  }
  return device_address;
}
}  // namespace

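// Removes ValueNodes that are neither consumed by any kernel nor produced as a graph
// output, based on the collected reference counts.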
void GraphAdapter::RemoveUnusedValueNodes(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto value_node_ref_counts = GetGraphValueNodeRefCounts(graph);
  for (const auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto iter = value_node_ref_counts.find(value_node);
    if (iter == value_node_ref_counts.end()) {
      MS_LOG(DEBUG) << "Remove unused ValueNode " << value_node->DebugString();
      graph->RemoveNodeFromGraph(value_node);
    }
  }
}

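// Replaces the device address of every forward-output ValueNode with an empty clone
// created by CloneEmptyDeviceAddress, so the old forward data is no longer referenced.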
void GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph,
                                                            const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  for (auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<tensor::BaseTensor>()) {
      auto tensor = value->cast<tensor::BaseTensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      if (!tensor->is_forward_output()) {
        continue;
      }

      if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
        MS_LOG(DEBUG) << "Output addr does not exist for ValueNode " << value_node->ToString();
        continue;
      }
      const auto &device_address = AnfAlgo::GetMutableOutputAddr(value_node, 0);
      auto new_device_address = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(device_address, device_context);
      AnfAlgo::SetOutputAddr(new_device_address, 0, value_node.get());
    }
  }
}

// The device address of a graph ValueNode needs to be released
// if the ValueNode is an output of forward_graph in PyNative mode.
void GraphAdapter::GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  HashMap<std::string, size_t> tensor_counts;
  HashMap<ValueNodePtr, size_t> value_node_ref_counts = GetGraphValueNodeRefCounts(graph);

  std::vector<size_t> value_node_ref_count_list;
  std::vector<bool> value_node_forward_output_flags;
  for (auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto tensor = GetTensorFromValueNode(value_node);
    if (tensor == nullptr || !tensor->is_forward_output()) {
      (void)value_node_ref_count_list.emplace_back(SIZE_MAX);
      (void)value_node_forward_output_flags.emplace_back(false);
      continue;
    }

    auto iter = value_node_ref_counts.find(value_node);
    if (iter == value_node_ref_counts.end()) {
      // The value_node is in the bprop graph but not used.
      // e.g. %1 MakeTuple(T1, T2) -> TupleGetItem(%1, 0). T2 is not used.
      MS_LOG(DEBUG) << "ValueNode " << value_node->ToString() << " is not used in graph";
      (void)value_node_ref_count_list.emplace_back(SIZE_MAX);
      (void)value_node_forward_output_flags.emplace_back(false);
      continue;
    }

    (void)value_node_ref_count_list.emplace_back(iter->second);
    (void)value_node_forward_output_flags.emplace_back(true);
    MS_LOG(DEBUG) << "ValueNode " << value_node->DebugString() << " ref_count " << iter->second;
  }
  graph->set_attr(kAttrBpropValueNodeRefCount, MakeValue(value_node_ref_count_list));
  graph->set_attr(kAttrValueNodeForwardOuputFlags, MakeValue(value_node_forward_output_flags));
}

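// For every kernel whose kernel selection is backed off, records the ValueNodes it
// consumes, so HandleBackoffValueNode can later create device tensors for them on the
// kernel's real device context.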
void GraphAdapter::GenerateBackoffValueNodeOwners(const KernelGraphPtr &graph) {
  for (auto &kernel : graph->execution_order()) {
    if (!AnfAlgo::IsKernelSelectBackoffOp(kernel)) {
      continue;
    }
    for (size_t j = 0; j < common::AnfAlgo::GetInputTensorNum(kernel); ++j) {
      const auto &input_node = common::AnfAlgo::GetInputNode(kernel, j);
      const auto &real_input_node = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false).first;
      MS_EXCEPTION_IF_NULL(real_input_node);
      if (real_input_node->isa<ValueNode>()) {
        (void)node_to_backoff_kernels_[real_input_node.get()].insert(kernel);
        MS_LOG(DEBUG) << "Generate backoff ValueNode " << real_input_node->DebugString() << " with kernel "
                      << kernel->DebugString();
      }
    }
  }
}

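// If the ValueNode is consumed by backoff kernels, creates a device tensor on each
// kernel's real device context and registers it in the device tensor store for the
// corresponding front node.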
void GraphAdapter::HandleBackoffValueNode(const ValueNodePtr &value_node, const AnfNodePtr &front_node,
                                          const DeviceContext *device_context) const {
  auto iter = node_to_backoff_kernels_.find(value_node.get());
  if (iter == node_to_backoff_kernels_.end()) {
    return;
  }

  MS_LOG(DEBUG) << "Backoff ValueNode " << value_node->ToString();
  const auto &kernels = iter->second;
  for (const auto &kernel : kernels) {
    const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_context);
    MS_EXCEPTION_IF_NULL(real_device_context);

    if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
      MS_LOG(EXCEPTION) << "The device address does not exist: " << value_node->ToString();
    }
    auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
    MS_EXCEPTION_IF_NULL(device_tensor);

    auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
      nullptr, device_tensor->GetSize(), device_tensor->kernel_tensor()->format(), device_tensor->type_id(),
      device_tensor->host_shape(), device_context->device_context_key().device_name_,
      device_context->device_context_key().device_id_);

    kernel_tensor->SetHostInfo(
      std::make_shared<abstract::TensorShape>(device_tensor->kernel_tensor()->GetShapeVector()),
      std::make_shared<TensorType>(TypeIdToType(device_tensor->kernel_tensor()->dtype_id())), nullptr);

    kernel_tensor->set_stream_id(device_tensor->stream_id());
    auto new_device_tensor = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    MS_EXCEPTION_IF_NULL(new_device_tensor);
    new_device_tensor->SetNodeIndex(value_node, 0);
    new_device_tensor->set_from_persistent_mem(true);
    MS_LOG(DEBUG) << "Create backoff device tensor:" << new_device_tensor << " type:" << new_device_tensor->type_id()
                  << " for ValueNode " << value_node->ToString();
    runtime::SchedulerHelper::AddDeviceTensorStore(front_node.get(), new_device_tensor);
  }
}

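// Refreshes the device addresses of forward-output ValueNodes in the bprop graph:
// fixes up heterogeneous addresses, makes them contiguous, applies the recorded
// reference counts, and publishes the addresses to the device tensor store.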
void GraphAdapter::UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph,
                                                   const device::DeviceContext *device_context, bool no_control_flow) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(DEBUG) << "Update start";
  auto value_node_ref_counts = GetValue<std::vector<size_t>>(graph->get_attr(kAttrBpropValueNodeRefCount));
  auto value_node_forward_output_flags = GetValue<std::vector<bool>>(graph->get_attr(kAttrValueNodeForwardOuputFlags));
  size_t value_node_size = graph->graph_value_nodes().size();
  if (value_node_ref_counts.size() != value_node_size || value_node_forward_output_flags.size() != value_node_size) {
    MS_LOG(EXCEPTION) << "value_node_ref_count.size " << value_node_ref_counts.size()
                      << " value_node_forward_output_flags.size " << value_node_forward_output_flags.size()
                      << " not equal to " << value_node_size;
  }

  size_t value_node_index = 0;
  HashMap<device::DeviceAddressPtr, size_t> address_ref_count;
  // Update ValueNode device address.
  for (auto &value_node : graph->graph_value_nodes()) {
    auto is_forward_output = value_node_forward_output_flags[value_node_index];
    if (!is_forward_output) {
      value_node_index++;
      continue;
    }
    size_t value_node_ref_count = value_node_ref_counts[value_node_index++];
    auto tensor = GetTensorFromValueNode(value_node);
    MS_EXCEPTION_IF_NULL(tensor);

    auto device_address = HandleAddressForHeterogeneous(tensor, value_node, device_context);
    device_address = std::dynamic_pointer_cast<device::DeviceAddress>(
      kernel::pyboost::PyBoostUtils::ContiguousByDeviceAddress(device_address));
    runtime::DeviceAddressUtils::CreateKernelTensor(device_address, tensor);
    tensor->set_device_address(device_address);
    auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(value_node, *graph);
    MS_EXCEPTION_IF_NULL(front_node);
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetDeviceType() != device::DeviceType::kCPU && no_control_flow) {
      address_ref_count[device_address] += value_node_ref_count;
      device_address->AddHeldByNode(front_node->cast<ValueNodePtr>());
    }
    runtime::DeviceTensorStore::GetInstance().Insert(front_node.get(), device_address);
    HandleBackoffValueNode(value_node, front_node, device_context);
  }

  for (auto &[address, ref_count] : address_ref_count) {
    MS_EXCEPTION_IF_NULL(address);
    address->set_original_ref_count(ref_count);
    address->ResetRefCount();
    MS_LOG(DEBUG) << "device_address " << address.get() << " ref_count " << address->ref_count();
  }
  MS_LOG(DEBUG) << "Update end";
}

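// Clears the device address of any input tensor whose device type does not match the
// corresponding device context, syncing its data back to host first.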
void GraphAdapter::HandleHeterogeneousTensors(const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
                                              const std::vector<device::DeviceContext *> &device_contexts) {
  if (input_tensors.size() < device_contexts.size()) {
    MS_LOG(EXCEPTION) << "Invalid input_tensors size " << input_tensors.size() << " device_contexts size "
                      << device_contexts.size();
  }
  for (size_t i = 0; i < device_contexts.size(); ++i) {
    auto tensors = input_tensors[i];
    auto device_context = device_contexts[i];
    MS_EXCEPTION_IF_NULL(device_context);
    for (auto &tensor : tensors) {
      if (tensor != nullptr && tensor->device_address() != nullptr) {
        auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
        MS_EXCEPTION_IF_NULL(device_address);
        if (device_address->GetDeviceType() != device_context->GetDeviceType()) {
          tensor->data_sync();
          tensor->set_device_address(nullptr);
        }
      }
    }
  }
}

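// Propagates the format, dtype, and device address of each matching input tensor to the
// corresponding graph parameter. Tensors with no address, or with an address on a
// different device type, are skipped.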
void GraphAdapter::ReplaceGraphParameterProperties(const KernelGraphPtr &graph,
                                                   const std::vector<tensor::TensorPtr> &input_tensors,
                                                   const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  size_t index = 0;
  for (const auto &input_node : graph->input_nodes()) {
    auto parameters = common::AnfAlgo::GetAllOutput(input_node);
    for (const auto &parameter : parameters) {
      MS_EXCEPTION_IF_NULL(parameter);
      if (index >= input_tensors.size()) {
        MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
                          << ", input size: " << input_tensors.size();
      }
      const auto &input_tensor = input_tensors[index++];
      MS_EXCEPTION_IF_NULL(input_tensor);
      const auto &tensor_address = input_tensor->device_address();
      auto address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address);
      if (address == nullptr || address->GetDeviceType() != device_context->GetDeviceType()) {
        // The input tensor properties must be discarded in heterogeneous scenarios.
        // For example, the device_address of input_tensor may use a 5D format,
        // which is invalid for a CPU graph parameter.
        continue;
      }

      auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
      kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{address->format()});
      kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{address->type_id()});
      kernel_build_info_builder->SetOutputsReshapeType({address->padding_type()});
      AnfAlgo::SetOutputAddr(address, 0, parameter.get());
      AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), parameter.get());

      auto abstract = parameter->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      auto shape = abstract->BuildShape();
      auto new_abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(address->type_id()), shape);
      parameter->set_abstract(new_abs);
    }
  }
}

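// Returns true when semi-auto or auto parallel mode is enabled.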
bool GraphAdapter::IsAutoParallel() {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  return parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
}

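// Returns true if any graph in the compiler info runs as a GE graph sink.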
bool GraphAdapter::IsPynativeGeGraphSink(const GraphCompilerInfo &graph_compiler_info) {
  bool is_sink = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
                             [](const KernelGraphPtr &graph) { return GraphAdapter::IsPynativeGeGraphSink(graph); });
  return is_sink;
}

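// A func graph is sunk to GE only when the backend policy is "ge", multi-graph sink is
// enabled, and the graph is not forced to run op by op.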
bool GraphAdapter::IsPynativeGeGraphSink(const FuncGraphPtr &func_graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->backend_policy() != "ge" || !context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
    return false;
  }

  MS_EXCEPTION_IF_NULL(func_graph);
  if (func_graph->has_flag(kFlagEnableRunGraphBySingleOp)) {
    return false;
  }

  return true;
}

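// Decides whether task sink can be enabled for a graph executed under PyNative: it
// requires GE graph sink, or jit_level O2 with no auto-parallel, communication, or
// backend bprop-cut ops. Always returns true outside PyNative mode.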
bool GraphAdapter::PyNativeEnableTaskSink(const FuncGraphPtr &func_graph) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  bool pynative_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
  if (!pynative_mode) {
    return true;
  }

  MS_EXCEPTION_IF_NULL(func_graph);
  if (GraphAdapter::IsPynativeGeGraphSink(func_graph)) {
    MS_LOG(DEBUG) << "Enable graph sink for PyNative";
    return true;
  }

  if (!func_graph->has_attr(kAttrJitLevel)) {
    MS_LOG(EXCEPTION) << "No jit_level set for func_graph";
  }
  auto jit_level_value = func_graph->get_attr(kAttrJitLevel);
  auto jit_level = GetValue<std::string>(jit_level_value);
  if (jit_level != kAttrJitLevelO2) {
    MS_LOG(INFO) << "jit_level is " << jit_level << ", task sink is disabled";
    return false;
  }

  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
  auto is_cut_graph = std::any_of(node_list.begin(), node_list.end(), [](const AnfNodePtr &node) {
    return common::AnfAlgo::IsBpropCutOpExecInBackend(node);
  });

  auto has_comm_op = std::any_of(node_list.begin(), node_list.end(),
                                 [](const AnfNodePtr &node) { return common::AnfAlgo::IsCommunicationOp(node); });

  auto is_auto_parallel = IsAutoParallel();

  MS_LOG(INFO) << "JitLevel is " << jit_level << " is_auto_parallel " << is_auto_parallel << " has_comm_op "
               << has_comm_op << " is_cut_graph " << is_cut_graph;

  return !is_auto_parallel && !has_comm_op && !is_cut_graph;
}

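// Rewrites a ValueNode's abstract so that its shape matches the actual shape of the
// tensor it holds.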
void UpdateValueNodeAbstractFromTensor(const ValueNodePtr &value_node, const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(tensor);
  auto real_shape = tensor->shape();
  auto old_abs = value_node->abstract();
  auto old_abs_tensor = dyn_cast<abstract::AbstractTensor>(old_abs);
  MS_EXCEPTION_IF_NULL(old_abs_tensor);
  auto new_abs = std::make_shared<abstract::AbstractTensor>(old_abs_tensor->element(),
                                                            std::make_shared<abstract::Shape>(real_shape));
  value_node->set_abstract(new_abs);
  MS_LOG(DEBUG) << "Change bprop ValueNode abstract from " << old_abs->ToString() << " to " << new_abs->ToString();
}

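// For dynamic-shape graphs, refreshes the abstract of every forward-output ValueNode
// from the current shape of its tensor.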
void GraphAdapter::UpdateDynamicValueNodeAbstract(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->is_dynamic_shape()) {
    return;
  }
  MS_LOG(INFO) << "Update dynamic shape value node for graph " << graph->graph_id();
  const auto &value_nodes = graph->graph_value_nodes();
  for (auto &value_node : value_nodes) {
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<tensor::BaseTensor>()) {
      auto tensor = value->cast<tensor::BaseTensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      if (tensor->is_forward_output()) {
        UpdateValueNodeAbstractFromTensor(value_node, tensor);
      }
    }
  }
}

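// For dynamic-shape graphs, copies sens tensors that do not yet have a device address
// to the device and binds the new addresses to their ValueNodes.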
void GraphAdapter::SensTensorToDevice(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->is_dynamic_shape()) {
    return;
  }
  const auto &value_nodes = graph->graph_value_nodes();
  for (const auto &value_node : value_nodes) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    std::vector<tensor::BaseTensorPtr> tensors;
    TensorValueToTensor(value, &tensors);
    for (const auto &tensor : tensors) {
      MS_EXCEPTION_IF_NULL(tensor);
      if (!tensor->has_user_data(kTensorUserDataIsSensTensor)) {
        continue;
      }
      const auto &device_address = tensor->device_address();
      if (device_address == nullptr) {
        UpdateValueNodeAbstractFromTensor(value_node, tensor);
        auto node_address = CreateValueNodeAddress(value_node, device_context);
        MS_EXCEPTION_IF_NULL(node_address);
        tensor->set_device_address(node_address);
        AnfAlgo::SetOutputAddr(node_address, 0, value_node.get());
        MS_LOG(DEBUG) << "Start to copy sens tensor to device";
        if (!CopyTensorData(tensor, node_address, value_node, device_context)) {
          MS_LOG(EXCEPTION) << "ValueNode host to device copy failed";
        }
      }
    }
  }
}
}  // namespace mindspore::pynative