1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/graph_compiler/backend_base.h"
17
18 #include <algorithm>
19 #include <map>
20 #include <vector>
21 #include <queue>
22 #if defined(_WIN32) || defined(_WIN64)
23 #include <windows.h>
24 #endif
25
26 #include "pipeline/jit/ps/parse/data_converter.h"
27 #include "backend/graph_compiler/transform.h"
28 #include "backend/common/pass/erase_invalid_micro_depend.h"
29 #include "backend/common/pass/erase_not_cut_attr.h"
30 #include "backend/common/pass/switch_not_cut.h"
31 #include "include/backend/distributed/recovery/recovery_context.h"
32 #include "include/common/utils/callbacks.h"
33 #include "include/common/utils/scoped_long_running.h"
34 #include "include/common/debug/anf_ir_dump.h"
35 #include "include/backend/mem_reuse/mem_tracker.h"
36 #include "ir/anf.h"
37 #include "ops/framework_ops.h"
38 #include "ops/sequence_ops.h"
39 #include "ops/sparse_tensor_ops.h"
40 #include "ops/nn_ops.h"
41 #include "runtime/device/device_address_utils.h"
42 #include "runtime/device/multi_stream_controller.h"
43 #include "runtime/graph_scheduler/graph_compiler.h"
44 #include "runtime/pynative/graph_adapter.h"
45 #include "pybind_api/gil_scoped_long_running.h"
46 #include "utils/log_adapter.h"
47 #ifdef ENABLE_DEBUGGER
48 #include "include/backend/debug/debugger/debugger.h"
49 #endif
50 #include "include/backend/debug/profiler/profiling.h"
51 #if defined(__linux__) && defined(WITH_BACKEND)
52 #include "include/backend/distributed/ps/ps_context.h"
53 #endif
54 #include "backend/common/graph_kernel/graph_kernel_flags.h"
55 #include "include/common/symbol_engine/symbol_engine_impl.h"
56
57 namespace mindspore {
58 namespace compile {
59 bool Backend::GetCond(const BaseRef &c, bool *value) {
60 mindspore::ScopedLongRunning long_running;
61 return BaseRefToBool(c, value);
62 }
63 bool Backend::GetIndex(const BaseRef &c, int64_t *value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
64
65 Backend::Backend(const std::string &name) : name_(name), is_multi_graph_sink_(false) {
66 MS_LOG(DEBUG) << "Select backend:" << name;
67 convert_fn_ = MsVmConvert;
68 }
69
70 void set_pydata_converter(const pyexecute::PyDataConverter &pydata_converter) {
71 pyexecute::set_pydata_converter(pydata_converter);
72 }
73
74 namespace {
75 // Insert the tensors related to front_node into input_tensors.
76 void PushTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
77 std::vector<tensor::TensorPtr> *input_tensors) {
78 MS_EXCEPTION_IF_NULL(input_tensors);
79 const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
80 if (iter == parameters.end()) {
81 (void)((*input_tensors).emplace_back(nullptr));
82 return;
83 }
84 auto position = iter - parameters.begin();
85
86 std::vector<tensor::TensorPtr> flatten_values;
87 AnfAlgo::FlattenInputArg(args[position], front_node, &flatten_values);
88 (void)std::copy(flatten_values.begin(), flatten_values.end(), std::back_inserter(*input_tensors));
89 }
90
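// Push the tensor at the given index of the flattened tuple argument that corresponds to front_node.
// The flattened values are cached in flatten_values so that each argument position is only flattened once.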
91 void PushTupleTensor(const VectorRef &args, const std::vector<AnfNodePtr> &parameters, const AnfNodePtr &front_node,
92 size_t index, std::map<size_t, std::vector<tensor::TensorPtr>> *flatten_values,
93 std::vector<tensor::TensorPtr> *input_tensors) {
94 MS_EXCEPTION_IF_NULL(input_tensors);
95 MS_EXCEPTION_IF_NULL(flatten_values);
96
97 const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
98 const size_t position = iter - parameters.begin();
99 // If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
100 // and there is no need to input a tensor.
101 if (position >= args.size()) {
102 MS_LOG(DEBUG) << "Position out of args range, position value is " << position << " and args size is " << args.size()
103 << ".";
104 (void)input_tensors->emplace_back(nullptr);
105 return;
106 }
107
108 // Avoid repeatedly flattening the tuple for each args position.
109 auto &flatten_value = (*flatten_values)[position];
110 if (flatten_value.empty()) {
111 AnfAlgo::FlattenInputArg(args[position], front_node, &flatten_value);
112 }
113
114 if (index >= flatten_value.size()) {
115 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Index out of flatten_value range, index value is "
116 << index << " and flatten_value size is " << flatten_value.size() << ".";
117 }
118 auto tensor_input = flatten_value[index];
119 MS_EXCEPTION_IF_NULL(tensor_input);
120 input_tensors->push_back(tensor_input);
121 }
122 } // namespace
123
124 bool GetTensorFromForwardOutputParameter(const AnfNodePtr &input_node, std::vector<tensor::TensorPtr> *input_tensors) {
125 MS_EXCEPTION_IF_NULL(input_node);
126 // If input_node is from a ValueNode,
127 // push the Tensor of the ValueNode to input_tensors.
128 if (input_node->isa<Parameter>()) {
129 auto parameter = input_node->cast<ParameterPtr>();
130 MS_EXCEPTION_IF_NULL(parameter);
131 if (parameter->has_user_data(kForwardOutput)) {
132 auto value = parameter->user_data<Value>(kForwardOutput);
133 auto tensor = value->cast<tensor::TensorPtr>();
134 MS_EXCEPTION_IF_NULL(tensor);
135 (void)input_tensors->emplace_back(tensor);
136 MS_LOG(DEBUG) << "Get forward output tensor " << tensor->ToString()
137 << " for graph input, address:" << tensor->device_address().get();
138 return true;
139 }
140 }
141 return false;
142 }
143
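// Collect the input tensors of every kernel graph (and of the control nodes) from the actual call args,
// following the order of the origin parameters of the root graph.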
144 std::vector<std::vector<tensor::TensorPtr>> GetRunGraphInputs(const GraphCompilerInfo &graph_compiler_info,
145 const VectorRef &args) {
146 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kRuntime, runtime::ProfilerEvent::kInputProcess,
147 graph_compiler_info.name_);
148 const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
149 std::vector<std::vector<tensor::TensorPtr>> input_tensor_lists;
150 std::map<size_t, std::vector<tensor::TensorPtr>> flatten_values;
151
152 for (const auto &kernel_graph : graph_compiler_info.graphs_) {
153 std::vector<tensor::TensorPtr> input_tensors;
154 MS_EXCEPTION_IF_NULL(kernel_graph);
155 bool is_pynative_bprop_kernel_graph = kernel_graph->has_flag(kFlagIsPyNativeBpropKernelGraph);
156 for (const auto &input_node : kernel_graph->input_nodes()) {
157 if (is_pynative_bprop_kernel_graph && GetTensorFromForwardOutputParameter(input_node, &input_tensors)) {
158 continue;
159 }
160
161 auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
162 if (element_pair.first) {
163 PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &flatten_values,
164 &input_tensors);
165 } else {
166 const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
167 // Use kernel graph in compile
168 if (front_node == nullptr && is_pynative_bprop_kernel_graph) {
169 PushTensor(args, origin_parameters, input_node, &input_tensors);
170 continue;
171 }
172 PushTensor(args, origin_parameters, front_node, &input_tensors);
173 }
174 }
175 (void)input_tensor_lists.emplace_back(input_tensors);
176 }
177
178 // Input tensors of the control node.
179 std::vector<tensor::TensorPtr> input_tensors;
180 MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
181 // Get inputs of control node which come from the host actor.
182 const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
183 for (const auto &parameter_with_index : control_node_parameters) {
184 const auto &parameter = parameter_with_index.first;
185 MS_EXCEPTION_IF_NULL(parameter);
186 const auto &abs = parameter->abstract();
187 MS_EXCEPTION_IF_NULL(abs);
188 if (abs->isa<abstract::AbstractSequence>() && (!common::AnfAlgo::IsDynamicSequence(parameter))) {
189 MS_LOG(DEBUG) << "Fetch input tensor for tuple parameter:" << parameter->DebugString() << " in control flow.";
190 PushTupleTensor(args, origin_parameters, parameter, parameter_with_index.second, &flatten_values, &input_tensors);
191 } else {
192 PushTensor(args, origin_parameters, parameter, &input_tensors);
193 }
194 }
195 (void)input_tensor_lists.emplace_back(input_tensors);
196
197 return input_tensor_lists;
198 }
199
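// Record the position(s) that every real output of the graph occupies in the flattened output list.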
200 runtime::KernelMapPosition FetchOriginOutputOrder(const AnfNodePtr &node) {
201 MS_EXCEPTION_IF_NULL(node);
202 runtime::KernelMapPosition outputs_order;
203 const auto &root_output = common::AnfAlgo::VisitKernelWithReturnType(node, 0, false, {prim::kPrimTupleGetItem}).first;
204 size_t position = 0;
205 auto outputs = common::AnfAlgo::GetAllOutputWithIndex(root_output);
206 for (const auto &output : outputs) {
207 if (outputs_order.count(output) == 0) {
208 outputs_order[output] = {position++};
209 } else {
210 (void)outputs_order[output].emplace_back(position++);
211 }
212 }
213 return outputs_order;
214 }
215
216 MindRTBackendBase::MindRTBackendBase(const std::string &backend_name, const std::string &device_name,
217 uint32_t device_id)
218 : Backend(backend_name), device_name_(device_name), device_id_(device_id) {
219 root_graph_ = nullptr;
220 auto ms_context = MsContext::GetInstance();
221 const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
222 auto &cut_list = pynative_mode ? GetControlOps() : GetMsNonlinearOps();
223
224 graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
225 graph_compiler_ = std::make_shared<GraphCompiler>();
226
227 const auto &device_context =
228 device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
229 MS_EXCEPTION_IF_NULL(device_context);
230 (void)profiler::CollectHostInfo(kModelNameRuntime, kEventDeviceInit, kStageDeviceInit, 1, 0, 0);
231 device_context->Initialize();
232 (void)profiler::CollectHostInfo(kModelNameRuntime, kEventDeviceInit, kStageDeviceInit, 1, 0, 1);
233 device_id_ = device_context->device_context_key().device_id_;
234 #ifdef ENABLE_DEBUGGER
235 SetDebuggerInit();
236 #endif
237 runtime::GraphScheduler::GetInstance().Initialize();
238 }
239
240 void MindRTBackendBase::ProcessNotSupportCnode(const FuncGraphPtr &func_graph,
241 const mindspore::device::DeviceType &old_target,
242 const mindspore::device::DeviceType &new_target) const {
243 const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
244 for (const auto &node : all_nodes) {
245 MS_EXCEPTION_IF_NULL(node);
246 if (!node->isa<CNode>()) {
247 continue;
248 }
249
250 auto cnode = node->cast<CNodePtr>();
251 if (!common::AnfAlgo::HasNodeAttr(mindspore::kAttrNotSupportOpForDevice, cnode)) {
252 continue;
253 }
254
255 auto not_support_device = common::AnfAlgo::GetNodeAttr<std::string>(node, mindspore::kAttrNotSupportOpForDevice);
256 if (device::GetDeviceTypeByName(not_support_device) != old_target) {
257 continue;
258 }
259
260 common::AnfAlgo::SetNodeAttr(kAttrPrimitiveTarget, MakeValue(device::GetDeviceNameByType(new_target)), node);
261 }
262 }
263
264 namespace {
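// Fetch the constant output index carried by the second input of a TupleGetItem node.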
265 int64_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item) {
266 MS_EXCEPTION_IF_NULL(tuple_get_item);
267 if (tuple_get_item->size() != kTupleGetItemInputSize) {
268 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The node tuple_get_item must have 2 inputs!";
269 }
270 auto output_index_value_node = tuple_get_item->input(kInputNodeOutputIndexInTupleGetItem);
271 MS_EXCEPTION_IF_NULL(output_index_value_node);
272 auto value_node = output_index_value_node->cast<ValueNodePtr>();
273 MS_EXCEPTION_IF_NULL(value_node);
274 auto value = value_node->value();
275 MS_EXCEPTION_IF_NULL(value);
276 auto idx = value->isa<Int64Imm>() ? GetValue<int64_t>(value) : GetValue<int>(value);
277 return idx;
278 }
279
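// Skip TupleGetItem/MakeTuple wrappers to find the real node, counting how deeply the TupleGetItem calls are nested.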
280 KernelWithIndex VisitRealNodeWithNestLevel(const AnfNodePtr &anf_node, size_t index, size_t *nest_level) {
281 MS_EXCEPTION_IF_NULL(anf_node);
282 if (!anf_node->isa<CNode>()) {
283 return {anf_node, index};
284 }
285 auto cnode = anf_node->cast<CNodePtr>();
286 if (common::AnfAlgo::GetCNodeName(cnode) == mindspore::kTupleGetItemOpName) {
287 (*nest_level)++;
288 auto real_node_with_index = VisitRealNodeWithNestLevel(common::AnfAlgo::GetTupleGetItemRealInput(cnode),
289 common::AnfAlgo::GetTupleGetItemOutIndex(cnode), nest_level);
290 auto real_node = real_node_with_index.first;
291 auto real_index = real_node_with_index.second;
292 MS_EXCEPTION_IF_NULL(real_node);
293 if (real_node->isa<CNode>() && common::AnfAlgo::GetCNodeName(real_node) == mindspore::kMakeTupleOpName) {
294 (*nest_level)--;
295 auto make_tuple = real_node->cast<CNodePtr>();
296 return VisitRealNodeWithNestLevel(make_tuple->input(real_index + 1), index, nest_level);
297 }
298 return real_node_with_index;
299 }
300 return common::AnfAlgo::VisitKernelWithReturnType(anf_node, index, false,
301 {prim::kPrimMakeTuple, prim::kPrimTupleGetItem});
302 }
303
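// A TupleGetItem is converted to RealTupleGetItem when its index input is not a value node, when the index is
// negative, or when a real kernel output is accessed through more than one level of tuple nesting.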
304 bool NeedConvertToRealTupleGetItem(const CNodePtr &cnode) {
305 if (cnode->size() != kTupleGetItemInputSize) {
306 return false;
307 }
308 if (!cnode->input(kInputNodeOutputIndexInTupleGetItem)->isa<ValueNode>() || GetTupleGetItemOutIndex(cnode) < 0) {
309 return true;
310 }
311 size_t nest_level = 0;
312 const size_t nest_limit = 1;
313 auto real_node = VisitRealNodeWithNestLevel(cnode, 0, &nest_level);
314 if (!common::AnfAlgo::IsCallNode(real_node.first) && AnfUtils::IsRealCNodeKernel(real_node.first) &&
315 nest_level > nest_limit) {
316 return true;
317 }
318 return false;
319 }
320
321 // On Windows, create a child thread with 8M stack space to call `common::AnfAlgo::GetRealPrevNodesOutput`.
322 #if defined(_WIN32) || defined(_WIN64)
323 typedef struct {
324 const AnfNodePtr *anf_node_;
325 size_t input_idx_;
326 std::vector<KernelWithIndex> *nodes_ptr_;
327 } WinThreadParam;
328
329 DWORD WINAPI WinThreadFunction(PVOID para) {
330 auto p = static_cast<WinThreadParam *>(para);
331 MS_EXCEPTION_IF_NULL(p->anf_node_);
332 MS_EXCEPTION_IF_NULL(p->nodes_ptr_);
333 const AnfNodePtr &anf_node = *(p->anf_node_);
334 std::vector<KernelWithIndex> *nodes_ptr = p->nodes_ptr_;
335 auto inputs = common::AnfAlgo::GetRealPrevNodesOutput(anf_node, p->input_idx_);
336 nodes_ptr->insert(nodes_ptr->end(), inputs.begin(), inputs.end());
337 return 0;
338 }
339 #endif
340
341 void CheckNodeValid(const AnfNodePtr &node) {
342 MS_EXCEPTION_IF_NULL(node);
343 // Check the joined any abstract.
344 const auto &node_abs = node->abstract();
345 if (node_abs != nullptr && node_abs->isa<abstract::AbstractJoinedAny>()) {
346 auto abs_joined_any = node_abs->cast<abstract::AbstractJoinedAnyPtr>();
347 if (abs_joined_any != nullptr) {
348 abs_joined_any->ThrowException();
349 }
350 }
351 }
352
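// Prepare a PyNative bprop kernel graph for backend compilation: create backend parameters and kernel info,
// clone the primitives of the CNodes, and wrap the output into a MakeTuple.
// Returns false when the graph only has a return node or is empty.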
353 bool AddKernelGraphCompileInfo(const KernelGraphPtr &kernel_graph, const session::SessionPtr &session_ptr) {
354 const auto ¶meters = kernel_graph->parameters();
355 // The graph just has a return node or is empty.
356 if ((kernel_graph->nodes().size() - parameters.size()) < kIndex2) {
357 return false;
358 }
359 // Update parameters info
360 const auto &manager = kernel_graph->manager();
361 MS_EXCEPTION_IF_NULL(manager);
362 const auto &users = manager->node_users();
363 for (const auto &p : parameters) {
364 // Exclude parameter not used in graph, such as constant input
365 if (users.find(p) != users.end()) {
366 (void)session_ptr->CreateNewParameterFromParameter(p, kernel_graph.get());
367 kernel_graph->SetKernelInfoForNode(p);
368 }
369 }
370
371 // Running by single op creates kernel info in the single-op graph, so there is no need to do it here;
372 // but running by actor needs kernel info, so do it here.
373 bool run_by_single_op = kernel_graph->has_flag(kFlagEnableRunGraphBySingleOp);
374 if (!run_by_single_op) {
375 const auto &nodes = TopoSort(kernel_graph->get_return());
376 for (const auto &node : nodes) {
377 if (node->isa<CNode>()) {
378 const auto &cnode = node->cast<CNodePtr>();
379 // Bprop cut uses prim_py, no need to change it.
380 if (auto prim = GetValueNode<PrimitivePtr>(cnode->input(kIndex0));
381 !IsPrimitiveEquals(prim, prim::kPrimBpropCut)) {
382 auto new_prim = std::make_shared<Primitive>(*prim);
383 cnode->set_input(kIndex0, NewValueNode(new_prim));
384 }
385 kernel_graph->PostNewCNode(cnode);
386 } else {
387 if (node->isa<ValueNode>()) {
388 session_ptr->CreateNewValueNode(node, kernel_graph.get());
389 }
390 // Kernel graph new value node will create kernel info
391 if (node->kernel_info() == nullptr) {
392 kernel_graph->SetKernelInfoForNode(node);
393 }
394 }
395 }
396 }
397 auto output_node = kernel_graph->NewCNode({NewValueNode(prim::kPrimMakeTuple), kernel_graph->output()});
398 AbstractBasePtrList output_abs_list{kernel_graph->output()->abstract()};
399 auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(output_abs_list);
400 output_node->set_abstract(abstract_tuple);
401 kernel_graph->set_output(output_node);
402 MS_LOG(INFO) << "Insert make tuple for output";
403 return true;
404 }
405
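// Multi-target checking is needed in graph mode, for jit call graphs that run by single op in PyNative mode,
// and for graphs that use sub-graphs (control flow).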
406 bool NeedCheckMultiTarget(const FuncGraphPtr &func_graph, int ms_execution_mode) {
407 if (ms_execution_mode == kGraphMode) {
408 return true;
409 }
410 bool run_in_dynamic = ms_execution_mode == kPynativeMode && func_graph->has_flag(kFlagEnableRunGraphBySingleOp);
411 bool is_call_graph = func_graph->has_flag(kFlagJitCallGraph);
412 bool is_control_flow = !func_graph->func_graphs_used_total().empty();
413 return (run_in_dynamic && is_call_graph) || is_control_flow;
414 }
415
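// Unify list ops to their tuple counterparts, and convert TupleGetItem/MakeTuple to
// RealTupleGetItem/RealMakeTuple when required (see the checks below).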
416 void UnifyIR(const CNodePtr &cnode, bool enable_run_graph_by_single_op) {
417 MS_EXCEPTION_IF_NULL(cnode);
418 static const std::map<std::string, std::string> kOpListToTupleNames = {
419 {mindspore::kMakeListNewOpName, mindspore::kMakeTupleOpName},
420 {mindspore::kListGetItemOpName, mindspore::kTupleGetItemOpName},
421 {mindspore::kListSetItemOpName, mindspore::kTupleSetItemOpName}};
422 // List name --> tuple name.
423 auto &&op_name = common::AnfAlgo::GetCNodeName(cnode);
424 auto iter = kOpListToTupleNames.find(op_name);
425 if (iter != kOpListToTupleNames.end()) {
426 common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
427 cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(iter->second)));
428 // Reset full scope name.
429 cnode->set_fullname_with_scope("");
430 MS_LOG(INFO) << "Rename op from " << iter->first << " to " << iter->second << " for op "
431 << cnode->fullname_with_scope() << ", debug name:" << cnode->DebugString();
432 op_name = iter->second;
433 }
434
435 // TupleGetItem --> RealTupleGetItem.
436 if (!enable_run_graph_by_single_op && op_name == mindspore::kTupleGetItemOpName &&
437 NeedConvertToRealTupleGetItem(cnode)) {
438 common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
439 cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(mindspore::kRealTupleGetItemOpName)));
440 // Reset full scope name.
441 cnode->set_fullname_with_scope("");
442 MS_LOG(INFO) << "Rename op from TupleGetItem to RealTupleGetItem for op " << cnode->fullname_with_scope()
443 << ", debug name:" << cnode->DebugString();
444 }
445
446 // MakeTuple --> RealMakeTuple
447 if (op_name == mindspore::kMakeTupleOpName && common::AnfAlgo::IsDynamicSequence(cnode)) {
448 common::AnfAlgo::SetNodeAttr(kAttrOpAdaptationProcessed, MakeValue(true), cnode);
449 cnode->set_input(0, mindspore::NewValueNode(std::make_shared<Primitive>(mindspore::kRealMakeTupleOpName)));
450 // Reset full scope name.
451 cnode->set_fullname_with_scope("");
452 MS_LOG(INFO) << "Rename op from MakeTuple to RealMakeTuple for op " << cnode->fullname_with_scope()
453 << ", debug name:" << cnode->DebugString();
454 }
455 }
456
457 bool EnableSymbolEngine(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
458 // Currently, only the Graph Kernel Fusion dynamic shape case needs to build the symbol engine.
459 if (run_mode != device::RunMode::kKernelMode) {
460 return false;
461 }
462 if (common::GetEnv("MS_SYMBOL_ENGINE_OPTIMIZE") == "off") {
463 return false;
464 }
465 if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
466 return false;
467 }
468 return common::AnfAlgo::IsDynamicGraph(func_graph);
469 }
470
471 void BuildSymbolEngine(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
472 if (func_graph == nullptr) {
473 return;
474 }
475 MS_LOG(INFO) << "Status record: start build symbol engine for function graph: " << func_graph->ToString();
476 if (!EnableSymbolEngine(func_graph, run_mode)) {
477 MS_LOG(INFO) << "Status record: skip build symbol engine for function graph: " << func_graph->ToString();
478 return;
479 }
480 try {
481 MS_LOG_TRY_CATCH_SCOPE;
482 symshape::SymbolEngineImpl::Build(func_graph);
483 } catch (std::exception &e) {
484 MS_LOG(WARNING) << "A problem occurs when build symbol engine for function graph[" << func_graph->ToString()
485 << "]: " << e.what();
486 }
487 MS_LOG(INFO) << "Status record: end build symbol engine for function graph: " << func_graph->ToString();
488 }
489 } // namespace
490
491 const ActorInfo &MindRTBackendBase::CompileGraphs(const FuncGraphPtr &func_graph) {
492 WaitTaskFinish();
493 MS_EXCEPTION_IF_NULL(graph_compiler_);
494 MS_EXCEPTION_IF_NULL(func_graph);
495 MS_LOG(INFO) << "Status record: start compile function graph: " << func_graph->ToString();
496 (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCompileGraphs, 1, 0, 0);
497 PROF_START(compile_backend_graph);
498
499 auto root_graph = WrapPrimitives(func_graph);
500 MS_EXCEPTION_IF_NULL(root_graph);
501 bool pynative_with_jit_call_graph = func_graph->has_flag(kFlagPyNativeWithJitCallGraph);
502 if (!pynative_with_jit_call_graph) {
503 UnifyMindIR(root_graph);
504 }
505 root_graph_ = root_graph;
506 // Use kernel graph, whose output may be changed by backend passes, so back up the output.
507 if (root_graph_->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
508 output_node_ = root_graph_->output();
509 }
510
511 // Register a summary callback function, which is called in the final stages of summary.
512 graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
513
514 auto context_ptr = MsContext::GetInstance();
515 MS_EXCEPTION_IF_NULL(context_ptr);
516 ms_execution_mode_ = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
517 func_graph->set_flag(kFlagPyNativeRunInGraph, ms_execution_mode_ == kPynativeMode);
518
519 // Compile root graph.
520 graph_id_to_device_context_.clear();
521 func_graph_to_kernel_graph_ids_.clear();
522 control_nodes_.clear();
523
524 const auto &device_context =
525 device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
526 MS_EXCEPTION_IF_NULL(device_context);
527 device_context->Initialize();
528 device_context->device_res_manager_->BindDeviceToCurrentThread(false);
529
530 // Currently only Ascend needs to do the check in PartitionGraph.
531 bool all_support = device_context->PartitionGraph(func_graph);
532 PROF_START(CompileSubGraph);
533 if (all_support) {
534 auto run_mode = device_context->GetRunMode(func_graph);
535 if (run_mode == device::RunMode::kGraphMode && pynative::GraphAdapter::PyNativeEnableTaskSink(func_graph)) {
536 auto graph_id = graph_compiler_->CompileWholeGraphForGraphRunMode(func_graph, device_context);
537 graph_id_to_device_context_[graph_id] = device_context;
538 } else {
539 // Build symbol engine for root graph before partition graph
540 BuildSymbolEngine(func_graph, device::RunMode::kKernelMode);
541 CompileSubGraph(func_graph, device::RunMode::kKernelMode);
542 }
543 } else {
544 if (NeedCheckMultiTarget(func_graph, ms_execution_mode_)) {
545 ProcessNotSupportCnode(func_graph, device_context->GetDeviceType(), mindspore::device::DeviceType::kCPU);
546 }
547 // Build symbol engine for root graph before partition graph
548 BuildSymbolEngine(func_graph, device_context->GetRunMode(func_graph));
549 CompileSubGraph(func_graph);
550 }
551 PROF_END(CompileSubGraph);
552
553 // Construct the graph compiler info.
554 auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
555 MS_EXCEPTION_IF_NULL(graph_compiler_info);
556 if ((ms_execution_mode_ == kGraphMode ||
557 (ms_execution_mode_ == kPynativeMode && pynative::GraphAdapter::IsPynativeGeGraphSink(root_graph_))) &&
558 ((!graph_compiler_info->graphs_.empty()) || graph_compiler_info->control_nodes_.size() > 1)) {
559 MS_LOG(DEBUG) << "Start transform";
560 PROF_START(GraphScheduler);
561 // Transform graph to actor DAG, and schedule the actor DAG.
562 ParseControlNodes(*graph_compiler_info);
563 const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info);
564 runtime::GraphScheduler::GetInstance().Schedule(actor_set);
565 PROF_END(GraphScheduler);
566 }
567 const ActorInfo &actor_info = graph_compiler_info->name_;
568 (void)actor_to_graph_compiler_info_.emplace(graph_compiler_info->name_, std::move(graph_compiler_info));
569 PROF_END(compile_backend_graph);
570
571 for (const auto &graph_id_to_context : graph_id_to_device_context_) {
572 auto context = graph_id_to_context.second;
573 device::MultiStreamController::GetInstance()->Refresh(context);
574 }
575
576 (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageCompileGraphs, 1, 0, 1);
577 MS_LOG(INFO) << "Status record: end compile function graph: " << func_graph->ToString()
578 << ", produce actor: " << actor_info;
579 return actor_info;
580 }
581
582 namespace {
583 void DoUnifyMindIRPass(const FuncGraphPtr &graph, const std::shared_ptr<opt::GraphOptimizer> &optimizer) {
584 MS_EXCEPTION_IF_NULL(graph);
585 MS_EXCEPTION_IF_NULL(optimizer);
586 auto context_ptr = MsContext::GetInstance();
587 MS_EXCEPTION_IF_NULL(context_ptr);
588 MS_LOG(INFO) << "Do unify mindir pass for graph " << graph->ToString();
589 #ifdef ENABLE_DUMP_IR
590 if (context_ptr->CanDump(kIntroductory)) {
591 std::string file_name = "hwopt_before_mindrt_unify_mindir_graph_" + graph->ToString() + ".ir";
592 DumpIR(file_name, graph, true, kWholeStack);
593 }
594 #endif
595 (void)optimizer->Optimize(graph);
596 #ifdef ENABLE_DUMP_IR
597 if (context_ptr->CanDump(kIntroductory)) {
598 std::string file_name = "hwopt_end_mindrt_unify_mindir_graph_" + graph->ToString() + ".ir";
599 DumpIR(file_name, graph, true, kWholeStack);
600 }
601 #endif
602 }
603
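// Whether the function graph contains a Switch node.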
604 bool HasSwitchNode(const FuncGraphPtr &func_graph) {
605 if (func_graph == nullptr) {
606 return false;
607 }
608 const auto &nodes = TopoSort(func_graph->get_return());
609 return std::any_of(nodes.begin(), nodes.end(), [](const AnfNodePtr &node) {
610 return node != nullptr && node->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch);
611 });
612 }
613
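// A node blocks switch inline when it has a dynamic-shape output, or when it is a Partial whose target
// graph takes dynamic-length or multi-element tuple parameters.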
614 bool IsNodeValid(const AnfNodePtr &node) {
615 if (node != nullptr && common::AnfAlgo::IsNodeOutputDynamicShape(node)) {
616 MS_LOG(INFO) << "Disable switch inline for dynamic shape node:" << node->DebugString();
617 return false;
618 } else if (common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
619 const auto &cnode = node->cast<CNodePtr>();
620 MS_EXCEPTION_IF_NULL(cnode);
621 if (cnode->size() <= 1 || cnode->input(1) == nullptr || !(IsValueNode<FuncGraph>(cnode->input(1)))) {
622 return true;
623 }
624 const auto &func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1));
625 MS_EXCEPTION_IF_NULL(func_graph);
626 if (std::any_of(func_graph->parameters().begin(), func_graph->parameters().end(), [](const AnfNodePtr &para) {
627 return para != nullptr && para->abstract() != nullptr &&
628 para->abstract()->isa<abstract::AbstractSequence>() &&
629 (para->abstract()->cast<abstract::AbstractSequencePtr>()->dynamic_len() ||
630 para->abstract()->cast<abstract::AbstractSequencePtr>()->size() > 1);
631 })) {
632 MS_LOG(INFO) << "Disable switch inline for tuple input in graph:" << func_graph->ToString()
633 << " for partial node:" << node->DebugString();
634 return false;
635 }
636 }
637 return true;
638 }
639
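// Decide whether switch inline can be enabled: it requires the GE backend in kernel-by-kernel executor mode
// with multiple func graphs, and is disabled for recursion, unsupported call nodes and invalid nodes.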
640 bool IsEnableControlFlowInline(const FuncGraphPtr &graph) {
641 auto context = MsContext::GetInstance();
642 MS_EXCEPTION_IF_NULL(context);
643 if (std::any_of(
644 graph->func_graphs_used_total().cbegin(), graph->func_graphs_used_total().cend(), [](const auto &sub_graph) {
645 return sub_graph != nullptr && sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && HasSwitchNode(sub_graph);
646 })) {
647 MS_LOG(INFO) << "Set reuse level from:" << context->CellReuseLevel() << " to:" << CellReuseLevel::kNoInline;
648 context->SetCellReuseLevel(CellReuseLevel::kNoInline);
649 }
650
651 static const auto is_disable_switch_inline = common::IsDisableRuntimeConfig(common::kRuntimeSwitchInline);
652 if (is_disable_switch_inline) {
653 MS_LOG(INFO) << "Disable switch inline by runtime config.";
654 return false;
655 }
656
657 // Only support ge backend, kernel by kernel mode and multi-funcgraph.
658 static const bool is_enable_ge = (context->backend_policy() == "ge");
659 if (!is_enable_ge || !context->IsKByKExecutorMode() || graph->func_graphs_used_total().empty()) {
660 MS_LOG(INFO) << "Disable switch inline, executor mode:" << context->IsKByKExecutorMode();
661 return false;
662 }
663
664 MS_EXCEPTION_IF_NULL(graph);
665 // Recursion is not supported.
666 if (std::any_of(graph->func_graphs_used_total().cbegin(), graph->func_graphs_used_total().cend(),
667 [](const auto &sub_graph) { return sub_graph->recursive(); })) {
668 MS_LOG(INFO) << "Disable switch inline for recursive.";
669 return false;
670 }
671
672 if (context->CellReuseLevel() != CellReuseLevel::kLazyInline) {
673 auto is_include_no_switch_call = [](const FuncGraphPtr &graph) {
674 MS_EXCEPTION_IF_NULL(graph);
675 const auto &nodes = TopoSort(graph->get_return());
676 for (const auto &node : nodes) {
677 MS_EXCEPTION_IF_NULL(node);
678 if (common::AnfAlgo::IsCallNode(node)) {
679 const auto &cnode = node->cast<CNodePtr>();
680 if (!common::AnfAlgo::CheckPrimitiveType(cnode->input(0), prim::kPrimSwitch)) {
681 return true;
682 }
683 }
684 }
685 return false;
686 };
687 if (is_include_no_switch_call(graph)) {
688 MS_LOG(INFO) << "Disable switch inline for unsupported call node.";
689 return false;
690 }
691 if (std::any_of(graph->func_graphs_used_total().begin(), graph->func_graphs_used_total().end(),
692 is_include_no_switch_call)) {
693 MS_LOG(INFO) << "Disable switch inline for unsupported call node.";
694 return false;
695 }
696 }
697 const auto &mng = graph->manager();
698 if (mng != nullptr && std::any_of(mng->all_nodes().begin(), mng->all_nodes().end(),
699 [](const AnfNodePtr &node) { return !IsNodeValid(node); })) {
700 return false;
701 }
702 MS_LOG(INFO) << "Enable switch inline.";
703 return true;
704 }
705
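// Mark the kernel graph as dynamic if any of its CNodes is dynamic shape.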
706 void AddGraphDynamicShapeAttr(const KernelGraphPtr &kernel_graph) {
707 MS_EXCEPTION_IF_NULL(kernel_graph);
708 if (kernel_graph->is_dynamic_shape()) {
709 return;
710 }
711
712 const auto &nodes = TopoSort(kernel_graph->output());
713 for (const auto &node : nodes) {
714 MS_EXCEPTION_IF_NULL(node);
715 if (node->isa<CNode>() && common::AnfAlgo::IsDynamicShape(node)) {
716 kernel_graph->SetGraphDynamicAttr(true);
717 break;
718 }
719 }
720 }
721 } // namespace
722
723 void MindRTBackendBase::UnifyMindIR(const FuncGraphPtr &root_graph) const {
724 MS_EXCEPTION_IF_NULL(root_graph);
725 MS_EXCEPTION_IF_NULL(root_graph->manager());
726 // When the input is an empty sequence, the number of inputs will be recorded as 0, and the tensor cannot be
727 // expressed, so the empty sequence is set to dynamic len.
728 for (const auto &parameter : root_graph->parameters()) {
729 MS_EXCEPTION_IF_NULL(parameter);
730 const auto &abs = parameter->abstract();
731 if (abs != nullptr && abs->isa<abstract::AbstractSequence>()) {
732 const auto &sequence_abs = abs->cast<abstract::AbstractSequencePtr>();
733 MS_EXCEPTION_IF_NULL(sequence_abs);
734 if ((!sequence_abs->dynamic_len()) && sequence_abs->empty()) {
735 MS_LOG(INFO) << "Set dynamic len flag for empty sequence input:" << parameter->DebugString();
736 sequence_abs->set_dynamic_len(true);
737 }
738 }
739 }
740 bool enable_run_graph_by_single_op = root_graph->has_flag(kFlagEnableRunGraphBySingleOp);
741 const auto &graphs = root_graph->manager()->func_graphs();
742 for (const auto &graph : graphs) {
743 MS_EXCEPTION_IF_NULL(graph);
744 auto output = graph->get_return();
745 if (!output->isa<CNode>()) {
746 continue;
747 }
748 auto seen = NewSeenGeneration();
749 std::queue<AnfNodePtr> to_visit;
750 to_visit.emplace(output);
751 while (!to_visit.empty()) {
752 auto node = to_visit.front();
753 to_visit.pop();
754 MS_EXCEPTION_IF_NULL(node);
755 CheckNodeValid(node);
756
757 const auto &cnode = node->cast<CNodePtr>();
758 MS_EXCEPTION_IF_NULL(cnode);
759 UnifyIR(cnode, enable_run_graph_by_single_op);
760 for (auto &input : cnode->inputs()) {
761 MS_EXCEPTION_IF_NULL(input);
762 if (input->seen_ == seen || !input->isa<CNode>()) {
763 continue;
764 }
765 to_visit.emplace(input);
766 input->seen_ = seen;
767 }
768 }
769 }
770
771 auto optimizer = std::make_shared<opt::GraphOptimizer>();
772 auto unify_mindir_pm = std::make_shared<opt::PassManager>("unify_mindir_pm");
773 unify_mindir_pm->AddPass(std::make_shared<opt::EraseInvalidMicroDepend>());
774 if (common::AnfAlgo::IsDynamicGraph(root_graph)) {
775 unify_mindir_pm->AddPass(std::make_shared<opt::EraseNotCutAttr>());
776 }
777 if (IsEnableControlFlowInline(root_graph)) {
778 unify_mindir_pm->AddPass(std::make_shared<opt::SwitchNotCut>());
779 }
780 optimizer->AddPassManager(unify_mindir_pm);
781
782 DoUnifyMindIRPass(root_graph, optimizer);
783 const auto &sub_graphs = root_graph->manager()->func_graphs_used_total(root_graph);
784 for (const auto &sub_graph : sub_graphs) {
785 MS_EXCEPTION_IF_NULL(sub_graph);
786 DoUnifyMindIRPass(sub_graph, optimizer);
787 }
788 }
789
790 void MindRTBackendBase::CompileSubGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
791 auto root_graph = func_graph;
792 if (!func_graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
793 root_graph = WrapPrimitives(func_graph);
794 }
795 MS_EXCEPTION_IF_NULL(root_graph);
796 auto manager = root_graph->manager();
797 CompileGraph(root_graph, run_mode);
798 auto context = MsContext::GetInstance();
799 MS_EXCEPTION_IF_NULL(context);
800 MS_EXCEPTION_IF_NULL(manager);
801 const auto &sub_graphs = manager->func_graphs_used_total(root_graph);
802 std::vector<FuncGraphPtr> cand_graph(sub_graphs.begin(), sub_graphs.end());
803 std::sort(cand_graph.begin(), cand_graph.end(),
804 [](const FuncGraphPtr &a, const FuncGraphPtr &b) { return a->ToString() < b->ToString(); });
805 for (const auto &sub_graph : cand_graph) {
806 MS_EXCEPTION_IF_NULL(sub_graph);
807 bool skip_inline_graph =
808 (sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE) && context->CellReuseLevel() == CellReuseLevel::kLazyInline) ||
809 sub_graph->has_flag(kFlagSwitchInline);
810 if (sub_graph != func_graph && sub_graph != nullptr && !sub_graph->has_flag(kFlagJitCallGraph) &&
811 !skip_inline_graph) {
812 MS_LOG(INFO) << "Compile sub graph " << sub_graph->ToString();
813 CompileGraph(sub_graph, run_mode);
814 }
815 }
816 }
817
818 void MindRTBackendBase::CompileGraph(const FuncGraphPtr &func_graph, device::RunMode run_mode) {
819 MS_EXCEPTION_IF_NULL(func_graph);
820 if (!func_graph->has_flag(kFlagIsPyNativeBpropKernelGraph)) {
821 (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageGraphPartition, 1, 0, 0);
822 // Split graph to segments.
823 MS_EXCEPTION_IF_NULL(graph_partition_);
824 const auto &segments = graph_partition_->Partition(func_graph);
825 (void)profiler::CollectHostInfo(kModelNameRuntime, kEventCompileGraph, kStageGraphPartition, 1, 0, 1);
826 MS_LOG(INFO) << "Compile graph: " << func_graph->ToString() << ", Split segments size: " << segments.size();
827
828 // Foreach the segments to compile graph.
829 for (const auto &segment : segments) {
830 CompileGraphFromSegment(segment, run_mode);
831 }
832 } else {
833 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
834 AddGraphDynamicShapeAttr(kernel_graph);
835 MS_EXCEPTION_IF_NULL(kernel_graph);
836 const auto &session = graph_compiler_->session_ptr();
837 MS_EXCEPTION_IF_NULL(session);
838 session->SetKernelGraphId(kernel_graph);
839 MS_LOG(INFO) << "Compile graph: " << kernel_graph->ToString() << ", kernel graph";
840 if (AddKernelGraphCompileInfo(kernel_graph, session)) {
841 kernel_graph->SetExecOrderByDefault();
842 auto context_ptr = MsContext::GetInstance();
843 MS_EXCEPTION_IF_NULL(context_ptr);
844 auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
845 {context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET), device_id_});
846 MS_EXCEPTION_IF_NULL(device_context);
847 device_context->Initialize();
848 CompileKernelGraph(kernel_graph, std::make_pair(kernel_graph->inputs(), kernel_graph->outputs()), device_context,
849 run_mode);
850 }
851 }
852 }
853
854 void MindRTBackendBase::CompileGraphFromSegment(const GraphSegmentPtr &segment, device::RunMode run_mode) {
855 MS_EXCEPTION_IF_NULL(segment);
856 // Compile the normal nodes, which don't contain the cut node.
857 if (segment->nodes_.empty()) {
858 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The segments size is 0.";
859 }
860 if (!segment->is_cut_) {
861 MS_EXCEPTION_IF_NULL(segment->nodes_[0]);
862 MS_LOG(INFO) << "Compile normal segment, the first node: " << segment->nodes_[0]->DebugString();
863
864 // Get the device context.
865 const auto &cur_device_name = GetCNodeTarget(segment->nodes_[0]);
866 auto device_context =
867 device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({cur_device_name, device_id_});
868 MS_EXCEPTION_IF_NULL(device_context);
869 device_context->Initialize();
870
871 // Transform nodes to inputs and outputs.
872 FuncGraphPtr fg;
873 AnfNodePtrList inputs;
874 AnfNodePtrList outputs;
875 std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_);
876
877 // Get segment run mode.
878 auto seg_run_mode = run_mode;
879 for (auto &node : outputs) {
880 if (node->isa<CNode>()) {
881 if (common::AnfAlgo::GetGraphSplitGroup(node) == kKernelGroup) {
882 seg_run_mode = device::RunMode::kKernelMode;
883 break;
884 }
885 }
886 }
887
888 GraphId graph_id;
889 if (root_graph_->has_flag(kFlagEnableRunGraphBySingleOp)) {
890 graph_id = graph_compiler_->CompileDynamicGraph(segment, outputs, device_context);
891 } else {
892 graph_id = graph_compiler_->CompileGraph(segment, std::make_pair(inputs, outputs), device_context, seg_run_mode,
893 ms_execution_mode_ == kPynativeMode);
894 if (graph_compiler_->Fetch(graph_id)->has_flag(kFlagEnableRunGraphBySingleOp)) {
895 MS_LOG(INFO)
896 << "Set kFlagEnableRunGraphBySingleOp: require the root_graph and subgraph to have the same markings ";
897 root_graph_->set_flag(kFlagEnableRunGraphBySingleOp, true);
898 }
899 }
900 CacheFuncGraphWithKernelGraphId(segment->nodes_[0]->func_graph(), graph_id, device_context);
901 } else {
902 // Compile the cut node.
903 auto cut_node = segment->nodes_[0];
904 MS_EXCEPTION_IF_NULL(cut_node);
905 MS_LOG(INFO) << "Compile cut segment, the cut node: " << cut_node->DebugString();
906 control_nodes_.push_back(cut_node);
907 if (common::AnfAlgo::IsCallNode(cut_node) || common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitch) ||
908 common::AnfAlgo::CheckPrimitiveType(cut_node, prim::kPrimSwitchLayer)) {
909 const auto &func_graph = cut_node->func_graph();
910 MS_EXCEPTION_IF_NULL(func_graph);
911 (void)func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>());
912 }
913 }
914 }
915
916 void MindRTBackendBase::CompileKernelGraph(const KernelGraphPtr &kernel_graph,
917 const std::pair<AnfNodePtrList, AnfNodePtrList> &io_nodes,
918 DeviceContext *device_context, device::RunMode run_mode) {
919 GraphId graph_id;
920 if (root_graph_->has_flag(kFlagEnableRunGraphBySingleOp)) {
921 graph_id = graph_compiler_->CompileDynamicGraph(kernel_graph, device_context);
922 } else {
923 graph_id = graph_compiler_->CompileGraph(kernel_graph, io_nodes, device_context, run_mode,
924 ms_execution_mode_ == kPynativeMode);
925 if (graph_compiler_->Fetch(graph_id)->has_flag(kFlagEnableRunGraphBySingleOp)) {
926 MS_LOG(INFO)
927 << "Set kFlagEnableRunGraphBySingleOp: require the root_graph and subgraph to have the same markings ";
928 root_graph_->set_flag(kFlagEnableRunGraphBySingleOp, true);
929 }
930 }
931 CacheFuncGraphWithKernelGraphId(kernel_graph, graph_id, device_context);
932 }
933
934 void MindRTBackendBase::CacheFuncGraphWithKernelGraphId(const FuncGraphPtr &func_graph, const GraphId &graph_id,
935 DeviceContext *device_context) {
936 graph_id_to_device_context_[graph_id] = device_context;
937 if (func_graph_to_kernel_graph_ids_.find(func_graph) == func_graph_to_kernel_graph_ids_.end()) {
938 (void)func_graph_to_kernel_graph_ids_[func_graph].emplace_back(std::vector<GraphId>{graph_id});
939 } else {
940 (void)func_graph_to_kernel_graph_ids_[func_graph].back().emplace_back(graph_id);
941 }
942 }
943
944 namespace {
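// Recursively expand a Value (tensor, scalar or value sequence) into a VectorRef of tensors;
// scalars are converted to tensors.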
945 void TensorValueToVector(const ValuePtr &value, VectorRef *outputs) {
946 MS_EXCEPTION_IF_NULL(value);
947 MS_EXCEPTION_IF_NULL(outputs);
948 if (value->isa<ValueSequence>()) {
949 auto value_tuple = value->cast<ValueSequencePtr>();
950 MS_EXCEPTION_IF_NULL(value_tuple);
951 for (size_t i = 0; i < value_tuple->size(); ++i) {
952 ValuePtr element = value_tuple->value()[i];
953 MS_EXCEPTION_IF_NULL(element);
954 if (element->isa<tensor::Tensor>()) {
955 auto tensor = element->cast<tensor::TensorPtr>();
956 MS_EXCEPTION_IF_NULL(tensor);
957 outputs->emplace_back(tensor);
958 } else if (element->isa<Scalar>()) {
959 auto scalar = element->cast<ScalarPtr>();
960 MS_EXCEPTION_IF_NULL(scalar);
961 outputs->emplace_back(ScalarToTensor(scalar));
962 } else if (element->isa<ValueSequence>()) {
963 VectorRef tuple;
964 TensorValueToVector(element, &tuple);
965 outputs->emplace_back(tuple);
966 }
967 }
968 } else if (value->isa<tensor::Tensor>()) {
969 auto tensor = value->cast<tensor::TensorPtr>();
970 MS_EXCEPTION_IF_NULL(tensor);
971 outputs->emplace_back(tensor);
972 } else if (value->isa<Scalar>()) {
973 auto scalar = value->cast<ScalarPtr>();
974 MS_EXCEPTION_IF_NULL(scalar);
975 outputs->emplace_back(ScalarToTensor(scalar));
976 }
977 }
978
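// Handle the trivial cases where the graph output is a constant or a parameter: fill outputs directly
// from the value or from args and return true, so no actor execution is needed.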
979 bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const VectorRef &args, VectorRef *outputs) {
980 MS_EXCEPTION_IF_NULL(graph_output);
981 MS_EXCEPTION_IF_NULL(outputs);
982 if (graph_output->isa<ValueNode>()) {
983 MS_LOG(INFO) << "Graph's output is a constant. No need to execute.";
984 VectorRef output_tmp;
985 ValuePtr value = GetValueNode(graph_output);
986 TensorValueToVector(value, &output_tmp);
987 MS_EXCEPTION_IF_NULL(value);
988 if (value->isa<ValueSequence>()) {
989 outputs->emplace_back(output_tmp);
990 } else if (value->isa<tensor::Tensor>() || value->isa<Scalar>()) {
991 *outputs = output_tmp;
992 } else {
993 MS_LOG(INFO) << "Graph output is empty!";
994 }
995 return true;
996 }
997
998 if (graph_output->isa<Parameter>()) {
999 MS_LOG(INFO) << "Graph's output is a parameter. If all params are inputs, no need to execute.";
1000 // Find the right parameter as ret_val.
1001 auto func_graph = graph_output->func_graph();
1002 MS_EXCEPTION_IF_NULL(func_graph);
1003 auto params = func_graph->parameters();
1004 if (args.size() != params.size()) {
1005 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Input size " << args.size()
1006 << " is not equal to graph input size " << params.size();
1007 }
1008
1009 auto it = std::find(params.begin(), params.end(), graph_output);
1010 if (it == params.end()) {
1011 MS_EXCEPTION(UnknownError) << "When graph output is Parameter, it should be found in graph parameters";
1012 }
1013 size_t index = it - params.cbegin();
1014 if (index >= args.size()) {
1015 MS_EXCEPTION(UnknownError) << "Index " << index << " equal or larger than args size " << args.size();
1016 }
1017
1018 outputs->emplace_back(args[index]);
1019 return true;
1020 }
1021 return false;
1022 }
1023 } // namespace
1024
1025 void MindRTBackendBase::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs,
1026 const FuncGraphPtr &root_graph) {
1027 MS_EXCEPTION_IF_NULL(actor_set);
1028 MS_EXCEPTION_IF_NULL(outputs);
1029 MS_EXCEPTION_IF_NULL(root_graph);
1030 bool need_construct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
1031 distributed::recovery::RecoveryContext::GetInstance()->need_reset());
1032 bool is_embedding_cache_server = false;
1033 #if defined(__linux__) && defined(WITH_BACKEND)
1034 is_embedding_cache_server = ps::PSContext::instance()->cache_enable() && ps::PSContext::instance()->is_server();
1035 #endif
1036 if (need_construct_output) {
1037 MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
1038 // Update device address for output node of graph.
1039 // Summary processing will use the output device address, so this must be done after the summary processing.
1040 if (!is_embedding_cache_server) {
1041 actor_set->output_actor_->UpdateOutputDeviceAddress();
1042 }
1043
1044 // Fetch outputs.
1045 auto &output_tensors = actor_set->output_actor_->outputs();
1046 if (!output_tensors.empty()) {
1047 size_t output_position = 0;
1048 std::vector<tensor::TensorPtr> tuple_tensors;
1049 ConstructOutputs(root_graph->output(), output_tensors, &output_position, outputs, &tuple_tensors);
1050
1051 // The tensor may be repeated, so its device address needs to be set to null last.
1052 for (auto &tuple_tensor : tuple_tensors) {
1053 MS_EXCEPTION_IF_NULL(tuple_tensor);
1054 tuple_tensor->set_device_address(nullptr);
1055 }
1056 }
1057 }
1058 }
1059
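// Convert non-contiguous (view) tensor arguments, including tensors inside value sequences,
// to contiguous tensors before they are passed to the actors.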
1060 void MindRTBackendBase::ContiguousArgs(const VectorRef &args, const GraphCompilerInfo &graph_compiler_info) {
1061 for (const auto &arg : args) {
1062 if (utils::isa<tensor::BaseTensorPtr>(arg)) {
1063 auto value = utils::cast<tensor::BaseTensorPtr>(arg);
1064 runtime::DeviceAddressUtils::ConvertContiguousTensorSync(value);
1065 } else if (utils::isa<ValuePtr>(arg)) {
1066 auto value = utils::cast<ValuePtr>(arg);
1067 MS_EXCEPTION_IF_NULL(value);
1068 if (!value->isa<ValueSequence>()) {
1069 return;
1070 }
1071 auto value_tuple = value->cast<ValueSequencePtr>();
1072 MS_EXCEPTION_IF_NULL(value_tuple);
1073 auto tuple_value = value_tuple->value();
1074 for (const auto &v : tuple_value) {
1075 if (!v->isa<tensor::BaseTensor>()) {
1076 continue;
1077 }
1078 auto t = v->cast<tensor::BaseTensorPtr>();
1079 runtime::DeviceAddressUtils::ConvertContiguousTensorSync(t);
1080 }
1081 }
1082 }
1083 }
1084
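// Synchronize the non-default streams of every device context when single-op multi-stream is enabled.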
1085 void MindRTBackendBase::WaitMultiStream(const GraphCompilerInfo &graph_compiler_info) {
1086 for (auto device_context : graph_compiler_info.device_contexts_) {
1087 MS_EXCEPTION_IF_NULL(device_context);
1088 if (device_context->device_res_manager_->single_op_multi_stream_enable()) {
1089 device_context->device_res_manager_->SyncNotDefaultStreams();
1090 }
1091 }
1092 }
1093
1094 void MindRTBackendBase::RunGraph(const ActorInfo &actor_info, const VectorRef &args, VectorRef *outputs) {
1095 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kRuntime, runtime::ProfilerEvent::kBackendGraphRunInner,
1096 actor_info, true);
1097 MS_EXCEPTION_IF_NULL(root_graph_);
1098 if (IsGraphOutputValueNodeOrParameter(root_graph_->output(), args, outputs)) {
1099 return;
1100 }
1101
1102 const auto &context_ptr = MsContext::GetInstance();
1103 MS_EXCEPTION_IF_NULL(context_ptr);
1104 if (context_ptr->get_param<bool>(MS_CTX_PRECOMPILE_ONLY)) {
1105 MS_LOG(INFO) << "PrecompileOnly, stop run graph";
1106 return;
1107 }
1108
1109 // Open abstract_lock for dynamic_shape
1110 AnfUtils::OpenAbstractLock();
1111
1112 // Fetch the graph compiler info.
1113 const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
1114 if (graph_iter == actor_to_graph_compiler_info_.end()) {
1115 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Can't find the graph compiler info.";
1116 }
1117 MS_EXCEPTION_IF_NULL(graph_iter->second);
1118 const auto &graph_compiler_info = *(graph_iter->second);
1119 // For pynative and graph mix execution.
1120 WaitTaskFinish();
1121 WaitMultiStream(graph_compiler_info);
1122
1123 // Run in the pynative mode.
1124 MS_EXCEPTION_IF_NULL(outputs);
1125 // There will be more than one kernel graph in a heterogeneous scenario in a jit of PyNative Mode.
1126 if (ms_execution_mode_ == kPynativeMode && !pynative::GraphAdapter::IsPynativeGeGraphSink(root_graph_)) {
1127 // The tensor needs to be converted to contiguous before being given to the actors.
1128 // After the view feature is supported in the graph mode, the following code will be deleted.
1129 ContiguousArgs(args, graph_compiler_info);
1130 RunGraphByCondition(actor_info, graph_compiler_info, args, outputs);
1131 return;
1132 }
1133
1134 MS_LOG(INFO) << "Status record: start run actor: " << actor_info;
1135 (void)profiler::CollectHostInfo(kModelNameRuntime, kEventRunGraph, kStageRunGraph, 1, 0, 0);
1136 std::vector<std::vector<tensor::TensorPtr>> input_tensors;
1137 if (graph_compiler_info.exist_flatten_concat_) {
1138 input_tensors = GetRunGraphInputs(graph_compiler_info, args);
1139 // The tensor needs to be converted to contiguous before being given to the actors.
1140 // After the view feature is supported in the graph mode, the following code will be deleted.
1141 // Single ops (run in pynative mode) output to the net (whose context is graph mode) input.
1142 (void)std::for_each(input_tensors.begin(), input_tensors.end(), [this](const auto &tensor_vec) {
1143 (void)std::for_each(tensor_vec.begin(), tensor_vec.end(), [](const tensor::TensorPtr &t) {
1144 runtime::DeviceAddressUtils::ConvertContiguousTensorSync(t);
1145 runtime::DeviceAddressUtils::CreateKernelTensor(t);
1146 });
1147 });
1148 }
1149 // Release python gil.
1150 mindspore::ScopedLongRunning long_running;
1151 // Run actor DAG.
1152 const auto &actor_set = runtime::GraphScheduler::GetInstance().Fetch(actor_info);
1153 MS_EXCEPTION_IF_NULL(actor_set);
1154 runtime::GraphScheduler::GetInstance().Run(actor_set, input_tensors, args);
1155
1156 {
1157 uint64_t start_time = 0;
1158 PROFILER_START(start_time);
1159 MS_EXCEPTION_IF_NULL(graph_compiler_);
1160 graph_compiler_->Summary(graph_compiler_info.graphs_);
1161 ConstructOutputs(actor_set, outputs, root_graph_);
1162 actor_set->output_actor_->FreeSummaryNodeMem();
1163 runtime::GraphScheduler::GetInstance().ClearActorData(actor_set);
1164 PROFILER_END(start_time, runtime::ProfilerModule::kKernel, runtime::ProfilerEvent::kOutputProcess, actor_set->name_,
1165 false);
1166 }
1167 // Close abstract_lock for dynamic_shape
1168 AnfUtils::CloseAbstractLock();
1169 (void)profiler::CollectHostInfo(kModelNameRuntime, kEventRunGraph, kStageRunGraph, 1, 0, 1);
1170 MS_LOG(INFO) << "Status record: end run actor: " << actor_info;
1171 }
1172
1173 std::string MindRTBackendBase::GetRandomStatus(const ActorInfo &actor_info) {
1174 auto iter = actor_to_graph_compiler_info_.find(actor_info);
1175 if (iter == actor_to_graph_compiler_info_.end()) {
1176 MS_LOG(EXCEPTION) << "Cannot find actor info " << actor_info;
1177 }
1178 MS_EXCEPTION_IF_NULL(iter->second);
1179
1180 auto device_context =
1181 device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name_, device_id_});
1182 MS_EXCEPTION_IF_NULL(device_context);
1183 if (device_context->graph_executor_ == nullptr) {
1184 return "";
1185 }
1186 std::vector<FuncGraphPtr> graphs;
1187 std::transform(iter->second->graphs_.begin(), iter->second->graphs_.end(), std::back_inserter(graphs),
1188 [](const auto &g) -> FuncGraphPtr { return g; });
1189 return device_context->graph_executor_->GetRandomStatus(graphs);
1190 }
1191
1192 namespace {
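// An output with an Any abstract is treated as a tuple when its device tensor actually carries a sequence shape.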
1193 bool IsTupleOutputOfAnyType(const abstract::AbstractBasePtr &abstract, const tensor::TensorPtr &tensor) {
1194 if (abstract == nullptr || !abstract->isa<abstract::AbstractAny>() || tensor == nullptr) {
1195 return false;
1196 }
1197 auto device_tensor = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
1198 return device_tensor != nullptr && device_tensor->user_data() == nullptr &&
1199 device_tensor->kernel_tensor() != nullptr && device_tensor->kernel_tensor()->GetShape() != nullptr &&
1200 device_tensor->kernel_tensor()->GetShape()->isa<abstract::SequenceShape>();
1201 }
1202 } // namespace
1203
1204 BaseRef MindRTBackendBase::ConstructOutputByAbstract(const abstract::AbstractBasePtr &abstract,
1205 const std::vector<tensor::TensorPtr> &output_tensors,
1206 size_t *output_position,
1207 std::vector<tensor::TensorPtr> *tuple_tensors) {
1208 MS_EXCEPTION_IF_NULL(abstract);
1209 MS_EXCEPTION_IF_NULL(output_position);
1210 MS_EXCEPTION_IF_NULL(tuple_tensors);
1211
1212 size_t outputs_num = common::AnfAlgo::GetOutputNumByAbstract(abstract);
1213 if (*output_position + outputs_num > output_tensors.size()) {
1214 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The output position is out of range: "
1215 << *output_position << " need:" << outputs_num << " total:" << output_tensors.size();
1216 }
1217
1218 if (!abstract->isa<abstract::AbstractSequence>()) {
1219 if (IsTupleOutputOfAnyType(abstract, output_tensors[*output_position])) {
1220 MS_LOG(DEBUG) << "Any output for position:" << *output_position;
1221 VectorRef outputs;
1222 auto device_tensor =
1223 std::dynamic_pointer_cast<device::DeviceAddress>(output_tensors[*output_position]->device_address());
1224 ConstructOutputByTupleTensor(output_tensors[*output_position],
1225 device_tensor->kernel_tensor()->GetShape()->cast<abstract::SequenceShapePtr>(),
1226 &outputs, tuple_tensors);
1227 (*output_position)++;
1228 std::vector<ValuePtr> values;
1229
1230 (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(values),
1231 [](const auto &output) { return utils::cast<ValuePtr>(output); });
1232 return std::make_shared<ValueList>(values);
1233 }
1234
1235 (*output_position)++;
1236 return output_tensors[(*output_position) - 1];
1237 }
1238
1239 VectorRef outputs;
1240 const auto &tuple_abstract = abstract->cast<abstract::AbstractSequencePtr>();
1241 MS_EXCEPTION_IF_NULL(tuple_abstract);
1242   // Dynamic-length tuple: the whole tuple may be stored in a single tensor.
1243 if (tuple_abstract->dynamic_len()) {
1244 auto &output_tensor = output_tensors[*output_position];
1245 MS_EXCEPTION_IF_NULL(output_tensor);
1246 auto &tensor_shape = output_tensor->base_shape_ptr();
1247     // Restore the tuple output from the tuple tensor.
1248 if ((tensor_shape != nullptr) && tensor_shape->isa<abstract::SequenceShape>()) {
1249 ConstructOutputByTupleTensor(output_tensor, tensor_shape->cast<abstract::SequenceShapePtr>(), &outputs,
1250 tuple_tensors);
1251 (*output_position)++;
1252 return outputs;
1253 }
1254 }
1255
1256 const auto &sub_abstracts = tuple_abstract->elements();
1257 for (const auto &sub_abstract : sub_abstracts) {
1258 MS_EXCEPTION_IF_NULL(sub_abstract);
1259 outputs.emplace_back(ConstructOutputByAbstract(sub_abstract, output_tensors, output_position, tuple_tensors));
1260 }
1261 return outputs;
1262 }
1263
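// Split a device tensor that stores a whole tuple into one tensor per element. The original tensor is kept alive
// in tuple_tensors, and every element gets a freshly allocated device address filled by a device-to-device copy
// from the element's byte offset within the original buffer.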
1264 void MindRTBackendBase::ConstructOutputByTupleTensor(tensor::TensorPtr output_tensor,
1265 const abstract::SequenceShapePtr &tensor_shape, VectorRef *outputs,
1266 std::vector<tensor::TensorPtr> *tuple_tensors) const {
1267 MS_EXCEPTION_IF_NULL(output_tensor);
1268 MS_EXCEPTION_IF_NULL(tensor_shape);
1269 MS_EXCEPTION_IF_NULL(outputs);
1270 MS_EXCEPTION_IF_NULL(tuple_tensors);
1271 MS_LOG(DEBUG) << "Tensor shape:" << tensor_shape->ToString();
1272   // If the output is an empty sequence, return an empty sequence value.
1273 if (tensor_shape->size() == 0) {
1274 if (tensor_shape->isa<abstract::TupleShape>()) {
1275 outputs->emplace_back(std::make_shared<ValueTuple>(std::vector<ValuePtr>()));
1276 } else {
1277 outputs->emplace_back(std::make_shared<ValueList>(std::vector<ValuePtr>()));
1278 }
1279 return;
1280 }
1281   // No need to split into multiple tensors when the tuple size is not greater than 1.
1282 if (tensor_shape->size() <= 1) {
1283 outputs->emplace_back(output_tensor);
1284 return;
1285 }
1286
1287 auto tensor_type_id = output_tensor->data_type();
1288 auto device_tensor = std::dynamic_pointer_cast<device::DeviceAddress>(output_tensor->device_address());
1289 MS_EXCEPTION_IF_NULL(device_tensor);
1290 auto tensor_device_ptr = device_tensor->GetMutablePtr();
1291 auto tensor_device_size = device_tensor->GetSize();
1292 MS_EXCEPTION_IF_NULL(tensor_device_ptr);
1293 auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
1294 {device_tensor->device_name(), device_tensor->device_id()});
1295 MS_EXCEPTION_IF_NULL(device_context);
1296 MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
1297
1298 const auto &output_kernel_tensor = device_tensor->kernel_tensor();
1299 MS_EXCEPTION_IF_NULL(output_kernel_tensor);
1300 TypePtr output_type = output_kernel_tensor->GetType();
1301 MS_EXCEPTION_IF_NULL(output_type);
1302 TuplePtr output_tuple_type = output_type->cast<TuplePtr>();
1303 MS_EXCEPTION_IF_NULL(output_tuple_type);
1304 const auto &element_types = output_tuple_type->elements();
1305 if (tensor_shape->size() != element_types.size()) {
1306 MS_LOG(EXCEPTION) << "The tensor shape size[" << tensor_shape->size() << "] is not equal to output element size["
1307 << element_types.size() << "].";
1308 }
1309
1310   // Split the tuple tensor into individual tensors.
1311 (void)tuple_tensors->emplace_back(output_tensor);
1312 size_t copy_offset_size = 0;
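  // Each element occupies SizeOf(shape) * GetTypeByte(dtype) bytes, and elements are laid out back to back in the
  // original device buffer. Illustrative example (not from the source): for a float32 tuple tensor with element
  // shapes ((2, 3), (4,)), the elements are copied from byte offsets 0 and 24 respectively.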
1313 for (size_t i = 0; i < tensor_shape->size(); ++i) {
1314 // Create split tensor.
1315 auto split_tensor_shape = BaseShapeToShape((*tensor_shape)[i]);
1316 auto split_tensor_size = SizeOf(split_tensor_shape) * GetTypeByte(TypeIdToType(tensor_type_id));
1317 auto split_tensor = std::make_shared<tensor::Tensor>(tensor_type_id, split_tensor_shape);
1318
1319 auto kernel_tensor = std::make_shared<kernel::KernelTensor>(
1320 nullptr, split_tensor_size, kernel::GetFormatFromStrToEnum(device_tensor->format()), device_tensor->type_id(),
1321 split_tensor_shape, device_context->device_context_key().device_name_,
1322 device_context->device_context_key().device_id_);
1323 kernel_tensor->SetType(element_types[i]);
1324 kernel_tensor->SetShape((*tensor_shape)[i]);
1325 kernel_tensor->set_stream_id(device_tensor->stream_id());
1326 auto split_device_tensor = device_context->device_res_manager_->CreateDeviceAddress(kernel_tensor);
1327 MS_LOG(DEBUG) << "Create device tensor:" << split_device_tensor << " type:" << device_tensor->type_id();
1328     // Copy data from the original tensor to the split tensor.
1329 device::DynamicMemAllocatorDebugInfo::SetDebugInfo("Split tuple outputs", device::AllocatorType::kOther);
1330 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddTask, "ConstructOutputByTupleTensor",
1331 "ConstructOutputByTupleTensor", "");
1332 device::tracker::CALL_MEMORY_TRACKER_WITH_FILE(AddMemInfo, "ConstructOutputByTupleTensor",
1333 device::tracker::MemType::kOther, split_device_tensor->GetSize(),
1334 split_device_tensor.get());
1335 if (!device_context->device_res_manager_->AllocateMemory(split_device_tensor.get())) {
1336 MS_LOG(EXCEPTION) << "#umsg#Memory not enough:#umsg#Device(id:" << device_context->device_context_key().device_id_
1337 << ") memory isn't enough and alloc failed, kernel name: Split tuple outputs, alloc size: "
1338 << split_device_tensor->GetSize() << "B.";
1339 }
1340 if (copy_offset_size + split_tensor_size > tensor_device_size) {
1341 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The copy size is out of range, copy size:"
1342 << split_tensor_size << ", copy offset size:" << copy_offset_size
1343 << ", total size:" << tensor_device_size;
1344 }
1345 if (!split_device_tensor->SyncDeviceToDevice(split_tensor_shape, split_tensor_size, device_tensor->type_id(),
1346 AddressOffset(tensor_device_ptr, copy_offset_size),
1347 device_tensor->format())) {
1348 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#Sync device to device failed, device type:"
1349 << split_device_tensor->GetDeviceType() << ", copy size:" << split_tensor_size
1350 << ", output node: Split tuple outputs.";
1351 }
1352 copy_offset_size += split_tensor_size;
1353
1354 // Fill the outputs.
1355 split_tensor->set_device_address(split_device_tensor);
1356 outputs->emplace_back(split_tensor);
1357 }
1358 }
1359
1360 namespace {
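// Check whether the output node is a dynamic-length sequence whose runtime value is empty, i.e. the corresponding
// output tensor carries a SequenceShape of size 0.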
1361 bool IsEmptySequence(const AnfNodePtr &output_node, const std::vector<tensor::TensorPtr> &output_tensors,
1362 const size_t *const output_position) {
1363 MS_EXCEPTION_IF_NULL(output_node);
1364 MS_EXCEPTION_IF_NULL(output_position);
1365   // When the output node is a ValueNode, the position may be out of range.
1366 if (*output_position >= output_tensors.size()) {
1367 return false;
1368 }
1369
1370 if (output_node->abstract() == nullptr || (!output_node->abstract()->isa<abstract::AbstractSequence>())) {
1371 return false;
1372 }
1373 const auto &tuple_abs = output_node->abstract()->cast<abstract::AbstractSequencePtr>();
1374 MS_EXCEPTION_IF_NULL(tuple_abs);
1375 if ((!tuple_abs->dynamic_len()) && tuple_abs->dynamic_len_element_abs() == nullptr) {
1376 return false;
1377 }
1378 const auto &tensor = output_tensors[*output_position];
1379 MS_EXCEPTION_IF_NULL(tensor);
1380 if (tensor->base_shape_ptr() == nullptr || (!tensor->base_shape_ptr()->isa<abstract::SequenceShape>())) {
1381 return false;
1382 }
1383 const auto &sequence_shape = tensor->base_shape_ptr()->cast<abstract::SequenceShapePtr>();
1384 MS_EXCEPTION_IF_NULL(sequence_shape);
1385 return sequence_shape->size() == 0;
1386 }
1387 } // namespace
1388
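// Recursively construct the front-end outputs for a graph output node: MakeTuple/MakeSparse nodes are expanded,
// Depend nodes are traced to their real input, value nodes return their value directly, call nodes are handled
// through their abstract, and ordinary nodes consume tensors from output_tensors at *output_position.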
1389 void MindRTBackendBase::ConstructOutputs(const AnfNodePtr &output_node,
1390 const std::vector<tensor::TensorPtr> &output_tensors, size_t *output_position,
1391 VectorRef *outputs, std::vector<tensor::TensorPtr> *tuple_tensors) {
1392 MS_EXCEPTION_IF_NULL(output_node);
1393 MS_EXCEPTION_IF_NULL(outputs);
1394 MS_EXCEPTION_IF_NULL(output_position);
1395 MS_EXCEPTION_IF_NULL(tuple_tensors);
1396 static const PrimitiveSet expand_prims{
1397 prim::kPrimMakeTuple,
1398 prim::kPrimMakeCSRTensor,
1399 prim::kPrimMakeCOOTensor,
1400 prim::kPrimMakeRowTensor,
1401 };
1402 MS_LOG(DEBUG) << "output node:" << output_node->DebugString();
1403   // If the output is an empty sequence, return an empty sequence value.
1404 if (IsEmptySequence(output_node, output_tensors, output_position)) {
1405 if (output_node->abstract()->isa<abstract::AbstractTuple>()) {
1406 outputs->emplace_back(std::make_shared<ValueTuple>(std::vector<ValuePtr>()));
1407 } else {
1408 outputs->emplace_back(std::make_shared<ValueList>(std::vector<ValuePtr>()));
1409 }
1410 ++(*output_position);
1411 return;
1412 }
1413
1414   // The MakeTuple/MakeSparse nodes need to be expanded and recursed into.
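  // Illustrative example (not from the source): an output MakeTuple(a, MakeTuple(b, c)) is rebuilt as the nested
  // VectorRef {a, {b, c}}.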
1415 if (IsOneOfPrimitiveCNode(output_node, expand_prims)) {
1416 auto make_tuple = output_node->cast<CNodePtr>();
1417 MS_EXCEPTION_IF_NULL(make_tuple);
1418 VectorRef make_tuple_output;
1419 for (size_t i = 1; i < make_tuple->size(); i++) {
1420 ConstructOutputs(make_tuple->input(i), output_tensors, output_position, &make_tuple_output, tuple_tensors);
1421 }
1422 outputs->emplace_back(std::move(make_tuple_output));
1423 return;
1424 }
1425
1426   // For a Depend node, construct the outputs from its real input node.
1427 if (common::AnfAlgo::CheckPrimitiveType(output_node, prim::kPrimDepend)) {
1428 auto depend_node = output_node->cast<CNodePtr>();
1429 MS_EXCEPTION_IF_NULL(depend_node);
1430 ConstructOutputs(depend_node->input(kRealInputIndexInDepend), output_tensors, output_position, outputs,
1431 tuple_tensors);
1432 return;
1433 }
1434
1435 auto outputs_num = AnfAlgo::GetOutputElementNum(output_node);
1436   // A value node outputs its value directly, to avoid the value's host memory being freed when the value node is destructed.
1437 if (output_node->isa<ValueNode>()) {
1438 auto value = output_node->cast<ValueNodePtr>()->value();
1439 MS_EXCEPTION_IF_NULL(value);
1440 if (value->isa<ValueSequence>()) {
1441 outputs->emplace_back(value);
1442 (*output_position) += CountValueNum(value->cast<ValueSequencePtr>());
1443 } else if (outputs_num != 0) {
1444 outputs->emplace_back(value);
1445 (*output_position) += outputs_num;
1446 }
1447     // An empty value node returns an empty VectorRef.
1448 return;
1449 }
1450
1451 if (common::AnfAlgo::IsCallNode(output_node)) {
1452 auto abstract = output_node->abstract();
1453 MS_EXCEPTION_IF_NULL(abstract);
1454 outputs->emplace_back(ConstructOutputByAbstract(abstract, output_tensors, output_position, tuple_tensors));
1455 return;
1456 }
1457
1458 auto &output_abstract = output_node->abstract();
1459 MS_EXCEPTION_IF_NULL(output_abstract);
1460   // Wrap the outputs in a VectorRef if the output is a tuple.
1461 MS_LOG(DEBUG) << "output abstract:" << output_abstract->ToString();
1462 if (output_abstract->isa<abstract::AbstractSequence>()) {
1463 VectorRef output_tuple;
1464 for (size_t i = 0; i < outputs_num; ++i) {
1465 MS_LOG(DEBUG) << "output index:" << i;
1466 if (*output_position >= output_tensors.size()) {
1467 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The output position is out of range: "
1468 << *output_position;
1469 }
1470 auto &output_tensor = output_tensors[*output_position];
1471 MS_EXCEPTION_IF_NULL(output_tensor);
1472 auto &tensor_shape = output_tensor->base_shape_ptr();
1473       // Restore the tuple output from the tuple tensor.
1474 if ((tensor_shape != nullptr) && tensor_shape->isa<abstract::SequenceShape>()) {
1475 ConstructOutputByTupleTensor(output_tensor, tensor_shape->cast<abstract::SequenceShapePtr>(), &output_tuple,
1476 tuple_tensors);
1477 } else {
1478 output_tuple.emplace_back(output_tensor);
1479 }
1480 ++(*output_position);
1481 }
1482 outputs->emplace_back(std::move(output_tuple));
1483 } else {
1484 for (size_t i = 0; i < outputs_num; ++i) {
1485 if (*output_position >= output_tensors.size()) {
1486 MS_LOG(INTERNAL_EXCEPTION) << "#dmsg#Runtime error info:#dmsg#The output position is out of range: "
1487 << *output_position;
1488 }
1489 outputs->emplace_back(output_tensors[*output_position]);
1490 ++(*output_position);
1491 }
1492 }
1493 }
1494
1495 #ifdef ENABLE_DEBUGGER
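// Initialize the debugger singleton with the current device id and the device target from the MindSpore context.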
1496 void MindRTBackendBase::SetDebuggerInit() const {
1497 auto debugger_ = Debugger::GetInstance();
1498 auto ms_context = MsContext::GetInstance();
1499 MS_EXCEPTION_IF_NULL(ms_context);
1500 MS_EXCEPTION_IF_NULL(debugger_);
1501 debugger_->Init(device_id_, ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET));
1502 }
1503 #endif
1504
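// Gather the compiled kernel graphs, their device contexts, the control nodes and the root graph's output order
// into a GraphCompilerInfo used to build the actor set. The execution strategy switches to
// kPipelineWithExecutionOrder when the memory optimize level is not O0.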
1505 std::shared_ptr<GraphCompilerInfo> MindRTBackendBase::ConstructGraphCompilerInfo(const FuncGraphPtr &root_graph) {
1506 MS_EXCEPTION_IF_NULL(root_graph);
1507 MS_EXCEPTION_IF_NULL(graph_compiler_);
1508
1509 std::vector<KernelGraphPtr> graphs;
1510 std::vector<DeviceContext *> device_contexts;
1511 std::string name = "kernel_graph";
1512 size_t graph_index = 0;
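  // The info name is composed from the first and last kernel graph ids, e.g. "kernel_graph_3-7" (illustrative ids).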
1513 for (const auto &graph_id_to_context : graph_id_to_device_context_) {
1514 (void)graphs.emplace_back(graph_compiler_->Fetch(graph_id_to_context.first));
1515 (void)device_contexts.emplace_back(graph_id_to_context.second);
1516 if (graph_index == 0) {
1517 (void)name.append("_").append(std::to_string(graph_id_to_context.first));
1518 } else if (graph_index == graph_id_to_device_context_.size() - 1) {
1519 (void)name.append("-").append(std::to_string(graph_id_to_context.first));
1520 }
1521 ++graph_index;
1522 }
1523
1524 auto parser = std::make_shared<ControlNodeParser>();
1525 const auto &root_output =
1526 common::AnfAlgo::VisitKernelWithReturnType(root_graph->output(), 0, false, {prim::kPrimTupleGetItem}).first;
1527 auto outputs_num = common::AnfAlgo::GetAllOutputWithIndex(root_output).size();
1528 runtime::KernelMapPosition outputs_order = FetchOriginOutputOrder(root_graph->output());
1529
1530 std::vector<std::vector<int64_t> *> tensors_mask;
1531 std::vector<std::vector<tensor::TensorPtr> *> input_tensors;
1532 auto strategy = runtime::GraphExecutionStrategy::kPipeline;
1533 auto context_ptr = MsContext::GetInstance();
1534 MS_EXCEPTION_IF_NULL(context_ptr);
1535 if (context_ptr->get_param<int>(MS_CTX_MEMORY_OPTIMIZE_LEVEL) != kOptimizeO0) {
1536 strategy = runtime::GraphExecutionStrategy::kPipelineWithExecutionOrder;
1537 }
1538 auto compile_func = [graph_compiler = this->graph_compiler_](
1539 const GraphSegmentPtr &segment, const std::pair<AnfNodePtrList, AnfNodePtrList> &io_nodes,
1540 const DeviceContext *device_context, device::RunMode run_mode) -> KernelGraphPtr {
1541 auto graph_id = graph_compiler->CompileGraph(segment, io_nodes, device_context, run_mode, false);
1542 return graph_compiler->Fetch(graph_id);
1543 };
1544
1545 return std::make_shared<GraphCompilerInfo>(graphs, device_contexts, tensors_mask, input_tensors, control_nodes_,
1546 root_graph->parameters(), parser, outputs_order, outputs_num,
1547 root_graph->GetPositionalArgsCount(), name, false, strategy, compile_func);
1548 }
1549
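// Group the compiled kernel graphs by their originating func graph and hand them, together with the control nodes
// and device contexts, to the ControlNodeParser of the given GraphCompilerInfo.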
1550 void MindRTBackendBase::ParseControlNodes(const GraphCompilerInfo &graph_compile_info) {
1551 MS_EXCEPTION_IF_NULL(graph_compiler_);
1552 MS_EXCEPTION_IF_NULL(graph_compile_info.control_node_parser_);
1553
1554 FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
1555 for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) {
1556 const auto &func_graph = func_graph_to_kernel_graph_ids.first;
1557 for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) {
1558 std::vector<KernelGraphPtr> kernel_graphs;
1559 for (const auto &graph_id : sub_kernel_graphs_ids) {
1560 const auto &kernel_graph = graph_compiler_->Fetch(graph_id);
1561 MS_EXCEPTION_IF_NULL(kernel_graph);
1562 (void)kernel_graphs.emplace_back(kernel_graph);
1563 }
1564 (void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs);
1565 }
1566 }
1567
1568 graph_compile_info.control_node_parser_->Parse(control_nodes_, graph_compile_info.graphs_,
1569 graph_compile_info.device_contexts_, root_graph_,
1570 func_graph_to_kernel_graphs);
1571 }
1572
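// Refresh the cached origin output order of the actor's GraphCompilerInfo from the current root graph output.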
1573 void MindRTBackendBase::UpdateGraphCompilerInfo(const ActorInfo &actor_info) {
1574 const auto &graph_iter = actor_to_graph_compiler_info_.find(actor_info);
1575 if (graph_iter == actor_to_graph_compiler_info_.end()) {
1576 return;
1577 }
1578 MS_EXCEPTION_IF_NULL(graph_iter->second);
1579 MS_EXCEPTION_IF_NULL(root_graph_);
1580 graph_iter->second->origin_outputs_order_ = FetchOriginOutputOrder(root_graph_->output());
1581 }
1582 } // namespace compile
1583 } // namespace mindspore
1584