/**
 * Copyright 2019-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/session/session_basic.h"

#include <algorithm>
#include <set>
#include <queue>
#include <unordered_map>
#include <utility>
#include <functional>

#include "ops/primitive_c.h"
#include "ir/manager.h"
#include "abstract/utils.h"
#include "backend/kernel_compiler/common_utils.h"
#include "base/core_ops.h"
#include "base/base_ref_utils.h"
#include "common/trans.h"
#include "utils/config_manager.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/executor_manager.h"
#include "backend/optimizer/common/common_backend_optimization.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/ms_utils.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#include "utils/utils.h"
#include "debug/anf_ir_dump.h"
#include "debug/dump_proto.h"
#include "utils/file_utils.h"
#include "utils/trace_base.h"
#include "frontend/parallel/context.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
#include "ps/ps_cache/ps_cache_manager.h"
#include "ps/constants.h"
#include "ps/util.h"
#include "ps/ps_context.h"
#include "abstract/abstract_value.h"
#endif
#include "backend/session/session_factory.h"
#include "backend/session/pynative_task_manager.h"

namespace mindspore {
namespace session {
MS_REG_SESSION(kSessionBasic, SessionBasic);

namespace {
const int kSummaryGetItem = 2;
const size_t max_depth = 128;
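// Returns true if any dimension of the given shape is negative, which is the
// convention used here to mark a dynamic (not yet inferred) dimension.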
bool IsShapeDynamic(const abstract::ShapePtr &shape) {
  if (shape == nullptr) {
    return false;
  }
  return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
}
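// Recursively walks the users of kernel.first through the manager to decide whether the
// node eventually feeds a real kernel. Depend/Load edges beyond the first input and
// Partial nodes are skipped, and *idx bounds the recursion depth at max_depth.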
bool RecursiveCheck(const FuncGraphManagerPtr &manager, const std::pair<AnfNodePtr, int64_t> &kernel, size_t *idx) {
  auto node = kernel.first;
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(node);
  if (kernel.second > 1 &&
      (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend) || AnfAlgo::CheckPrimitiveType(node, prim::kPrimLoad))) {
    return false;
  }
  if (AnfAlgo::IsRealKernel(node) && !AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
    return true;
  }
  (*idx) += 1;
  // max recursion depth
  if (*idx <= max_depth) {
    auto users = manager->node_users()[node];
    if (std::any_of(users.begin(), users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
          return RecursiveCheck(manager, kernel, idx);
        })) {
      return true;
    }
  }
  return false;
}

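// Returns true if node is (transitively) consumed by at least one real kernel inside
// the graph identified by graph_id; users from other graphs are ignored.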
bool IsUsedByRealKernel(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, const uint32_t graph_id) {
  MS_EXCEPTION_IF_NULL(manager);
  MS_EXCEPTION_IF_NULL(node);
  auto node_users = manager->node_users()[node];
  // Filter out user nodes that are not in the current graph.
  for (auto iter = node_users.begin(); iter != node_users.end();) {
    auto func_graph = iter->first->func_graph();
    auto kernel_graph = func_graph->cast<KernelGraphPtr>();
    if (kernel_graph == nullptr) {
      MS_LOG(EXCEPTION) << "Casting func graph to kernel graph failed, related node is: " << iter->first->DebugString();
    }
    if (kernel_graph->graph_id() != graph_id) {
      iter = node_users.erase(iter);
    } else {
      ++iter;
    }
  }

  size_t idx = 0;
  if (std::any_of(node_users.begin(), node_users.end(), [&](const std::pair<AnfNodePtr, int64_t> &kernel) {
        return RecursiveCheck(manager, kernel, &idx);
      })) {
    return true;
  }
  return false;
}

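// Marks each Parameter input of the graph: whether it is actually consumed by a real
// kernel in this graph, and whether its shape contains dynamic dimensions.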
void SetInputNodeUsage(const KernelGraphPtr &graph, const FuncGraphManagerPtr &manager) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(manager);
  auto input_nodes = graph->input_nodes();
  for (auto &input_node : input_nodes) {
    if (input_node->isa<Parameter>()) {
      auto node_ptr = input_node->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(node_ptr);
      if (!IsUsedByRealKernel(manager, input_node, graph->graph_id())) {
        node_ptr->SetNotUsedByRealKernelInGraph(graph->graph_id());
      }
      auto shape = node_ptr->Shape();
      if (IsShapeDynamic(shape->cast<abstract::ShapePtr>())) {
        node_ptr->set_has_dynamic_shape(true);
      }
    }
  }
}

ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
  if (node == nullptr) {
    return nullptr;
  }
  auto parameter = node->cast<ParameterPtr>();
  if (parameter == nullptr || !parameter->has_default()) {
    return nullptr;
  }
  return parameter->param_info();
}

static bool IsPynativeMode() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
}

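// Tries to resolve a graph output directly from known values: monad outputs become a
// placeholder bool tensor, value nodes return their value, and parameters not used by
// any real kernel are looked up in the forwarded input tensors. Returns nullptr when
// the output has to be synced back from the device instead.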
BaseRef GetNodeOutputTensorFromInputs(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                      const std::vector<tensor::TensorPtr> &input_tensors) {
  auto &node = node_output_pair.first;
  MS_EXCEPTION_IF_NULL(node);
  if (HasAbstractMonad(node)) {
    return std::make_shared<tensor::Tensor>(int64_t(0), kBool);
  }
  // If the node is a value node, there is no need to sync the address from device to host.
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    return value_node->value();
  }
  if (IsPynativeMode()) {
    return nullptr;
  }
  if (!node->isa<Parameter>()) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto param_node = node->cast<ParameterPtr>();
  if (param_node != nullptr && param_node->IsUsedByRealKernelInGraph(graph->graph_id())) {
    return nullptr;
  }
  for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
    if (input_idx >= input_tensors.size()) {
      MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " out of range:" << input_tensors.size();
    }
    if (graph->inputs()[input_idx] == node) {
      return input_tensors[input_idx];
    }
  }
  return nullptr;
}

int64_t ShapeSize(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
}

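// Creates (or reuses, for internal outputs) the host tensor that will receive one output
// of the node. The device data type is preferred over the inferred one, dynamic shapes
// fall back to the larger of the max shape and the inferred shape, and the sync policy
// depends on the execution mode and device target.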
BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                               const std::vector<tensor::TensorPtr> &input_tensors,
                               std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
  auto &node = node_output_pair.first;
  size_t output_index = node_output_pair.second;
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto tensor_from_input = GetNodeOutputTensorFromInputs(node_output_pair, graph, input_tensors);
  if (tensor_from_input != nullptr) {
    return tensor_from_input;
  }
  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (type_id == kTypeUnknown) {
    type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  }
  std::vector<int64_t> temp_shape;
  auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  if (AnfAlgo::IsDynamicShape(node)) {
    auto max_shape = AnfAlgo::GetOutputMaxShape(node, output_index);
    temp_shape = ShapeSize(max_shape) > ShapeSize(temp_shape) ? max_shape : temp_shape;
  }
  tensor::TensorPtr tensor;
  bool is_internal_output = graph->IsInternalOutput(node, output_index);
  if (is_internal_output) {
    tensor = graph->GetInternalOutputTensor(node, output_index);
    if (tensor == nullptr) {
      tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
      graph->AddInternalOutputTensor(node, output_index, tensor);
    }
  } else {
    tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  }
  MS_EXCEPTION_IF_NULL(tensor);
  tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
  if (is_internal_output) {
    tensor->set_sync_status(kNoNeedSync);
  } else {
    // In PyNative mode, data is only copied to host when the user wants to print it.
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
        ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
      tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
    } else {
      tensor->set_sync_status(kNeedSyncDeviceToHost);
    }
  }
  tensor->SetIsGraphOutput();
  (*tensor_to_node)[tensor] = node_output_pair;
  return tensor;
}

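// Recursively builds the output structure for anf: make_tuple fans out into a VectorRef,
// nodes that produce no outputs yield an empty VectorRef, and each real output is mapped
// to a host tensor (deduplicated through node_to_tensor).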
BaseRef CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                const std::vector<tensor::TensorPtr> &input_tensors,
                                std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
                                KernelMapTensor *node_to_tensor) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  MS_EXCEPTION_IF_NULL(node_to_tensor);
  MS_LOG(DEBUG) << "Create tensor for output[" << anf->DebugString() << "]";
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(DEBUG) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  // Special handling for make_tuple.
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node, node_to_tensor);
      ret.push_back(out);
    }
    return ret;
  }
  // If the graph returns nothing, the function should return an empty VectorRef.
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }

  // Several graph outputs may refer to the same kernel node; in that case there is no need
  // to create a new tensor.
  const auto &iter = node_to_tensor->find(item_with_index);
  if (iter != node_to_tensor->end()) {
    return iter->second;
  }

  const auto &tensor = CreateNodeOutputTensor(item_with_index, graph, input_tensors, tensor_to_node);
  (*node_to_tensor)[item_with_index] = tensor;
  return tensor;
}

ValueNodePtr CreateNewValueNode(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  auto value_node = anf->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<None>()) {
    return nullptr;
  }
  auto new_value_node = graph->NewValueNode(value_node);
  graph->FrontBackendlMapAdd(anf, new_value_node);
  graph->AddValueNodeToGraph(new_value_node);
  return new_value_node;
}

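// Builds a backend Parameter for one input tensor of a single-op (PyNative) graph. If the
// tensor already has a device address, its format, type and address are reused; otherwise
// a default format is selected. Weight tensors are attached as the parameter's default value.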
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
                                     int64_t tensor_mask) {
  MS_EXCEPTION_IF_NULL(graph);
  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  if (tensor_mask == kParameterWeightTensorMask) {
    param->set_default_param(input_tensor);
  }
  // Set the kernel info of the parameter.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
  if (device_address == nullptr) {
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
    TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  } else {
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
    kernel_build_info_builder->SetOutputsReshapeType({input_tensor->padding_type()});
    AnfAlgo::SetOutputAddr(device_address, 0, param.get());
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // Construct the abstract of the parameter.
  auto type_of_tensor = input_tensor->Dtype();
  auto shape_of_tensor = input_tensor->shape();
  auto abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, shape_of_tensor);
  param->set_abstract(abstract);
  return param;
}

void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  MS_LOG(INFO) << "Graph outputs:";
  const size_t max_deep = 10;
  if (recurse_level > max_deep) {
    MS_LOG(INFO) << "Recurse too deep";
    return;
  }
  std::string tab_str;
  for (size_t i = 0; i < recurse_level; i++) {
    tab_str = tab_str.append("  ");
  }
  if (any.is<AnyList>()) {
    (void)tab_str.append("{");
    MS_LOG(INFO) << tab_str;
    auto any_list = any.cast<AnyList>();
    for (auto &it : any_list) {
      DumpGraphOutput(it, recurse_level + 1);
    }
    (void)tab_str.append("}");
    MS_LOG(INFO) << tab_str;
  }
  (void)tab_str.append(any.ToString());
  MS_LOG(INFO) << tab_str;
}

#ifndef ENABLE_SECURITY
bool ExistSummaryNode(const KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto ret = graph->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  auto all_nodes = DeepLinkedGraphSearch(ret);
  for (auto &n : all_nodes) {
    if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
        IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
      return true;
    }
  }
  return false;
}
#endif

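// Creates a placeholder for one graph output. Value nodes and parameters are resolved
// immediately; for everything else, the position in the nested output structure is
// recorded in output_indexes so the real tensor can be filled in after execution.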
BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  auto &node = node_output_pair.first;
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(DEBUG) << "Create placeholder for output[" << node->DebugString() << "] index[" << node_output_pair.second
                << "]";
  // If the node is a value node, there is no need to sync the address from device to host.
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    return value_node->value();
  }
  if (node->isa<Parameter>()) {
    for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
      if (input_idx >= input_tensors.size()) {
        MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " out of range:" << input_tensors.size();
      }
      if (graph->inputs()[input_idx] == node) {
        return input_tensors[input_idx];
      }
    }
    MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
  }
  (*output_indexes)[node_output_pair].emplace_back(indexes);
  BaseRef output_placeholder = std::make_shared<BaseRef>();
  return output_placeholder;
}

BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(DEBUG) << "Create placeholder for output[" << anf->DebugString() << "]";
  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(DEBUG) << "Create placeholder for output after visit:" << item_with_index.first->DebugString();
  // Special handling for make_tuple.
  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
      std::vector<size_t> cur_index = indexes;
      cur_index.emplace_back(i - 1);
      auto out = CreateNodeOutputPlaceholder(cnode->input(i), graph, input_tensors, cur_index, output_indexes);
      ret.push_back(out);
    }
    return ret;
  }
  // If the graph returns nothing, the function should return an empty VectorRef.
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
}

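// Validates that the shape of an input tensor matches the inferred shape of the kernel's
// corresponding input, raising an exception on any rank or dimension mismatch.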
void CheckInputTensorShape(const TensorPtr &tensor, const CNodePtr &kernel, size_t input_index) {
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &tensor_shape = tensor->shape();
  const auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel, input_index);
  if (tensor_shape.size() != input_shape.size()) {
    MS_LOG(EXCEPTION) << "The input tensor's shape size: " << tensor_shape.size()
                      << " is not equal to expected size: " << input_shape.size() << " for input[" << input_index
                      << "] of kernel: " << AnfAlgo::GetCNodeName(kernel);
  }
  for (size_t i = 0; i < tensor_shape.size(); i++) {
    if (tensor_shape[i] < 0 || static_cast<size_t>(tensor_shape[i]) != input_shape[i]) {
      MS_LOG(EXCEPTION) << "The input tensor's shape: " << tensor_shape
                        << " is not equal to expected shape: " << input_shape << " for input[" << input_index
                        << "] of kernel: " << AnfAlgo::GetCNodeName(kernel);
    }
  }
}

void UpdateGraphAquireGilAttr(const NotNull<KernelGraphPtr> &root_graph) {
  for (const auto &cnode : root_graph->execution_order()) {
    if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPyFunc)) {
      MS_LOG(INFO) << "The graph requires the GIL. Graph id: " << root_graph->graph_id();
      root_graph->set_is_need_gil(true);
      return;
    }
  }
}

bool ExistGraphCaller(const AnfNodePtr &partial_node) {
  MS_EXCEPTION_IF_NULL(partial_node);
  auto partial_cnode = partial_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(partial_cnode);
  auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
  MS_EXCEPTION_IF_NULL(partial_graph);
  auto graph_nodes = TopoSort(partial_graph->get_return());
  return std::any_of(graph_nodes.begin(), graph_nodes.end(), IsValueNode<FuncGraph>);
}

// 1. Convert the node to make_tuple if the node is a ValueNode<ValueTuple> and it is the input of the 'return' node.
// 2. Set the return of the graph if the node is a "Return" node.
void SetReturnNode(const AnfNodePtr &node, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(node);

  if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
    constexpr auto kReturnInputIdx = 1;
    auto return_node = node->cast<CNodePtr>();
    graph->set_return(return_node);
    auto graph_output = return_node->input(kReturnInputIdx);
    MS_EXCEPTION_IF_NULL(graph_output);

    // If return's input is a value node, the graph has no kernel, and the pass 'trans tuple to make_tuple'
    // cannot match this pattern because that pass starts from the output node rather than the return node.
    // So the value tuple is transformed to make_tuple here.
    if (AnfAlgo::IsTupleOutput(graph_output) && graph_output->isa<ValueNode>()) {
      return_node->set_input(kReturnInputIdx, graph->TransTupleToMakeTuple(graph_output));
    }
  }
}
}  // namespace

GraphId SessionBasic::graph_sum_ = 0;

void SessionBasic::InitExecutor(const std::string &device_name, uint32_t device_id) {
  device_id_ = device_id;
  context_ = std::make_shared<Context>(device_name, device_id);
  executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id);
}

GraphId SessionBasic::GetGraphIdByNode(const AnfNodePtr &front_anf) const {
  for (const auto &graph_item : graphs_) {
    auto graph = graph_item.second;
    MS_EXCEPTION_IF_NULL(graph);
    // If front_anf is a parameter, there may be more than one corresponding backend parameter.
    if (graph->GetBackendAnfByFrontAnf(front_anf) != nullptr) {
      return graph_item.first;
    }
  }
  MS_EXCEPTION_IF_NULL(front_anf);
  MS_LOG(DEBUG) << "Front_anf " << front_anf->DebugString() << " does not exist in any graph";
  return kInvalidGraphId;
}

KernelGraphPtr SessionBasic::GetGraph(mindspore::GraphId graph_id) const {
  auto it = graphs_.find(graph_id);
  if (it == graphs_.end()) {
    MS_LOG(INFO) << "Can't find graph " << graph_id;
    return nullptr;
  }
  return it->second;
}

void SessionBasic::ClearGraph() {
  auto graph_iter = graphs_.begin();
  while (graph_iter != graphs_.end()) {
    graph_iter->second.reset();
    graphs_.erase(graph_iter++);
  }
  graph_sum_ = 0;
}

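// When a parameter of this graph is fed by an internal output of a previously built graph,
// copy that output's device address, format and device type onto the parameter, so the
// parameter can read the producing kernel's buffer directly.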
void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter) {
  auto graph_id = GetGraphIdByNode(out_node);
  if (graph_id == kInvalidGraphId) {
    return;
  }
  auto node_graph = GetGraph(graph_id);
  if (node_graph == nullptr) {
    return;
  }
  MS_LOG(INFO) << "Init parameter with pre graph output node: " << out_node->DebugString();
  auto ref_node = node_graph->GetInternalOutputByFrontNode(out_node);
  if (ref_node == nullptr) {
    MS_LOG(INFO) << "No corresponding internal output for output node";
    return;
  }
  size_t output_idx = 0;
  if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
    output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
  }
  auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
  auto ref_real_node = real_kernel.first;
  auto ref_real_node_index = real_kernel.second;
  if (ref_real_node->isa<CNode>() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) {
    auto kernel_info = ref_real_node->kernel_info();
    if (kernel_info == nullptr || !kernel_info->has_build_info()) {
      MS_LOG(INFO) << "No kernel info";
      return;
    }
    if (!opt::IsNopNode(ref_real_node) && !AnfAlgo::OutputAddrExist(ref_real_node, ref_real_node_index)) {
      MS_LOG(INFO) << "No kernel address";
      return;
    }
    auto address = AnfAlgo::GetMutableOutputAddr(ref_real_node, ref_real_node_index);
    auto format = AnfAlgo::GetOutputFormat(ref_real_node, ref_real_node_index);
    auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index);
    auto d_kernel_info = std::make_shared<device::KernelInfo>();
    MS_EXCEPTION_IF_NULL(d_kernel_info);
    parameter->set_kernel_info(d_kernel_info);
    kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
    builder.SetOutputsDeviceType({type});
    builder.SetOutputsFormat({format});
    d_kernel_info->set_select_kernel_build_info(builder.Build());
    AnfAlgo::SetOutputAddr(address, 0, parameter.get());
    auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type),
                                                               parameter->Shape()->cast<abstract::BaseShapePtr>());
    parameter->set_abstract(abstract);
  }
}

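// Expands a (possibly tuple-typed) front-end node into backend parameters: one parameter
// is created per output, registered as a graph input, and linked back to the front node so
// internal outputs of earlier graphs can be wired to it.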
AnfNodePtr SessionBasic::CreateParameterFromTuple(const AnfNodePtr &node, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto new_parameter = graph->TransTupleToMakeTuple(graph->NewParameter(node->abstract()));
  auto parameters = AnfAlgo::GetAllOutput(new_parameter);
  std::vector<AnfNodePtr> pre_graph_out = {node};
  // If a cnode is a call, its input0 is a cnode too, so it doesn't have a primitive.
  if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) {
    pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState});
  }

  for (size_t i = 0; i < parameters.size(); ++i) {
    const auto &parameter = parameters[i];
    // In control flow, if the input of the cnode is a call node, it will be processed as a make_tuple input,
    // which needs to be linked when processing the internal node.
    graph->CacheInternalParameterToFrontNode(parameter, {node, i});
    auto valid_inputs = graph->MutableValidInputs();
    MS_EXCEPTION_IF_NULL(valid_inputs);
    auto graph_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(graph_inputs);
    valid_inputs->push_back(true);
    graph_inputs->push_back(parameter);
  }
  size_t param_index = 0;
  for (const auto &out_node : pre_graph_out) {
    size_t output_size = AnfAlgo::GetOutputTensorNum(out_node);
    for (size_t i = 0; i < output_size; i++) {
      if (param_index >= parameters.size()) {
        MS_LOG(EXCEPTION) << "Parameters size:" << parameters.size() << " out of range. Node:" << node->DebugString()
                          << ",out_node:" << out_node->DebugString();
      }
      InitInternalOutputParameter(out_node, parameters[param_index++]);
    }
  }
  return new_parameter;
}

ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto param_value = GetParamDefaultValue(anf);
  auto valid_inputs = graph->MutableValidInputs();
  MS_EXCEPTION_IF_NULL(valid_inputs);
  auto graph_inputs = graph->MutableInputs();
  MS_EXCEPTION_IF_NULL(graph_inputs);
  ParameterPtr new_parameter = nullptr;
  // If the parameter's Python object already has an associated backend parameter, reuse it.
  if (param_value != nullptr) {
    new_parameter = param_value->parameter();
  }
  if (new_parameter == nullptr) {
    TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());

    auto input_node_iter = partial_parameters_map_.find(anf);
    if (input_node_iter != partial_parameters_map_.end()) {
      InitInternalOutputParameter(input_node_iter->second, new_parameter);
    }

    if (param_value != nullptr) {
      param_value->set_parameter(new_parameter);
    }
  }
  new_parameter->IncreaseUsedGraphCount();
  graph_inputs->push_back(new_parameter);
  valid_inputs->push_back(true);
  return new_parameter;
}

AnfNodePtr SessionBasic::CreateNewParameterFromCNode(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Create a new parameter from cnode[" << anf->DebugString() << "]";
  return CreateParameterFromTuple(anf, graph);
}

void SessionBasic::GetCNodeInfo(const CNodePtr &cnode, std::vector<AnfNodePtr> *cnode_inputs) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(cnode_inputs);
  auto prim = AnfAlgo::GetCNodePrimitive(cnode);
  if (prim != nullptr) {
    // Push the attr to inputs[0] of the new cnode.
    cnode_inputs->push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*prim)));
  } else {
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs->push_back(std::make_shared<ValueNode>(new_fg));
  }
}

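// Translates the inputs of a front-end cnode into backend nodes: already-translated nodes
// and cached cross-graph cnodes are reused, value nodes and parameters are recreated in
// this graph, and cnodes from other graphs are replaced by fresh parameters.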
void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
                                     std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  MS_EXCEPTION_IF_NULL(cnode_inputs);
  auto origin_inputs = cnode->inputs();
  const bool is_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend);
  // If there are multiple depend inputs, only the first is kept as a real parameter.
  for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) {
    auto anf = origin_inputs[input_idx];
    MS_EXCEPTION_IF_NULL(anf);
    // anf has been created before
    if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
      (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
      continue;
    } else if ((is_depend && input_idx > kRealInputIndexInDepend)) {
      cnode_inputs->push_back(NewValueNode(MakeValue(SizeToInt(input_idx))));
      continue;
    } else if (other_graph_cnode->find(anf) != other_graph_cnode->end()) {
      cnode_inputs->push_back((*other_graph_cnode)[anf]);
      continue;
    } else if (anf->isa<ValueNode>() && !IsValueNode<FuncGraph>(anf)) {
      // If the input is a value node, create a new value node in this graph.
      auto new_value_node = CreateNewValueNode(anf, graph);
      if (new_value_node != nullptr) {
        (void)cnode_inputs->emplace_back(new_value_node);
      }
      continue;
    } else if (anf->isa<Parameter>()) {
      auto new_parameter = CreateNewParameterFromParameter(anf, graph);
      cnode_inputs->push_back(new_parameter);
      graph->FrontBackendlMapAdd(anf, new_parameter);
      continue;
    } else {
      // The input node is a cnode from another graph.
      auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph);
      if (parameter_from_cnode == nullptr) {
        parameter_from_cnode = NewValueNode(MakeValue(SizeToLong(input_idx)));
      }
      if (parameter_from_cnode->isa<Parameter>() && IsPrimitiveCNode(anf, prim::kPrimLoad)) {
        auto para = parameter_from_cnode->cast<ParameterPtr>();
        auto load_cnode = anf->cast<CNodePtr>();
        para->set_name(load_cnode->input(kFirstDataInputIndex)->fullname_with_scope());
      }
      cnode_inputs->push_back(parameter_from_cnode);
      (*other_graph_cnode)[anf] = parameter_from_cnode;
    }
  }
}

CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph,
                                      std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(other_graph_cnode);
  // Get the primitive of the old node.
  std::vector<AnfNodePtr> cnode_inputs;
  GetCNodeInfo(cnode, &cnode_inputs);
  GetNewCNodeInputs(cnode, graph, &cnode_inputs, other_graph_cnode);
  TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
  return new_cnode;
}

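// Normalizes a switch branch input into a partial node. Branches that are already partial
// nodes or graph value nodes are used directly; any other node is wrapped in a trivial
// kernel graph that simply returns it.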
CNodePtr SessionBasic::CreateSwitchInput(const CNodePtr &cnode, const AnfNodePtr &node_input, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(node_input);
  MS_EXCEPTION_IF_NULL(graph);
  // A switch input is generalized to a partial node.
  std::vector<AnfNodePtr> partial_inputs = {NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name()))};
  if (AnfAlgo::CheckPrimitiveType(node_input, prim::kPrimPartial)) {
    auto backend_node = graph->GetBackendAnfByFrontAnf(node_input);
    return backend_node->cast<CNodePtr>();
  } else if (node_input->isa<ValueNode>() && IsValueNode<FuncGraph>(node_input)) {
    partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
  } else {
    KernelGraphPtr kernel_graph = NewKernelGraph();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto parameter = CreateNewParameterFromCNode(cnode, kernel_graph.get());
    MS_EXCEPTION_IF_NULL(parameter);
    parameter->set_abstract(cnode->abstract());
    auto primitive = NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()));
    auto return_node = kernel_graph->NewCNode({primitive, parameter});
    return_node->set_abstract(cnode->abstract());
    kernel_graph->set_return(return_node);
    partial_inputs.emplace_back(std::make_shared<ValueNode>(kernel_graph));
    partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(node_input));
  }
  auto partial_node = graph->NewCNode(partial_inputs);
  return partial_node;
}

std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  auto switch_cnode = cnode_input->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(switch_cnode);
  if (cnode->inputs().size() <= 1) {
    cnode_inputs = switch_cnode->inputs();
    return cnode_inputs;
  }
  std::vector<AnfNodePtr> switch_inputs = {switch_cnode->input(kAnfPrimitiveIndex),
                                           switch_cnode->input(kFirstDataInputIndex)};
  for (size_t index = kFirstBranchInSwitch; index < switch_cnode->inputs().size(); index++) {
    auto node = switch_cnode->input(index);
    // The call has real inputs; they must be forwarded to both the true and false branches of the switch.
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimPartial)) {
      auto partial_node = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(partial_node);
      std::vector<AnfNodePtr> partial_inputs = partial_node->inputs();
      // Put all call args at the end of partial inputs.
      for (size_t i = kFirstDataInputIndex; i < cnode->size(); ++i) {
        (void)partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(i)));
      }
      auto new_partial = graph->NewCNode(partial_inputs);
      (void)switch_inputs.emplace_back(new_partial);
    }
  }
  if (switch_inputs.size() < kSwitchInputSize) {
    MS_LOG(EXCEPTION) << "Switch inputs size: " << switch_inputs.size() << " less than " << kSwitchInputSize;
  }
  auto switch_node = graph->NewCNode(switch_inputs);
  (void)cnode_inputs.emplace_back(switch_node);
  return cnode_inputs;
}

void SessionBasic::ProcessNodeRetFunc(const CNodePtr &cnode, KernelGraph *graph,
                                      const std::vector<AnfNodePtr> &real_inputs) {
  MS_EXCEPTION_IF_NULL(cnode);
  // func1 = switch(branch1, branch2)
  // func2 = func1(param1)
  // out = func2(param2)
  // Process the last cnode (func2), not func1, whose abstract is an AbstractFunction.
  if (cnode->abstract()->isa<abstract::AbstractFunction>()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto ret = graph->get_return();
  MS_EXCEPTION_IF_NULL(ret);
  auto return_input = ret->input(kFirstDataInputIndex);
  // The return node is a function.
  std::vector<AnfNodePtr> call_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
    auto return_input_cnode = return_input->cast<CNodePtr>();
    auto partial_inputs = return_input_cnode->inputs();
    call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
  } else if (IsValueNode<KernelGraph>(return_input)) {  // The return node is a kernel graph.
    call_inputs.emplace_back(return_input);
  } else {  // The return node is a value node.
    KernelGraphPtr kernel_graph = NewKernelGraph();
    auto valid_inputs = kernel_graph->MutableValidInputs();
    MS_EXCEPTION_IF_NULL(valid_inputs);
    auto graph_inputs = kernel_graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(graph_inputs);
    std::vector<AnfNodePtr> cnode_inputs = {return_input};
    for (auto &real_input : real_inputs) {
      auto new_parameter = kernel_graph->NewParameter(real_input->abstract());
      valid_inputs->push_back(true);
      graph_inputs->push_back(new_parameter);
      cnode_inputs.push_back(new_parameter);
    }
    auto new_cnode = kernel_graph->NewCNode(cnode_inputs);
    new_cnode->set_abstract(cnode->abstract());
    std::vector<AnfNodePtr> return_inputs = {
      kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimReturn->name()))), new_cnode};
    auto return_node = kernel_graph->NewCNode(return_inputs);
    return_node->set_abstract(cnode->abstract());
    kernel_graph->set_return(return_node);
    call_inputs.push_back(std::make_shared<ValueNode>(kernel_graph));
  }

  // New call node inputs.
  for (auto &input_node : real_inputs) {
    auto parameter_for_input = CreateNewParameterFromCNode(input_node, graph);
    call_inputs.emplace_back(parameter_for_input);
  }

  auto call_node = graph->NewCNode(call_inputs);
  call_node->set_abstract(cnode->abstract());
  // Update the return input.
  ret->set_input(kFirstDataInputIndex, call_node);
}

std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  auto switch_layer_cnode = cnode_input->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(switch_layer_cnode);
  std::vector<AnfNodePtr> switch_layer_inputs = {switch_layer_cnode->input(kAnfPrimitiveIndex),
                                                 switch_layer_cnode->input(kFirstDataInputIndex)};
  auto make_tuple_node = switch_layer_cnode->input(kMakeTupleInSwitchLayerIndex);
  MS_EXCEPTION_IF_NULL(make_tuple_node);
  auto node = make_tuple_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(node);
  auto make_tuple_inputs = node->inputs();
  // The call has real inputs; they must be forwarded to the make_tuple of the switch_layer.
  std::vector<AnfNodePtr> real_inputs;
  for (size_t idx = kFirstDataInputIndex; idx < cnode->inputs().size(); ++idx) {
    real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx)));
  }
  std::vector<AnfNodePtr> new_make_tuple_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
  for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
    auto partial_idx = make_tuple_inputs[idx];
    MS_EXCEPTION_IF_NULL(cnode->abstract());
    std::vector<AnfNodePtr> new_partial_inputs;
    KernelGraphPtr partial_kernel_graph;
    // The switch_layer node input is a partial cnode.
    if (AnfAlgo::CheckPrimitiveType(partial_idx, prim::kPrimPartial)) {
      auto partial_node = partial_idx->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(partial_node);
      auto partial_input = partial_node->input(kFirstDataInputIndex);
      partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
      new_partial_inputs = partial_node->inputs();
    } else if (IsValueNode<KernelGraph>(partial_idx)) {  // The switch_layer node input is a kernel graph value node.
      new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
      new_partial_inputs.emplace_back(partial_idx);
      partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_idx);
    }
    // Handle the case when a branch of the switch_layer returns a function.
    MS_EXCEPTION_IF_NULL(partial_kernel_graph);
    auto ret = partial_kernel_graph->get_return();
    MS_EXCEPTION_IF_NULL(ret);
    auto return_input = ret->input(kFirstDataInputIndex);
    if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || return_input->isa<ValueNode>()) {
      ProcessNodeRetFunc(cnode, partial_kernel_graph.get(), real_inputs);
    }
    // Append the call args to the partial node's inputs.
    new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
    // Create the new partial node.
    auto new_partial = graph->NewCNode(new_partial_inputs);
    new_make_tuple_inputs.emplace_back(new_partial);
  }
  auto new_make_tuple = graph->NewCNode(new_make_tuple_inputs);
  auto abstract = make_tuple_node->abstract();
  if (abstract == nullptr) {
    abstract = std::make_shared<abstract::AbstractTuple>(AbstractBasePtrList());
  }
  new_make_tuple->set_abstract(abstract);
  switch_layer_inputs.emplace_back(new_make_tuple);
  auto new_switch_layer = graph->NewCNode(switch_layer_inputs);
  cnode_inputs.emplace_back(new_switch_layer);
  return cnode_inputs;
}

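// Builds the backend call inputs for a front-end call whose callee is itself a cnode:
// partial callees are inlined, switch and switch_layer callees are delegated to their
// dedicated builders, and anything else is reported as an error.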
std::vector<AnfNodePtr> SessionBasic::CreateSwitchOrPartialNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  // Create the primitive of the cnode: call(partial or switch or switch_layer).
  std::vector<AnfNodePtr> cnode_inputs = {
    graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto cnode_input = graph->GetBackendAnfByFrontAnf(attr_input);
  if (cnode_input == nullptr) {
    MS_LOG(ERROR) << "CNode input[0] is CNode:" << attr_input->DebugString() << ", but input[0] has not been created.";
    return {};
  }
  // If the node is a partial, insert its inputs into the call.
  if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimPartial)) {
    auto partial_node = attr_input->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(partial_node);
    auto partial_inputs = partial_node->inputs();
    (void)std::transform(partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end(),
                         std::back_inserter(cnode_inputs), [&graph](const AnfNodePtr &node) {
                           MS_EXCEPTION_IF_NULL(graph->GetBackendAnfByFrontAnf(node));
                           return graph->GetBackendAnfByFrontAnf(node);
                         });
    return cnode_inputs;
  } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitch)) {
    return CreateCallSwitchInputs(cnode, graph);
  } else if (AnfAlgo::CheckPrimitiveType(cnode_input, prim::kPrimSwitchLayer)) {
    return CreateCallSwitchLayerInputs(cnode, graph);
  }
  MS_LOG(ERROR) << "CNode:" << cnode->DebugString() << " input[0]:" << cnode_input->DebugString()
                << " must be partial or switch or switch_layer.";
  return {};
}

std::vector<AnfNodePtr> SessionBasic::CreateValueNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (AnfAlgo::IsGraphKernel(cnode)) {
    auto fg = AnfAlgo::GetCNodeFuncGraphPtr(cnode);
    MS_EXCEPTION_IF_NULL(fg);
    auto new_fg = BasicClone(fg);
    cnode_inputs.push_back(std::make_shared<ValueNode>(new_fg));
  } else {
    // Create the primitive of the cnode: call.
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
    // Create a ValueNode<KernelGraph> as the input of the cnode: call.
    if (graph->GetBackendAnfByFrontAnf(attr_input) != nullptr) {
      cnode_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(attr_input));
    } else {
      auto new_value_node = CreateValueNodeKernelGraph(attr_input, graph);
      if (new_value_node != nullptr) {
        cnode_inputs.emplace_back(new_value_node);
      }
    }
  }
  return cnode_inputs;
}

void SessionBasic::CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
    (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex)));
    for (size_t index = kFirstBranchInSwitch; index < cnode->inputs().size(); index++) {
      auto node_input = cnode->input(index);
      auto switch_input = CreateSwitchInput(cnode, node_input, graph);
      (void)cnode_inputs->emplace_back(switch_input);
    }
  } else {
    for (size_t input_idx = kFirstDataInputIndex; input_idx < cnode->inputs().size(); input_idx++) {
      auto anf = cnode->input(input_idx);
      MS_EXCEPTION_IF_NULL(anf);
      // anf has been created before
      if (graph->GetBackendAnfByFrontAnf(anf) != nullptr) {
        (void)cnode_inputs->emplace_back(graph->GetBackendAnfByFrontAnf(anf));
        continue;
      } else if (IsValueNode<None>(anf)) {
        continue;
      }
      MS_LOG(EXCEPTION) << "Unexpected input[" << anf->DebugString() << "]";
    }
  }
}

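// Clones a front-end cnode into the kernel graph. Depending on input[0] this produces a
// plain primitive cnode, a call on a sub kernel graph, or a call on partial/switch/
// switch_layer; call-on-switch results are then unwrapped to the switch node itself.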
CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<AnfNodePtr> cnode_inputs;
  auto attr_input = cnode->input(kAnfPrimitiveIndex);
  MS_EXCEPTION_IF_NULL(attr_input);
  if (IsValueNode<FuncGraph>(attr_input)) {
    // The cnode is a graph or a call.
    cnode_inputs = CreateValueNode(cnode, graph);
  } else if (attr_input->isa<CNode>()) {
    // The cnode is a call (partial/switch/switch_layer):
    // 1. take the args of the call to the partial node, as the real args to call the switch's or
    //    switch_layer's child graph;
    // 2. the call in the frontend maps to the partial/switch/switch_layer in the backend, which
    //    has not been created yet.
    cnode_inputs = CreateSwitchOrPartialNode(cnode, graph);
    if (cnode_inputs.empty()) {
      MS_LOG_ERROR << "Create switch or partial failed, cnode:" << cnode->DebugString();
      return nullptr;
    }
  } else {
    // Get the primitive of the old node.
    auto prim = AnfAlgo::GetCNodePrimitive(cnode);
    MS_EXCEPTION_IF_NULL(prim);
    // Push the attr to inputs[0] of the new cnode.
    cnode_inputs = {graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(*prim)))};
  }
  // Handle the inputs of the cnode except the primitive.
  CreateCNodeInputs(cnode, graph, &cnode_inputs);
  TraceGuard trace_guard(std::make_shared<TraceCopy>(cnode->debug_info()));
  auto new_cnode = graph->NewCNodeWithInfos(cnode_inputs, cnode);
  // If the cnode is a call on switch, remove the call.
  if (new_cnode->inputs().size() > 1) {
    auto first_input = new_cnode->input(kFirstDataInputIndex);
    MS_EXCEPTION_IF_NULL(first_input);
    if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
        AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitch)) {
      new_cnode = first_input->cast<CNodePtr>();
    }
    if (AnfAlgo::CheckPrimitiveType(new_cnode, prim::kPrimCall) &&
        AnfAlgo::CheckPrimitiveType(first_input, prim::kPrimSwitchLayer)) {
      auto abstract = cnode->abstract();
      new_cnode = first_input->cast<CNodePtr>();
      new_cnode->set_abstract(abstract);
    }
  }
  return new_cnode;
}

ValueNodePtr SessionBasic::CreateValueNodeKernelGraph(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  auto value_node = anf->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto sub_func_graph = AnfAlgo::GetValueNodeFuncGraph(anf);
  MS_EXCEPTION_IF_NULL(sub_func_graph);
  if (front_backend_graph_map_.find(sub_func_graph.get()) == front_backend_graph_map_.end()) {
    MS_LOG(EXCEPTION) << "FuncGraph: " << sub_func_graph->ToString() << " has not been transformed to KernelGraph.";
  }
  auto sub_kernel_graph = front_backend_graph_map_[sub_func_graph.get()];

  ValueNodePtr new_value_node = std::make_shared<ValueNode>(sub_kernel_graph);
  new_value_node->set_abstract(value_node->abstract());
  // Create new kernel_info for the new value_node.
  auto kernel_info = std::make_shared<device::KernelInfo>();
  new_value_node->set_kernel_info(kernel_info);
  // Create kernel_build_info for the new value node.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
  AnfAlgo::SetGraphId(graph->graph_id(), new_value_node.get());

  graph->FrontBackendlMapAdd(anf, new_value_node);

  return new_value_node;
}

ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(graph);
  if (!anf->isa<Parameter>()) {
    MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
  }

  auto param_value = GetParamDefaultValue(anf);
  ParameterPtr new_parameter = nullptr;
  // If the parameter's Python object already has an associated backend parameter, reuse it.
  if (param_value != nullptr) {
    new_parameter = param_value->parameter();
    if (new_parameter == nullptr) {
      TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
      new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
      param_value->set_parameter(new_parameter);
    }
  } else {
    TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info()));
    new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
  }

  new_parameter->IncreaseUsedGraphCount();

  return new_parameter;
}

1128 KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs,
1129                                                   bool common_opt) {
1130   std::unordered_map<AnfNodePtr, AnfNodePtr> other_graph_cnode;
1131   auto graph = NewKernelGraph();
1132   MS_EXCEPTION_IF_NULL(graph);
1133   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
1134   for (const auto &node : lst) {
1135     MS_EXCEPTION_IF_NULL(node);
1136     MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
1137     if (!node->isa<CNode>()) {
1138       MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " is not a CNode";
1139     }
1140     auto cnode = node->cast<CNodePtr>();
1141     MS_EXCEPTION_IF_NULL(cnode);
1142     // create a new cnode object
1143     auto new_cnode = CreateNewCNode(cnode, graph.get(), &other_graph_cnode);
1144     MS_EXCEPTION_IF_NULL(new_cnode);
1145     new_cnode->set_abstract(cnode->abstract());
1146     new_cnode->set_scope(cnode->scope());
1147     if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
1148       new_cnode->set_fullname_with_scope(cnode->input(kFirstDataInputIndex)->fullname_with_scope());
1149     }
1150   // record the mapping between the frontend (ME) anf node and the new anf node used in the backend
1151     graph->FrontBackendlMapAdd(node, new_cnode);
1152   }
1153   // add a make_tuple at the end of graph as output
1154   graph->set_output(ConstructOutput(outputs, graph));
1155   FuncGraphManagerPtr manager = MakeManager({graph});
1156   if (manager) {
1157     manager->AddFuncGraph(graph);
1158     graph->set_manager(manager);
1159   }
1160   graph->SetExecOrderByDefault();
1161 
1162 #ifndef ENABLE_SECURITY
1163   if (ExistSummaryNode(graph.get())) {
1164     graph->set_summary_node_exist(true);
1165   }
1166 #endif
1167 
1168   UnifyMindIR(graph);
1169   // Update Graph Dynamic Shape Attr
1170   UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
1171   UpdateGraphAquireGilAttr(NOT_NULL(graph));
1172   if (common_opt) {
1173     opt::BackendCommonOptimization(graph);
1174   }
1175   graph->SetInputNodes();
1176   SetInputNodeUsage(graph, manager);
1177   graph->SetOptimizerFlag();
1178   return graph;
1179 }
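// A hedged usage sketch for the overload above (illustrative only; `session`
// and the node/output lists are hypothetical and would normally come from a
// partitioned graph segment handed over by the frontend):
//   AnfNodePtrList nodes = ...;  // topologically sorted CNodes of a segment
//   AnfNodePtrList outs = ...;   // the segment's output anf nodes
//   KernelGraphPtr kg = session->ConstructKernelGraph(nodes, outs, true);
//   MS_LOG(INFO) << "Built kernel graph " << kg->graph_id();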
1180 
1181 GraphInfo SessionBasic::GetSingleOpGraphInfo(const CNodePtr &kernel,
1182                                              const std::vector<tensor::TensorPtr> &input_tensors) {
1183   MS_EXCEPTION_IF_NULL(kernel);
1184   auto prim = AnfAlgo::GetCNodePrimitive(kernel);
1185   MS_EXCEPTION_IF_NULL(prim);
1186   const AbstractBasePtr &abstract = kernel->abstract();
1187   MS_EXCEPTION_IF_NULL(abstract);
1188   size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
1189   GraphInfo graph_info;
1190   // get input tensor info
1191   for (const auto &tensor : input_tensors) {
1192     MS_EXCEPTION_IF_NULL(tensor);
1193     auto tensor_shape = tensor->shape();
1194     (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
1195                         [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
1196     (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
1197     if (tensor->device_address() != nullptr) {
1198       const auto type_id = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id();
1199       (void)graph_info.append(std::to_string(type_id) + "_");
1200       const auto format = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format();
1201       (void)graph_info.append(format + "_");
1202     }
1203     for (const auto &padding_type : tensor->padding_type()) {
1204       (void)graph_info.append(std::to_string(padding_type) + "_");
1205     }
1206   }
1207   // get attr info
1208   const auto &attr_map = prim->attrs();
1209   (void)std::for_each(attr_map.begin(), attr_map.end(), [&](const auto &element) {
1210     if (element.second->ToString().empty()) {
1211       return;
1212     }
1213     (void)graph_info.append(element.second->ToString() + "_");
1214   });
1215   auto build_shape = abstract->BuildShape();
1216   MS_EXCEPTION_IF_NULL(build_shape);
1217   (void)graph_info.append(build_shape->ToString() + "_");
1218   for (size_t output_index = 0; output_index < output_num; output_index += 1) {
1219     const auto output_type = AnfAlgo::GetOutputInferDataType(kernel, output_index);
1220     (void)graph_info.append(std::to_string(output_type) + "_");
1221   }
1222   graph_info.append(std::to_string(prim->id()));
1223   return graph_info;
1224 }
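// The string built above acts as a cache key for single-op (PyNative) graphs:
// per input it appends the shape dims, host dtype and, when a device address
// exists, the device dtype and format, then any padding types; it then appends
// the non-empty primitive attribute values, the inferred output shape, the
// output dtypes, and finally the primitive id. Ops differing in any of these
// fields therefore never share a cached graph. (This key layout is an informal
// reading of the code above, not a stable format.)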
1225 
1226 void SessionBasic::GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info) {
1227   MS_EXCEPTION_IF_NULL(cnode);
1228   MS_EXCEPTION_IF_NULL(run_info);
1229   auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
1230   run_info->primitive = primitive;
1231   run_info->op_name = primitive->name();
1232   const auto &abstract = cnode->abstract();
1233   if (abstract == nullptr) {
1234     MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
1235   }
1236   run_info->abstract = abstract;
1237   const auto &shape = abstract->BuildShape();
1238   MS_EXCEPTION_IF_NULL(shape);
1239   run_info->is_dynamic_shape = shape->IsDynamic();
1240 }
1241 
1242 void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
1243                                      std::map<AnfNodePtr, size_t> *parameter_index) {
1244   size_t index = 0;
1245   for (const auto &input_node : graph->inputs()) {
1246     auto params = AnfAlgo::GetAllOutput(input_node);
1247     for (const auto &param : params) {
1248       if (index >= inputs.size()) {
1249         MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
1250                           << ", input size: " << inputs.size();
1251       }
1252       const auto &input = inputs[index];
1253       MS_EXCEPTION_IF_NULL(input);
1254       // Check shape of input and parameter
1255       const auto &input_shape = input->shape();
1256       const auto &param_shape = AnfAlgo::GetOutputInferShape(param, 0);
1257       if (input_shape.size() != param_shape.size()) {
1258         MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
1259                           << ", parameter: " << param->fullname_with_scope();
1260       }
1261       bool is_dynamic = param->Shape()->IsDynamic();
1262       for (size_t i = 0; i < input_shape.size(); i += 1) {
1263         if (input_shape[i] < 0 || (static_cast<size_t>(input_shape[i]) != param_shape[i] && !is_dynamic)) {
1264           MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
1265                             << ", parameter: " << param->fullname_with_scope();
1266         }
1267       }
1268       parameter_index->emplace(param, index++);
1269     }
1270   }
1271 }
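// Illustrative result (hypothetical names): if the graph inputs flatten to
// parameters [p0, p1, p2] fed by tensors [t0, t1, t2], the map becomes
// {p0 -> 0, p1 -> 1, p2 -> 2}; the checks above reject rank mismatches and,
// for statically shaped parameters, any per-dimension mismatch.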
1272 
1273 void SessionBasic::CreateOutputPlaceholder(
1274   const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *const outputs,
1275   std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
1276   MS_EXCEPTION_IF_NULL(kernel_graph);
1277   MS_EXCEPTION_IF_NULL(outputs);
1278   MS_EXCEPTION_IF_NULL(output_indexes);
1279   auto anf_outputs = kernel_graph->outputs();
1280   size_t index = 0;
1281   for (auto &item : anf_outputs) {
1282     MS_EXCEPTION_IF_NULL(item);
1283     std::vector<size_t> indexes{index++};
1284     outputs->emplace_back(CreateNodeOutputPlaceholder(item, kernel_graph, input_tensors, indexes, output_indexes));
1285   }
1286 }
1287 
1288 void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
1289   MS_EXCEPTION_IF_NULL(graph);
1290   for (const auto &kernel : graph->execution_order()) {
1291     for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
1292       const auto &input = kernel->input(i);
1293       auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
1294       const auto &node = kernel_with_index.first;
1295       if (node->isa<CNode>()) {
1296         (*ref_count)[kernel_with_index] += 1;
1297       }
1298     }
1299   }
1300 }
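// Example of the counting above: if CNode A's output (A, 0) feeds two kernels
// B and C, then (*ref_count)[{A, 0}] == 2. HandleOpInputs below decrements the
// entry after each consumer has run and drops the cached output tensor from
// the op output map once the count reaches zero.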
1301 
1302 void SessionBasic::HandleOpInputs(const std::set<KernelWithIndex> &input_kernel,
1303                                   std::map<KernelWithIndex, size_t> *ref_count,
1304                                   std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
1305   MS_EXCEPTION_IF_NULL(ref_count);
1306   MS_EXCEPTION_IF_NULL(op_output_map);
1307   for (auto &kernel_with_index : input_kernel) {
1308     MS_EXCEPTION_IF_NULL(kernel_with_index.first);
1309     if (!kernel_with_index.first->isa<CNode>()) {
1310       continue;
1311     }
1312     auto ref_iter = ref_count->find(kernel_with_index);
1313     if (ref_iter == ref_count->end()) {
1314       MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
1315                         << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
1316     }
1317     // Decrease the reference count; once it reaches zero, release the no-longer-needed output of the previous node.
1318     ref_iter->second -= 1;
1319     if (ref_iter->second != 0) {
1320       continue;
1321     }
1322     ref_count->erase(ref_iter);
1323     auto output_iter = op_output_map->find(kernel_with_index);
1324     if (output_iter == op_output_map->end()) {
1325       MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
1326                         << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
1327     }
1328     op_output_map->erase(output_iter);
1329   }
1330 }
1331 
1332 void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
1333                                    const std::map<KernelWithIndex, size_t> &ref_count,
1334                                    std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map,
1335                                    GraphOutputInfo *const graph_output_info) {
1336   MS_EXCEPTION_IF_NULL(kernel);
1337   MS_EXCEPTION_IF_NULL(op_output_map);
1338   MS_EXCEPTION_IF_NULL(graph_output_info);
1339   MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
1340   auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
1341   if (output_tensors.size() > op_outputs.size()) {
1342     MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
1343   }
1344   size_t out_index = 0;
1345   for (const auto &output_tensor : output_tensors) {
1346     auto kernel_with_index = make_pair(kernel, out_index++);
1347     if (ref_count.find(kernel_with_index) != ref_count.end()) {
1348       (*op_output_map)[kernel_with_index] = output_tensor;
1349     }
1350     const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
1351     if (iter == graph_output_info->output_indexes.end()) {
1352       continue;
1353     }
1354     const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
1355     for (const auto &ref_indexes : multiple_ref_indexes) {
1356       size_t n = 0;
1357       const VectorRef *cur_vector_ref = graph_output_info->graph_outputs;
1358       for (; n < ref_indexes.size() - 1; n += 1) {
1359         size_t index = ref_indexes.at(n);
1360         if (index >= cur_vector_ref->size()) {
1361           MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vertor ref is "
1362                             << cur_vector_ref->size();
1363         }
1364         const BaseRef &base_ref = (*cur_vector_ref)[index];
1365         if (!utils::isa<VectorRef>(base_ref)) {
1366           MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << "cur n: " << n;
1367         }
1368         cur_vector_ref = &utils::cast<VectorRef>(base_ref);
1369       }
1370       BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
1371       tensor_ref = output_tensor;
1372       graph_output_info->graph_output_tensors.emplace_back(output_tensor);
1373     }
1374   }
1375 }
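// The nested walk above resolves index paths into the (possibly nested) graph
// output VectorRef. Illustrative example: an entry
//   output_indexes[{kernel, 0}] = {{1, 0}}
// means the kernel's first output tensor is stored at outputs[1][0], where
// outputs[1] is itself a VectorRef created for a tuple output.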
1376 TensorPtr SessionBasic::GetValueNodeOutputTensor(const AnfNodePtr &node, size_t output_index) {
1377   MS_EXCEPTION_IF_NULL(node);
1378   if (!node->isa<ValueNode>()) {
1379     return nullptr;
1380   }
1381   auto value_node = node->cast<ValueNodePtr>();
1382   MS_EXCEPTION_IF_NULL(value_node);
1383   auto value = GetValueNode(value_node);
1384   MS_EXCEPTION_IF_NULL(value);
1385   if (value->isa<ValueTuple>()) {
1386     auto value_tuple = value->cast<ValueTuplePtr>();
1387     MS_EXCEPTION_IF_NULL(value_tuple);
1388     if (output_index >= value_tuple->size()) {
1389       MS_LOG(EXCEPTION) << "Index " << output_index << "is out of value tuple range";
1390     }
1391     auto tensor_value = value_tuple->value()[output_index];
1392     if (tensor_value->isa<tensor::Tensor>()) {
1393       return tensor_value->cast<tensor::TensorPtr>();
1394     }
1395   } else if (value->isa<tensor::Tensor>()) {
1396     if (output_index != 0) {
1397       MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << output_index;
1398     }
1399     return value->cast<TensorPtr>();
1400   }
1401   return nullptr;
1402 }
1403 
1404 TensorPtr SessionBasic::GetParameterOutputTensor(const AnfNodePtr &node,
1405                                                  const std::map<AnfNodePtr, size_t> &parameter_index,
1406                                                  const std::vector<tensor::TensorPtr> &graph_inputs) {
1407   MS_EXCEPTION_IF_NULL(node);
1408   if (!node->isa<Parameter>()) {
1409     return nullptr;
1410   }
1411   const auto &iter = parameter_index.find(node);
1412   if (iter == parameter_index.end()) {
1413     MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, parameter = " << node->DebugString();
1414   }
1415   const size_t index = iter->second;
1416   if (index >= graph_inputs.size()) {
1417     MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = " << index
1418                       << ", input tensor size = " << graph_inputs.size();
1419   }
1420   return graph_inputs[index];
1421 }
1422 
1423 TensorPtr SessionBasic::GetCNodeOutputTensor(const KernelWithIndex &kernel_with_index,
1424                                              const std::map<KernelWithIndex, tensor::TensorPtr> &op_output) {
1425   const auto &iter = op_output.find(kernel_with_index);
1426   if (iter == op_output.end()) {
1427     MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << kernel_with_index.first->DebugString();
1428   }
1429   return iter->second;
1430 }
1431 
1432 void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
1433                                      const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
1434                                      const std::map<AnfNodePtr, size_t> &parameter_index,
1435                                      const std::vector<tensor::TensorPtr> &graph_inputs,
1436                                      InputTensorInfo *input_tensor_info) {
1437   MS_EXCEPTION_IF_NULL(cnode);
1438   MS_EXCEPTION_IF_NULL(input_tensor_info);
1439   const auto input_tensor_num = AnfAlgo::GetInputTensorNum(cnode);
1440   for (size_t i = 1; i <= input_tensor_num; i += 1) {
1441     const auto &input = cnode->input(i);
1442     auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
1443     auto real_input = kernel_with_index.first;
1444     MS_EXCEPTION_IF_NULL(real_input);
1445     tensor::TensorPtr tensor = nullptr;
1446     if (real_input->isa<ValueNode>()) {
1447       tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second);
1448     } else if (real_input->isa<Parameter>()) {
1449       tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
1450     } else if (real_input->isa<CNode>()) {
1451       tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
1452       if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
1453         CheckInputTensorShape(tensor, cnode, i - 1);
1454       }
1455       input_tensor_info->input_kernel.insert(kernel_with_index);
1456     } else {
1457       MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
1458     }
1459     MS_EXCEPTION_IF_NULL(tensor);
1460     MS_LOG(DEBUG) << "Get" << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
1461                   << real_input->fullname_with_scope() << "-" << kernel_with_index.second;
1462     input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask
1463                                                                               : kParameterDataTensorMask);
1464     input_tensor_info->input_tensors.emplace_back(tensor);
1465   }
1466 }
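// Summary of the resolution above: ValueNode inputs yield their constant
// tensors, Parameter inputs are fetched from graph_inputs via parameter_index,
// and CNode inputs come from op_output filled in by HandleOpOutputs;
// input_kernel records which cached op outputs were consumed so HandleOpInputs
// can decrement their reference counts afterwards.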
1467 
1468 tensor::TensorPtr SessionBasic::GetOpInputTensorByIndex(const CNodePtr &cnode,
1469                                                         const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
1470                                                         const std::map<AnfNodePtr, size_t> &parameter_index,
1471                                                         const std::vector<tensor::TensorPtr> &graph_inputs,
1472                                                         InputTensorInfo *const input_tensor_info, size_t input_index) {
1473   MS_EXCEPTION_IF_NULL(cnode);
1474   MS_EXCEPTION_IF_NULL(input_tensor_info);
1475   if (input_index >= cnode->inputs().size() - 1) {
1476     MS_LOG(EXCEPTION) << "Input index is out of range:" << cnode->inputs().size() << ",cnode:" << cnode->DebugString();
1477   }
1478 
1479   const auto &input = cnode->input(input_index + 1);
1480   auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
1481   auto real_input = kernel_with_index.first;
1482   MS_EXCEPTION_IF_NULL(real_input);
1483 
1484   if (real_input->isa<Parameter>()) {
1485     return GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
1486   } else if (real_input->isa<CNode>()) {
1487     tensor::TensorPtr tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
1488     if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
1489       CheckInputTensorShape(tensor, cnode, input_index);
1490     }
1491     input_tensor_info->input_kernel.insert(kernel_with_index);
1492     return tensor;
1493   } else {
1494     MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
1495   }
1496 }
1497 
1498 bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph *graph) {
1499   MS_EXCEPTION_IF_NULL(node);
1500   MS_EXCEPTION_IF_NULL(graph);
1501   auto cnode = node->cast<CNodePtr>();
1502   MS_EXCEPTION_IF_NULL(cnode);
1503   // create a new cnode object
1504   auto new_cnode = CreateNewCNode(cnode, graph);
1505   if (new_cnode == nullptr) {
1506     return false;
1507   }
1508   new_cnode->set_abstract(cnode->abstract());
1509   std::string fullname;
1510   if (cnode->input(kAnfPrimitiveIndex)->isa<CNode>()) {
1511     fullname = cnode->input(kAnfPrimitiveIndex)->fullname_with_scope();
1512   } else if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) {
1513     fullname = cnode->input(kFirstDataInputIndex)->fullname_with_scope();
1514   } else {
1515     fullname = cnode->fullname_with_scope();
1516   }
1517   new_cnode->set_fullname_with_scope(fullname);
1518   new_cnode->set_scope(cnode->scope());
1519   graph->FrontBackendlMapAdd(node, new_cnode);
1520   SetReturnNode(new_cnode, graph);
1521   return true;
1522 }
1523 
1524 std::shared_ptr<KernelGraph> SessionBasic::ConstructKernelGraph(const FuncGraphPtr &func_graph,
1525                                                                 std::vector<KernelGraphPtr> *all_out_graph) {
1526   MS_EXCEPTION_IF_NULL(func_graph);
1527   MS_EXCEPTION_IF_NULL(all_out_graph);
1528   auto node_list = TopoSort(func_graph->get_return());
1529   auto graph = NewKernelGraph();
1530   MS_EXCEPTION_IF_NULL(graph);
1531   front_backend_graph_map_[func_graph.get()] = graph;
1532   MS_LOG(INFO) << "Create graph: " << graph->graph_id();
1533   for (const auto &node : node_list) {
1534     MS_EXCEPTION_IF_NULL(node);
1535     MS_LOG(DEBUG) << "Start create new cnode, node = " << node->DebugString();
1536     // Create parameter
1537     if (node->isa<Parameter>()) {
1538       auto graph_inputs = graph->MutableInputs();
1539       MS_EXCEPTION_IF_NULL(graph_inputs);
1540       auto new_parameter = CreateNewParameter(node, graph.get());
1541       graph_inputs->push_back(new_parameter);
1542       graph->FrontBackendlMapAdd(node, new_parameter);
1543       continue;
1544     }
1545     // Create value node
1546     if (node->isa<ValueNode>()) {
1547       // Create common value node
1548       if (!IsValueNode<FuncGraph>(node)) {
1549         (void)CreateNewValueNode(node, graph.get());
1550         continue;
1551       }
1552       // Create a child kernel graph according to the ValueNode<FuncGraph>
1553       FuncGraphPtr child_graph = AnfAlgo::GetValueNodeFuncGraph(node);
1554       if (front_backend_graph_map_.find(child_graph.get()) == front_backend_graph_map_.end()) {
1555         (void)ConstructKernelGraph(child_graph, all_out_graph);
1556       }
1557       (void)CreateValueNodeKernelGraph(node, graph.get());
1558       continue;
1559     }
1560     // Create cnode
1561     if (!CreateCNodeOfKernelGraph(node, graph.get())) {
1562 #ifdef ENABLE_DUMP_IR
1563       DumpIR("construct_kernel_graph_fail.ir", func_graph);
1564 #endif
1565       MS_LOG(EXCEPTION) << "Construct func graph " << func_graph->ToString() << " failed."
1566                         << trace::DumpSourceLines(node);
1567     }
1568   }
1569 
1570   AddParameterToGraphInputs(func_graph->parameters(), graph.get());
1571   FuncGraphManagerPtr manager = MakeManager({graph});
1572   graph->SetInputNodes();
1573   SetInputNodeUsage(graph, manager);
1574   graph->SetExecOrderByDefault();
1575 
1576 #ifndef ENABLE_SECURITY
1577   if (ExistSummaryNode(graph.get())) {
1578     graph->set_summary_node_exist(true);
1579   }
1580 #endif
1581 
1582   all_out_graph->push_back(graph);
1583   return graph;
1584 }
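// A hedged call sketch (names illustrative): the overload above recurses into
// every ValueNode<FuncGraph> it meets, so all_out_graph ends up holding each
// child graph followed by the root graph itself:
//   std::vector<KernelGraphPtr> all_graphs;
//   auto root_graph = session->ConstructKernelGraph(func_graph, &all_graphs);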
1585 
1586 void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph) {
1587   MS_EXCEPTION_IF_NULL(graph);
1588   auto graph_inputs = graph->MutableInputs();
1589   MS_EXCEPTION_IF_NULL(graph_inputs);
1590   graph_inputs->clear();
1591   for (auto &parameter : parameters) {
1592     MS_EXCEPTION_IF_NULL(parameter);
1593     auto backend_parameter = graph->GetBackendAnfByFrontAnf(parameter);
1594     if (backend_parameter == nullptr) {
1595       // for example, in "def f(x,y,z) {return x + y}" parameter z is unused
1596       auto new_parameter = CreateNewParameter(parameter, graph);
1597       graph_inputs->push_back(new_parameter);
1598       MS_LOG(INFO) << "Can't find parameter:" << parameter->DebugString();
1599       continue;
1600     }
1601     graph_inputs->push_back(backend_parameter);
1602   }
1603 }
1604 
1605 void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
1606                                  const std::vector<tensor::TensorPtr> &input_tensors,
1607                                  std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const {
1608   MS_EXCEPTION_IF_NULL(kernel_graph);
1609   MS_EXCEPTION_IF_NULL(outputs);
1610   MS_EXCEPTION_IF_NULL(tensor_to_node);
1611   KernelMapTensor node_to_tensor;
1612   auto anf_outputs = kernel_graph->outputs();
1613   for (auto &item : anf_outputs) {
1614     MS_EXCEPTION_IF_NULL(item);
1615     MS_LOG(DEBUG) << "Update output[" << item->DebugString() << "]";
1616     outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor));
1617   }
1618 
1619   auto ms_context = MsContext::GetInstance();
1620   MS_EXCEPTION_IF_NULL(ms_context);
1621   for (auto &item : *tensor_to_node) {
1622     auto &tensor = item.first;
1623     auto &node = item.second.first;
1624     auto &output_index = item.second.second;
1625     DeviceAddressPtr address = nullptr;
1626     if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
1627         ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
1628       address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
1629     } else {
1630       address = AnfAlgo::GetMutableOutputAddr(node, output_index);
1631     }
1632     MS_EXCEPTION_IF_NULL(tensor);
1633     tensor->set_device_address(address);
1634     tensor->SetNeedWait(false);
1635     MS_LOG(DEBUG) << "Debug address: Output tensor obj " << tensor.get() << ", tensor id " << tensor->id()
1636                   << ", device address " << tensor->device_address().get();
1637     if (AnfAlgo::IsDynamicShape(node)) {
1638       const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
1639       ShapeVector int_shape;
1640       (void)std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
1641       (void)tensor->set_shape(int_shape);
1642     }
1643     if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
1644       tensor->data_sync(false);
1645       tensor->set_sync_status(kNeedSyncHostToDevice);
1646     }
1647   }
1648 }
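// Note on the loop above: outside PyNative mode each output tensor is synced
// device-to-host immediately and then flagged kNeedSyncHostToDevice so a later
// step can push refreshed data back; for dynamic-shape nodes the tensor's host
// shape is first overwritten with the freshly inferred shape.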
1649 
1650 void SessionBasic::UpdateOutputAbstract(const std::shared_ptr<KernelGraph> &kernel_graph,
1651                                         OpRunInfo *op_run_info) const {
1652   MS_EXCEPTION_IF_NULL(kernel_graph);
1653   MS_EXCEPTION_IF_NULL(op_run_info);
1654   const auto &kernels = kernel_graph->execution_order();
1655   for (const auto &kernel : kernels) {
1656     MS_EXCEPTION_IF_NULL(kernel);
1657     if (AnfAlgo::GetCNodeName(kernel) == op_run_info->op_name) {
1658       op_run_info->abstract = kernel->abstract();
1659     }
1660   }
1661 }
1662 
1663 std::vector<tensor::TensorPtr> SessionBasic::GetInputNeedLockTensors(const GraphId &graph_id,
1664                                                                      const std::vector<tensor::TensorPtr> &inputs) {
1665   auto graph = GetGraph(graph_id);
1666   MS_EXCEPTION_IF_NULL(graph);
1667   if (!graph->has_optimizer()) {
1668     return {};
1669   }
1670   auto input_nodes = graph->inputs();
1671   bool check_monad = false;
1672   if (input_nodes.size() == inputs.size()) {
1673     check_monad = true;
1674   }
1675   std::vector<tensor::TensorPtr> result;
1676   for (size_t i = 0; i < inputs.size(); ++i) {
1677     if (check_monad && HasAbstractMonad(input_nodes[i])) {
1678       continue;
1679     }
1680     auto &tensor = inputs[i];
1681     MS_EXCEPTION_IF_NULL(tensor);
1682     if (!tensor->IsGraphOutput()) {
1683       result.emplace_back(tensor);
1684     }
1685   }
1686   return result;
1687 }
1688 
1689 void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
1690                                        VectorRef *outputs,
1691                                        std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
1692   auto kernel_graph = GetGraph(graph_id);
1693   MS_EXCEPTION_IF_NULL(kernel_graph);
1694   MS_EXCEPTION_IF_NULL(outputs);
1695   MS_EXCEPTION_IF_NULL(tensor_to_node);
1696   auto anf_outputs = kernel_graph->outputs();
1697   KernelMapTensor node_to_tensor;
1698   for (auto &item : anf_outputs) {
1699     MS_EXCEPTION_IF_NULL(item);
1700     MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
1701     outputs->emplace_back(CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor));
1702   }
1703   auto ms_context = MsContext::GetInstance();
1704   MS_EXCEPTION_IF_NULL(ms_context);
1705   auto enable_mem_scheduler = ms_context->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
1706   if (enable_mem_scheduler) {
1707     kernel_graph->SetOutputNodeToTensor(node_to_tensor);
1708   }
1709 }
1710 
1711 void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
1712                                        const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
1713                                        std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
1714   auto context_ptr = MsContext::GetInstance();
1715   MS_EXCEPTION_IF_NULL(context_ptr);
1716   auto enable_mem_scheduler = context_ptr->get_param<bool>(MS_CTX_ENABLE_MEM_SCHEDULER);
1717   if (enable_mem_scheduler) {
1718     return;
1719   }
1720   MS_EXCEPTION_IF_NULL(outputs);
1721   for (const auto &item : *outputs) {
1722     if (utils::isa<VectorRefPtr>(item)) {
1723       const auto &vector_ref = utils::cast<VectorRef>(item);
1724       std::map<DeviceAddressPtr, DeviceAddressPtr> new_to_old_device_address;
1725       UpdateOutputTensors(&vector_ref, tensor_to_node, &new_to_old_device_address);
1726     } else if (utils::isa<tensor::TensorPtr>(item)) {
1727       const auto &tensor = utils::cast<tensor::TensorPtr>(item);
1728       MS_EXCEPTION_IF_NULL(tensor);
1729       const auto &iter = tensor_to_node.find(tensor);
1730       if (iter != tensor_to_node.end()) {
1731         const auto &node = iter->second.first;
1732         const auto &output_index = iter->second.second;
1733         if (!AnfAlgo::OutputAddrExist(node, output_index, true)) {
1734           continue;
1735         }
1736         const auto &address = AnfAlgo::GetMutableOutputAddr(node, output_index);
1737         tensor->set_device_address(address);
1738 
1739         if (AnfAlgo::IsDynamicShape(node)) {
1740           const auto &updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
1741           ShapeVector int_shape;
1742           (void)std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
1743           (void)tensor->set_shape(int_shape);
1744         }
1745       }
1746       if (tensor->NeedSyncDeviceToHostImmediately()) {
1747         tensor->data_sync(false);
1748         tensor->set_device_address(nullptr);
1749         tensor->set_sync_status(kNeedSyncHostToDevice);
1750       }
1751     }
1752   }
1753 }
1754 
1755 void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
1756                                       std::vector<std::string> *inputs_name) const {
1757   MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
1758   auto kernel_graph = GetGraph(graph_id);
1759   MS_EXCEPTION_IF_NULL(kernel_graph);
1760   MS_EXCEPTION_IF_NULL(inputs);
1761   MS_EXCEPTION_IF_NULL(inputs_name);
1762   auto kernel_graph_inputs = kernel_graph->inputs();
1763   // find parameters of graph inputs
1764   for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
1765     if (!kernel_graph_inputs[i]->isa<Parameter>()) {
1766       MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
1767       continue;
1768     }
1769     auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
1770     if (!AnfAlgo::IsParameterWeight(parameter)) {
1771       vector<int64_t> input_shape;
1772       auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
1773       (void)std::transform(parameter_shape.begin(), parameter_shape.end(), std::back_inserter(input_shape),
1774                            [](const size_t dim) { return SizeToLong(dim); });
1775       auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
1776       auto data_type = kernel_build_info->GetOutputDeviceType(0);
1777       auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
1778       inputs->push_back(ms_tensor);
1779       inputs_name->push_back(parameter->name());
1780     }
1781   }
1782 }
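// A hedged usage sketch (identifiers illustrative): once a graph is compiled,
// the exported model's input metadata can be queried like
//   std::vector<tensor::TensorPtr> in_tensors;
//   std::vector<std::string> in_names;
//   session->GetModelInputsInfo(graph_id, &in_tensors, &in_names);
// Weight parameters are filtered out above, so only user-fed inputs appear.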
1783 
1784 void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
1785                                        std::vector<std::string> *output_names) const {
1786   std::vector<tensor::TensorPtr> inputs;
1787   std::vector<std::string> input_names;
1788   GetModelInputsInfo(graph_id, &inputs, &input_names);
1789 
1790   auto kernel_graph = GetGraph(graph_id);
1791   MS_EXCEPTION_IF_NULL(kernel_graph);
1792   MS_EXCEPTION_IF_NULL(outputs);
1793   MS_EXCEPTION_IF_NULL(output_names);
1794 
1795   VectorRef vector_outputs;
1796   std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
1797   KernelMapTensor node_to_tensor;
1798   auto anf_outputs = kernel_graph->outputs();
1799   for (auto &item : anf_outputs) {
1800     MS_EXCEPTION_IF_NULL(item);
1801     MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
1802     vector_outputs.emplace_back(CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node, &node_to_tensor));
1803   }
1804   *outputs = TransformVectorRefToMultiTensor(vector_outputs);
1805   for (size_t i = 0; i < outputs->size(); i++) {
1806     output_names->push_back("output" + std::to_string(i));
1807   }
1808 }
1809 
1810 #ifndef ENABLE_SECURITY
1811 void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
1812   MS_EXCEPTION_IF_NULL(callback);
1813   summary_callback_ = callback;
1814 }
1815 
1816 void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
1817   MS_LOG(DEBUG) << "Update summary Start";
1818   MS_EXCEPTION_IF_NULL(graph);
1819   if (!graph->summary_node_exist()) {
1820     return;
1821   }
1822   auto summary = graph->summary_nodes();
1823   auto apply_list = TopoSort(graph->get_return());
1824   for (auto &n : apply_list) {
1825     MS_EXCEPTION_IF_NULL(n);
1826     if (IsPrimitiveCNode(n, prim::kPrimScalarSummary) || IsPrimitiveCNode(n, prim::kPrimTensorSummary) ||
1827         IsPrimitiveCNode(n, prim::kPrimImageSummary) || IsPrimitiveCNode(n, prim::kPrimHistogramSummary)) {
1828       auto cnode = n->cast<CNodePtr>();
1829       MS_EXCEPTION_IF_NULL(cnode);
1830       if (cnode->inputs().size() <= kSummaryGetItem) {
1831         MS_LOG(EXCEPTION) << "The node Summary should have 2 inputs at least!";
1832       }
1833       auto node = cnode->input(kSummaryGetItem);
1834       MS_EXCEPTION_IF_NULL(node);
1835       auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, true);
1836       MS_EXCEPTION_IF_NULL(item_with_index.first);
1837       if (!AnfAlgo::IsRealKernel(item_with_index.first)) {
1838         MS_LOG(EXCEPTION) << "Unexpected node:" << item_with_index.first->DebugString();
1839       }
1840       summary[n->fullname_with_scope()] = item_with_index;
1841     }
1842   }
1843   graph->set_summary_nodes(summary);
1844   MS_LOG(DEBUG) << "Update summary end size: " << summary.size();
1845 }
1846 
1847 void SessionBasic::Summary(KernelGraph *graph) {
1848   if (summary_callback_ == nullptr) {
1849     return;
1850   }
1851   MS_EXCEPTION_IF_NULL(graph);
1852   bool exist_summary = graph->summary_node_exist();
1853   if (!exist_summary) {
1854     return;
1855   }
1856 
1857   static bool is_first = true;
1858   if (is_first && !IsSupportSummary()) {
1859     is_first = false;
1860     MS_LOG(ERROR) << "The Summary operator can not collect data correctly. Detail: the data sink mode is used and the"
1861                      " sink size(in model.train() python api) is not equal to 1.";
1862   }
1863   SetSummaryNodes(graph);
1864   auto summary_outputs = graph->summary_nodes();
1865   std::map<std::string, tensor::TensorPtr> params_list;
1866   // fetch the outputs of the summary kernels in the session and run the callback function
1867   for (auto &output_item : summary_outputs) {
1868     auto node = output_item.second.first;
1869     size_t index = IntToSize(output_item.second.second);
1870     auto address = AnfAlgo::GetOutputAddr(node, index);
1871     auto shape = AnfAlgo::GetOutputInferShape(node, index);
1872     TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
1873     std::vector<int64_t> temp_shape;
1874     (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
1875     tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
1876     MS_EXCEPTION_IF_NULL(address);
1877     if (!address->GetPtr()) {
1878       continue;
1879     }
1880     if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
1881                                    tensor->data_type(), tensor->data_c())) {
1882       MS_LOG(ERROR) << "Failed to sync output from device to host.";
1883     }
1884     tensor->set_sync_status(kNoNeedSync);
1885     params_list[output_item.first] = tensor;
1886   }
1887   // call callback function here
1888   summary_callback_(0, params_list);
1889 }
1890 #endif
1891 
1892 namespace {
1893 bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) {
1894   if (node == nullptr) {
1895     return false;
1896   }
1897   auto cnode = node->cast<CNodePtr>();
1898   if (cnode == nullptr) {
1899     return false;
1900   }
1901   auto prim = cnode->input(kAnfPrimitiveIndex);
1902   if (prim == nullptr || !IsValueNode<Primitive>(prim)) {
1903     return false;
1904   }
1905   return true;
1906 }
1907 
1908 std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager,
1909                                         const AnfNodePtr &front_node) {
1910   MS_EXCEPTION_IF_NULL(front_func_graph_manager);
1911   auto &users = front_func_graph_manager->node_users()[front_node];
1912   std::vector<AnfNodePtr> result;
1913   for (auto &user : users) {
1914     if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend) ||
1915         AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimLoad)) {
1916       auto depend_cnode = user.first->cast<CNodePtr>();
1917       if (depend_cnode == nullptr) {
1918         continue;
1919       }
1920       if (front_node != depend_cnode->input(1)) {
1921         continue;
1922       }
1923       auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
1924       result.insert(result.end(), res.begin(), res.end());
1925     } else if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) {
1926       auto res = ExtendNodeUsers(front_func_graph_manager, user.first);
1927       (void)result.insert(result.end(), res.begin(), res.end());
1928     } else {
1929       (void)result.emplace_back(user.first);
1930     }
1931   }
1932   return result;
1933 }
1934 
1935 AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) {
1936   MS_EXCEPTION_IF_NULL(front_node);
1937   if (!front_node->isa<CNode>()) {
1938     return nullptr;
1939   }
1940   if (AnfAlgo::IsRealKernel(front_node)) {
1941     return front_node;
1942   }
1943   if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
1944     return front_node;
1945   }
1946   if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimMakeTuple)) {
1947     auto cnode = front_node->cast<CNodePtr>();
1948     MS_EXCEPTION_IF_NULL(cnode);
1949     auto &inputs = cnode->inputs();
1950     if (inputs.size() > 1) {
1951       return GetSupportedInternalNode(inputs[1]);
1952     }
1953   }
1954   if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimDepend)) {
1955     auto cnode = front_node->cast<CNodePtr>();
1956     MS_EXCEPTION_IF_NULL(cnode);
1957     auto &inputs = cnode->inputs();
1958     if (inputs.size() >= kDependInputSize) {
1959       return GetSupportedInternalNode(inputs[kRealInputIndexInDepend]);
1960     }
1961   }
1962   return nullptr;
1963 }
1964 }  // namespace
1965 
1966 constexpr auto kMixTarget = "MixTarget";
1967 constexpr auto kNoTarget = "NoTarget";
1968 std::string SessionBasic::AddPartialParametersMap(const AnfNodePtr &partial_node) {
1969   MS_EXCEPTION_IF_NULL(partial_node);
1970   auto iter = partial_target_map_.find(partial_node);
1971   if (iter != partial_target_map_.end()) {
1972     return iter->second;
1973   }
1974   auto partial_cnode = partial_node->cast<CNodePtr>();
1975   MS_EXCEPTION_IF_NULL(partial_cnode);
1976   auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex));
1977   MS_EXCEPTION_IF_NULL(partial_graph);
1978   auto parameters = partial_graph->parameters();
1979   auto partial_inputs = partial_cnode->inputs();
1980   const size_t kNonParameterNum = 2;
1981   if (parameters.size() + kNonParameterNum != partial_inputs.size()) {
1982     return kMixTarget;
1983   }
1984   for (size_t i = 0; i < parameters.size(); ++i) {
1985     partial_parameters_map_[parameters[i]] = partial_inputs[kNonParameterNum + i];
1986   }
1987   auto graph_nodes = TopoSort(partial_graph->get_return());
1988   std::string graph_target = kNoTarget;
1989   for (auto &node : graph_nodes) {
1990     if (!node->isa<CNode>()) {
1991       continue;
1992     }
1993     if (!AnfAlgo::IsRealKernel(node)) {
1994       continue;
1995     }
1996     std::string cur_target = GetCNodeTarget(node);
1997     if (graph_target == kNoTarget) {
1998       graph_target = cur_target;
1999     }
2000     if (graph_target != cur_target) {
2001       graph_target = kMixTarget;
2002       break;
2003     }
2004   }
2005   (void)partial_target_map_.emplace(std::pair<AnfNodePtr, std::string>(partial_node, graph_target));
2006   return graph_target;
2007 }
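// The returned string is the execution target of the Partial's subgraph:
// kNoTarget when it contains no real kernels, the common target when all real
// kernels agree, and kMixTarget otherwise (or when the parameter and argument
// counts do not line up); results are memoized in partial_target_map_ so each
// Partial node is analyzed only once.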
2008 
2009 void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node,
2010                                         const FuncGraphManagerPtr &front_func_graph_manager,
2011                                         const std::shared_ptr<KernelGraph> &backend_graph) {
2012   auto front_node = GetSupportedInternalNode(input_front_node);
2013   if (front_node == nullptr) {
2014     return;
2015   }
2016   auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0);
2017   auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0);
2018   auto backend_real_kernel = backend_real_kernel_pair.first;
2019   if (backend_real_kernel == nullptr || !backend_real_kernel->isa<CNode>()) {
2020     return;
2021   }
2022   auto front_real_kernel = front_real_kernel_pair.first;
2023   std::string kernel_target = GetCNodeTarget(front_real_kernel);
2024   bool internal_output = CNodeFirstInputIsPrimitive(front_real_kernel);
2025   bool unique_target = true;
2026   if (internal_output && opt::IsNopNode(front_real_kernel)) {
2027     auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0);
2028     auto pre_node_target = GetCNodeTarget(pre_node_pair.first);
2029     if (pre_node_target != kernel_target) {
2030       unique_target = false;
2031     }
2032   }
2033   if (internal_output) {
2034     auto users = ExtendNodeUsers(front_func_graph_manager, front_node);
2035     for (auto &user : users) {
2036       if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial) && kernel_target != kGPUDevice &&
2037           !ExistGraphCaller(user)) {
2038         auto partial_target = AddPartialParametersMap(user);
2039         if (partial_target != kNoTarget && partial_target != kernel_target) {
2040           unique_target = false;
2041         }
2042         continue;
2043       }
2044       if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimUpdateState)) {
2045         continue;
2046       }
2047       if (!CNodeFirstInputIsPrimitive(user)) {
2048         internal_output = false;
2049         break;
2050       }
2051       if (!AnfAlgo::IsRealKernel(user)) {
2052         internal_output = false;
2053         break;
2054       }
2055       if (kernel_target != GetCNodeTarget(user)) {
2056         unique_target = false;
2057       }
2058     }
2059   }
2060   if (internal_output) {
2061     MS_LOG(INFO) << "AddInternalOutput: " << front_node->DebugString() << " To " << backend_real_kernel->DebugString()
2062                  << ", unique_target: " << unique_target;
2063     backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target);
2064   }
2065 }
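// Rough intent of the logic above (informal reading): an "internal output" is
// a graph output whose consumers are themselves backend kernels, so its device
// memory can be handed over to the next graph instead of being copied to host.
// unique_target records whether every consumer runs on the same device target
// as the producer, which decides whether that hand-over is safe.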
2066 
2067 CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
2068   MS_EXCEPTION_IF_NULL(graph);
2069   std::vector<AnfNodePtr> output_args;
2070   for (const auto &output : outputs) {
2071     MS_EXCEPTION_IF_NULL(output);
2072     MS_LOG(INFO) << "Output:" << output->DebugString();
2073   }
2074   auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr {
2075     auto backend_anf = graph->GetBackendAnfByFrontAnf(out);
2076     if (backend_anf != nullptr) {
2077       auto context_ptr = MsContext::GetInstance();
2078       MS_EXCEPTION_IF_NULL(context_ptr);
2079       if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
2080         return backend_anf;
2081       }
2082 
2083       MS_EXCEPTION_IF_NULL(out);
2084       auto out_func_graph = out->func_graph();
2085       MS_EXCEPTION_IF_NULL(out_func_graph);
2086       auto out_func_graph_manager = out_func_graph->manager();
2087       if (out_func_graph_manager == nullptr) {
2088         return backend_anf;
2089       }
2090       HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph);
2091       return backend_anf;
2092     }
2093     MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!";
2094   };
2095   output_args.push_back(NewValueNode(prim::kPrimMakeTuple));
2096   (void)std::transform(outputs.begin(), outputs.end(), std::back_inserter(output_args),
2097                        [&](const AnfNodePtr &out) -> AnfNodePtr { return FindEqu(out); });
2098   return graph->NewCNode(output_args);
2099 }
2100 
2101 void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
2102   std::vector<AnfNodePtr> make_tuple_inputs;
2103   make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
2104   MS_EXCEPTION_IF_NULL(graph);
2105   if (AnfRuntimeAlgorithm::GetOutputTensorNum(cnode) > 1) {
2106     for (size_t output_index = 0; output_index < AnfRuntimeAlgorithm::GetOutputTensorNum(cnode); output_index++) {
2107       auto idx = NewValueNode(SizeToLong(output_index));
2108       MS_EXCEPTION_IF_NULL(idx);
2109       auto imm = std::make_shared<Int64Imm>(output_index);
2110       idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
2111       auto getitem = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx});
2112       std::vector<TypeId> types = {AnfAlgo::GetOutputInferDataType(cnode, output_index)};
2113       std::vector<std::vector<size_t>> shapes = {AnfAlgo::GetOutputInferShape(cnode, output_index)};
2114       AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
2115       make_tuple_inputs.push_back(getitem);
2116     }
2117   } else {
2118     make_tuple_inputs.push_back(cnode);
2119   }
2120   // create output
2121   auto g_output = graph->NewCNode(make_tuple_inputs);
2122   graph->set_output(g_output);
2123 }
2124 
2125 std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
2126                                                                   const std::vector<tensor::TensorPtr> &input_tensors,
2127                                                                   const std::vector<int64_t> &tensors_mask,
2128                                                                   bool is_ascend) {
2129   auto graph = std::make_shared<KernelGraph>();
2130   graph->set_graph_id(graph_sum_);
2131   graph_sum_++;
2132   std::vector<AnfNodePtr> inputs;
2133   // set input[0]
2134   PrimitivePtr op_prim = op_run_info.primitive;
2135   MS_EXCEPTION_IF_NULL(op_prim);
2136   // Decoupling of frontend PrimitivePy and backend Primitive
2137   inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(*op_prim)));
2138   // set input parameter
2139   if (input_tensors.size() != tensors_mask.size()) {
2140     MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size "
2141                       << tensors_mask.size();
2142   }
2143   for (size_t i = 0; i < input_tensors.size(); ++i) {
2144     if (tensors_mask[i] == kValueNodeTensorMask) {
2145       auto value_node = graph->NewValueNode(input_tensors[i]);
2146       inputs.push_back(value_node);
2147       continue;
2148     }
2149     auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
2150     inputs.push_back(parameter);
2151     auto mutable_inputs = graph->MutableInputs();
2152     MS_EXCEPTION_IF_NULL(mutable_inputs);
2153     mutable_inputs->push_back(parameter);
2154   }
2155   // create the cnode for the single op
2156   auto cnode = graph->NewCNode(inputs);
2157   MS_EXCEPTION_IF_NULL(cnode);
2158   // set the abstract, which includes the inferred shapes and types
2159   cnode->set_abstract(op_run_info.abstract);
2160   // get output dynamic shape info
2161   AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info.is_dynamic_shape), cnode);
2162   if (op_run_info.is_auto_mixed_precision) {
2163     AnfAlgo::SetNodeAttr(kAttrPynativeNextOpName, MakeValue(op_run_info.next_op_name), cnode);
2164     AnfAlgo::SetNodeAttr(kAttrPynativeNextIndex, MakeValue(op_run_info.next_input_index), cnode);
2165   }
2166   // set execution order
2167   std::vector<CNodePtr> exe_order = {cnode};
2168   graph->set_execution_order(exe_order);
2169   // set output
2170   if (is_ascend) {
2171     graph->set_output(cnode);
2172   } else {
2173     CreateOutputNode(cnode, graph);
2174   }
2175   graph->SetInputNodes();
2176   auto manager = MakeManager({graph});
2177   if (manager != nullptr) {
2178     manager->AddFuncGraph(graph);
2179     graph->set_manager(manager);
2180   }
2181   auto ms_context = MsContext::GetInstance();
2182   MS_EXCEPTION_IF_NULL(ms_context);
2183   if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
2184     UnifyMindIR(graph);
2185   }
2186   graph->UpdateGraphDynamicAttr();
2187   return graph;
2188 }
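// A hedged sketch of single-op graph construction in PyNative mode (inputs are
// illustrative; run_info comes from GetSingleOpRunInfo and the tensors/masks
// from GetOpInputTensors):
//   auto kg = ConstructSingleOpGraph(run_info, info.input_tensors,
//                                    info.input_tensors_mask, false);
// A mask of kValueNodeTensorMask turns that tensor into a ValueNode; any other
// mask makes it a Parameter and therefore a graph input.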
2189 
2190 KernelGraphPtr SessionBasic::NewKernelGraph() {
2191   auto graph = std::make_shared<KernelGraph>();
2192   graph->set_graph_id(graph_sum_);
2193   graphs_[graph_sum_++] = graph;
2194   return graph;
2195 }
2196 
2197 AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) {
2198   MS_EXCEPTION_IF_NULL(push_node);
2199   for (auto &node : node_list) {
2200     if (node != nullptr && node->isa<CNode>()) {
2201       for (auto input : node->cast<CNodePtr>()->inputs()) {
2202         if (push_node == AnfAlgo::VisitKernel(input, 0).first) {
2203           if (AnfAlgo::GetCNodeName(node) != kPullOpName) {
2204             MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid.";
2205           }
2206           return node;
2207         }
2208       }
2209     }
2210   }
2211   return nullptr;
2212 }
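// Hedged note: this pairing is presumably used in parameter-server mode, where
// each Push kernel must be matched with the Pull kernel that consumes its
// output; any consumer that is not a Pull op is treated as an invalid edge.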
2213 
2214 GraphId SessionBasic::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) {
2215   MS_EXCEPTION_IF_NULL(executor_);
2216   return executor_->CompileGraph(shared_from_this(), segment, outputs);
2217 }
2218 
2219 GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
2220   MS_EXCEPTION_IF_NULL(executor_);
2221   return executor_->CompileGraph(shared_from_this(), func_graph);
2222 }
2223 
2224 void SessionBasic::BuildGraph(GraphId graph_id) {
2225   MS_EXCEPTION_IF_NULL(executor_);
2226   executor_->BuildGraph(shared_from_this(), graph_id);
2227 }
2228 
2229 void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
2230                          std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
2231                          const std::vector<int64_t> &tensors_mask) {
2232   MS_EXCEPTION_IF_NULL(executor_);
2233   executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs, tensors_mask);
2234 }
2235 
2236 void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
2237                                  VectorRef *outputs) {
2238   MS_EXCEPTION_IF_NULL(executor_);
2239   executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs);
2240 }
2241 
RunGraph(const GraphId & graph_id,const std::vector<tensor::TensorPtr> & inputs,VectorRef * outputs)2242 void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
2243   MS_EXCEPTION_IF_NULL(executor_);
2244   executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
2245 }
2246 
RunGraphAsync(const GraphId & graph_id,const std::vector<tensor::TensorPtr> & inputs,VectorRef * outputs)2247 void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
2248                                  VectorRef *outputs) {
2249   MS_EXCEPTION_IF_NULL(executor_);
2250   executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
2251 }
2252 
void SessionBasic::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                VectorRef *const outputs) {
  MS_LOG(INFO) << "Run graph start, graph id: " << graph_id;
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Skip graphs that are not executable, i.e. no child graph produces an ANF output.
  if (!kernel_graph->executable()) {
    MS_LOG(INFO) << "No child graph has ANF output";
    return;
  }
  PreExecuteGraph(kernel_graph, inputs, outputs);
  ExecuteGraph(kernel_graph);
  PostExecuteGraph(kernel_graph, inputs, outputs);
  MS_LOG(INFO) << "Run graph end, graph id: " << graph_id;
}

void SessionBasic::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                     VectorRef *outputs) {
  MS_LOG(INFO) << "Clean tasks in queue";
  session::PynativeTaskManager::GetInstance().ExecuteRemainingTasks();
  MS_LOG(INFO) << "Start!";
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  std::map<AnfNodePtr, size_t> parameter_index;
  GetParameterIndex(kernel_graph.get(), inputs, &parameter_index);
  GraphOutputInfo graph_output_info;
  graph_output_info.graph_outputs = outputs;
  CreateOutputPlaceholder(kernel_graph, inputs, graph_output_info.graph_outputs, &graph_output_info.output_indexes);
  std::map<KernelWithIndex, size_t> cnode_refcount;
  GetRefCount(kernel_graph.get(), &cnode_refcount);
  BuildOpsInGraph(graph_id, parameter_index, inputs, cnode_refcount);

  // Clear bucket resources every step
  if (kernel_graph->is_bprop()) {
    ClearAllBucket(graph_id);
  }

  std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
  for (const auto &kernel : kernel_graph->execution_order()) {
    // Generate input tensors, tensor masks and input kernels with index
    InputTensorInfo input_tensor_info;
    GetOpInputTensors(kernel, op_output_map, parameter_index, inputs, &input_tensor_info);

    // Get OpRunInfo and GraphInfo
    OpRunInfo run_info;
    GetSingleOpRunInfo(kernel, &run_info);
    GraphInfo graph_info = GetSingleOpGraphInfo(kernel, input_tensor_info.input_tensors);

    // Build and run the current single op
    VectorRef op_outputs;
    RunOpImplOrigin(graph_info, &run_info, &input_tensor_info.input_tensors, &op_outputs,
                    input_tensor_info.input_tensors_mask);
    graph_output_info.graph_output_tensors.clear();
    // Handle inputs and outputs of the current op
    HandleOpInputs(input_tensor_info.input_kernel, &cnode_refcount, &op_output_map);
    HandleOpOutputs(kernel, op_outputs, cnode_refcount, &op_output_map, &graph_output_info);
    // Save grad nodes to the bucket
    if (kernel_graph->is_bprop()) {
      AddGradAddrToBucket(graph_id, graph_output_info.graph_output_tensors);
    }
  }
  MS_LOG(INFO) << "Finish!";
}
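
// Per-kernel pipeline, summarized: for each CNode in execution order the session
// (1) gathers input tensors from the graph inputs and earlier op outputs,
// (2) builds the OpRunInfo/GraphInfo cache key, (3) compiles and runs the single op,
// (4) releases consumed inputs via the ref counts and caches the new outputs, and
// (5) stages gradient tensors into AllReduce buckets when the graph is a bprop graph.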

void SessionBasic::EraseValueNodeTensor(const std::vector<int64_t> &tensors_mask,
                                        std::vector<tensor::TensorPtr> *input_tensors) const {
  MS_EXCEPTION_IF_NULL(input_tensors);
  if (input_tensors->size() != tensors_mask.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
                      << tensors_mask.size();
  }
  std::vector<tensor::TensorPtr> new_input_tensors;
  for (size_t index = 0; index < tensors_mask.size(); ++index) {
    if (tensors_mask[index] != kValueNodeTensorMask) {
      new_input_tensors.emplace_back(input_tensors->at(index));
    }
  }
  *input_tensors = new_input_tensors;
}
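
// Worked example (values illustrative): with tensors_mask = {0, kValueNodeTensorMask, 1}
// and input_tensors = {a, b, c}, the value-node tensor b is dropped and input_tensors
// becomes {a, c}; the two vectors must therefore be the same length on entry.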

void SessionBasic::UpdateAllGraphDynamicShapeAttr(const std::vector<KernelGraphPtr> &all_graphs) {
  bool is_dynamic = false;
  for (const auto &graph : all_graphs) {
    UpdateGraphDynamicShapeAttr(NOT_NULL(graph));
    is_dynamic = graph->is_dynamic_shape() || is_dynamic;
  }
  if (is_dynamic && all_graphs.size() > 1) {
    MS_LOG(EXCEPTION)
      << "Dynamic shape is not supported with control flow (loop and conditional control statements).";
  }
}

void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph) {
  for (const auto &cnode : root_graph->execution_order()) {
    if (AnfAlgo::IsNodeDynamicShape(cnode)) {
      AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode);
      MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
    }
  }
  root_graph->UpdateGraphDynamicAttr();
}

bool SessionBasic::IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  for (const auto &kernel_node : kernel_graph->execution_order()) {
    auto kernel_name = AnfAlgo::GetCNodeName(kernel_node);
    if (kernel_name == kGetNextOpName) {
      auto prim = AnfAlgo::GetCNodePrimitive(kernel_node);
      MS_EXCEPTION_IF_NULL(prim);
      *channel_name = GetValue<std::string>(prim->GetAttr("shared_name"));
      return true;
    }
  }
  return false;
}

void SessionBasic::RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    opt::RemoveNopNode(kernel_graph.get());
  }
}

void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    opt::HideNopNode(kernel_graph.get());
  }
}

std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string group = GetCommWorldGroup();
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  // PyNative does not support multi-group allreduce, so look up the single fused group.
  group += "sum1";
  return parallel_context->GetAllReduceFusionSplitIndices(group);
}
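
// Note: the "sum1" suffix selects the per-group key used by
// GetAllReduceFusionSplitIndices; e.g. with an Ascend world group the lookup key
// would be "hccl_world_groupsum1" (group name shown for illustration only).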

uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) {
  return AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem}).size();
}

void SetGraphBpropAttr(const KernelGraphPtr &graph) {
  auto &execution_orders = graph->execution_order();
  if (std::any_of(execution_orders.begin(), execution_orders.end(),
                  [](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
    graph->set_is_bprop(true);
    MS_LOG(INFO) << "Match bprop graph";
  } else {
    graph->set_is_bprop(false);
  }
}

std::vector<uint32_t> GenerateBucketSizeList(const KernelGraphPtr &graph, const std::vector<uint32_t> &split_index) {
  if (split_index.empty()) {
    auto grads_count = GetBpropGraphGradsCount(graph);
    if (grads_count == 0) {
      MS_LOG(EXCEPTION) << "Bprop graph has no grad";
    }
    return {grads_count};
  }

  std::vector<uint32_t> bucket_size_list;
  uint32_t old_index = 0;
  for (const auto &index : split_index) {
    if (old_index == 0) {
      bucket_size_list.emplace_back(index - old_index + 1);
    } else {
      bucket_size_list.emplace_back(index - old_index);
    }
    old_index = index;
  }
  return bucket_size_list;
}
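
// Worked example: split indices are inclusive grad indices, so split_index = {3, 7}
// yields buckets covering grads [0, 3] and [4, 7], i.e. {3 - 0 + 1, 7 - 3} = {4, 4}.
// With an empty split_index all grads_count gradients fall into a single bucket.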

void CheckSplitIndexValid(const vector<uint32_t> &split_index) {
  uint32_t last = 0;
  for (size_t i = 0; i < split_index.size(); ++i) {
    if (split_index[i] <= last && i != 0) {
      MS_LOG(EXCEPTION) << "Invalid split index:" << split_index;
    }
    last = split_index[i];
  }
}

void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
  MS_EXCEPTION_IF_NULL(split_index);
  if (split_index->empty()) {
    return;
  }

  CheckSplitIndexValid(*split_index);
  // The last split index is the largest grad index named in the config.
  auto split_index_num = split_index->back();
  // Obtain the total number of graph output gradients.
  auto grads_count = GetBpropGraphGradsCount(graph);
  if (split_index_num >= grads_count) {
    MS_LOG(WARNING) << "Invalid all_reduce_fusion_config:" << *split_index << " total grads count:" << grads_count
                    << ". All AllReduce operators will be fused into one.";
    split_index->clear();
    split_index->push_back(grads_count - 1);
  } else if (split_index_num < grads_count - 1) {
    split_index->push_back(grads_count - 1);
  }
}
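
// Worked example (grads_count of 10 assumed): split_index = {3, 7} becomes {3, 7, 9},
// giving the trailing gradients their own bucket; split_index = {12} exceeds the grad
// count and collapses to {9}, fusing all AllReduce operators into a single bucket.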

void SessionBasic::InitAllBucket(const KernelGraphPtr &graph, const device::DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Init Bucket start, graph_id:" << graph->graph_id();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  if (!pynative_mode || parallel_mode != parallel::DATA_PARALLEL) {
    return;
  }
  SetGraphBpropAttr(graph);

  if (!graph->is_bprop()) {
    return;
  }

  std::vector<std::shared_ptr<device::Bucket>> bucket_list;
  // Create a bucket for every AllReduce fusion segment.
  auto split_index = GetAllReduceSplitIndex();
  PreProcessOnSplitIndex(graph, &split_index);
  auto bucket_size_list = GenerateBucketSizeList(graph, split_index);
  uint32_t bucket_id = 0;
  for (const auto &bucket_size : bucket_size_list) {
    MS_LOG(INFO) << "Create new bucket:" << bucket_id << " size:" << bucket_size;
    std::shared_ptr<device::Bucket> bucket = nullptr;
    if (device_context != nullptr) {
      bucket = device_context->CreateBucket(bucket_id++, bucket_size);
    } else {
      bucket = CreateBucket(bucket_id++, bucket_size);
    }
    bucket_list.emplace_back(bucket);
  }

  auto bucket_ret = bucket_map_.try_emplace(graph->graph_id(), bucket_list);
  if (!bucket_ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate bucket_map_ graph key:" << graph->graph_id();
  }
  // Reset the free bucket index for this graph to 0.
  auto free_bucket_ret = free_bucket_id_map_.try_emplace(graph->graph_id(), 0);
  if (!free_bucket_ret.second) {
    MS_LOG(EXCEPTION) << "Duplicate free_bucket_id_map_ graph key:" << graph->graph_id();
  }
  MS_LOG(INFO) << "Init Bucket finish";
}
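
// Bucket initialization only applies to PyNative data-parallel bprop graphs: the split
// indices from all_reduce_fusion_config are validated and padded above, converted into
// per-bucket gradient counts, and one device::Bucket is created per segment; bucket_map_
// holds each graph's buckets while free_bucket_id_map_ tracks the next bucket to fill.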

void SessionBasic::AddGradAddrToBucket(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &grad_tensor) {
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  if (parallel_mode != parallel::DATA_PARALLEL) {
    return;
  }

  auto iter = bucket_map_.find(graph_id);
  if (iter == bucket_map_.end()) {
    MS_LOG(EXCEPTION) << "Unknown graph id:" << graph_id;
  }
  auto &bucket_list = iter->second;
  auto free_bucket_iter = free_bucket_id_map_.find(graph_id);
  if (free_bucket_iter == free_bucket_id_map_.end()) {
    MS_LOG(EXCEPTION) << "Unknown free graph id:" << graph_id;
  }

  auto free_bucket_index = free_bucket_iter->second;
  for (auto &tensor : grad_tensor) {
    if (free_bucket_index >= bucket_list.size()) {
      MS_LOG(EXCEPTION) << "Invalid free bucket id:" << free_bucket_iter->second
                        << " total bucket num:" << bucket_list.size();
    }
    auto &free_bucket = bucket_list[free_bucket_index];
    free_bucket->AddGradTensor(tensor);
    if (free_bucket->full()) {
      MS_LOG(INFO) << "Bucket is full";
      free_bucket->Launch();
      free_bucket_index = ++free_bucket_iter->second;
      MS_LOG(INFO) << "New free bucket:" << free_bucket_index;
    }
  }
}
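
// Fill/launch cycle sketch (bucket sizes of {4, 4} assumed): the first four gradient
// tensors land in bucket 0; when it is full its fused AllReduce is launched and the
// free bucket index advances to bucket 1, which then collects the remaining grads of
// the step.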

void SessionBasic::ClearAllBucket(const GraphId &graph_id) {
  auto iter = bucket_map_.find(graph_id);
  if (iter != bucket_map_.end()) {
    auto &bucket_list = iter->second;
    for (auto &bucket : bucket_list) {
      MS_LOG(INFO) << "Clear bucket:" << bucket->id();
      bucket->Release();
    }
  }
  auto free_iter = free_bucket_id_map_.find(graph_id);
  if (free_iter != free_bucket_id_map_.end()) {
    free_iter->second = 0;
  }
}

void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id();
  opt::CommonFinalOptimization(graph);
  MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id();
}

void SessionBasic::DumpGraph(const std::shared_ptr<KernelGraph> &kernel_graph) {
#ifdef ENABLE_DUMP_IR
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (save_graphs) {
    DumpIR("graph_build_" + std::to_string(kernel_graph->graph_id()) + ".ir", kernel_graph, true, kWholeStack);
    DumpIRProto(kernel_graph, "vm_build_" + std::to_string(kernel_graph->graph_id()));
    DumpIR("trace_code_graph", kernel_graph, true, kWholeStack);
  }
#endif
}

void SessionBasic::UnifyMindIR(const KernelGraphPtr &graph) { opt::CommonUnifyMindIROptimization(graph); }

#if ((defined ENABLE_CPU) && (!defined _WIN32))
void SessionBasic::InitPsWorker(const KernelGraphPtr &kernel_graph) {
  if (!ps::PSContext::instance()->is_worker()) {
    return;
  }
  CheckPSModeConsistence(kernel_graph);
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    if (!ps::ps_cache_instance.initialized_ps_cache()) {
      auto context_ptr = MsContext::GetInstance();
      MS_EXCEPTION_IF_NULL(context_ptr);
      auto device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
      auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(device_target, device_id_);
      MS_EXCEPTION_IF_NULL(runtime_instance);
      auto context = runtime_instance->context();
      const auto &kernels = kernel_graph->execution_order();
      if (!kernels.empty() && AnfAlgo::GetCNodeName(kernels[0]) == "InitDataSetQueue") {
        GetBatchElements(kernels[0]);
        ps::ps_cache_instance.Initialize();
      }
      ps::ps_cache_instance.DoProcessData(device_id_, context);
    }
  } else {
    // Assign parameter keys.
    AssignParamKey(kernel_graph);
  }
}

void SessionBasic::GetBatchElements(const AnfNodePtr &kernel_node) const {
  auto shapes = AnfAlgo::GetNodeAttr<std::vector<std::vector<int64_t>>>(kernel_node, "shapes");
  auto types = AnfAlgo::GetNodeAttr<std::vector<TypePtr>>(kernel_node, "types");
  if (shapes.size() != types.size() || shapes.size() == 0 || types.size() == 0) {
    MS_LOG(EXCEPTION) << "Invalid shapes of op[InitDataSetQueue]: shapes size " << shapes.size() << ", types size "
                      << types.size();
  }
  size_t batch_elements = 1;
  const auto &shape = shapes[0];
  for (size_t i = 0; i < shape.size(); ++i) {
    batch_elements *= LongToSize(shape[i]);
  }
  ps::ps_cache_instance.set_batch_elements(batch_elements);
}
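
// Worked example (shape is illustrative): for an InitDataSetQueue op with
// shapes = {{32, 3, 224, 224}}, batch_elements = 32 * 3 * 224 * 224 = 4816896,
// which the PS cache records as its per-batch element count.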

void SessionBasic::CheckPSModeConsistence(const KernelGraphPtr &kernel_graph) const {
  auto input_nodes = kernel_graph->inputs();
  for (const auto &input_node : input_nodes) {
    if (!input_node->isa<Parameter>()) {
      continue;
    }
    auto pk_node = input_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(pk_node);
    auto param_info_ptr = pk_node->param_info();
    const std::string &param_name = pk_node->fullname_with_scope();
    if (param_info_ptr != nullptr && param_info_ptr->init_in_server() &&
        !ps::ps_cache_instance.IsHashTable(param_name)) {
      MS_LOG(EXCEPTION) << "Cannot initialize the parameter [" << param_name
                        << "] on the server: this parameter is used by a kernel that executes on the device.";
    }
  }
}

void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // PS EmbeddingLookup cache check.
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    MS_LOG(EXCEPTION) << "Other parameters cannot be set to PS mode when the EmbeddingLookup cache is enabled in "
                         "parameter server training mode.";
  }
  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph->get_return());
  for (auto &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      // Assign the key for the forward kernel EmbeddingLookup.
      // The key will be assigned to the embedding table and the Push kernel as well.
      if (AnfAlgo::GetCNodeName(node) == kEmbeddingLookupOpName) {
        size_t embedding_table_idx = 0;
        auto embedding_table = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), embedding_table_idx);
        size_t key = ps::Worker::GetInstance().SetParamKey(embedding_table->fullname_with_scope());
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
      } else if (AnfAlgo::GetCNodeName(node) == kPushOpName) {
        auto pull_node = FindPullNode(node, node_list);
        if (!pull_node) {
          MS_LOG(EXCEPTION) << "Assigning parameter key failed: can't find the Pull node of the Push node.";
        }

        // The second input of the Pull node is the trainable parameter.
        size_t parameter_index = 1;
        auto parameter_node = AnfAlgo::GetInputNode(pull_node->cast<CNodePtr>(), parameter_index);
        size_t key = ps::Worker::GetInstance().SetParamKey(parameter_node->fullname_with_scope());
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), node);
        AnfAlgo::SetNodeAttr(kAttrPsKey, MakeValue(key), pull_node);

        std::string optimizer_name = AnfAlgo::GetNodeAttr<std::string>(node, kAttrOptimizerType);
        ps::Worker::GetInstance().SetKeyOptimId(key, optimizer_name);
      }
    }
  }
}

void SessionBasic::InitPSParamAndOptim(const KernelGraphPtr &kernel_graph,
                                       const std::vector<tensor::TensorPtr> &inputs_const) {
  if (!ps::PSContext::instance()->is_worker()) {
    return;
  }
  std::vector<tensor::TensorPtr> inputs(inputs_const);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto input_nodes = kernel_graph->inputs();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (size_t i = 0; i < inputs.size(); ++i) {
    auto tensor = inputs[i];
    MS_EXCEPTION_IF_NULL(tensor);
    auto input_node = input_nodes[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
      ps::Worker::GetInstance().InitPSParamAndOptim(input_node, tensor);
    }
  }
}
#endif
}  // namespace session
void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir,
                       const std::vector<CNodePtr> &execution_order) {
  std::string file_path = target_dir + "/execution_order/" + file_name;
  auto realpath = Common::CreatePrefixPath(file_path);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Failed to get real path: [" << file_path << "] in dump graph execution order.";
    return;
  }
  file_path = realpath.value();

  ChangeFileMode(file_path, S_IWUSR);
  // Write the execution order to a CSV file.
  std::ofstream ofs(file_path);
  if (!ofs.is_open()) {
    MS_LOG(ERROR) << "Failed to open file [" << file_path
                  << "] in dump graph execution order, please check the file access permission and whether disk space "
                     "is available.";
    return;
  }
  ofs << "NodeExecutionOrder-FullNameWithScope\n";
  for (const CNodePtr &node : execution_order) {
    ofs << node->fullname_with_scope() << "\n";
  }
  ofs.close();
  // Set the file mode to read-only by the user.
  ChangeFileMode(file_path, S_IRUSR);
}
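
// Example output (node names illustrative): the dumped file lists one
// fullname_with_scope per line under a single header, e.g.
//   NodeExecutionOrder-FullNameWithScope
//   Default/Conv2D-op0
//   Default/ReLU-op1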

uint32_t GetRankId() {
  uint32_t rank_id = 0;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  std::string world_group;
  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (backend == kAscendDevice) {
    world_group = kHcclWorldGroup;
  } else if (backend == kGPUDevice) {
    world_group = kNcclWorldGroup;
  } else {
    MS_LOG(ERROR) << "Invalid backend: " << backend;
    return rank_id;
  }
  if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
    MS_LOG(INFO) << "Failed to get rank id.";
  }
  return rank_id;
}
}  // namespace mindspore