/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/pynative/graph_adapter.h"

#include <string>
#include <memory>
#include <vector>
#include "ir/tensor.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/parallel_context.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/backend/mem_reuse/mem_tracker.h"
#include "runtime/graph_scheduler/device_tensor_store.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "runtime/graph_scheduler/scheduler_helper.h"
#include "runtime/device/device_address_utils.h"
#include "kernel/pyboost/pyboost_utils.h"

namespace mindspore::pynative {
namespace {
constexpr auto kAttrBpropValueNodeRefCount = "bprop_value_node_ref_count";
constexpr auto kAttrValueNodeForwardOutputFlags = "value_node_forward_output_flags";

tensor::BaseTensorPtr GetTensorFromValueNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<ValueNode>()) {
    return nullptr;
  }
  auto value_node = node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  // ValueTuple is already expanded into tensors in backend.
  if (!value->isa<tensor::BaseTensor>()) {
    MS_LOG(DEBUG) << "Only need to process forward output tensor. value:" << value->ToString();
    return nullptr;
  }

  auto tensor = value->cast<tensor::BaseTensorPtr>();
  return tensor;
}

HashMap<ValueNodePtr, size_t> GetGraphValueNodeRefCounts(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  HashMap<ValueNodePtr, size_t> value_node_ref_counts;
  // For example:
  // %1 MakeTuple(V1, V2)
  // %2 TupleGetItem(%1, 0)
  // %3 Kernel(%2)
  // V2 is not used by any kernel, so it needs to be removed.
  auto execution_nodes = graph->execution_order();
  for (auto &node : execution_nodes) {
    std::vector<session::KernelWithIndex> real_inputs;
    common::AnfAlgo::GetRealInputs(node, &real_inputs);
    for (auto &real_input : real_inputs) {
      auto input = real_input.first;
      MS_EXCEPTION_IF_NULL(input);
      if (input->isa<ValueNode>()) {
        auto value_node = input->cast<ValueNodePtr>();
        value_node_ref_counts[value_node] += 1;
      }
    }
  }

  // ValueNodes as graph outputs
  auto outputs = common::AnfAlgo::GetAllOutput(graph->output());
  for (auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (output->isa<ValueNode>()) {
      auto value_node = output->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(value_node);
      value_node_ref_counts[value_node] += 1;
    }
  }

  return value_node_ref_counts;
}

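// Creates a device address for the first output of a ValueNode, falling back to the inferred data
// type when the selected device data type is unknown.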
device::DeviceAddressPtr CreateValueNodeAddress(const ValueNodePtr &value_node,
                                                const device::DeviceContext *device_context) {
  size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, 0);
  TypeId data_type = AnfAlgo::GetOutputDeviceDataType(value_node, 0);
  if (data_type == kTypeUnknown) {
    data_type = common::AnfAlgo::GetOutputInferDataType(value_node, 0);
  }
  auto output_format = AnfAlgo::GetOutputFormat(value_node, 0);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  const auto &kernel_tensor = AnfAlgo::CreateOutputKernelTensorWithDeviceInfo(
    {value_node, 0}, nullptr, tensor_size, output_format, data_type, trans::GetRuntimePaddingShape(value_node, 0),
    device_context->device_context_key().device_name_, device_context->device_context_key().device_id_);
  return device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
}

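// Allocates device memory for the address when it has no pointer yet, then copies the host tensor
// data to the device. Returns false if allocation or the host-to-device sync fails.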
bool CopyTensorData(const tensor::BaseTensorPtr &tensor, const device::DeviceAddressPtr &device_address,
                    const AnfNodePtr &node, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(device_address);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  device::DynamicMemAllocatorDebugInfo::SetDebugInfo(node->fullname_with_scope(), device::AllocatorType::kConstantValue,
                                                     0);
  if (device_address->GetPtr() == nullptr) {
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "CopyTensorData", "CopyTensorData", "");
    auto mem_type =
      tensor->is_parameter() ? device::tracker::MemType::kWeight : device::tracker::MemType::kPyNativeInput;
    device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "CopyTensorData", mem_type, device_address->GetSize(),
                                                   device_address.get());
    if (!device_context->device_res_manager_->AllocateMemory(device_address.get())) {
      MS_LOG(ERROR) << "Allocate memory failed, allocate size " << device_address->GetSize();
      return false;
    }
  }

  // Copy data from host tensor to device.
  auto host_tensor_size = LongToSize(tensor->data().nbytes());
  auto host_tensor_type = tensor->data_type();
  if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), host_tensor_size, host_tensor_type,
                                        kOpFormat_DEFAULT, tensor->data_ptr())) {
    std::string error_info = "SyncHostToDevice failed, node name: " + node->fullname_with_scope() +
                             ", tensor size: " + std::to_string(host_tensor_size) +
                             ", tensor type: " + std::to_string(static_cast<int>(host_tensor_type)) +
                             ", device address size: " + std::to_string(device_address->GetSize());
    MS_LOG(ERROR) << error_info;
    return false;
  }
  return true;
}

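// Returns a device address for the forward-output tensor that matches the given device context.
// If the tensor has no address, or its address is on a different device type, a new address is
// created on the target device and the tensor data is copied into it.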
device::DeviceAddressPtr HandleAddressForHeterogeneous(const tensor::BaseTensorPtr &tensor,
                                                       const ValueNodePtr &value_node,
                                                       const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(tensor);
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(device_context);
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
  if (device_address == nullptr) {
    MS_LOG(INFO) << "Forward output " << tensor->ToString() << " device address is null";
    device_address = CreateValueNodeAddress(value_node, device_context);
    if (!CopyTensorData(tensor, device_address, value_node, device_context)) {
      MS_LOG(EXCEPTION) << "CopyTensorData failed, value_node " << value_node->DebugString();
    }
  }
  MS_EXCEPTION_IF_NULL(device_address);
  if (device_address->GetDeviceType() != device_context->GetDeviceType()) {
    tensor->data_sync();
    auto new_device_address = CreateValueNodeAddress(value_node, device_context);
    MS_EXCEPTION_IF_NULL(new_device_address);
    if (!CopyTensorData(tensor, new_device_address, value_node, device_context)) {
      MS_LOG(EXCEPTION) << "CopyTensorData failed, value_node " << value_node->DebugString();
    }
    return new_device_address;
  }
  return device_address;
}
}  // namespace

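// Removes ValueNodes that are neither a real input of any kernel nor a graph output.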
void GraphAdapter::RemoveUnusedValueNodes(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto value_node_ref_counts = GetGraphValueNodeRefCounts(graph);
  for (const auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto iter = value_node_ref_counts.find(value_node);
    if (iter == value_node_ref_counts.end()) {
      MS_LOG(DEBUG) << "Remove unused ValueNode " << value_node->DebugString();
      graph->RemoveNodeFromGraph(value_node);
    }
  }
}

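// Replaces the output address of every forward-output ValueNode with an empty clone so the graph no
// longer holds the forward step's device memory.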
void GraphAdapter::ClearForwardOutputValueNodeDeviceAddress(const KernelGraphPtr &graph,
                                                            const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  for (auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<tensor::BaseTensor>()) {
      auto tensor = value->cast<tensor::BaseTensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      if (!tensor->is_forward_output()) {
        continue;
      }

      if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
203 MS_LOG(DEBUG) << "Output addr is not exist for ValueNode " << value_node->ToString();
        continue;
      }
      const auto &device_address = AnfAlgo::GetMutableOutputAddr(value_node, 0);
      auto new_device_address = runtime::DeviceAddressUtils::CloneEmptyDeviceAddress(device_address, device_context);
      AnfAlgo::SetOutputAddr(new_device_address, 0, value_node.get());
    }
  }
}

// The device address of a graph value node needs to be released
// if the value node is an output of the forward graph in PyNative mode.
void GraphAdapter::GenerateRefCountForBpropValueNode(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  HashMap<std::string, size_t> tensor_counts;
  HashMap<ValueNodePtr, size_t> value_node_ref_counts = GetGraphValueNodeRefCounts(graph);

  std::vector<size_t> value_node_ref_count_list;
  std::vector<bool> value_node_forward_output_flags;
  for (auto &value_node : graph->graph_value_nodes()) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto tensor = GetTensorFromValueNode(value_node);
    if (tensor == nullptr || !tensor->is_forward_output()) {
      (void)value_node_ref_count_list.emplace_back(SIZE_MAX);
      (void)value_node_forward_output_flags.emplace_back(false);
      continue;
    }

    auto iter = value_node_ref_counts.find(value_node);
    if (iter == value_node_ref_counts.end()) {
      // The value_node is in bp graph but not used.
      // e.g. %1-MakeTuple(T1, T2) -> TupleGetItem(%1, 0). T2 is not used.
      MS_LOG(DEBUG) << "ValueNode " << value_node->ToString() << " is not used in graph";
      (void)value_node_ref_count_list.emplace_back(SIZE_MAX);
      (void)value_node_forward_output_flags.emplace_back(false);
      continue;
    }

    (void)value_node_ref_count_list.emplace_back(iter->second);
    (void)value_node_forward_output_flags.emplace_back(true);
    MS_LOG(DEBUG) << "ValueNode " << value_node->DebugString() << " ref_count " << iter->second;
  }
  graph->set_attr(kAttrBpropValueNodeRefCount, MakeValue(value_node_ref_count_list));
  graph->set_attr(kAttrValueNodeForwardOutputFlags, MakeValue(value_node_forward_output_flags));
}

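// Records which backoff-selected kernels consume each ValueNode; HandleBackoffValueNode uses this
// map to create extra device tensors on the kernels' real device contexts.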
void GraphAdapter::GenerateBackoffValueNodeOwners(const KernelGraphPtr &graph) {
  for (auto &kernel : graph->execution_order()) {
    if (!AnfAlgo::IsKernelSelectBackoffOp(kernel)) {
      continue;
    }
    for (size_t j = 0; j < common::AnfAlgo::GetInputTensorNum(kernel); ++j) {
      const auto &input_node = common::AnfAlgo::GetInputNode(kernel, j);
      const auto &real_input_node = common::AnfAlgo::VisitKernelWithReturnType(input_node, 0, false).first;
      MS_EXCEPTION_IF_NULL(real_input_node);
      if (real_input_node->isa<ValueNode>()) {
        (void)node_to_backoff_kernels_[real_input_node.get()].insert(kernel);
        MS_LOG(DEBUG) << "Generate backoff ValueNode " << real_input_node->DebugString() << " with kernel "
                      << kernel->DebugString();
      }
    }
  }
}

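// For a ValueNode consumed by backoff kernels, creates a device tensor on each kernel's real device
// context and adds it to the device tensor store of the corresponding front node.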
void GraphAdapter::HandleBackoffValueNode(const ValueNodePtr &value_node, const AnfNodePtr &front_node,
                                          const DeviceContext *device_context) const {
  auto iter = node_to_backoff_kernels_.find(value_node.get());
  if (iter == node_to_backoff_kernels_.end()) {
    return;
  }

  MS_LOG(DEBUG) << "Backoff ValueNode " << value_node->ToString();
  const auto &kernels = iter->second;
  for (const auto &kernel : kernels) {
    const auto &real_device_context = device::FetchRealDeviceContext(kernel, device_context);
    MS_EXCEPTION_IF_NULL(real_device_context);

    if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
281 MS_LOG(EXCEPTION) << "The device address is not exist: " << value_node->ToString();
    }
    auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
    MS_EXCEPTION_IF_NULL(device_tensor);

    auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
      nullptr, device_tensor->GetSize(), device_tensor->kernel_tensor()->format(), device_tensor->type_id(),
      device_tensor->host_shape(), device_context->device_context_key().device_name_,
      device_context->device_context_key().device_id_);

    kernel_tensor->SetHostInfo(
      std::make_shared<abstract::TensorShape>(device_tensor->kernel_tensor()->GetShapeVector()),
      std::make_shared<TensorType>(TypeIdToType(device_tensor->kernel_tensor()->dtype_id())), nullptr);

    kernel_tensor->set_stream_id(device_tensor->stream_id());
    auto new_device_tensor = real_device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
    MS_EXCEPTION_IF_NULL(new_device_tensor);
    new_device_tensor->SetNodeIndex(value_node, 0);
    new_device_tensor->set_from_persistent_mem(true);
    MS_LOG(DEBUG) << "Create backoff device tensor:" << new_device_tensor << " type:" << new_device_tensor->type_id()
                  << " for ValueNode " << value_node->ToString();
    runtime::SchedulerHelper::AddDeviceTensorStore(front_node.get(), new_device_tensor);
  }
}

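// Binds forward-output tensors to the bprop graph: resolves heterogeneous or non-contiguous device
// addresses, registers them in the device tensor store, and applies the pre-computed ref counts.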
void GraphAdapter::UpdateForwardOutputInBpropGraph(const KernelGraphPtr &graph,
                                                   const device::DeviceContext *device_context, bool no_control_flow) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(DEBUG) << "Update start";
  auto value_node_ref_counts = GetValue<std::vector<size_t>>(graph->get_attr(kAttrBpropValueNodeRefCount));
  auto value_node_forward_output_flags = GetValue<std::vector<bool>>(graph->get_attr(kAttrValueNodeForwardOutputFlags));
  size_t value_node_size = graph->graph_value_nodes().size();
  if (value_node_ref_counts.size() != value_node_size || value_node_forward_output_flags.size() != value_node_size) {
    MS_LOG(EXCEPTION) << "value_node_ref_count.size " << value_node_ref_counts.size()
                      << " value_node_forward_output_flags.size " << value_node_forward_output_flags.size()
                      << " not equal to " << value_node_size;
  }

  size_t value_node_index = 0;
  HashMap<device::DeviceAddressPtr, size_t> address_ref_count;
  // Update ValueNode device address
  for (auto &value_node : graph->graph_value_nodes()) {
    auto is_forward_output = value_node_forward_output_flags[value_node_index];
    if (!is_forward_output) {
      value_node_index++;
      continue;
    }
    size_t value_node_ref_count = value_node_ref_counts[value_node_index++];
    auto tensor = GetTensorFromValueNode(value_node);
    MS_EXCEPTION_IF_NULL(tensor);

    auto device_address = HandleAddressForHeterogeneous(tensor, value_node, device_context);
    device_address = std::dynamic_pointer_cast<device::DeviceAddress>(
      kernel::pyboost::PyBoostUtils::ContiguousByDeviceAddress(device_address));
    runtime::DeviceAddressUtils::CreateKernelTensor(device_address, tensor);
    tensor->set_device_address(device_address);
    auto front_node = AnfAlgo::FetchFrontNodeByBackendNode(value_node, *graph);
    MS_EXCEPTION_IF_NULL(front_node);
    MS_EXCEPTION_IF_NULL(device_address);
    if (device_address->GetDeviceType() != device::DeviceType::kCPU && no_control_flow) {
      address_ref_count[device_address] += value_node_ref_count;
      device_address->AddHeldByNode(front_node->cast<ValueNodePtr>());
    }
    runtime::DeviceTensorStore::GetInstance().Insert(front_node.get(), device_address);
    HandleBackoffValueNode(value_node, front_node, device_context);
  }

  for (auto &[address, ref_count] : address_ref_count) {
    MS_EXCEPTION_IF_NULL(address);
    address->set_original_ref_count(ref_count);
    address->ResetRefCount();
    MS_LOG(DEBUG) << "device_address " << address.get() << " ref_count " << address->ref_count();
  }
  MS_LOG(DEBUG) << "Update end";
}

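// Syncs input tensors whose device address belongs to a different device type back to host and
// clears their addresses, so the data can be re-prepared for the target device.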
void GraphAdapter::HandleHeterogeneousTensors(const std::vector<std::vector<tensor::TensorPtr>> &input_tensors,
                                              const std::vector<device::DeviceContext *> &device_contexts) {
  if (input_tensors.size() < device_contexts.size()) {
    MS_LOG(EXCEPTION) << "Invalid input_tensors size " << input_tensors.size() << " device_contexts size "
                      << device_contexts.size();
  }
  for (size_t i = 0; i < device_contexts.size(); ++i) {
    auto tensors = input_tensors[i];
    auto device_context = device_contexts[i];
    MS_EXCEPTION_IF_NULL(device_context);
    for (auto &tensor : tensors) {
      if (tensor != nullptr && tensor->device_address() != nullptr) {
        auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
        MS_EXCEPTION_IF_NULL(device_address);
        if (device_address->GetDeviceType() != device_context->GetDeviceType()) {
          tensor->data_sync();
          tensor->set_device_address(nullptr);
        }
      }
    }
  }
}

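// Copies the format, dtype and device address of the actual input tensors onto the graph parameters
// (kernel build info, output address and abstract). Inputs whose addresses belong to another device
// type are skipped, since their properties are invalid for this graph.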
void GraphAdapter::ReplaceGraphParameterProperties(const KernelGraphPtr &graph,
                                                   const std::vector<tensor::TensorPtr> &input_tensors,
                                                   const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(graph);
  size_t index = 0;
  for (const auto &input_node : graph->input_nodes()) {
    auto parameters = common::AnfAlgo::GetAllOutput(input_node);
    for (const auto &parameter : parameters) {
      MS_EXCEPTION_IF_NULL(parameter);
      if (index >= input_tensors.size()) {
        MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
                          << ", input size: " << input_tensors.size();
      }
      const auto &input_tensor = input_tensors[index++];
      MS_EXCEPTION_IF_NULL(input_tensor);
      const auto &tensor_address = input_tensor->device_address();
      auto address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_address);
      if (address == nullptr || address->GetDeviceType() != device_context->GetDeviceType()) {
        // Need to discard input tensor properties in heterogeneous scenarios.
        // For example, the format of device_address in input_tensor is 5D format,
        // and it's invalid for CPU graph parameter.
        continue;
      }

      auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
      kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{address->format()});
      kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{address->type_id()});
      kernel_build_info_builder->SetOutputsReshapeType({address->padding_type()});
      AnfAlgo::SetOutputAddr(address, 0, parameter.get());
      AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), parameter.get());

      auto abstract = parameter->abstract();
      MS_EXCEPTION_IF_NULL(abstract);
      auto shape = abstract->BuildShape();
      auto new_abs = std::make_shared<abstract::AbstractTensor>(TypeIdToType(address->type_id()), shape);
      parameter->set_abstract(new_abs);
    }
  }
}

bool GraphAdapter::IsAutoParallel() {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  return parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
}

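// Returns true if any graph in the compiler info is executed as a GE graph sink under PyNative.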
bool GraphAdapter::IsPynativeGeGraphSink(const GraphCompilerInfo &graph_compiler_info) {
  bool is_sink = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(),
                             [](const KernelGraphPtr &graph) { return GraphAdapter::IsPynativeGeGraphSink(graph); });
  return is_sink;
}

bool GraphAdapter::IsPynativeGeGraphSink(const FuncGraphPtr &func_graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->backend_policy() != "ge" || !context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
    return false;
  }

  MS_EXCEPTION_IF_NULL(func_graph);
  if (func_graph->has_flag(kFlagEnableRunGraphBySingleOp)) {
    return false;
  }

  return true;
}

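// Task sink is enabled under PyNative only for GE graph sink, or for jit_level O2 graphs that
// contain no bprop-cut nodes, no communication ops and do not run in auto-parallel.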
bool GraphAdapter::PyNativeEnableTaskSink(const FuncGraphPtr &func_graph) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  bool pynative_mode = ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
  if (!pynative_mode) {
    return true;
  }

  MS_EXCEPTION_IF_NULL(func_graph);
  if (GraphAdapter::IsPynativeGeGraphSink(func_graph)) {
    MS_LOG(DEBUG) << "Enable graph sink for PyNative";
    return true;
  }

  if (!func_graph->has_attr(kAttrJitLevel)) {
465 MS_LOG(EXCEPTION) << "Not jit_level set to func_graph";
  }
  auto jit_level_value = func_graph->get_attr(kAttrJitLevel);
  auto jit_level = GetValue<std::string>(jit_level_value);
  if (jit_level != kAttrJitLevelO2) {
    MS_LOG(INFO) << "jit_level is " << jit_level << ", task sink is disabled";
    return false;
  }

  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return());
  auto is_cut_graph = std::any_of(node_list.begin(), node_list.end(), [](const AnfNodePtr &node) {
    return common::AnfAlgo::IsBpropCutOpExecInBackend(node);
  });

  auto has_comm_op = std::any_of(node_list.begin(), node_list.end(),
                                 [](const AnfNodePtr &node) { return common::AnfAlgo::IsCommunicationOp(node); });

  auto is_auto_parallel = IsAutoParallel();

  MS_LOG(INFO) << "JitLevel is " << jit_level << " is_auto_parallel " << is_auto_parallel << " has_comm_op "
               << has_comm_op << " is_cut_graph " << is_cut_graph;

  return !is_auto_parallel && !has_comm_op && !is_cut_graph;
}

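// Rebuilds the ValueNode's abstract with the tensor's actual shape, keeping the element type.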
void UpdateValueNodeAbstractFromTensor(const ValueNodePtr &value_node, const tensor::BaseTensorPtr &tensor) {
  MS_EXCEPTION_IF_NULL(value_node);
  MS_EXCEPTION_IF_NULL(tensor);
  auto real_shape = tensor->shape();
  auto old_abs = value_node->abstract();
  auto old_abs_tensor = dyn_cast<abstract::AbstractTensor>(old_abs);
  MS_EXCEPTION_IF_NULL(old_abs_tensor);
  auto new_abs = std::make_shared<abstract::AbstractTensor>(old_abs_tensor->element(),
                                                            std::make_shared<abstract::Shape>(real_shape));
  value_node->set_abstract(new_abs);
  MS_LOG(DEBUG) << "Change bprop ValueNode abstract from " << old_abs->ToString() << " to " << new_abs->ToString();
}

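// For dynamic-shape graphs, refreshes the abstract of every forward-output ValueNode from the
// actual shape of its tensor.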
void GraphAdapter::UpdateDynamicValueNodeAbstract(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->is_dynamic_shape()) {
    return;
  }
  MS_LOG(INFO) << "Update dynamic shape value node for graph " << graph->graph_id();
  const auto &value_nodes = graph->graph_value_nodes();
  for (auto &value_node : value_nodes) {
    MS_EXCEPTION_IF_NULL(value_node);
    const auto &value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    if (value->isa<tensor::BaseTensor>()) {
      auto tensor = value->cast<tensor::BaseTensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      if (tensor->is_forward_output()) {
        UpdateValueNodeAbstractFromTensor(value_node, tensor);
      }
    }
  }
}

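// For dynamic-shape graphs, copies sens tensors that have no device address to the device and
// attaches the new address to both the tensor and its ValueNode.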
void GraphAdapter::SensTensorToDevice(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  if (!graph->is_dynamic_shape()) {
    return;
  }
  const auto &value_nodes = graph->graph_value_nodes();
  for (const auto &value_node : value_nodes) {
    MS_EXCEPTION_IF_NULL(value_node);
    auto value = value_node->value();
    MS_EXCEPTION_IF_NULL(value);
    std::vector<tensor::BaseTensorPtr> tensors;
    TensorValueToTensor(value, &tensors);
    for (const auto &tensor : tensors) {
      MS_EXCEPTION_IF_NULL(tensor);
      if (!tensor->has_user_data(kTensorUserDataIsSensTensor)) {
        continue;
      }
      const auto &device_address = tensor->device_address();
      if (device_address == nullptr) {
        UpdateValueNodeAbstractFromTensor(value_node, tensor);
        auto node_address = CreateValueNodeAddress(value_node, device_context);
        MS_EXCEPTION_IF_NULL(node_address);
        tensor->set_device_address(node_address);
        AnfAlgo::SetOutputAddr(node_address, 0, value_node.get());
        MS_LOG(DEBUG) << "Start to copy sens tensor to device";
        if (!CopyTensorData(tensor, node_address, value_node, device_context)) {
          MS_LOG(EXCEPTION) << "ValueNode host to device copy failed";
        }
      }
    }
  }
}
}  // namespace mindspore::pynative