1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/session/session_basic.h"
17
18 #include <algorithm>
19 #include <set>
20 #include <queue>
21 #include <unordered_map>
22 #include <utility>
23 #include <functional>
24
25 #include "ops/primitive_c.h"
26 #include "ir/manager.h"
27 #include "abstract/utils.h"
28 #include "backend/kernel_compiler/common_utils.h"
29 #include "base/core_ops.h"
30 #include "base/base_ref_utils.h"
31 #include "common/trans.h"
32 #include "utils/config_manager.h"
33 #include "backend/session/anf_runtime_algorithm.h"
34 #include "backend/session/executor_manager.h"
35 #include "backend/optimizer/common/common_backend_optimization.h"
36 #include "backend/optimizer/common/helper.h"
37 #include "runtime/device/kernel_runtime_manager.h"
38 #include "utils/ms_utils.h"
39 #include "ir/anf.h"
40 #include "ir/func_graph_cloner.h"
41 #include "utils/utils.h"
42 #include "debug/anf_ir_dump.h"
43 #include "debug/dump_proto.h"
44 #include "utils/file_utils.h"
45 #include "utils/trace_base.h"
46 #include "frontend/parallel/context.h"
47 #if ((defined ENABLE_CPU) && (!defined _WIN32))
48 #include "ps/ps_cache/ps_cache_manager.h"
49 #include "ps/constants.h"
50 #include "ps/util.h"
51 #include "ps/ps_context.h"
52 #include "abstract/abstract_value.h"
53 #endif
54 #include "backend/session/session_factory.h"
55 #include "backend/session/pynative_task_manager.h"
56
57 namespace mindspore {
58 namespace session {
59 MS_REG_SESSION(kSessionBasic, SessionBasic);
60
61 namespace {
62 const int kSummaryGetItem = 2;
63 const size_t max_depth = 128;
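// Returns true when any dimension of the given shape is negative, i.e. the shape is dynamic.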
64 bool IsShapeDynamic(const abstract::ShapePtr &shape) {
65 if (shape == nullptr) {
66 return false;
67 }
68 return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
69 }
70 bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodePtr, int64_t> &kernel, size_t *idx) {
71 auto node = kernel.first;
72 MS_EXCEPTION_IF_NULL(manager);
73 MS_EXCEPTION_IF_NULL(node);
74 if (kernel.second > 1 &&
75 (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
76 return false;
77 }
78 if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
79 return true;
80 }
81 (*idx) += 1;
82 // max recursion depth
83 if (*idx <= max_depth) {
84 auto users = manager->node_users()[node];
85 if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
86 return RecursiveCheck(manager, kernel, idx);
87 })) {
88 return true;
89 }
90 }
91 return false;
92 }
93
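// Returns true when `node` is consumed, directly or transitively, by a real kernel belonging to the graph with
// `graph_id`; users from other graphs are filtered out before the check.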
94 bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, const uint32_t graph_id) {
95 MS_EXCEPTION_IF_NULL(manager);
96 MS_EXCEPTION_IF_NULL(node);
97 auto node_users = manager->node_users()[node];
98 // filter nodes not in current graph
99 for (auto iter = node_users.begin(); iter != node_users.end();) {
100 auto func_graph = iter->first->func_graph();
101 auto kernel_graph = func_graph->cast<KernelGraphPtr>();
102 if (kernel_graph == nullptr) {
103 MS_LOG(EXCEPTION) << "Cast func graph to kernel graph failed, related node is: " << iter->first->DebugString();
104 }
105 if (kernel_graph->graph_id() != graph_id) {
106 iter = node_users.erase(iter);
107 } else {
108 ++iter;
109 }
110 }
111
112 size_t idx = 0;
113 if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
114 return RecursiveCheck(manager, kernel, &idx);
115 })) {
116 return true;
117 }
118 return false;
119 }
120
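// Marks every Parameter input of `graph` that is not used by any real kernel in this graph, and flags parameters
// whose shape contains dynamic dimensions.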
121 void SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager) {
122 MS_EXCEPTION_IF_NULL(graph);
123 MS_EXCEPTION_IF_NULL(manager);
124 auto input_nodes = graph->input_nodes();
125 for (auto &input_node : input_nodes) {
126 if (input_node->isa<Parameter>()) {
127 auto node_ptr = input_node->cast<ParameterPtr>();
128 MS_EXCEPTION_IF_NULL(node_ptr);
129 if (!IsUsedByRealKernel(manager, input_node, graph->graph_id())) {
130 node_ptr->SetNotUsedByRealKernelInGraph(graph->graph_id());
131 }
132 auto shape = node_ptr->Shape();
133 if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
134 node_ptr->set_has_dynamic_shape(true);
135 }
136 }
137 }
138 }
139
140 ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
141 if (node == nullptr) {
142 return nullptr;
143 }
144 auto parameter = node->cast<ParameterPtr>();
145 if (parameter == nullptr || !parameter->has_default()) {
146 return nullptr;
147 }
148 return parameter->param_info();
149 }
150
151 static bool IsPynativeMode() {
152 auto ms_context = MsContext::GetInstance();
153 MS_EXCEPTION_IF_NULL(ms_context);
154 return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
155 }
156
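// Tries to resolve a node output directly on the host side: monads become a dummy bool tensor, value nodes return
// their value, and (outside PyNative mode) a parameter that no real kernel consumes returns the matching input
// tensor. Returns nullptr when the value has to come from the device instead.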
157 BaseRef GetNodeOutputTensorFromInputs(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
158 const std::vector<tensor::TensorPtr> &input_tensors) {
159 auto &node = node_output_pair.first;
160 MS_EXCEPTION_IF_NULL(node);
161 if (HasAbstractMonad(node)) {
162 return std::make_shared<tensor::Tensor>(int64_t(0), kBool);
163 }
164 // if the node is a value node, there is no need to sync its addr from device to host
165 if (node->isa<ValueNode>()) {
166 auto value_node = node->cast<ValueNodePtr>();
167 MS_EXCEPTION_IF_NULL(value_node);
168 return value_node->value();
169 }
170 if (IsPynativeMode()) {
171 return nullptr;
172 }
173 if (!node->isa<Parameter>()) {
174 return nullptr;
175 }
176 MS_EXCEPTION_IF_NULL(graph);
177 auto param_node = node->cast<ParameterPtr>();
178 if (param_node != nullptr && param_node->IsUsedByRealKernelInGraph(graph->graph_id())) {
179 return nullptr;
180 }
181 for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
182 if (input_idx >= input_tensors.size()) {
183 MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " out of range:" << input_tensors.size();
184 }
185 if (graph->inputs()[input_idx] == node) {
186 return input_tensors[input_idx];
187 }
188 }
189 return nullptr;
190 }
191
192 int64_t ShapeSize(const std::vector<int64_t> &shape) {
193 return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
194 }
195
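// Creates (or, for internal outputs, reuses) a host tensor for one node output. The device data type is preferred
// when it is known, the max shape is used for dynamic-shape nodes, and the sync status decides when data is copied
// back from the device: immediately in graph mode on non-GPU targets, lazily otherwise.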
196 BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
197 const std::vector<tensor::TensorPtr> &input_tensors,
198 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
199 auto &node = node_output_pair.first;
200 size_t output_index = node_output_pair.second;
201 MS_EXCEPTION_IF_NULL(node);
202 MS_EXCEPTION_IF_NULL(graph);
203 auto tensor_from_input = GetNodeOutputTensorFromInputs(node_output_pair, graph, input_tensors);
204 if (tensor_from_input != nullptr) {
205 return tensor_from_input;
206 }
207 TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
208 if (type_id == kTypeUnknown) {
209 type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
210 }
211 std::vector<int64_t> temp_shape;
212 auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
213 (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
214 if (AnfAlgo::IsDynamicShape(node)) {
215 auto max_shape = AnfAlgo::GetOutputMaxShape(node, output_index);
216 temp_shape = ShapeSize(max_shape) > ShapeSize(temp_shape) ? max_shape : temp_shape;
217 }
218 tensor::TensorPtr tensor;
219 bool is_internal_output = graph->IsInternalOutput(node, output_index);
220 if (is_internal_output) {
221 tensor = graph->GetInternalOutputTensor(node, output_index);
222 if (tensor == nullptr) {
223 tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
224 graph->AddInternalOutputTensor(node, output_index, tensor);
225 }
226 } else {
227 tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
228 }
229 MS_EXCEPTION_IF_NULL(tensor);
230 tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
231 if (is_internal_output) {
232 tensor->set_sync_status(kNoNeedSync);
233 } else {
234 // in pynative mode, data is only copied to host when the user wants to print it
235 auto ms_context = MsContext::GetInstance();
236 MS_EXCEPTION_IF_NULL(ms_context);
237 if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
238 ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
239 tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
240 } else {
241 tensor->set_sync_status(kNeedSyncDeviceToHost);
242 }
243 }
244 tensor->SetIsGraphOutput();
245 (*tensor_to_node)[tensor] = node_output_pair;
246 return tensor;
247 }
248
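// Recursively builds the output structure for `anf`: make_tuple outputs become a VectorRef of their elements,
// nodes without outputs become an empty VectorRef, and outputs that map to the same kernel reuse the tensor cached
// in `node_to_tensor`.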
249 BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph,
250 const std::vector<tensor::TensorPtr> &input_tensors,
251 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
252 KernelMapTensor *node_to_tensor) {
253 MS_EXCEPTION_IF_NULL(anf);
254 MS_EXCEPTION_IF_NULL(tensor_to_node);
255 MS_EXCEPTION_IF_NULL(node_to_tensor);
256 MS_LOG(DEBUG) << "Create tensor for output[" << anf->DebugString() << "]";
257 auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
258 MS_EXCEPTION_IF_NULL(item_with_index.first);
259 MS_LOG(DEBUG) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
260 // special handling for make_tuple
261 if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
262 auto cnode = item_with_index.first->cast<CNodePtr>();
263 MS_EXCEPTION_IF_NULL(cnode);
264 VectorRef ret;
265 for (size_t i = 1; i < cnode->inputs().size(); ++i) {
266 auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node, node_to_tensor);
267 ret.push_back(out);
268 }
269 return ret;
270 }
271 // if the graph returns nothing, the function should return an empty VectorRef
272 size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
273 if (size == 0) {
274 return VectorRef();
275 }
276
277 // The outputs of graph may have the same kernel node, no need to create new tensor.
278 const auto &iter = node_to_tensor->find(item_with_index);
279 if (iter != node_to_tensor->end()) {
280 return iter->second;
281 }
282
283 const auto &tensor = CreateNodeOutputTensor(item_with_index, graph, input_tensors, tensor_to_node);
284 (*node_to_tensor)[item_with_index] = tensor;
285 return tensor;
286 }
287
288 ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
289 MS_EXCEPTION_IF_NULL(anf);
290 MS_EXCEPTION_IF_NULL(graph);
291 auto value_node = anf->cast<ValueNodePtr>();
292 MS_EXCEPTION_IF_NULL(value_node);
293 auto value = value_node->value();
294 MS_EXCEPTION_IF_NULL(value);
295 if (value->isa<None>()) {
296 return nullptr;
297 }
298 auto new_value_node = graph->NewValueNode(value_node);
299 graph->FrontBackendlMapAdd(anf, new_value_node);
300 graph->AddValueNodeToGraph(new_value_node);
301 return new_value_node;
302 }
303
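// Builds a Parameter for a single-op graph from an input tensor. A weight tensor becomes the parameter's default
// value; if the tensor already carries a device address, its format, type and address are attached to the
// parameter, otherwise the default format is used.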
304 ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
305 int64_t tensor_mask) {
306 MS_EXCEPTION_IF_NULL(graph);
307 auto param = graph->NewParameter();
308 MS_EXCEPTION_IF_NULL(param);
309 if (tensor_mask == kParameterWeightTensorMask) {
310 param->set_default_param(input_tensor);
311 }
312 // set the kernel info of parameter
313 auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
314 MS_EXCEPTION_IF_NULL(input_tensor);
315 auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
316 if (device_address == nullptr) {
317 kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
318 TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
319 kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
320 } else {
321 kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
322 kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
323 kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
324 AnfAlgo::SetOutputAddr(device_address, 0, param.get());
325 }
326 AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
327 // construct abstract of parameter
328 auto type_of_tensor = input_tensor->Dtype();
329 auto shape_of_tensor = input_tensor->shape();
330 auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
331 param->set_abstract(abstract);
332 return param;
333 }
334
335 void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
336 MS_LOG(INFO) << "Graph outputs:";
337 const size_t max_deep = 10;
338 if (recurse_level > max_deep) {
339 MS_LOG(INFO) << "Recurse too deep";
340 return;
341 }
342 std::string tab_str;
343 for (size_t i = 0; i < recurse_level; i++) {
344 tab_str = tab_str.append(" ");
345 }
346 if (any.is<AnyList>()) {
347 (void)tab_str.append("{");
348 MS_LOG(INFO) << tab_str;
349 auto any_list = any.cast<AnyList>();
350 for (auto &it : any_list) {
351 DumpGraphOutput(it, recurse_level + 1);
352 }
353 (void)tab_str.append("}");
354 MS_LOG(INFO) << tab_str;
355 }
356 (void)tab_str.append(any.ToString());
357 MS_LOG(INFO) << tab_str;
358 }
359
360 #ifndef ENABLE_SECURITY
361 bool ExistSummaryNode(const KernelGraph *graph) {
362 MS_EXCEPTION_IF_NULL(graph);
363 auto ret = graph->get_return();
364 MS_EXCEPTION_IF_NULL(ret);
365 auto all_nodes = DeepLinkedGraphSearch(ret);
366 for (auto &n : all_nodes) {
367 if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
368 IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
369 return true;
370 }
371 }
372 return false;
373 }
374 #endif
375
376 BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
377 const std::vector<tensor::TensorPtr> &input_tensors,
378 const std::vector<size_t> &indexes,
379 std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
380 auto &node = node_output_pair.first;
381 MS_EXCEPTION_IF_NULL(node);
382 MS_EXCEPTION_IF_NULL(graph);
383 MS_EXCEPTION_IF_NULL(output_indexes);
384 MS_LOG(DEBUG) << "Create placeholder for output[" << node->DebugString() << "] index[" << node_output_pair.second
385 << "]";
386 // if the node is a value node, there is no need to sync its addr from device to host
387 if (node->isa<ValueNode>()) {
388 auto value_node = node->cast<ValueNodePtr>();
389 MS_EXCEPTION_IF_NULL(value_node);
390 return value_node->value();
391 }
392 if (node->isa<Parameter>()) {
393 for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
394 if (input_idx >= input_tensors.size()) {
395 MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " out of range:" << input_tensors.size();
396 }
397 if (graph->inputs()[input_idx] == node) {
398 return input_tensors[input_idx];
399 }
400 }
401 MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
402 }
403 (*output_indexes)[node_output_pair].emplace_back(indexes);
404 BaseRef output_placeholder = std::make_shared<BaseRef>();
405 return output_placeholder;
406 }
407
408 BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
409 const std::vector<tensor::TensorPtr> &input_tensors,
410 const std::vector<size_t> &indexes,
411 std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
412 MS_EXCEPTION_IF_NULL(anf);
413 MS_EXCEPTION_IF_NULL(output_indexes);
414 MS_LOG(DEBUG) << "Create placeholder for output[" << anf->DebugString() << "]";
415 auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
416 MS_EXCEPTION_IF_NULL(item_with_index.first);
417 MS_LOG(DEBUG) << "Create placeholder for output after visit:" << item_with_index.first->DebugString();
418 // special handling for make_tuple
419 if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
420 auto cnode = item_with_index.first->cast<CNodePtr>();
421 MS_EXCEPTION_IF_NULL(cnode);
422 VectorRef ret;
423 for (size_t i = 1; i < cnode->inputs().size(); ++i) {
424 std::vector<size_t> cur_index = indexes;
425 cur_index.emplace_back(i - 1);
426 auto out = CreateNodeOutputPlaceholder(cnode->input(i), graph, input_tensors, cur_index, output_indexes);
427 ret.push_back(out);
428 }
429 return ret;
430 }
431 // if the graph returns nothing, the function should return an empty VectorRef
432 size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
433 if (size == 0) {
434 return VectorRef();
435 }
436 return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
437 }
438
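// Verifies that the host tensor feeding the input_index-th input of `kernel` matches the inferred input shape, and
// raises an exception with a descriptive message on any mismatch.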
439 void CheckInputTensorShape(const TensorPtr &tensor, const CNodePtr &kernel, size_t input_index) {
440 MS_EXCEPTION_IF_NULL(tensor);
441 const auto &tensor_shape = tensor->shape();
442 const auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel, input_index);
443 if (tensor_shape.size() != input_shape.size()) {
444 MS_LOG(EXCEPTION) << "The input tensor's shape size: " << tensor_shape.size()
445 << " is not equal to expected size: " << input_shape.size() << " for input[" << input_index
446 << "] of kernel: " << AnfAlgo::GetCNodeName(kernel);
447 }
448 for (size_t i = 0; i < tensor_shape.size(); i++) {
449 if (tensor_shape[i] < 0 || static_cast<size_t>(tensor_shape[i]) != input_shape[i]) {
450 MS_LOG(EXCEPTION) << "The input tensor's shape: " << tensor_shape
451 << " is not equal to expected shape: " << input_shape << " for input[" << input_index
452 << "] of kernel: " << AnfAlgo::GetCNodeName(kernel);
453 }
454 }
455 }
456
457 void UpdateGraphAquireGilAttr(const NotNull<KernelGraphPtr> &root_graph) {
458 for (const auto &cnode : root_graph->execution_order()) {
459 if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPyFunc)) {
460 MS_LOG(INFO) << "The graph requires GIL. Graph id: " << root_graph->graph_id();
461 root_graph->set_is_need_gil(true);
462 return;
463 }
464 }
465 return;
466 }
467
468 bool ExistGraphCaller(const AnfNodePtr &partial_node) {
469 MS_EXCEPTION_IF_NULL(partial_node);
470 auto partial_cnode = partial_node->cast<CNodePtr>();
471 MS_EXCEPTION_IF_NULL(partial_cnode);
472 auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
473 MS_EXCEPTION_IF_NULL(partial_graph);
474 auto graph_nodes = TopoSort(partial_graph->get_return());
475 return std::any_of(graph_nodes.begin(), graph_nodes.end(), IsValueNode<FuncGraph>);
476 }
477
478 // 1. Convert the node to make_tuple if the node is a ValueNode<ValueTuple> and it's the input of 'return' node.
479 // 2. Set the return of graph if node is "Return" node.
480 void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
481 MS_EXCEPTION_IF_NULL(graph);
482 MS_EXCEPTION_IF_NULL(node);
483
484 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
485 constexpr auto kReturnInputIdx = 1;
486 auto return_node = node->cast<CNodePtr>();
487 graph->set_return(return_node);
488 auto graph_output = return_node->input(kReturnInputIdx);
489 MS_EXCEPTION_IF_NULL(graph_output);
490
491 // If return's input is a value node, then the graph has no kernel, and the pass 'trans tuple to make_tuple' cannot
492 // match this pattern, because that pass starts from the output node rather than the return node. So we transform the
493 // value tuple to make_tuple here.
494 if (AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
495 return_node->set_input(kReturnInputIdx, graph->TransTupleToMakeTuple(graph_output));
496 }
497 }
498 }
499 } // namespace
500
501 GraphId SessionBasic::graph_sum_ = 0;
502
503 void SessionBasic::InitExecutor(const std::string &device_name, uint32_t device_id) {
504 device_id_ = device_id;
505 context_ = std::make_shared<Context>(device_name, device_id);
506 executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id);
507 }
508
509 GraphId SessionBasic::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
510 for (const auto &graph_item : graphs_) {
511 auto graph = graph_item.second;
512 MS_EXCEPTION_IF_NULL(graph);
513 // if front_anf is a parameter, the backend parameter may have two
514 if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
515 return graph_item.first;
516 }
517 }
518 MS_EXCEPTION_IF_NULL(front_anf);
519 MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " does not exist in any graph";
520 return kInvalidGraphId;
521 }
522
523 KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
524 auto it = graphs_.find(graph_id);
525 if (it == graphs_.end()) {
526 MS_LOG(INFO) << "Can't find graph " << graph_id;
527 return nullptr;
528 }
529 return it->second;
530 }
531
532 void SessionBasic::ClearGraph() {
533 auto graph_iter = graphs_.begin();
534 while (graph_iter != graphs_.end()) {
535 graph_iter->second.reset();
536 graphs_.erase(graph_iter++);
537 }
538 graph_sum_ = 0;
539 }
540
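// When a backend parameter corresponds to an internal output of a previously built graph, copy that output's
// device address, format and data type onto the parameter so the two graphs can share the buffer directly.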
541 void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
542 auto graph_id = GetGraphIdByNode(out_node);
543 if (graph_id == kInvalidGraphId) {
544 return;
545 }
546 auto node_graph = GetGraph(graph_id);
547 if (node_graph == nullptr) {
548 return;
549 }
550 MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
551 auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
552 if (ref_node == nullptr) {
553 MS_LOG(INFO) << "No corresponding internal output for output node";
554 return;
555 }
556 size_t output_idx = 0;
557 if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
558 output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
559 }
560 auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
561 auto ref_real_node = real_kernel.first;
562 auto ref_real_node_index = real_kernel.second;
563 if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
564 auto kernel_info = ref_real_node->kernel_info();
565 if (kernel_info == nullptr || !kernel_info->has_build_info()) {
566 MS_LOG(INFO) << "No kernel info";
567 return;
568 }
569 if (!opt::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
570 MS_LOG(INFO) << "No kernel address";
571 return;
572 }
573 auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
574 auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
575 auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
576 auto d_kernel_info = std::make_shared<device::KernelInfo>();
577 MS_EXCEPTION_IF_NULL(d_kernel_info);
578 parameter->set_kernel_info(d_kernel_info);
579 kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
580 builder.SetOutputsDeviceType({type});
581 builder.SetOutputsFormat({format});
582 d_kernel_info->set_select_kernel_build_info(builder.Build());
583 AnfAlgo::SetOutputAddr(address, 0, parameter.get());
584 auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type),
585 parameter->Shape()->cast<abstract::BaseShapePtr>());
586 parameter->set_abstract(abstract);
587 }
588 }
589
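// Expands a (possibly tuple-typed) front-end node into backend parameters via TransTupleToMakeTuple, registers
// them as graph inputs, and links each parameter to the matching output of the preceding graph through
// InitInternalOutputParameter.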
590 AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
591 MS_EXCEPTION_IF_NULL(node);
592 MS_EXCEPTION_IF_NULL(graph);
593 auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
594 auto parameters = AnfAlgo::GetAllOutput(new_parameter);
595 std::vector<AnfNodePtr> pre_graph_out = {node};
596 // If a cnode is a call, its input0 is a cnode too, so it doesn't have a primitive
597 if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
598 pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
599 }
600
601 for (size_t i = 0; i < parameters.size(); ++i) {
602 const auto &parameter = parameters[i];
603 // In control flow, if the input of the cnode is a call node, it will be processed as a make_tuple input,
604 // which needs to be linked when processing the internal node.
605 graph->CacheInternalParameterToFrontNode(parameter, {node, i});
606 auto valid_inputs = graph->MutableValidInputs();
607 MS_EXCEPTION_IF_NULL(valid_inputs);
608 auto graph_inputs = graph->MutableInputs();
609 MS_EXCEPTION_IF_NULL(graph_inputs);
610 valid_inputs->push_back(true);
611 graph_inputs->push_back(parameter);
612 }
613 size_t param_index = 0;
614 for (const auto &out_node : pre_graph_out) {
615 size_t output_size = AnfAlgo::GetOutputTensorNum(out_node);
616 for (size_t i = 0; i < output_size; i++) {
617 if (param_index >= parameters.size()) {
618 MS_LOG(EXCEPTION) << "Parameters size:" << parameters.size() << " out of range. Node:" << node->DebugString()
619 << ", out_node:" << out_node->DebugString();
620 }
621 InitInternalOutputParameter(out_node, parameters[param_index++]);
622 }
623 }
624 return new_parameter;
625 }
626
627 ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
628 MS_EXCEPTION_IF_NULL(anf);
629 if (!anf->isa<Parameter>()) {
630 MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
631 }
632 MS_EXCEPTION_IF_NULL(graph);
633 auto param_value = GetParamDefaultValue(anf);
634 auto valid_inputs = graph->MutableValidInputs();
635 MS_EXCEPTION_IF_NULL(valid_inputs);
636 auto graph_inputs = graph->MutableInputs();
637 MS_EXCEPTION_IF_NULL(graph_inputs);
638 ParameterPtr new_parameter = nullptr;
639 // if the parameter's python parameter already has a backend parameter, reuse the existing one
640 if (param_value != nullptr) {
641 new_parameter = param_value->parameter();
642 }
643 if (new_parameter == nullptr) {
644 TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
645 new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
646
647 auto input_node_iter = partial_parameters_map_.find(anf);
648 if (input_node_iter != partial_parameters_map_.end()) {
649 InitInternalOutputParameter(input_node_iter->second, new_parameter);
650 }
651
652 if (param_value != nullptr) {
653 param_value->set_parameter(new_parameter);
654 }
655 }
656 new_parameter->IncreaseUsedGraphCount();
657 graph_inputs->push_back(new_parameter);
658 valid_inputs->push_back(true);
659 return new_parameter;
660 }
661
662 AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
663 MS_EXCEPTION_IF_NULL(anf);
664 MS_EXCEPTION_IF_NULL(graph);
665 MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
666 return CreateParameterFromTuple(anf, graph);
667 }
668
669 void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const {
670 MS_EXCEPTION_IF_NULL(cnode);
671 MS_EXCEPTION_IF_NULL(cnode_inputs);
672 auto prim = AnfAlgo::GetCNodePrimitive(cnode);
673 if (prim != nullptr) {
674 // push attr to inputs[0] of new cnode
675 cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
676 } else {
677 auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
678 MS_EXCEPTION_IF_NULL(fg);
679 auto new_fg = BasicClone(fg);
680 cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg));
681 }
682 }
683
684 void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
685 std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
686 MS_EXCEPTION_IF_NULL(cnode);
687 MS_EXCEPTION_IF_NULL(graph);
688 MS_EXCEPTION_IF_NULL(other_graph_cnode);
689 MS_EXCEPTION_IF_NULL(cnode_inputs);
690 auto origin_inputs = cnode->inputs();
691 const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
692 // if there are multiple depends, only select the first depend as a parameter
693 for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
694 auto anf = origin_inputs[input_idx];
695 MS_EXCEPTION_IF_NULL(anf);
696 // anf has been created before
697 if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
698 (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
699 continue;
700 } else if ((is_depend && input_idx > kRealInputIndexInDepend)) {
701 cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
702 continue;
703 } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
704 cnode_inputs->push_back((*other_graph_cnode)[anf]);
705 continue;
706 } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
707 // if the input is a value node, create a corresponding value node in this graph
708 auto new_value_node = CreateNewValueNode(anf, graph);
709 if (new_value_node != nullptr) {
710 (void)cnode_inputs->emplace_back(new_value_node);
711 }
712 continue;
713 } else if (anf->isa<Parameter>()) {
714 auto new_parameter = CreateNewParameterFromParameter(anf, graph);
715 cnode_inputs->push_back(new_parameter);
716 graph->FrontBackendlMapAdd(anf, new_parameter);
717 continue;
718 } else {
719 // the input node is a cnode from another graph
720 auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph);
721 if (parameter_from_cnode == nullptr) {
722 parameter_from_cnode = NewValueNode(MakeValue(SizeToLong(input_idx)));
723 }
724 if (parameter_from_cnode->isa<Parameter>() && IsPrimitiveCNode(anf, prim::kPrimLoad)) {
725 auto para = parameter_from_cnode->cast<ParameterPtr>();
726 auto load_cnode = anf->cast<CNodePtr>();
727 para->set_name(load_cnode->input(kFirstDataInputIndex)->fullname_with_scope());
728 }
729 cnode_inputs->push_back(parameter_from_cnode);
730 (*other_graph_cnode)[anf] = parameter_from_cnode;
731 }
732 }
733 }
734
735 CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
736 std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
737 MS_EXCEPTION_IF_NULL(cnode);
738 MS_EXCEPTION_IF_NULL(graph);
739 MS_EXCEPTION_IF_NULL(other_graph_cnode);
740 // get primitive of old node
741 std::vector<AnfNodePtr> cnode_inputs;
742 GetCNodeInfo(cnode, &cnode_inputs);
743 GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
744 TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
745 auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
746 return new_cnode;
747 }
748
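// Normalizes one Switch branch input into a Partial cnode: an existing Partial or FuncGraph input is reused, while
// any other input is wrapped in a small single-return kernel graph that is then partially applied to that input.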
749 CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) {
750 MS_EXCEPTION_IF_NULL(node_input);
751 MS_EXCEPTION_IF_NULL(graph);
752 // switch input generalizes partial
753 std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
754 if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
755 auto backend_node = graph->GetBackendAnfByFrontAnf(node_input);
756 return backend_node->cast<CNodePtr>();
757 } else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
758 partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
759 } else {
760 KernelGraphPtr kernel_graph = NewKernelGraph();
761 MS_EXCEPTION_IF_NULL(kernel_graph);
762 auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
763 MS_EXCEPTION_IF_NULL(parameter);
764 parameter->set_abstract(cnode->abstract());
765 auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
766 auto return_node = kernel_graph->NewCNode({primitive, parameter});
767 return_node->set_abstract(cnode->abstract());
768 kernel_graph->set_return(return_node);
769 partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
770 partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
771 }
772 auto partial_node = graph->NewCNode(partial_inputs);
773 return partial_node;
774 }
775
776 std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) {
777 MS_EXCEPTION_IF_NULL(cnode);
778 MS_EXCEPTION_IF_NULL(graph);
779 std::vector<AnfNodePtr> cnode_inputs = {
780 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
781 auto attr_input = cnode->input(kAnfPrimitiveIndex);
782 MS_EXCEPTION_IF_NULL(attr_input);
783 auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
784 auto switch_cnode = cnode_input->cast<CNodePtr>();
785 MS_EXCEPTION_IF_NULL(switch_cnode);
786 if (cnode->inputs().size() <= 1) {
787 cnode_inputs = switch_cnode->inputs();
788 return cnode_inputs;
789 }
790 std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
791 switch_cnode->input(kFirstDataInputIndex)};
792 for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
793 auto node = switch_cnode->input(index);
794 // there are real inputs in the call, they should be put into the true and false branches of the switch
795 if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
796 auto partial_node = node->cast<CNodePtr>();
797 MS_EXCEPTION_IF_NULL(partial_node);
798 std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
799 // Put all call args at the end of partial inputs.
800 for (size_t i = kFirstDataInputIndex; i < cnode->size(); ++i) {
801 (void)partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(i)));
802 }
803 auto new_partial = graph->NewCNode(partial_inputs);
804 (void)switch_inputs.emplace_back(new_partial);
805 }
806 }
807 if (switch_inputs.size() < kSwitchInputSize) {
808 MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << " less than " << kSwitchInputSize;
809 }
810 auto switch_node = graph->NewCNode(switch_inputs);
811 (void)cnode_inputs.emplace_back(switch_node);
812 return cnode_inputs;
813 }
814
815 void SessionBasic::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
816 const std::vector<AnfNodePtr> &real_inputs) {
817 MS_EXCEPTION_IF_NULL(cnode);
818 // func1 = switch(branch1, branch2)
819 // func2 = func1(param1)
820 // out = func2(param2)
821 // process the last cnode (func2), not func1, whose abstract is AbstractFunction
822 if (cnode->abstract()->isa<abstract::AbstractFunction>()) {
823 return;
824 }
825 MS_EXCEPTION_IF_NULL(graph);
826 auto ret = graph->get_return();
827 MS_EXCEPTION_IF_NULL(ret);
828 auto return_input = ret->input(kFirstDataInputIndex);
829 // return node is a function
830 std::vector<AnfNodePtr> call_inputs = {
831 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
832 if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
833 auto return_input_cnode = return_input->cast<CNodePtr>();
834 auto partial_inputs = return_input_cnode->inputs();
835 call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
836 } else if (IsValueNode<KernelGraph>(return_input)) { // return node is kernel graph
837 call_inputs.emplace_back(return_input);
838 } else { // return node is value node
839 KernelGraphPtr kernel_graph = NewKernelGraph();
840 auto valid_inputs = kernel_graph->MutableValidInputs();
841 MS_EXCEPTION_IF_NULL(valid_inputs);
842 auto graph_inputs = kernel_graph->MutableInputs();
843 MS_EXCEPTION_IF_NULL(graph_inputs);
844 std::vector<AnfNodePtr> cnode_inputs = {return_input};
845 for (auto &real_input : real_inputs) {
846 auto new_parameter = kernel_graph->NewParameter(real_input->abstract());
847 valid_inputs->push_back(true);
848 graph_inputs->push_back(new_parameter);
849 cnode_inputs.push_back(new_parameter);
850 }
851 auto new_cnode = kernel_graph->NewCNode(cnode_inputs);
852 new_cnode->set_abstract(cnode->abstract());
853 std::vector<AnfNodePtr> return_inputs = {
854 kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))), new_cnode};
855 auto return_node = kernel_graph->NewCNode(return_inputs);
856 return_node->set_abstract(cnode->abstract());
857 kernel_graph->set_return(return_node);
858 call_inputs.push_back(std::make_shared<ValueNode>(kernel_graph));
859 }
860
861 // new call node inputs
862 for (auto &input_node : real_inputs) {
863 auto parameter_for_input = CreateNewParameterFromCNode(input_node, graph);
864 call_inputs.emplace_back(parameter_for_input);
865 }
866
867 auto call_node = graph->NewCNode(call_inputs);
868 call_node->set_abstract(cnode->abstract());
869 // update return input
870 ret->set_input(kFirstDataInputIndex, call_node);
871 }
872
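// Rewrites call(switch_layer(...)) for the backend: the real call arguments are appended to every branch inside
// the switch_layer's make_tuple, and branches that themselves return a function are post-processed by
// ProcessNodeRetFunc before the new switch_layer is attached to the call.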
873 std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
874 MS_EXCEPTION_IF_NULL(cnode);
875 MS_EXCEPTION_IF_NULL(graph);
876 std::vector<AnfNodePtr> cnode_inputs = {
877 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
878 auto attr_input = cnode->input(kAnfPrimitiveIndex);
879 MS_EXCEPTION_IF_NULL(attr_input);
880 auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
881 auto switch_layer_cnode = cnode_input->cast<CNodePtr>();
882 MS_EXCEPTION_IF_NULL(switch_layer_cnode);
883 std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
884 switch_layer_cnode->input(kFirstDataInputIndex)};
885 auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex);
886 MS_EXCEPTION_IF_NULL(make_tuple_node);
887 auto node = make_tuple_node->cast<CNodePtr>();
888 MS_EXCEPTION_IF_NULL(node);
889 auto make_tuple_inputs = node->inputs();
890 // there are real inputs in the call, they should be put into the make_tuple in switch_layer
891 std::vector<AnfNodePtr> real_inputs;
892 for (size_t idx = kFirstDataInputIndex; idx < cnode->inputs().size(); ++idx) {
893 real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx)));
894 }
895 std::vector<AnfNodePtr> new_make_tuple_inputs = {
896 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
897 for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
898 auto partial_idx = make_tuple_inputs[idx];
899 MS_EXCEPTION_IF_NULL(cnode->abstract());
900 std::vector<AnfNodePtr> new_partial_inputs;
901 KernelGraphPtr partial_kernel_graph;
902 // switch_layer node input is partial cnode
903 if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
904 auto partial_node = partial_idx->cast<CNodePtr>();
905 MS_EXCEPTION_IF_NULL(partial_node);
906 auto partial_input = partial_node->input(kFirstDataInputIndex);
907 partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
908 new_partial_inputs = partial_node->inputs();
909 } else if (IsValueNode<KernelGraph>(partial_idx)) { // switch_layer node input is kernel graph value node
910 new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
911 new_partial_inputs.emplace_back(partial_idx);
912 partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_idx);
913 }
914 // when a branch in switch_layer returns a function
915 MS_EXCEPTION_IF_NULL(partial_kernel_graph);
916 auto ret = partial_kernel_graph->get_return();
917 MS_EXCEPTION_IF_NULL(ret);
918 auto return_input = ret->input(kFirstDataInputIndex);
919 if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
920 ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
921 }
922 // add the input args to the partial node
923 new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
924 // create new partial node
925 auto new_partial = graph->NewCNode(new_partial_inputs);
926 new_make_tuple_inputs.emplace_back(new_partial);
927 }
928 auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
929 auto abstract = make_tuple_node->abstract();
930 if (abstract == nullptr) {
931 abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList());
932 }
933 new_make_tuple->set_abstract(abstract);
934 switch_layer_inputs.emplace_back(new_make_tuple);
935 auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
936 cnode_inputs.emplace_back(new_switch_layer);
937 return cnode_inputs;
938 }
939
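// Builds the backend inputs of a call whose callee is itself a cnode: a Partial callee is inlined into the call,
// while Switch and SwitchLayer callees are delegated to CreateCallSwitchInputs / CreateCallSwitchLayerInputs.
// Returns an empty vector on failure.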
940 std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
941 MS_EXCEPTION_IF_NULL(cnode);
942 MS_EXCEPTION_IF_NULL(graph);
943 // create primitive of cnode:call(partial or switch or switch_layer)
944 std::vector<AnfNodePtr> cnode_inputs = {
945 graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
946 auto attr_input = cnode->input(kAnfPrimitiveIndex);
947 MS_EXCEPTION_IF_NULL(attr_input);
948 auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
949 if (cnode_input == nullptr) {
950 MS_LOG(ERROR) << "CNode input[0] is CNode:" << attr_input->DebugString() << ", but input[0] has not been created.";
951 return {};
952 }
953 // if the node is partial, insert the inputs of the partial into the call
954 if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
955 auto partial_node = attr_input->cast<CNodePtr>();
956 MS_EXCEPTION_IF_NULL(partial_node);
957 auto partial_inputs = partial_node->inputs();
958 (void)std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
959 std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
960 MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
961 return graph->GetBackendAnfByFrontAnf(node);
962 });
963 return cnode_inputs;
964 } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
965 return CreateCallSwitchInputs(cnode, graph);
966 } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
967 return CreateCallSwitchLayerInputs(cnode, graph);
968 }
969 MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]:" << cnode_input->DebugString()
970 << " must be partial or switch or switch_layer.";
971 return {};
972 }
973
974 std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) {
975 MS_EXCEPTION_IF_NULL(cnode);
976 MS_EXCEPTION_IF_NULL(graph);
977 std::vector<AnfNodePtr> cnode_inputs;
978 auto attr_input = cnode->input(kAnfPrimitiveIndex);
979 MS_EXCEPTION_IF_NULL(attr_input);
980 if (AnfAlgo::IsGraphKernel(cnode)) {
981 auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
982 MS_EXCEPTION_IF_NULL(fg);
983 auto new_fg = BasicClone(fg);
984 cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
985 } else {
986 // create primitive of cnode:call
987 cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
988 // create a ValueNode<KernelGraph> as input of cnode:call
989 if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
990 cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
991 } else {
992 auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
993 if (new_value_node != nullptr) {
994 cnode_inputs.emplace_back(new_value_node);
995 }
996 }
997 }
998 return cnode_inputs;
999 }
1000
1001 void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs) {
1002 MS_EXCEPTION_IF_NULL(cnode);
1003 MS_EXCEPTION_IF_NULL(graph);
1004 if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
1005 (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
1006 for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
1007 auto node_input = cnode->input(index);
1008 auto switch_input = CreateSwitchInput(cnode, node_input, graph);
1009 (void)cnode_inputs->emplace_back(switch_input);
1010 }
1011 } else {
1012 for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) {
1013 auto anf = cnode->input(input_idx);
1014 MS_EXCEPTION_IF_NULL(anf);
1015 // anf has been created before
1016 if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
1017 (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
1018 continue;
1019 } else if (IsValueNode<None>(anf)) {
1020 continue;
1021 }
1022 MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
1023 }
1024 }
1025 }
1026
1027 CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
1028 MS_EXCEPTION_IF_NULL(cnode);
1029 MS_EXCEPTION_IF_NULL(graph);
1030 std::vector<AnfNodePtr> cnode_inputs;
1031 auto attr_input = cnode->input(kAnfPrimitiveIndex);
1032 MS_EXCEPTION_IF_NULL(attr_input);
1033 if (IsValueNode<FuncGraph>(attr_input)) {
1034 // cnode is a graph or a call
1035 cnode_inputs = CreateValueNode(cnode, graph);
1036 } else if (attr_input->isa<CNode>()) {
1037 // cnode is a call (partial/switch/switch_layer)
1038 // 1. take the args of the call to the partial node, as the real args to call switch's or switch_layer's child graph
1039 // 2. the call in the frontend is mapped to the partial/switch/switch_layer in the backend, which hasn't been created yet
1040 cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
1041 if (cnode_inputs.empty()) {
1042 MS_LOG(ERROR) << "Create switch or partial failed, cnode:" << cnode->DebugString();
1043 return nullptr;
1044 }
1045 } else {
1046 // get primitive of old node
1047 auto prim = AnfAlgo::GetCNodePrimitive(cnode);
1048 MS_EXCEPTION_IF_NULL(prim);
1049 // push attr to inputs[0] of new cnode
1050 cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
1051 }
1052 // handle inputs of cnode except primitive
1053 CreateCNodeInputs(cnode, graph, &cnode_inputs);
1054 TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
1055 auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
1056 // if the cnode is a call of switch, remove the call
1057 if (new_cnode->inputs().size() > 1) {
1058 auto first_input = new_cnode->input(kFirstDataInputIndex);
1059 MS_EXCEPTION_IF_NULL(first_input);
1060 if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
1061 AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
1062 new_cnode = first_input->cast<CNodePtr>();
1063 }
1064 if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
1065 AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
1066 auto abstract = cnode->abstract();
1067 new_cnode = first_input->cast<CNodePtr>();
1068 new_cnode->set_abstract(abstract);
1069 }
1070 }
1071 return new_cnode;
1072 }
1073
1074 ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
1075 MS_EXCEPTION_IF_NULL(anf);
1076 MS_EXCEPTION_IF_NULL(graph);
1077 auto value_node = anf->cast<ValueNodePtr>();
1078 MS_EXCEPTION_IF_NULL(value_node);
1079 auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
1080 MS_EXCEPTION_IF_NULL(sub_func_graph);
1081 if (front_backend_graph_map_.find(sub_func_graph.get()) == front_backend_graph_map_.end()) {
1082 MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
1083 }
1084 auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph.get()];
1085
1086 ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
1087 new_value_node->set_abstract(value_node->abstract());
1088 // create new kernel_info of new value_node
1089 auto kernel_info = std::make_shared<device::KernelInfo>();
1090 new_value_node->set_kernel_info(kernel_info);
1091 // create kernel_build_info for new value node
1092 auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
1093 AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
1094 AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());
1095
1096 graph->FrontBackendlMapAdd(anf, new_value_node);
1097
1098 return new_value_node;
1099 }
1100
1101 ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
1102 MS_EXCEPTION_IF_NULL(anf);
1103 MS_EXCEPTION_IF_NULL(graph);
1104 if (!anf->isa<Parameter>()) {
1105 MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
1106 }
1107
1108 auto param_value = GetParamDefaultValue(anf);
1109 ParameterPtr new_parameter = nullptr;
1110 // if the parameter's python parameter already has a backend parameter, reuse the existing one
1111 if (param_value != nullptr) {
1112 new_parameter = param_value->parameter();
1113 if (new_parameter == nullptr) {
1114 TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1115 new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
1116 param_value->set_parameter(new_parameter);
1117 }
1118 } else {
1119 TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
1120 new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
1121 }
1122
1123 new_parameter->IncreaseUsedGraphCount();
1124
1125 return new_parameter;
1126 }
1127
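// Translates the front-end node list `lst` into a new KernelGraph: each front cnode is cloned into a backend
// cnode, front/backend mappings are recorded, a make_tuple output is attached, and the graph then goes through
// UnifyMindIR, the dynamic-shape and GIL attribute updates, and (optionally) the common backend optimizations.
// A minimal usage sketch from a derived session (hypothetical call site, not part of this file):
//   auto kernel_graph = ConstructKernelGraph(func_graph_nodes, func_graph_outputs);
//   // then compile and run kernel_graph through the device-specific session implementation.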
1128 KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
1129 bool common_opt) {
1130 std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
1131 auto graph = NewKernelGraph();
1132 MS_EXCEPTION_IF_NULL(graph);
1133 MS_LOG(INFO) << "Create graph: " << graph->graph_id();
1134 for (const auto &node : lst) {
1135 MS_EXCEPTION_IF_NULL(node);
1136 MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
1137 if (!node->isa<CNode>()) {
1138 MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not CNode";
1139 }
1140 auto cnode = node->cast<CNodePtr>();
1141 MS_EXCEPTION_IF_NULL(cnode);
1142 // create a new cnode object
1143 auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
1144 MS_EXCEPTION_IF_NULL(new_cnode);
1145 new_cnode->set_abstract(cnode->abstract());
1146 new_cnode->set_scope(cnode->scope());
1147 if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
1148 new_cnode->set_fullname_with_scope(cnode->input(kFirstDataInputIndex)->fullname_with_scope());
1149 }
1150 // record map relations between anf from ME and new anf node used in backend
1151 graph->FrontBackendlMapAdd(node, new_cnode);
1152 }
1153 // add a make_tuple at the end of graph as output
1154 graph->set_output(ConstructOutput(outputs, graph));
1155 FuncGraphManagerPtr manager = MakeManager({graph});
1156 if (manager) {
1157 manager->AddFuncGraph(graph);
1158 graph->set_manager(manager);
1159 }
1160 graph->SetExecOrderByDefault();
1161
1162 #ifndef ENABLE_SECURITY
1163 if (ExistSummaryNode(graph.get())) {
1164 graph->set_summary_node_exist(true);
1165 }
1166 #endif
1167
1168 UnifyMindIR(graph);
1169 // Update Graph Dynamic Shape Attr
1170 UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
1171 UpdateGraphAquireGilAttr(NOT_NULL(graph));
1172 if (common_opt) {
1173 opt::BackendCommonOptimization(graph);
1174 }
1175 graph->SetInputNodes();
1176 SetInputNodeUsage(graph, manager);
1177 graph->SetOptimizerFlag();
1178 return graph;
1179 }
1180
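// Builds the cache key used to look up a previously compiled single-op graph. The key concatenates, separated by
// '_': each input tensor's shape and data type (plus device type, format and padding when a device address
// exists), the non-empty primitive attributes, the output shape, the inferred output types, and finally the
// primitive id. Illustrative sketch only (the exact text depends on the operator): a 2x3 input contributes
// "2_3_" followed by its data type id.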
1181 GraphInfo SessionBasic::GetSingleOpGraphInfo(const CNodePtr &kernel,
1182 const std::vector<tensor::TensorPtr> &input_tensors) {
1183 MS_EXCEPTION_IF_NULL(kernel);
1184 auto prim = AnfAlgo::GetCNodePrimitive(kernel);
1185 MS_EXCEPTION_IF_NULL(prim);
1186 const AbstractBasePtr &abstract = kernel->abstract();
1187 MS_EXCEPTION_IF_NULL(abstract);
1188 size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
1189 GraphInfo graph_info;
1190 // get input tensor info
1191 for (const auto &tensor : input_tensors) {
1192 MS_EXCEPTION_IF_NULL(tensor);
1193 auto tensor_shape = tensor->shape();
1194 (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
1195 [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
1196 (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
1197 if (tensor->device_address() != nullptr) {
1198 const auto type_id = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id();
1199 (void)graph_info.append(std::to_string(type_id) + "_");
1200 const auto format = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format();
1201 (void)graph_info.append(format + "_");
1202 }
1203 for (const auto &padding_type : tensor->padding_type()) {
1204 (void)graph_info.append(std::to_string(padding_type) + "_");
1205 }
1206 }
1207 // get attr info
1208 const auto &attr_map = prim->attrs();
1209 (void)std::for_each(attr_map.begin(), attr_map.end(), [&](const auto &element) {
1210 if (element.second->ToString().empty()) {
1211 return;
1212 }
1213 (void)graph_info.append(element.second->ToString() + "_");
1214 });
1215 auto build_shape = abstract->BuildShape();
1216 MS_EXCEPTION_IF_NULL(build_shape);
1217 (void)graph_info.append(build_shape->ToString() + "_");
1218 for (size_t output_index = 0; output_index < output_num; output_index += 1) {
1219 const auto output_type = AnfAlgo::GetOutputInferDataType(kernel, output_index);
1220 (void)graph_info.append(std::to_string(output_type) + "_");
1221 }
1222 graph_info.append(std::to_string(prim->id()));
1223 return graph_info;
1224 }
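// Rough sketch of the cache key built above for a single float32 input of shape (2, 3) with no
// bound device address and no padding (the exact numeric type ids are illustrative only):
//   "2_3_43_<attr values>_<abstract shape>_<output types>_<primitive id>"
// Two launches reuse the same single-op graph only when every dimension, dtype, format,
// attribute string and output type in this key matches exactly.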
1225
1226 void SessionBasic::GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info) {
1227 MS_EXCEPTION_IF_NULL(cnode);
1228 MS_EXCEPTION_IF_NULL(run_info);
1229 auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
1230 run_info->primitive = primitive;
1231 run_info->op_name = primitive->name();
1232 const auto &abstract = cnode->abstract();
1233 if (abstract == nullptr) {
1234 MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
1235 }
1236 run_info->abstract = abstract;
1237 const auto &shape = abstract->BuildShape();
1238 MS_EXCEPTION_IF_NULL(shape);
1239 run_info->is_dynamic_shape = shape->IsDynamic();
1240 }
1241
1242 void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
1243 std::map<AnfNodePtr, size_t> *parameter_index) {
1244 size_t index = 0;
1245 for (const auto &input_node : graph->inputs()) {
1246 auto params = AnfAlgo::GetAllOutput(input_node);
1247 for (const auto ¶m : params) {
1248 if (index >= inputs.size()) {
1249 MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
1250 << ", input size: " << inputs.size();
1251 }
1252 const auto &input = inputs[index];
1253 MS_EXCEPTION_IF_NULL(input);
1254 // Check shape of input and parameter
1255 const auto &input_shape = input->shape();
1256 const auto ¶m_shape = AnfAlgo::GetOutputInferShape(param, 0);
1257 if (input_shape.size() != param_shape.size()) {
1258 MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
1259 << ", parameter: " << param->fullname_with_scope();
1260 }
1261 bool is_dynamic = param->Shape()->IsDynamic();
1262 for (size_t i = 0; i < input_shape.size(); i += 1) {
1263 if (input_shape[i] < 0 || (static_cast<size_t>(input_shape[i]) != param_shape[i] && !is_dynamic)) {
1264 MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
1265 << ", parameter: " << param->fullname_with_scope();
1266 }
1267 }
1268 parameter_index->emplace(param, index++);
1269 }
1270 }
1271 }
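// Sketch of the resulting index map (names are illustrative): for graph inputs (p0, tuple_param)
// where the tuple parameter expands into two real outputs via GetAllOutput, parameter_index ends
// up as { p0 -> 0, tuple_param.out0 -> 1, tuple_param.out1 -> 2 }, so the caller-supplied
// `inputs` vector must follow the same flattened order.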
1272
1273 void SessionBasic::CreateOutputPlaceholder(
1274 const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *const outputs,
1275 std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
1276 MS_EXCEPTION_IF_NULL(kernel_graph);
1277 MS_EXCEPTION_IF_NULL(outputs);
1278 MS_EXCEPTION_IF_NULL(output_indexes);
1279 auto anf_outputs = kernel_graph->outputs();
1280 size_t index = 0;
1281 for (auto &item : anf_outputs) {
1282 MS_EXCEPTION_IF_NULL(item);
1283 std::vector<size_t> indexes{index++};
1284 outputs->emplace_back(CreateNodeOutputPlaceholder(item, kernel_graph, input_tensors, indexes, output_indexes));
1285 }
1286 }
1287
1288 void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
1289 MS_EXCEPTION_IF_NULL(graph);
1290 for (const auto &kernel : graph->execution_order()) {
1291 for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
1292 const auto &input = kernel->input(i);
1293 auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
1294 const auto &node = kernel_with_index.first;
1295 if (node->isa<CNode>()) {
1296 (*ref_count)[kernel_with_index] += 1;
1297 }
1298 }
1299 }
1300 }
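// Sketch: if the output {A, 0} of cnode A feeds two later kernels, ref_count[{A, 0}] becomes 2.
// HandleOpInputs below decrements the counter once per consumed input and drops the cached
// output tensor from op_output_map when the counter reaches zero.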
1301
1302 void SessionBasic::HandleOpInputs(const std::set<KernelWithIndex> &input_kernel,
1303 std::map<KernelWithIndex, size_t> *ref_count,
1304 std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
1305 MS_EXCEPTION_IF_NULL(ref_count);
1306 MS_EXCEPTION_IF_NULL(op_output_map);
1307 for (auto &kernel_with_index : input_kernel) {
1308 MS_EXCEPTION_IF_NULL(kernel_with_index.first);
1309 if (!kernel_with_index.first->isa<CNode>()) {
1310 continue;
1311 }
1312 auto ref_iter = ref_count->find(kernel_with_index);
1313 if (ref_iter == ref_count->end()) {
1314 MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
1315 << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
1316 }
1317     // Reduce the reference count; once it reaches zero, release the cached output of the previous node.
1318 ref_iter->second -= 1;
1319 if (ref_iter->second != 0) {
1320 continue;
1321 }
1322 ref_count->erase(ref_iter);
1323 auto output_iter = op_output_map->find(kernel_with_index);
1324 if (output_iter == op_output_map->end()) {
1325 MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
1326 << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
1327 }
1328 op_output_map->erase(output_iter);
1329 }
1330 }
1331
1332 void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
1333 const std::map<KernelWithIndex, size_t> &ref_count,
1334 std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map,
1335 GraphOutputInfo *const graph_output_info) {
1336 MS_EXCEPTION_IF_NULL(kernel);
1337 MS_EXCEPTION_IF_NULL(op_output_map);
1338 MS_EXCEPTION_IF_NULL(graph_output_info);
1339 MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
1340 auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
1341 if (output_tensors.size() > op_outputs.size()) {
1342 MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
1343 }
1344 size_t out_index = 0;
1345 for (const auto &output_tensor : output_tensors) {
1346 auto kernel_with_index = make_pair(kernel, out_index++);
1347 if (ref_count.find(kernel_with_index) != ref_count.end()) {
1348 (*op_output_map)[kernel_with_index] = output_tensor;
1349 }
1350 const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
1351 if (iter == graph_output_info->output_indexes.end()) {
1352 continue;
1353 }
1354 const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
1355 for (const auto &ref_indexes : multiple_ref_indexes) {
1356 size_t n = 0;
1357 const VectorRef *cur_vector_ref = graph_output_info->graph_outputs;
1358 for (; n < ref_indexes.size() - 1; n += 1) {
1359 size_t index = ref_indexes.at(n);
1360 if (index >= cur_vector_ref->size()) {
1361           MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vector ref is "
1362 << cur_vector_ref->size();
1363 }
1364 const BaseRef &base_ref = (*cur_vector_ref)[index];
1365 if (!utils::isa<VectorRef>(base_ref)) {
1366           MS_LOG(EXCEPTION) << "Get non-VectorRef by ref index, index: " << index << ", cur n: " << n;
1367 }
1368 cur_vector_ref = &utils::cast<VectorRef>(base_ref);
1369 }
1370 BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
1371 tensor_ref = output_tensor;
1372 graph_output_info->graph_output_tensors.emplace_back(output_tensor);
1373 }
1374 }
1375 }
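// Sketch of how a ref-index path is applied above: for ref_indexes = {1, 0}, the walk descends
// into (*graph_outputs)[1], which must itself be a VectorRef, and assigns the tensor to element 0
// of it, i.e. the first member of the second graph output. Only the final index is written; the
// leading indexes are used solely to navigate the nested VectorRefs.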
1376 TensorPtr SessionBasic::GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index) {
1377 MS_EXCEPTION_IF_NULL(node);
1378 if (!node->isa<ValueNode>()) {
1379 return nullptr;
1380 }
1381 auto value_node = node->cast<ValueNodePtr>();
1382 MS_EXCEPTION_IF_NULL(value_node);
1383 auto value = GetValueNode(value_node);
1384 MS_EXCEPTION_IF_NULL(value);
1385 if (value->isa<ValueTuple>()) {
1386 auto value_tuple = value->cast<ValueTuplePtr>();
1387 MS_EXCEPTION_IF_NULL(value_tuple);
1388 if (output_index >= value_tuple->size()) {
1389       MS_LOG(EXCEPTION) << "Index " << output_index << " is out of value tuple range";
1390 }
1391 auto tensor_value = value_tuple->value()[output_index];
1392 if (tensor_value->isa<tensor::Tensor>()) {
1393 return tensor_value->cast<tensor::TensorPtr>();
1394 }
1395 } else if (value->isa<tensor::Tensor>()) {
1396 if (output_index != 0) {
1397 MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << output_index;
1398 }
1399 return value->cast<TensorPtr>();
1400 }
1401 return nullptr;
1402 }
1403
1404 TensorPtr SessionBasic::GetParameterOutputTensor(const AnfNodePtr &node,
1405 const std::map<AnfNodePtr, size_t> ¶meter_index,
1406 const std::vector<tensor::TensorPtr> &graph_inputs) {
1407 MS_EXCEPTION_IF_NULL(node);
1408 if (!node->isa<Parameter>()) {
1409 return nullptr;
1410 }
1411 const auto &iter = parameter_index.find(node);
1412 if (iter == parameter_index.end()) {
1413 MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, parameter = " << node->DebugString();
1414 }
1415 const size_t index = iter->second;
1416 if (index >= graph_inputs.size()) {
1417 MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = " << index
1418 << ", input tensor size = " << graph_inputs.size();
1419 }
1420 return graph_inputs[index];
1421 }
1422
1423 TensorPtr SessionBasic::GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
1424 const std::map<KernelWithIndex, tensor::TensorPtr> &op_output) {
1425 const auto &iter = op_output.find(kernel_with_index);
1426 if (iter == op_output.end()) {
1427 MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << kernel_with_index.first->DebugString();
1428 }
1429 return iter->second;
1430 }
1431
1432 void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
1433 const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
1434 const std::map<AnfNodePtr, size_t> ¶meter_index,
1435 const std::vector<tensor::TensorPtr> &graph_inputs,
1436 InputTensorInfo *input_tensor_info) {
1437 MS_EXCEPTION_IF_NULL(cnode);
1438 MS_EXCEPTION_IF_NULL(input_tensor_info);
1439 const auto input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
1440 for (size_t i = 1; i <= input_tensor_num; i += 1) {
1441 const auto &input = cnode->input(i);
1442 auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
1443 auto real_input = kernel_with_index.first;
1444 MS_EXCEPTION_IF_NULL(real_input);
1445 tensor::TensorPtr tensor = nullptr;
1446 if (real_input->isa<ValueNode>()) {
1447 tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second);
1448 } else if (real_input->isa<Parameter>()) {
1449 tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
1450 } else if (real_input->isa<CNode>()) {
1451 tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
1452 if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
1453 CheckInputTensorShape(tensor, cnode, i - 1);
1454 }
1455 input_tensor_info->input_kernel.insert(kernel_with_index);
1456 } else {
1457 MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
1458 }
1459 MS_EXCEPTION_IF_NULL(tensor);
1460     MS_LOG(DEBUG) << "Get " << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
1461 << real_input->fullname_with_scope() << "-" << kernel_with_index.second;
1462 input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask
1463 : kParameterDataTensorMask);
1464 input_tensor_info->input_tensors.emplace_back(tensor);
1465 }
1466 }
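// The input tensors gathered above come from three sources: ValueNode constants, graph
// Parameters resolved through parameter_index/graph_inputs, and cached outputs of previously
// executed cnodes in op_output. A tensor flagged is_parameter() is given the weight mask so
// kernel selection can distinguish weights from ordinary data inputs.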
1467
1468 tensor::TensorPtr SessionBasic::GetOpInputTensorByIndex(const CNodePtr &cnode,
1469 const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
1470 const std::map<AnfNodePtr, size_t> ¶meter_index,
1471 const std::vector<tensor::TensorPtr> &graph_inputs,
1472 InputTensorInfo *const input_tensor_info, size_t input_index) {
1473 MS_EXCEPTION_IF_NULL(cnode);
1474 MS_EXCEPTION_IF_NULL(input_tensor_info);
1475 if (input_index >= cnode->inputs().size() - 1) {
1476     MS_LOG(EXCEPTION) << "Input index is out of range: " << cnode->inputs().size() << ", cnode: " << cnode->DebugString();
1477 }
1478
1479 const auto &input = cnode->input(input_index + 1);
1480 auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
1481 auto real_input = kernel_with_index.first;
1482 MS_EXCEPTION_IF_NULL(real_input);
1483
1484 if (real_input->isa<Parameter>()) {
1485 return GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
1486 } else if (real_input->isa<CNode>()) {
1487 tensor::TensorPtr tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
1488 if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
1489 CheckInputTensorShape(tensor, cnode, input_index);
1490 }
1491 input_tensor_info->input_kernel.insert(kernel_with_index);
1492 return tensor;
1493 } else {
1494 MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
1495 }
1496 }
1497
1498 bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
1499 MS_EXCEPTION_IF_NULL(node);
1500 MS_EXCEPTION_IF_NULL(graph);
1501 auto cnode = node->cast<CNodePtr>();
1502 MS_EXCEPTION_IF_NULL(cnode);
1503 // create a new cnode object
1504 auto new_cnode = CreateNewCNode(cnode, graph);
1505 if (new_cnode == nullptr) {
1506 return false;
1507 }
1508 new_cnode->set_abstract(cnode->abstract());
1509 std::string fullname;
1510 if (cnode->input(kAnfPrimitiveIndex)->isa<CNode>()) {
1511 fullname = cnode->input(kAnfPrimitiveIndex)->fullname_with_scope();
1512 } else if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
1513 fullname = cnode->input(kFirstDataInputIndex)->fullname_with_scope();
1514 } else {
1515 fullname = cnode->fullname_with_scope();
1516 }
1517 new_cnode->set_fullname_with_scope(fullname);
1518 new_cnode->set_scope(cnode->scope());
1519 graph->FrontBackendlMapAdd(node, new_cnode);
1520 SetReturnNode(new_cnode, graph);
1521 return true;
1522 }
1523
1524 std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
1525 std::vector<KernelGraphPtr> *all_out_graph) {
1526 MS_EXCEPTION_IF_NULL(func_graph);
1527 MS_EXCEPTION_IF_NULL(all_out_graph);
1528 auto node_list = TopoSort(func_graph->get_return());
1529 auto graph = NewKernelGraph();
1530 MS_EXCEPTION_IF_NULL(graph);
1531 front_backend_graph_map_[func_graph.get()] = graph;
1532 MS_LOG(INFO) << "Create graph: " << graph->graph_id();
1533 for (const auto &node : node_list) {
1534 MS_EXCEPTION_IF_NULL(node);
1535 MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
1536 // Create parameter
1537 if (node->isa<Parameter>()) {
1538 auto graph_inputs = graph->MutableInputs();
1539 MS_EXCEPTION_IF_NULL(graph_inputs);
1540 auto new_parameter = CreateNewParameter(node, graph.get());
1541 graph_inputs->push_back(new_parameter);
1542 graph->FrontBackendlMapAdd(node, new_parameter);
1543 continue;
1544 }
1545 // Create value node
1546 if (node->isa<ValueNode>()) {
1547 // Create common value node
1548 if (!IsValueNode<FuncGraph>(node)) {
1549 (void)CreateNewValueNode(node, graph.get());
1550 continue;
1551 }
1552       // Create child kernel graph according to ValueNode<FuncGraph>
1553 FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
1554 if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) {
1555 (void)ConstructKernelGraph(child_graph, all_out_graph);
1556 }
1557 (void)CreateValueNodeKernelGraph(node, graph.get());
1558 continue;
1559 }
1560 // Create cnode
1561 if (!CreateCNodeOfKernelGraph(node, graph.get())) {
1562 #ifdef ENABLE_DUMP_IR
1563 DumpIR("construct_kernel_graph_fail.ir", func_graph);
1564 #endif
1565 MS_LOG(EXCEPTION) << "Construct func graph " << func_graph->ToString() << " failed."
1566 << trace::DumpSourceLines(node);
1567 }
1568 }
1569
1570 AddParameterToGraphInputs(func_graph->parameters(), graph.get());
1571 FuncGraphManagerPtr manager = MakeManager({graph});
1572 graph->SetInputNodes();
1573 SetInputNodeUsage(graph, manager);
1574 graph->SetExecOrderByDefault();
1575
1576 #ifndef ENABLE_SECURITY
1577 if (ExistSummaryNode(graph.get())) {
1578 graph->set_summary_node_exist(true);
1579 }
1580 #endif
1581
1582 all_out_graph->push_back(graph);
1583 return graph;
1584 }
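// Sketch of the recursion above: every ValueNode<FuncGraph> forces its child kernel graph to be
// constructed first (guarded by front_backend_graph_map_ to avoid rebuilding shared graphs), so
// all_out_graph is filled child-graphs-first with the current graph appended last.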
1585
1586 void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
1587 MS_EXCEPTION_IF_NULL(graph);
1588 auto graph_inputs = graph->MutableInputs();
1589 MS_EXCEPTION_IF_NULL(graph_inputs);
1590 graph_inputs->clear();
1591 for (auto ¶meter : parameters) {
1592 MS_EXCEPTION_IF_NULL(parameter);
1593 auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
1594 if (backend_parameter == nullptr) {
1595       // for example "def f(x,y,z) {return x + y}", parameter z is unused
1596 auto new_parameter = CreateNewParameter(parameter, graph);
1597 graph_inputs->push_back(new_parameter);
1598 MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
1599 continue;
1600 }
1601 graph_inputs->push_back(backend_parameter);
1602 }
1603 }
1604
1605 void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
1606 const std::vector<tensor::TensorPtr> &input_tensors,
1607 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const {
1608 MS_EXCEPTION_IF_NULL(kernel_graph);
1609 MS_EXCEPTION_IF_NULL(outputs);
1610 MS_EXCEPTION_IF_NULL(tensor_to_node);
1611 KernelMapTensor node_to_tensor;
1612 auto anf_outputs = kernel_graph->outputs();
1613 for (auto &item : anf_outputs) {
1614 MS_EXCEPTION_IF_NULL(item);
1615 MS_LOG(DEBUG) << "Update output[" << item->DebugString() << "]";
1616 outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor));
1617 }
1618
1619 auto ms_context = MsContext::GetInstance();
1620 MS_EXCEPTION_IF_NULL(ms_context);
1621 for (auto &item : *tensor_to_node) {
1622 auto &tensor = item.first;
1623 auto &node = item.second.first;
1624 auto &output_index = item.second.second;
1625 DeviceAddressPtr address = nullptr;
1626 if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
1627 ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
1628 address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
1629 } else {
1630 address = AnfAlgo::GetMutableOutputAddr(node, output_index);
1631 }
1632 MS_EXCEPTION_IF_NULL(tensor);
1633 tensor->set_device_address(address);
1634 tensor->SetNeedWait(false);
1635 MS_LOG(DEBUG) << "Debug address: Output tensor obj " << tensor.get() << ", tensor id " << tensor->id()
1636 << ", device address " << tensor->device_address().get();
1637 if (AnfAlgo::IsDynamicShape(node)) {
1638 const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
1639 ShapeVector int_shape;
1640 (void)std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
1641 (void)tensor->set_shape(int_shape);
1642 }
1643 if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
1644 tensor->data_sync(false);
1645 tensor->set_sync_status(kNeedSyncHostToDevice);
1646 }
1647 }
1648 }
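// Sketch: in graph mode the output tensors are synced back to host right here (data_sync(false))
// and re-marked kNeedSyncHostToDevice, while in PyNative mode only the device address is
// attached and the host copy is deferred until the tensor is actually read.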
1649
1650 void SessionBasic::UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph,
1651 OpRunInfo *op_run_info) const {
1652 MS_EXCEPTION_IF_NULL(kernel_graph);
1653 MS_EXCEPTION_IF_NULL(op_run_info);
1654 const auto &kernels = kernel_graph->execution_order();
1655 for (const auto &kernel : kernels) {
1656 MS_EXCEPTION_IF_NULL(kernel);
1657 if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
1658 op_run_info->abstract = kernel->abstract();
1659 }
1660 }
1661 }
1662
1663 std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const GraphId &graph_id,
1664 const std::vector<tensor::TensorPtr> &inputs) {
1665 auto graph = GetGraph(graph_id);
1666 MS_EXCEPTION_IF_NULL(graph);
1667 if (!graph->has_optimizer()) {
1668 return {};
1669 }
1670 auto input_nodes = graph->inputs();
1671 bool check_monad = false;
1672 if (input_nodes.size() == inputs.size()) {
1673 check_monad = true;
1674 }
1675 std::vector<tensor::TensorPtr> result;
1676 for (size_t i = 0; i < inputs.size(); ++i) {
1677 if (check_monad && HasAbstractMonad(input_nodes[i])) {
1678 continue;
1679 }
1680 auto &tensor = inputs[i];
1681 MS_EXCEPTION_IF_NULL(tensor);
1682 if (!tensor->IsGraphOutput()) {
1683 result.emplace_back(tensor);
1684 }
1685 }
1686 return result;
1687 }
1688
1689 void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
1690 VectorRef *outputs,
1691 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
1692 auto kernel_graph = GetGraph(graph_id);
1693 MS_EXCEPTION_IF_NULL(kernel_graph);
1694 MS_EXCEPTION_IF_NULL(outputs);
1695 MS_EXCEPTION_IF_NULL(tensor_to_node);
1696 auto anf_outputs = kernel_graph->outputs();
1697 KernelMapTensor node_to_tensor;
1698 for (auto &item : anf_outputs) {
1699 MS_EXCEPTION_IF_NULL(item);
1700 MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
1701 outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor));
1702 }
1703 auto ms_context = MsContext::GetInstance();
1704 MS_EXCEPTION_IF_NULL(ms_context);
1705 auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
1706 if (enable_mem_scheduler) {
1707 kernel_graph->SetOutputNodeToTensor(node_to_tensor);
1708 }
1709 }
1710
1711 void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
1712 const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
1713 std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
1714 auto context_ptr = MsContext::GetInstance();
1715 MS_EXCEPTION_IF_NULL(context_ptr);
1716 auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
1717 if (enable_mem_scheduler) {
1718 return;
1719 }
1720 MS_EXCEPTION_IF_NULL(outputs);
1721 for (const auto &item : *outputs) {
1722 if (utils::isa<VectorRefPtr>(item)) {
1723 const auto &vector_ref = utils::cast<VectorRef>(item);
1724 std::map<DeviceAddressPtr, DeviceAddressPtr> new_to_old_device_address;
1725 UpdateOutputTensors(&vector_ref, tensor_to_node, &new_to_old_device_address);
1726 } else if (utils::isa<tensor::TensorPtr>(item)) {
1727 const auto &tensor = utils::cast<tensor::TensorPtr>(item);
1728 MS_EXCEPTION_IF_NULL(tensor);
1729 const auto &iter = tensor_to_node.find(tensor);
1730 if (iter != tensor_to_node.end()) {
1731 const auto &node = iter->second.first;
1732 const auto &output_index = iter->second.second;
1733 if (!AnfAlgo::OutputAddrExist(node, output_index, true)) {
1734 continue;
1735 }
1736 const auto &address = AnfAlgo::GetMutableOutputAddr(node, output_index);
1737 tensor->set_device_address(address);
1738
1739 if (AnfAlgo::IsDynamicShape(node)) {
1740 const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
1741 ShapeVector int_shape;
1742 (void)std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
1743 (void)tensor->set_shape(int_shape);
1744 }
1745 }
1746 if (tensor->NeedSyncDeviceToHostImmediately()) {
1747 tensor->data_sync(false);
1748 tensor->set_device_address(nullptr);
1749 tensor->set_sync_status(kNeedSyncHostToDevice);
1750 }
1751 }
1752 }
1753 }
1754
1755 void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
1756 std::vector<std::string> *inputs_name) const {
1757 MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
1758 auto kernel_graph = GetGraph(graph_id);
1759 MS_EXCEPTION_IF_NULL(kernel_graph);
1760 MS_EXCEPTION_IF_NULL(inputs);
1761 MS_EXCEPTION_IF_NULL(inputs_name);
1762 auto kernel_graph_inputs = kernel_graph->inputs();
1763 // find parameters of graph inputs
1764 for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
1765 if (!kernel_graph_inputs[i]->isa<Parameter>()) {
1766       MS_LOG(ERROR) << "Kernel graph input is an anfnode which is not a Parameter.";
1767 continue;
1768 }
1769 auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
1770 if (!AnfAlgo::IsParameterWeight(parameter)) {
1771 vector<int64_t> input_shape;
1772 auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
1773 (void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
1774 [](const size_t dim) { return SizeToLong(dim); });
1775 auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
1776 auto data_type = kernel_build_info->GetOutputDeviceType(0);
1777 auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
1778 inputs->push_back(ms_tensor);
1779 inputs_name->push_back(parameter->name());
1780 }
1781 }
1782 }
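// Minimal usage sketch (variable names are illustrative, not from this file):
//   std::vector<tensor::TensorPtr> inputs;
//   std::vector<std::string> names;
//   session->GetModelInputsInfo(graph_id, &inputs, &names);
// Each returned tensor is an empty host tensor carrying the device dtype and shape of one
// non-weight parameter, in the order the kernel graph declares its inputs.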
1783
1784 void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
1785 std::vector<std::string> *output_names) const {
1786 std::vector<tensor::TensorPtr> inputs;
1787 std::vector<std::string> input_names;
1788 GetModelInputsInfo(graph_id, &inputs, &input_names);
1789
1790 auto kernel_graph = GetGraph(graph_id);
1791 MS_EXCEPTION_IF_NULL(kernel_graph);
1792 MS_EXCEPTION_IF_NULL(outputs);
1793 MS_EXCEPTION_IF_NULL(output_names);
1794
1795 VectorRef vector_outputs;
1796 std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
1797 KernelMapTensor node_to_tensor;
1798 auto anf_outputs = kernel_graph->outputs();
1799 for (auto &item : anf_outputs) {
1800 MS_EXCEPTION_IF_NULL(item);
1801 MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
1802 vector_outputs.emplace_back(CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node, &node_to_tensor));
1803 }
1804 *outputs = TransformVectorRefToMultiTensor(vector_outputs);
1805 for (size_t i = 0; i < outputs->size(); i++) {
1806 output_names->push_back("output" + std::to_string(i));
1807 }
1808 }
1809
1810 #ifndef ENABLE_SECURITY
1811 void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
1812 MS_EXCEPTION_IF_NULL(callback);
1813 summary_callback_ = callback;
1814 }
1815
1816 void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
1817 MS_LOG(DEBUG) << "Update summary Start";
1818 MS_EXCEPTION_IF_NULL(graph);
1819 if (!graph->summary_node_exist()) {
1820 return;
1821 }
1822 auto summary = graph->summary_nodes();
1823 auto apply_list = TopoSort(graph->get_return());
1824 for (auto &n : apply_list) {
1825 MS_EXCEPTION_IF_NULL(n);
1826 if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
1827 IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
1828 auto cnode = n->cast<CNodePtr>();
1829 MS_EXCEPTION_IF_NULL(cnode);
1830 if (cnode->inputs().size() <= kSummaryGetItem) {
1831         MS_LOG(EXCEPTION) << "The Summary node should have at least 2 inputs!";
1832 }
1833 auto node = cnode->input(kSummaryGetItem);
1834 MS_EXCEPTION_IF_NULL(node);
1835 auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
1836 MS_EXCEPTION_IF_NULL(item_with_index.first);
1837 if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
1838 MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
1839 }
1840 summary[n->fullname_with_scope()] = item_with_index;
1841 }
1842 }
1843 graph->set_summary_nodes(summary);
1844 MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
1845 }
1846
1847 void SessionBasic::Summary(KernelGraph *graph) {
1848 if (summary_callback_ == nullptr) {
1849 return;
1850 }
1851 MS_EXCEPTION_IF_NULL(graph);
1852 bool exist_summary = graph->summary_node_exist();
1853 if (!exist_summary) {
1854 return;
1855 }
1856
1857 static bool is_first = true;
1858 if (is_first && !IsSupportSummary()) {
1859 is_first = false;
1860 MS_LOG(ERROR) << "The Summary operator can not collect data correctly. Detail: the data sink mode is used and the"
1861                   " sink size (in the model.train() python api) is not equal to 1.";
1862 }
1863 SetSummaryNodes(graph);
1864 auto summary_outputs = graph->summary_nodes();
1865 std::map<std::string, tensor::TensorPtr> params_list;
1866   // fetch the outputs of apply kernels in the session & run the callback functions
1867 for (auto &output_item : summary_outputs) {
1868 auto node = output_item.second.first;
1869 size_t index = IntToSize(output_item.second.second);
1870 auto address = AnfAlgo::GetOutputAddr(node, index);
1871 auto shape = AnfAlgo::GetOutputInferShape(node, index);
1872 TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
1873 std::vector<int64_t> temp_shape;
1874 (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
1875 tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
1876 MS_EXCEPTION_IF_NULL(address);
1877 if (!address->GetPtr()) {
1878 continue;
1879 }
1880 if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
1881 tensor->data_type(), tensor->data_c())) {
1882 MS_LOG(ERROR) << "Failed to sync output from device to host.";
1883 }
1884 tensor->set_sync_status(kNoNeedSync);
1885 params_list[output_item.first] = tensor;
1886 }
1887 // call callback function here
1888 summary_callback_(0, params_list);
1889 }
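// Sketch: each summary value is copied into a freshly allocated host tensor via SyncDeviceToHost
// using the runtime padding shape, and the registered summary_callback_ then receives the whole
// {fullname_with_scope -> tensor} map in a single call.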
1890 #endif
1891
1892 namespace {
1893 bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
1894 if (node == nullptr) {
1895 return false;
1896 }
1897 auto cnode = node->cast<CNodePtr>();
1898 if (cnode == nullptr) {
1899 return false;
1900 }
1901 auto prim = cnode->input(kAnfPrimitiveIndex);
1902 if (prim == nullptr || !IsValueNode<Primitive>(prim)) {
1903 return false;
1904 }
1905 return true;
1906 }
1907
1908 std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
1909 const AnfNodePtr &front_node) {
1910 MS_EXCEPTION_IF_NULL(front_func_graph_manager);
1911 auto &users = front_func_graph_manager->node_users()[front_node];
1912 std::vector<AnfNodePtr> result;
1913 for (auto &user : users) {
1914 if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend) ||
1915 AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimLoad)) {
1916 auto depend_cnode = user.first->cast<CNodePtr>();
1917 if (depend_cnode == nullptr) {
1918 continue;
1919 }
1920 if (front_node != depend_cnode->input(1)) {
1921 continue;
1922 }
1923 auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
1924 result.insert(result.end(), res.begin(), res.end());
1925 } else if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) {
1926 auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
1927 (void)result.insert(result.end(), res.begin(), res.end());
1928 } else {
1929 (void)result.emplace_back(user.first);
1930 }
1931 }
1932 return result;
1933 }
1934
1935 AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
1936 MS_EXCEPTION_IF_NULL(front_node);
1937 if (!front_node->isa<CNode>()) {
1938 return nullptr;
1939 }
1940 if (AnfAlgo::IsRealKernel(front_node)) {
1941 return front_node;
1942 }
1943 if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
1944 return front_node;
1945 }
1946 if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimMakeTuple)) {
1947 auto cnode = front_node->cast<CNodePtr>();
1948 MS_EXCEPTION_IF_NULL(cnode);
1949 auto &inputs = cnode->inputs();
1950 if (inputs.size() > 1) {
1951 return GetSupportedInternalNode(inputs[1]);
1952 }
1953 }
1954 if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) {
1955 auto cnode = front_node->cast<CNodePtr>();
1956 MS_EXCEPTION_IF_NULL(cnode);
1957 auto &inputs = cnode->inputs();
1958 if (inputs.size() >= kDependInputSize) {
1959 return GetSupportedInternalNode(inputs[kRealInputIndexInDepend]);
1960 }
1961 }
1962 return nullptr;
1963 }
1964 } // namespace
1965
1966 constexpr auto kMixTarget = "MixTarget";
1967 constexpr auto kNoTarget = "NoTarget";
1968 std::string SessionBasic::AddPartialParametersMap(const AnfNodePtr &partial_node) {
1969 MS_EXCEPTION_IF_NULL(partial_node);
1970 auto iter = partial_target_map_.find(partial_node);
1971 if (iter != partial_target_map_.end()) {
1972 return iter->second;
1973 }
1974 auto partial_cnode = partial_node->cast<CNodePtr>();
1975 MS_EXCEPTION_IF_NULL(partial_cnode);
1976 auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
1977 MS_EXCEPTION_IF_NULL(partial_graph);
1978 auto parameters = partial_graph->parameters();
1979 auto partial_inputs = partial_cnode->inputs();
1980 const size_t kNonParameterNum = 2;
1981 if (parameters.size() + kNonParameterNum != partial_inputs.size()) {
1982 return kMixTarget;
1983 }
1984 for (size_t i = 0; i < parameters.size(); ++i) {
1985 partial_parameters_map_[parameters[i]] = partial_inputs[kNonParameterNum + i];
1986 }
1987 auto graph_nodes = TopoSort(partial_graph->get_return());
1988 std::string graph_target = kNoTarget;
1989 for (auto &node : graph_nodes) {
1990 if (!node->isa<CNode>()) {
1991 continue;
1992 }
1993 if (!AnfAlgo::IsRealKernel(node)) {
1994 continue;
1995 }
1996 std::string cur_target = GetCNodeTarget(node);
1997 if (graph_target == kNoTarget) {
1998 graph_target = cur_target;
1999 }
2000 if (graph_target != cur_target) {
2001 graph_target = kMixTarget;
2002 break;
2003 }
2004 }
2005 (void)partial_target_map_.emplace(std::pair<AnfNodePtr, std::string>(partial_node, graph_target));
2006 return graph_target;
2007 }
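// Sketch of the returned target string: kNoTarget when the partial's FuncGraph contains no real
// kernels, kMixTarget when its kernels span different device targets (or the parameter count is
// unexpected), and otherwise the single device target shared by every real kernel; the result is
// cached in partial_target_map_ for later lookups.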
2008
2009 void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
2010 const FuncGraphManagerPtr &front_func_graph_manager,
2011 const std::shared_ptr<KernelGraph> &backend_graph) {
2012 auto front_node = GetSupportedInternalNode(input_front_node);
2013 if (front_node == nullptr) {
2014 return;
2015 }
2016 auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0);
2017 auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0);
2018 auto backend_real_kernel = backend_real_kernel_pair.first;
2019 if (backend_real_kernel == nullptr || !backend_real_kernel->isa<CNode>()) {
2020 return;
2021 }
2022 auto front_real_kernel = front_real_kernel_pair.first;
2023 std::string kernel_target = GetCNodeTarget(front_real_kernel);
2024 bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel);
2025 bool unique_target = true;
2026 if (internal_output && opt::IsNopNode(front_real_kernel)) {
2027 auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
2028 auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
2029 if (pre_node_target != kernel_target) {
2030 unique_target = false;
2031 }
2032 }
2033 if (internal_output) {
2034 auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
2035 for (auto &user : users) {
2036 if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
2037 !ExistGraphCaller(user)) {
2038 auto partial_target = AddPartialParametersMap(user);
2039 if (partial_target != kNoTarget && partial_target != kernel_target) {
2040 unique_target = false;
2041 }
2042 continue;
2043 }
2044 if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
2045 continue;
2046 }
2047 if (!CNodeFirstInputIsPrimitive(user)) {
2048 internal_output = false;
2049 break;
2050 }
2051 if (!AnfAlgo::IsRealKernel(user)) {
2052 internal_output = false;
2053 break;
2054 }
2055 if (kernel_target != GetCNodeTarget(user)) {
2056 unique_target = false;
2057 }
2058 }
2059 }
2060 if (internal_output) {
2061 MS_LOG(INFO) << "AddInternalOutput: " << front_node->DebugString() << " To " << backend_real_kernel->DebugString()
2062 << ", unique_target: " << unique_target;
2063 backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target);
2064 }
2065 }
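// Sketch: a front node is registered as an internal output only when every (transitively
// extended) user, apart from skipped Partial and UpdateState users, is a real primitive cnode;
// unique_target additionally requires those users (and a Partial's body) to sit on the same
// device target as the kernel itself, which is intended to let the backend avoid an extra copy
// for graph outputs that are also consumed on the same device.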
2066
2067 CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
2068 MS_EXCEPTION_IF_NULL(graph);
2069 std::vector<AnfNodePtr> output_args;
2070 for (const auto &output : outputs) {
2071 MS_EXCEPTION_IF_NULL(output);
2072 MS_LOG(INFO) << "Output:" << output->DebugString();
2073 }
2074 auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr {
2075 auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
2076 if (backend_anf != nullptr) {
2077 auto context_ptr = MsContext::GetInstance();
2078 MS_EXCEPTION_IF_NULL(context_ptr);
2079 if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
2080 return backend_anf;
2081 }
2082
2083 MS_EXCEPTION_IF_NULL(out);
2084 auto out_func_graph = out->func_graph();
2085 MS_EXCEPTION_IF_NULL(out_func_graph);
2086 auto out_func_graph_manager = out_func_graph->manager();
2087 if (out_func_graph_manager == nullptr) {
2088 return backend_anf;
2089 }
2090 HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph);
2091 return backend_anf;
2092 }
2093 MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
2094 };
2095 output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
2096 (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
2097 [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
2098 return graph->NewCNode(output_args);
2099 }
2100
2101 void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
2102 std::vector<AnfNodePtr> make_tuple_inputs;
2103 make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
2104 MS_EXCEPTION_IF_NULL(graph);
2105 if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
2106 for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
2107 auto idx = NewValueNode(SizeToLong(output_index));
2108 MS_EXCEPTION_IF_NULL(idx);
2109 auto imm = std::make_shared<Int64Imm>(output_index);
2110 idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
2111 auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
2112 std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
2113 std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
2114 AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
2115 make_tuple_inputs.push_back(getitem);
2116 }
2117 } else {
2118 make_tuple_inputs.push_back(cnode);
2119 }
2120 // create output
2121 auto g_output = graph->NewCNode(make_tuple_inputs);
2122 graph->set_output(g_output);
2123 }
2124
2125 std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
2126 const std::vector<tensor::TensorPtr> &input_tensors,
2127 const std::vector<int64_t> &tensors_mask,
2128 bool is_ascend) {
2129 auto graph = std::make_shared<KernelGraph>();
2130 graph->set_graph_id(graph_sum_);
2131 graph_sum_++;
2132 std::vector<AnfNodePtr> inputs;
2133 // set input[0]
2134 PrimitivePtr op_prim = op_run_info.primitive;
2135 MS_EXCEPTION_IF_NULL(op_prim);
2136 // Decoupling of frontend PrimitivePy and backend Primitive
2137 inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*op_prim)));
2138 // set input parameter
2139 if (input_tensors.size() != tensors_mask.size()) {
2140 MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
2141 << tensors_mask.size();
2142 }
2143 for (size_t i = 0; i < input_tensors.size(); ++i) {
2144 if (tensors_mask[i] == kValueNodeTensorMask) {
2145 auto value_node = graph->NewValueNode(input_tensors[i]);
2146 inputs.push_back(value_node);
2147 continue;
2148 }
2149 auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
2150 inputs.push_back(parameter);
2151 auto mutable_inputs = graph->MutableInputs();
2152 MS_EXCEPTION_IF_NULL(mutable_inputs);
2153 mutable_inputs->push_back(parameter);
2154 }
2155   // create the single-op cnode
2156 auto cnode = graph->NewCNode(inputs);
2157 MS_EXCEPTION_IF_NULL(cnode);
2158   // set abstract, which includes the inferred shapes and types
2159 cnode->set_abstract(op_run_info.abstract);
2160 // get output dynamic shape info
2161 AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info.is_dynamic_shape), cnode);
2162 if (op_run_info.is_auto_mixed_precision) {
2163 AnfAlgo::SetNodeAttr(kAttrPynativeNextOpName, MakeValue(op_run_info.next_op_name), cnode);
2164 AnfAlgo::SetNodeAttr(kAttrPynativeNextIndex, MakeValue(op_run_info.next_input_index), cnode);
2165 }
2166 // set execution order
2167 std::vector<CNodePtr> exe_order = {cnode};
2168 graph->set_execution_order(exe_order);
2169 // set output
2170 if (is_ascend) {
2171 graph->set_output(cnode);
2172 } else {
2173 CreateOutputNode(cnode, graph);
2174 }
2175 graph->SetInputNodes();
2176 auto manager = MakeManager({graph});
2177 if (manager != nullptr) {
2178 manager->AddFuncGraph(graph);
2179 graph->set_manager(manager);
2180 }
2181 auto ms_context = MsContext::GetInstance();
2182 MS_EXCEPTION_IF_NULL(ms_context);
2183 if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
2184 UnifyMindIR(graph);
2185 }
2186 graph->UpdateGraphDynamicAttr();
2187 return graph;
2188 }
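// Sketch of the graph produced above for a hypothetical single Add with two data inputs:
// the node inputs are [ValueNode(Primitive("Add")), Parameter p0, Parameter p1], the execution
// order holds just that one cnode, and on non-Ascend backends CreateOutputNode wraps the result
// as MakeTuple(Add) (or MakeTuple(TupleGetItem(Add, i), ...) for multi-output ops).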
2189
2190 KernelGraphPtr SessionBasic::NewKernelGraph() {
2191 auto graph = std::make_shared<KernelGraph>();
2192 graph->set_graph_id(graph_sum_);
2193 graphs_[graph_sum_++] = graph;
2194 return graph;
2195 }
2196
2197 AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) {
2198 MS_EXCEPTION_IF_NULL(push_node);
2199 for (auto &node : node_list) {
2200 if (node != nullptr && node->isa<CNode>()) {
2201 for (auto input : node->cast<CNodePtr>()->inputs()) {
2202 if (push_node == AnfAlgo::VisitKernel(input, 0).first) {
2203 if (AnfAlgo::GetCNodeName(node) != kPullOpName) {
2204 MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid.";
2205 }
2206 return node;
2207 }
2208 }
2209 }
2210 }
2211 return nullptr;
2212 }
2213
2214 GraphId SessionBasic::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) {
2215 MS_EXCEPTION_IF_NULL(executor_);
2216 return executor_->CompileGraph(shared_from_this(), segment, outputs);
2217 }
2218
2219 GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
2220 MS_EXCEPTION_IF_NULL(executor_);
2221 return executor_->CompileGraph(shared_from_this(), func_graph);
2222 }
2223
2224 void SessionBasic::BuildGraph(GraphId graph_id) {
2225 MS_EXCEPTION_IF_NULL(executor_);
2226 executor_->BuildGraph(shared_from_this(), graph_id);
2227 }
2228
2229 void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
2230 std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
2231 const std::vector<int64_t> &tensors_mask) {
2232 MS_EXCEPTION_IF_NULL(executor_);
2233 executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs, tensors_mask);
2234 }
2235
2236 void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
2237 VectorRef *outputs) {
2238 MS_EXCEPTION_IF_NULL(executor_);
2239 executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs);
2240 }
2241
2242 void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
2243 MS_EXCEPTION_IF_NULL(executor_);
2244 executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
2245 }
2246
2247 void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
2248 VectorRef *outputs) {
2249 MS_EXCEPTION_IF_NULL(executor_);
2250 executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
2251 }
2252
2253 void SessionBasic::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
2254 VectorRef *const outputs) {
2255 MS_LOG(INFO) << "Run graph start, graph id: " << graph_id;
2256 auto kernel_graph = GetGraph(graph_id);
2257 MS_EXCEPTION_IF_NULL(kernel_graph);
2258   // skip if the graph has no executable child graph and no anf output
2259 if (!kernel_graph->executable()) {
2260 MS_LOG(INFO) << "No child graph has anf output";
2261 return;
2262 }
2263 PreExecuteGraph(kernel_graph, inputs, outputs);
2264 ExecuteGraph(kernel_graph);
2265 PostExecuteGraph(kernel_graph, inputs, outputs);
2266 MS_LOG(INFO) << "Run graph end, graph id: " << graph_id;
2267 }
2268
2269 void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
2270 VectorRef *outputs) {
2271 MS_LOG(INFO) << "Clean task in Queue";
2272 session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
2273 MS_LOG(INFO) << "Start!";
2274 auto kernel_graph = GetGraph(graph_id);
2275 MS_EXCEPTION_IF_NULL(kernel_graph);
2276 std::map<AnfNodePtr, size_t> parameter_index;
2277 GetParameterIndex(kernel_graph.get(), inputs, ¶meter_index);
2278 GraphOutputInfo graph_output_info;
2279 graph_output_info.graph_outputs = outputs;
2280 CreateOutputPlaceholder(kernel_graph, inputs, graph_output_info.graph_outputs, &graph_output_info.output_indexes);
2281 std::map<KernelWithIndex, size_t> cnode_refcount;
2282 GetRefCount(kernel_graph.get(), &cnode_refcount);
2283 BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);
2284
2285 // Clear bucket resources every step
2286 if (kernel_graph->is_bprop()) {
2287 ClearAllBucket(graph_id);
2288 }
2289
2290 std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
2291 for (const auto &kernel : kernel_graph->execution_order()) {
2292 // Generate input tensors, tensor masks and input kernel with index
2293 InputTensorInfo input_tensor_info;
2294 GetOpInputTensors(kernel, op_output_map, parameter_index, inputs, &input_tensor_info);
2295
2296 // Get OpRunInfo and GraphInfo
2297 OpRunInfo run_info;
2298 GetSingleOpRunInfo(kernel, &run_info);
2299 GraphInfo graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);
2300
2301 // Build and run current single op
2302 VectorRef op_outputs;
2303 RunOpImplOrigin(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
2304 input_tensor_info.input_tensors_mask);
2305 graph_output_info.graph_output_tensors.clear();
2306 // Handle inputs and outputs of current op
2307 HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
2308 HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
2309 // Save grad node to Bucket
2310 if (kernel_graph->is_bprop()) {
2311 AddGradAddrToBucket(graph_id, graph_output_info.graph_output_tensors);
2312 }
2313 }
2314 MS_LOG(INFO) << "Finish!";
2315 }
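// Sketch of the per-kernel loop above: gather input tensors -> build OpRunInfo/GraphInfo ->
// RunOpImplOrigin -> HandleOpInputs releases cached inputs whose reference count dropped to
// zero -> HandleOpOutputs caches the new outputs and scatters any graph outputs into the
// caller's VectorRef (and, for bprop graphs, the gradients are fed into the current bucket).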
2316
2317 void SessionBasic::EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
2318 std::vector<tensor::TensorPtr> *input_tensors) const {
2319 MS_EXCEPTION_IF_NULL(input_tensors);
2320 if (input_tensors->size() != tensors_mask.size()) {
2321 MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
2322 << tensors_mask.size();
2323 }
2324 std::vector<tensor::TensorPtr> new_input_tensors;
2325 for (size_t index = 0; index < tensors_mask.size(); ++index) {
2326 if (tensors_mask[index] != kValueNodeTensorMask) {
2327 new_input_tensors.emplace_back(input_tensors->at(index));
2328 }
2329 }
2330 *input_tensors = new_input_tensors;
2331 }
2332
2333 void SessionBasic::UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs) {
2334 bool is_dynamic = false;
2335 for (const auto &graph : all_graphs) {
2336 UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
2337 is_dynamic = graph->is_dynamic_shape() || is_dynamic;
2338 }
2339 if (is_dynamic && all_graphs.size() > 1) {
2340 MS_LOG(EXCEPTION)
2341       << "Dynamic shape is not supported with control flow (loop control statements and condition control statements).";
2342 }
2343 }
2344
2345 void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph) {
2346 for (const auto &cnode : root_graph->execution_order()) {
2347 if (AnfAlgo::IsNodeDynamicShape(cnode)) {
2348 AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode);
2349 MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
2350 }
2351 }
2352 root_graph->UpdateGraphDynamicAttr();
2353 }
2354
2355 bool SessionBasic::IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name) {
2356 MS_EXCEPTION_IF_NULL(kernel_graph);
2357 for (const auto &kernel_node : kernel_graph->execution_order()) {
2358 auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
2359 if (kernel_name == kGetNextOpName) {
2360 auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
2361 MS_EXCEPTION_IF_NULL(prim);
2362 *channel_name = GetValue<std::string>(prim->GetAttr("shared_name"));
2363 return true;
2364 }
2365 }
2366 return false;
2367 }
2368
2369 void SessionBasic::RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const {
2370 auto ms_context = MsContext::GetInstance();
2371 MS_EXCEPTION_IF_NULL(ms_context);
2372 if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
2373 opt::RemoveNopNode(kernel_graph.get());
2374 }
2375 }
2376
2377 void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) {
2378 auto ms_context = MsContext::GetInstance();
2379 MS_EXCEPTION_IF_NULL(ms_context);
2380 if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
2381 opt::HideNopNode(kernel_graph.get());
2382 }
2383 }
2384
2385 std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() {
2386 auto ms_context = MsContext::GetInstance();
2387 MS_EXCEPTION_IF_NULL(ms_context);
2388 std::string group = GetCommWorldGroup();
2389 auto parallel_context = parallel::ParallelContext::GetInstance();
2390 MS_EXCEPTION_IF_NULL(parallel_context);
2391   // PyNative does not support multi-group allreduce
2392 group += "sum1";
2393 return parallel_context->GetAllReduceFusionSplitIndices(group);
2394 }
2395
2396 uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) {
2397 return AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}).size();
2398 }
2399
2400 void SetGraphBpropAttr(const KernelGraphPtr &graph) {
2401 auto &execution_orders = graph->execution_order();
2402 if (std::any_of(execution_orders.begin(), execution_orders.end(),
2403 [](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
2404 graph->set_is_bprop(true);
2405 MS_LOG(INFO) << "Match bprop graph";
2406 } else {
2407 graph->set_is_bprop(false);
2408 }
2409 }
2410
2411 std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const std::vector<uint32_t> &split_index) {
2412 if (split_index.empty()) {
2413 auto grads_count = GetBpropGraphGradsCount(graph);
2414 if (grads_count == 0) {
2415 MS_LOG(EXCEPTION) << "Bprop graph has no grad";
2416 }
2417 return {grads_count};
2418 }
2419
2420 std::vector<uint32_t> bucket_size_list;
2421 uint32_t old_index = 0;
2422 for (const auto &index : split_index) {
2423 if (old_index == 0) {
2424 bucket_size_list.emplace_back(index - old_index + 1);
2425 } else {
2426 bucket_size_list.emplace_back(index - old_index);
2427 }
2428 old_index = index;
2429 }
2430 return bucket_size_list;
2431 }
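// Worked example (values are illustrative): with 10 gradients and split_index = {3, 7, 9}
// (PreProcessOnSplitIndex below appends grads_count - 1 = 9), the bucket sizes become
// {3 - 0 + 1, 7 - 3, 9 - 7} = {4, 4, 2}, which together cover all 10 gradients.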
2432
2433 void CheckSplitIndexValid(const vector<uint32_t> &split_index) {
2434 uint32_t last = 0;
2435 for (size_t i = 0; i < split_index.size(); ++i) {
2436 if (split_index[i] <= last && i != 0) {
2437 MS_LOG(EXCEPTION) << "Invalid split index:" << split_index;
2438 }
2439 last = split_index[i];
2440 }
2441 }
2442
2443 void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
2444 MS_EXCEPTION_IF_NULL(split_index);
2445 if (split_index->empty()) {
2446 return;
2447 }
2448
2449 CheckSplitIndexValid(*split_index);
2450 // calculate split index num
2451 auto split_index_num = split_index->back();
2452 // obtain graph output tensor num
2453 auto grads_count = GetBpropGraphGradsCount(graph);
2454 if (split_index_num >= grads_count) {
2455 MS_LOG(WARNING) << "Invalid all_reduce_fusion_config:" << *split_index << " total grads count:" << grads_count
2456 << ". All AllReduce operators will be fused into one.";
2457 split_index->clear();
2458 split_index->push_back(grads_count - 1);
2459 } else if (split_index_num < grads_count - 1) {
2460 split_index->push_back(grads_count - 1);
2461 }
2462 }
2463
void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  if (!pynative_mode || parallel_mode != parallel::DATA_PARALLEL) {
    return;
  }
  SetGraphBpropAttr(graph);

  if (!graph->is_bprop()) {
    return;
  }

  std::vector<std::shared_ptr<device::Bucket>> bucket_list;
  // Create a bucket for every AllReduce fusion segment.
  auto split_index = GetAllReduceSplitIndex();
  PreProcessOnSplitIndex(graph, &split_index);
  auto bucket_size_list = GenerateBucketSizeList(graph, split_index);
  uint32_t bucket_id = 0;
  for (const auto &bucket_size : bucket_size_list) {
    MS_LOG(INFO) << "Create new bucket:" << bucket_id << " size:" << bucket_size;
    std::shared_ptr<device::Bucket> bucket = nullptr;
    if (device_context != nullptr) {
      bucket = device_context->CreateBucket(bucket_id++, bucket_size);
    } else {
      bucket = CreateBucket(bucket_id++, bucket_size);
    }
    bucket_list.emplace_back(bucket);
  }

  auto bucket_ret = bucket_map_.try_emplace(graph->graph_id(), bucket_list);
  if (!bucket_ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate bucket_map_ graph key:" << graph->graph_id();
  }
  // Reset the free bucket index of this graph to 0.
  auto free_bucket_ret = free_bucket_id_map_.try_emplace(graph->graph_id(), 0);
  if (!free_bucket_ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate free_bucket_id_map_ graph key:" << graph->graph_id();
  }
  MS_LOG(INFO) << "Init Bucket finish";
}

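// AddGradAddrToBucket appends gradient tensors to the current free bucket of the graph.
// Whenever a bucket becomes full, its fused AllReduce is launched immediately and the next
// bucket becomes the current one, so communication can overlap with the remaining backward
// computation.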
void SessionBasic::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  if (parallel_mode != parallel::DATA_PARALLEL) {
    return;
  }

  auto iter = bucket_map_.find(graph_id);
  if (iter == bucket_map_.end()) {
    MS_LOG(EXCEPTION) << "Unknown graph id:" << graph_id;
  }
  auto &bucket_list = iter->second;
  auto free_bucket_iter = free_bucket_id_map_.find(graph_id);
  if (free_bucket_iter == free_bucket_id_map_.end()) {
    MS_LOG(EXCEPTION) << "Unknown free bucket graph id:" << graph_id;
  }

  auto free_bucket_index = free_bucket_iter->second;
  for (auto &tensor : grad_tensor) {
    if (free_bucket_index >= bucket_list.size()) {
      MS_LOG(EXCEPTION) << "Invalid free bucket id:" << free_bucket_iter->second
                        << " total bucket num:" << bucket_list.size();
    }
    auto &free_bucket = bucket_list[free_bucket_index];
    free_bucket->AddGradTensor(tensor);
    if (free_bucket->full()) {
      MS_LOG(INFO) << "Bucket is full, launch it";
      free_bucket->Launch();
      free_bucket_index = ++free_bucket_iter->second;
      MS_LOG(INFO) << "New free bucket:" << free_bucket_index;
    }
  }
}

void SessionBasic::ClearAllBucket(const GraphId &graph_id) {
  auto iter = bucket_map_.find(graph_id);
  if (iter != bucket_map_.end()) {
    auto bucket_list = iter->second;
    for (auto &bucket : bucket_list) {
      MS_LOG(INFO) << "Clear bucket:" << bucket->id();
      bucket->Release();
    }
  }
  auto free_iter = free_bucket_id_map_.find(graph_id);
  if (free_iter != free_bucket_id_map_.end()) {
    free_iter->second = 0;
  }
}

void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const {
  MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id();
  opt::CommonFinalOptimization(graph);
  MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id();
}

void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
#ifdef ENABLE_DUMP_IR
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    DumpIR("graph_build_" + std::to_string(kernel_graph->graph_id()) + ".ir", kernel_graph, true, kWholeStack);
    DumpIRProto(kernel_graph, "vm_build_" + std::to_string(kernel_graph->graph_id()));
    DumpIR("trace_code_graph", kernel_graph, true, kWholeStack);
  }
#endif
}

void SessionBasic::UnifyMindIR(const KernelGraphPtr &graph) { opt::CommonUnifyMindIROptimization(graph); }

#if ((defined ENABLE_CPU) && (!defined _WIN32))
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
  if (!ps::PSContext::instance()->is_worker()) {
    return;
  }
  CheckPSModeConsistence(kernel_graph);
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    if (!ps::ps_cache_instance.initialized_ps_cache()) {
      auto context_ptr = MsContext::GetInstance();
      MS_EXCEPTION_IF_NULL(context_ptr);
      auto device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
      auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_target, device_id_);
      MS_EXCEPTION_IF_NULL(runtime_instance);
      auto context = runtime_instance->context();
      const auto &kernels = kernel_graph->execution_order();
      if (!kernels.empty() && AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") {
        GetBatchElements(kernels[0]);
        ps::ps_cache_instance.Initialize();
      }
      ps::ps_cache_instance.DoProcessData(device_id_, context);
    }
  } else {
    // Assign parameter keys.
    AssignParamKey(kernel_graph);
  }
}

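// GetBatchElements derives the number of elements in one batch from the "shapes" attribute
// of the InitDataSetQueue kernel and passes it to the PS embedding cache. For example
// (illustrative numbers), a first shape entry of [32, 3, 224, 224] gives
// 32 * 3 * 224 * 224 = 4816896 elements per batch.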
void SessionBasic::GetBatchElements(const AnfNodePtr &kernel_node) const {
  auto shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
  auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types");
  if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) {
    MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size "
                      << types.size();
  }
  size_t batch_elements = 1;
  const auto &shape = shapes[0];
  for (size_t i = 0; i < shape.size(); ++i) {
    batch_elements *= LongToSize(shape[i]);
  }
  ps::ps_cache_instance.set_batch_elements(batch_elements);
}

void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const {
  auto input_nodes = kernel_graph->inputs();
  for (const auto &input_node : input_nodes) {
    if (!input_node->isa<Parameter>()) {
      continue;
    }
    auto pk_node = input_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(pk_node);
    auto param_info_ptr = pk_node->param_info();
    const std::string &param_name = pk_node->fullname_with_scope();
    if (param_info_ptr != nullptr && param_info_ptr->init_in_server() &&
        !ps::ps_cache_instance.IsHashTable(param_name)) {
      MS_LOG(EXCEPTION) << "Cannot initialize the parameter[" << param_name
                        << "] on the server, because it is used by a kernel which executes on the device";
    }
  }
}

void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // PS EmbeddingLookup cache check.
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    MS_LOG(EXCEPTION) << "Other parameters cannot be set to PS mode when the EmbeddingLookup cache is enabled in "
                         "parameter server training mode.";
  }
  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
  for (auto &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      // Assign a key for the forward kernel EmbeddingLookup.
      // The key will be assigned to the embedding table and the Push kernel as well.
      if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
        size_t embedding_table_idx = 0;
        auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
        size_t key = ps::Worker::GetInstance().SetParamKey(embedding_table->fullname_with_scope());
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
      } else if (AnfAlgo::GetCNodeName(node) == kPushOpName) {
        auto pull_node = FindPullNode(node, node_list);
        if (!pull_node) {
          MS_LOG(EXCEPTION) << "Assigning parameter key failed: cannot find the Pull node of the Push node.";
        }

        // The second input of the Pull node is the trainable parameter.
        size_t parameter_index = 1;
        auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
        size_t key = ps::Worker::GetInstance().SetParamKey(parameter_node->fullname_with_scope());
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);

        std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
        ps::Worker::GetInstance().SetKeyOptimId(key, optimizer_name);
      }
    }
  }
}

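// InitPSParamAndOptim runs on worker processes only: for every graph input that is a
// Parameter with an allocated output address, it hands the corresponding input tensor to
// ps::Worker so that the weight and its optimizer state can be initialized on the
// parameter server.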
void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
                                       const std::vector<tensor::TensorPtr> &inputs_const) {
  if (!ps::PSContext::instance()->is_worker()) {
    return;
  }
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto input_nodes = kernel_graph->inputs();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      ps::Worker::GetInstance().InitPSParamAndOptim(input_node, tensor);
    }
  }
}
#endif
}  // namespace session
void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir,
                       const std::vector<CNodePtr> &execution_order) {
  std::string file_path = target_dir + "/execution_order/" + file_name;
  auto realpath = Common::CreatePrefixPath(file_path);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Failed to get real path: [" << file_path << "] in dump graph execution order.";
    return;
  }
  file_path = realpath.value();

  ChangeFileMode(file_path, S_IWUSR);
  // Write the execution order to a CSV file.
  std::ofstream ofs(file_path);
  if (!ofs.is_open()) {
    MS_LOG(ERROR) << "Failed to open file [" << file_path
                  << "] in dump graph execution order, please check the file access permission and whether disk space "
                     "is available.";
    return;
  }
  ofs << "NodeExecutionOrder-FullNameWithScope\n";
  for (const CNodePtr &node : execution_order) {
    ofs << node->fullname_with_scope() << "\n";
  }
  ofs.close();
  // Set the file mode to read-only for the user.
  ChangeFileMode(file_path, S_IRUSR);
}

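// GetRankId resolves the communication world group from the configured device target
// (HCCL for Ascend, NCCL for GPU) and queries the rank id from CommManager; it falls back
// to rank 0 when the backend is unsupported or the query fails.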
uint32_t GetRankId() {
  uint32_t rank_id = 0;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  std::string world_group;
  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (backend == kAscendDevice) {
    world_group = kHcclWorldGroup;
  } else if (backend == kGPUDevice) {
    world_group = kNcclWorldGroup;
  } else {
    MS_LOG(ERROR) << "Invalid backend: " << backend;
    return rank_id;
  }
  if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
    MS_LOG(INFO) << "Failed to get rank id.";
  }
  return rank_id;
}
}  // namespace mindspore