/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/session/session_basic.h"

#include <algorithm>
#include <set>
#include <queue>
#include <utility>
#include <functional>
#include <unordered_map>

#include "ops/ascend_op_name.h"
#include "ops/structure_op_name.h"
#include "ops/framework_op_name.h"
#include "ops/sequence_ops.h"
#include "utils/hash_map.h"
#include "ops/primitive_c.h"
#include "ir/manager.h"
#include "abstract/utils.h"
#include "kernel/common_utils.h"
#include "base/base_ref_utils.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "include/common/utils/config_manager.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/session/executor_manager.h"
#include "backend/common/optimizer/common_backend_optimization.h"
#include "include/backend/optimizer/helper.h"
#include "include/backend/optimizer/op_adaptation_info_factory.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/pynative/op_compiler.h"
#include "utils/ms_utils.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#include "include/common/utils/utils.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/debug/dump_proto.h"
#include "utils/file_utils.h"
#include "utils/trace_base.h"
#include "include/common/utils/parallel_context.h"
#include "kernel/oplib/oplib.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
#include "include/backend/distributed/ps/constants.h"
#include "include/backend/distributed/ps/util.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "abstract/abstract_value.h"
#endif
#include "backend/common/session/session_factory.h"
#include "runtime/pynative/op_executor.h"
#ifdef ENABLE_DEBUGGER
#include "debug/tensor_load.h"
#include "debug/debugger/proto_exporter.h"
#endif
#include "include/backend/debug/debugger/proto_exporter.h"
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/graph_exec_order_recorder.h"
#include "include/common/debug/rdr/recorder_manager.h"
#include "debug/rdr/graph_recorder.h"
#include "runtime/hardware/device_context_manager.h"
#endif
#ifndef ENABLE_SECURITY
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "include/backend/debug/data_dump/e2e_dump.h"
#endif

namespace mindspore {
namespace session {
MS_REG_SESSION(kSessionBasic, SessionBasic);

namespace {
constexpr int64_t kInvalidShape = -2;

static bool IsPynativeMode() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
}

BaseRef GetNodeOutputTensorFromInputs(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                      const std::vector<tensor::TensorPtr> &input_tensors) {
  auto &node = node_output_pair.first;
  MS_EXCEPTION_IF_NULL(node);
  if (HasAbstractMonad(node)) {
    return std::make_shared<tensor::Tensor>(int64_t(0), kBool);
  }
  // If the node is a ValueNode, there is no need to sync its address from device to host.
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    return value_node->value();
  }
  if (IsPynativeMode()) {
    return nullptr;
  }
  if (!node->isa<Parameter>()) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto param_node = node->cast<ParameterPtr>();
  if (param_node != nullptr && param_node->IsUsedByRealKernelInGraph(graph->graph_id())) {
    return nullptr;
  }
  for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
    if (input_idx >= input_tensors.size()) {
      MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " is out of range:" << input_tensors.size();
    }
    if (graph->inputs()[input_idx] == node) {
      return input_tensors[input_idx];
    }
  }
  return nullptr;
}

BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                               const std::vector<tensor::TensorPtr> &input_tensors,
                               std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
  auto &node = node_output_pair.first;
  size_t output_index = node_output_pair.second;
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto tensor_from_input = GetNodeOutputTensorFromInputs(node_output_pair, graph, input_tensors);
  if (tensor_from_input != nullptr) {
    return tensor_from_input;
  }
  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (type_id == kTypeUnknown) {
    type_id = common::AnfAlgo::GetOutputInferDataType(node, output_index);
  }

  auto shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
  if (common::AnfAlgo::IsDynamicShape(node)) {
    auto max_shape = common::AnfAlgo::GetOutputMaxShape(node, output_index);
    if (abstract::ShapeSize(max_shape) > abstract::ShapeSize(shape)) {
      shape = max_shape;
    }
  }
  tensor::TensorPtr tensor;
  bool is_internal_output = graph->IsInternalOutput(node, output_index);
  if (is_internal_output) {
    tensor = graph->GetInternalOutputTensor(node, output_index);
    if (tensor == nullptr) {
      tensor = std::make_shared<tensor::Tensor>(type_id, shape);
      graph->AddInternalOutputTensor(node, output_index, tensor);
    }
  } else {
    tensor = std::make_shared<tensor::Tensor>(type_id, shape);
  }
  MS_EXCEPTION_IF_NULL(tensor);
  if (is_internal_output) {
    tensor->set_sync_status(kNoNeedSync);
  } else {
    // In PyNative mode, data is only copied to host when the user wants to print it.
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
        ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
      tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
    } else {
      tensor->set_sync_status(kNeedSyncDeviceToHost);
    }
  }
  tensor->SetIsGraphOutput();
  (*tensor_to_node)[tensor] = node_output_pair;
  return tensor;
}
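// Summary of the sync-status policy chosen above (descriptive of the branches
// in CreateNodeOutputTensor):
//   internal output                -> kNoNeedSync (consumed inside the graph)
//   graph mode on a non-GPU target -> kNeedSyncDeviceToHostImmediately
//   PyNative mode or GPU target    -> kNeedSyncDeviceToHost (lazy, synced on demand)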

std::string GetOpRunDeviceTarget(const PrimitivePtr &op_prim) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const std::string &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);

  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &attr_map = op_prim->attrs();
  auto iter = attr_map.find(kAttrPrimitiveTarget);
  if (iter != attr_map.end()) {
    return GetValue<std::string>(iter->second);
  }
  return device_target;
}
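// Illustrative sketch (not from this file; assumes Primitive::AddAttr and
// MakeValue behave as in the ir headers): an op can be pinned to a device via
// the kAttrPrimitiveTarget attribute, which takes priority over the context
// target:
//
//   auto prim = std::make_shared<Primitive>("Add");
//   (void)prim->AddAttr(kAttrPrimitiveTarget, MakeValue(std::string("CPU")));
//   GetOpRunDeviceTarget(prim);  // -> "CPU", regardless of MS_CTX_DEVICE_TARGET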

// Need to discard input tensor properties in heterogeneous scenarios.
// For example, the format of the device_address in input_tensor may be a 5D format,
// which is invalid for a CPU graph parameter.
bool NeedDiscardTensorProperties(const std::string &op_device_target,
                                 const device::DeviceAddressPtr &tensor_device_address) {
  if (tensor_device_address == nullptr) {
    return true;
  }

  if (op_device_target == device::GetDeviceNameByType(tensor_device_address->GetDeviceType())) {
    return false;
  }
  return true;
}
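// Worked example (hypothetical variable, assuming GetDeviceNameByType maps the
// Ascend device type to "Ascend"): a tensor produced on Ascend may carry a
// device-specific 5D layout in its device address. If the next op is dispatched
// to CPU, that layout is meaningless there, so the properties are discarded and
// the parameter falls back to kOpFormat_DEFAULT:
//
//   NeedDiscardTensorProperties("CPU", ascend_address);     // true  -> discard
//   NeedDiscardTensorProperties("Ascend", ascend_address);  // false -> keep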

ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
                                     const tensor::BaseTensorPtr &input_tensor, const BackendOpRunInfoPtr &op_run_info,
                                     InputType input_type) {
  MS_EXCEPTION_IF_NULL(graph);
  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  if (input_type == InputType::kParameter) {
    param->set_default_param(input_tensor);
  }

  // Set the kernel info of the parameter.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
  if (NeedDiscardTensorProperties(op_run_info->base_op_run_info.device_target, device_address)) {
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
    TypeId param_init_data_type = common::AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  } else {
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
    kernel_build_info_builder->SetOutputsReshapeType({device_address->padding_type()});
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
  }
  if (input_tensor->isa<tensor::MapTensor>()) {
    auto map_tensor = input_tensor->cast<tensor::MapTensorPtr>();
    auto map_tensor_abs = std::make_shared<abstract::AbstractMapTensor>(map_tensor);
    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
    param->set_abstract(map_tensor_abs);
    return param;
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // Construct the abstract of the parameter.
  auto type_of_tensor = input_tensor->Dtype();
  std::shared_ptr<abstract::AbstractTensor> abstract;
  // base_shape_ptr is set in the dynamic shape scenario; if it is nullptr, the shape is not dynamic.
  if (input_tensor->base_shape_ptr() != nullptr) {
    abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, input_tensor->base_shape_ptr());
  } else {
    abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, input_tensor->shape());
  }
  param->set_abstract(abstract);
  return param;
}

void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  if (recurse_level == 0) {
    MS_LOG(INFO) << "Graph outputs:";
  }
  const size_t max_depth = 10;
  if (recurse_level > max_depth) {
    MS_LOG(INFO) << "Recursion too deep";
    return;
  }
  std::string tab_str;
  for (size_t i = 0; i < recurse_level; i++) {
    tab_str = tab_str.append("  ");
  }
  if (any.is<AnyList>()) {
    (void)tab_str.append("{");
    MS_LOG(INFO) << tab_str;
    auto any_list = any.cast<AnyList>();
    for (auto &it : any_list) {
      DumpGraphOutput(it, recurse_level + 1);
    }
    (void)tab_str.append("}");
    MS_LOG(INFO) << tab_str;
  }
  (void)tab_str.append(any.ToString());
  MS_LOG(INFO) << tab_str;
}

BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  auto &node = node_output_pair.first;
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(DEBUG) << "Create placeholder for output[" << node->DebugString() << "] index[" << node_output_pair.second
                << "]";
  // If the node is a ValueNode, there is no need to sync its address from device to host.
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    return value_node->value();
  }
  if (node->isa<Parameter>()) {
    const auto &input_nodes = graph->input_nodes();
    for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) {
      if (input_idx >= input_tensors.size()) {
        MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " is out of range:" << input_tensors.size();
      }
      if (input_nodes[input_idx] == node) {
        return input_tensors[input_idx];
      }
    }
    MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
  }
  (*output_indexes)[node_output_pair].emplace_back(indexes);
  BaseRef output_placeholder = std::make_shared<BaseRef>();
  return output_placeholder;
}

BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(DEBUG) << "Create placeholder for output[" << anf->DebugString() << "]";
  auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(DEBUG) << "Create placeholder for output after visit:" << item_with_index.first->DebugString();
  // Special handling for MakeTuple.
  if (common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->size(); ++i) {
      std::vector<size_t> cur_index = indexes;
      cur_index.emplace_back(i - 1);
      auto out = CreateNodeOutputPlaceholder(cnode->input(i), graph, input_tensors, cur_index, output_indexes);
      ret.push_back(out);
    }
    return ret;
  }
  // If the graph returns nothing, the function should return an empty VectorRef.
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
}

void CheckInputTensorShape(const tensor::BaseTensorPtr &tensor, const CNodePtr &kernel, size_t input_index) {
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &tensor_shape = tensor->shape();
  const auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel, input_index);
  if (tensor_shape.size() != input_shape.size()) {
    MS_LOG(EXCEPTION) << "The input tensor's shape size: " << tensor_shape.size()
                      << " is not equal to expected size: " << input_shape.size() << " for input[" << input_index
                      << "] of kernel: " << common::AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
  }
  for (size_t i = 0; i < tensor_shape.size(); i++) {
    if (tensor_shape[i] < 0 || (tensor_shape[i] != input_shape[i] && input_shape[i] >= 0)) {
      MS_LOG(EXCEPTION) << "The input tensor's shape: " << tensor_shape
                        << " is not equal to expected shape: " << input_shape << " for input[" << input_index
                        << "] of kernel: " << common::AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
    }
  }
}
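// Note on the per-dimension check above: an expected dim that is negative
// (dynamic) accepts any concrete value, so only a mismatch against a concrete
// expected dim, or a negative actual dim, raises an exception.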

bool is_param_scalar(const size_t &param_shape_size, const size_t &input_shape_size) {
  if (param_shape_size == 1 && input_shape_size == 0) {
    return true;
  }
  if (param_shape_size == 0 && input_shape_size == 1) {
    return true;
  }
  return false;
}
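// Example (values taken from the feed-mode comment in GetParameterIndex below):
// a scalar fed to the graph arrives as a 0-D tensor (shape []), while the
// matching parameter is inferred as shape [1]; is_param_scalar(1, 0) treats the
// pair as compatible, as does the mirrored case is_param_scalar(0, 1).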

ValuePtrList ConvertVectorRefOutputs(const VectorRef &op_outputs) {
  ValuePtrList op_output_values;
  for (const auto &value : op_outputs.elements_) {
    (void)op_output_values.emplace_back(utils::cast<ValuePtr>(value));
  }
  return op_output_values;
}
}  // namespace

BaseRef SessionBasic::CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                              const std::vector<tensor::TensorPtr> &input_tensors,
                                              std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
                                              KernelMapTensor *node_to_tensor) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  MS_EXCEPTION_IF_NULL(node_to_tensor);
  MS_LOG(DEBUG) << "Create tensor for output[" << anf->DebugString() << "]";
  auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(DEBUG) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  // Special handling for MakeTuple.
  if (common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->size(); ++i) {
      auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node, node_to_tensor);
      (void)ret.emplace_back(out);
    }
    return ret;
  }
  // If the graph returns nothing, the function should return an empty VectorRef.
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }

  // The outputs of the graph may share the same kernel node; there is no need to create a new tensor.
  const auto &iter = node_to_tensor->find(item_with_index);
  if (iter != node_to_tensor->end()) {
    return iter->second;
  }

  const auto &tensor = CreateNodeOutputTensor(item_with_index, graph, input_tensors, tensor_to_node);
  (*node_to_tensor)[item_with_index] = tensor;
  return tensor;
}

void SessionBasic::InitExecutor(const std::string &device_name, uint32_t device_id) {
  device_id_ = device_id;
  context_ = std::make_shared<Context>(device_name, device_id);
  executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id);
}

BackendOpRunInfoPtr SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const InputInfo &input_info,
                                                     const GraphOutputInfo *const graph_output_info) const {
  MS_EXCEPTION_IF_NULL(cnode);
  auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &abstract = cnode->abstract();
  if (abstract == nullptr) {
    MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
  }
  const auto &shape = abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(shape);

  std::vector<size_t> output_indexes;
  bool is_gradient_out = false;
  if (graph_output_info != nullptr) {
    for (auto &item : graph_output_info->output_indexes) {
      if (item.first.first == cnode) {
        is_gradient_out = true;
        (void)output_indexes.emplace_back(item.first.second);
      }
    }
  }

  pynative::BaseOpRunInfo base_op_run_info;
  base_op_run_info.is_mixed_precision_cast = false;
  base_op_run_info.has_dynamic_output = shape->IsDynamic();
  base_op_run_info.op_name = primitive->name();
  base_op_run_info.next_op_name = std::string();
  base_op_run_info.device_target = GetOpRunDeviceTarget(primitive);
  base_op_run_info.next_input_index = 0;
  base_op_run_info.expanded_input_values.clear();
  for (auto const &value : input_info.input_values) {
    base_op_run_info.expanded_input_values.emplace_back(value);
  }
  base_op_run_info.input_types = input_info.input_types;
  base_op_run_info.abstract = abstract;
  base_op_run_info.output_indexes = output_indexes;
  return std::make_shared<BackendOpRunInfo>(base_op_run_info, primitive, false, is_gradient_out);
}

void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
                                     std::map<AnfNodePtr, size_t> *parameter_index) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter_index);
  size_t index = 0;
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  bool is_parallel_forward_jit =
    !graph->has_flag(kFlagIsPynativeBpropGraph) &&
    (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
  for (const auto &input_node : graph->input_nodes()) {
    auto params = common::AnfAlgo::GetAllOutput(input_node);
    for (const auto &param : params) {
      if (index >= inputs.size()) {
        MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
                          << ", input size: " << inputs.size();
      }
      const auto &input = inputs[index];
      MS_EXCEPTION_IF_NULL(input);
      MS_EXCEPTION_IF_NULL(param);
      // Check the shapes of the input and the parameter.
      const auto &input_shape = input->shape();
      const auto &param_shape = common::AnfAlgo::GetOutputInferShape(param, 0);
      bool is_dynamic = param->Shape()->IsDynamic();
      // Dynamic shape feed mode: the shape is dynamic but the max shape is ().
      if (!is_dynamic || !param_shape.empty()) {
        if (!is_parallel_forward_jit && input_shape.size() != param_shape.size()) {
          // An inferred shape of -2 indicates that the shape cannot be inferred currently.
          if (param_shape.size() == 1 && param_shape[0] == kInvalidShape) {
            parameter_index->emplace(param, index++);
            continue;
          }
          // The input is a scalar: the param shape will be [1] and the input shape will be [].
          if (is_param_scalar(param_shape.size(), input_shape.size())) {
            parameter_index->emplace(param, index++);
            continue;
          }
          MS_LOG(EXCEPTION) << "Shape size of input tensor(" << input_shape << ") and parameter(" << param_shape
                            << ") are different, input index: " << index << ", parameter: " << param->DebugString();
        }
        for (size_t i = 0; i < input_shape.size(); i += 1) {
          if (input_shape[i] < 0 || (!is_parallel_forward_jit && input_shape[i] != param_shape[i] && !is_dynamic)) {
            MS_LOG(EXCEPTION) << "Input tensor shape(" << input_shape << ") and parameter shape(" << param_shape
                              << ") are different, input index: " << index << ", parameter: " << param->DebugString();
          }
        }
      }
      parameter_index->emplace(param, index++);
    }
  }
}

void SessionBasic::CreateOutputPlaceholder(
  const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *const outputs,
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_indexes);
  auto anf_outputs = kernel_graph->outputs();
  size_t index = 0;
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    std::vector<size_t> indexes{index++};
    (void)outputs->emplace_back(
      CreateNodeOutputPlaceholder(item, kernel_graph, input_tensors, indexes, output_indexes));
  }
}

void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) const {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &kernel : graph->execution_order()) {
    for (size_t i = 1; i < kernel->size(); i += 1) {
      auto input = kernel->inputs()[i];
      CalculateRefCount(input, ref_count);
    }
  }
}

void SessionBasic::CalculateRefCount(const AnfNodePtr &node, std::map<KernelWithIndex, size_t> *ref_count) const {
  if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
    auto kernel_with_index = common::AnfAlgo::VisitKernel(node, 0);
    const auto &real_input = kernel_with_index.first;
    if (real_input->isa<CNode>()) {
      (*ref_count)[kernel_with_index] += 1;
    }
    return;
  }
  auto cnode = node->cast<CNodePtr>();
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto input = cnode->input(i);
    CalculateRefCount(input, ref_count);
  }
}

void SessionBasic::GetForwardOpOutputRefCount(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
                                              std::map<std::string, size_t> *forward_op_output_tensor_id,
                                              const std::map<AnfNodePtr, size_t> &parameter_index) const {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // CPU cannot clear the device address, because its device address and host address are the same.
  if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice) {
    return;
  }
  MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
  for (const auto &kernel : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    const auto input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t i = 1; i <= input_tensor_num; ++i) {
      const auto &input = kernel->input(i);
      auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
      auto real_input = kernel_with_index.first;
      MS_EXCEPTION_IF_NULL(real_input);
      if (real_input->isa<ValueNode>()) {
        const auto &value = GetValueNodeOutput(real_input, kernel_with_index.second);
        if (value == nullptr || !value->isa<tensor::Tensor>()) {
          continue;
        }
        auto tensor = value->cast<tensor::TensorPtr>();
        if (tensor->is_forward_output()) {
          (*forward_op_output_tensor_id)[tensor->id()] += 1;
        }
      } else if (real_input->isa<Parameter>()) {
        // A forward op output may be used as sens, so a reference needs to be added.
        auto iter = parameter_index.find(real_input);
        if (iter != parameter_index.end()) {
          auto tensor = inputs[iter->second];
          if (tensor->is_forward_output()) {
            (*forward_op_output_tensor_id)[tensor->id()] += 1;
          }
        }
      }
    }
  }
  MS_LOG(DEBUG) << "Forward op output tensor in bprop graph size " << forward_op_output_tensor_id->size();
}

void SessionBasic::ReleaseForwardOpOutput(const std::vector<ValuePtr> &input_values,
                                          std::map<std::string, size_t> *forward_op_output_tensor_id) const {
  MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
  for (const auto &value : input_values) {
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    if (tensor == nullptr) {
      continue;
    }

    if (!tensor->is_forward_output()) {
      continue;
    }
    auto it = forward_op_output_tensor_id->find(tensor->id());
    if (it != forward_op_output_tensor_id->end()) {
      if (--(it->second) == 0) {
        tensor->set_device_address(nullptr);
        forward_op_output_tensor_id->erase(it);
      }
    }
  }
}

void SessionBasic::HandleOpInputs(const std::set<KernelWithIndex> &input_kernel,
                                  std::map<KernelWithIndex, size_t> *ref_count,
                                  std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map) const {
  MS_EXCEPTION_IF_NULL(ref_count);
  MS_EXCEPTION_IF_NULL(op_output_map);
  for (const auto &kernel_with_index : input_kernel) {
    if (!kernel_with_index.first->isa<CNode>()) {
      continue;
    }

    // Release the previous output.
    auto ref_iter = ref_count->find(kernel_with_index);
    if (ref_iter == ref_count->end()) {
      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
                        << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
    }
    // Decrease the reference count; when it reaches zero, release the now-useless output of the previous node.
    ref_iter->second -= 1;
    if (ref_iter->second != 0) {
      continue;
    }
    ref_count->erase(ref_iter);
    auto output_iter = op_output_map->find(kernel_with_index);
    if (output_iter == op_output_map->end()) {
      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
                        << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
    }
    op_output_map->erase(output_iter);
  }
}
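// Sketch of how GetRefCount and HandleOpInputs cooperate in a single-op
// execution loop (illustrative only; the actual driver and the op_output_map
// bookkeeping live in the callers):
//
//   std::map<KernelWithIndex, size_t> ref_count;
//   std::map<KernelWithIndex, tensor::BaseTensorPtr> op_output_map;
//   GetRefCount(graph.get(), &ref_count);  // count each CNode output's consumers
//   for (const auto &kernel : graph->execution_order()) {
//     // ... build InputInfo, run the op, store its outputs in op_output_map ...
//     HandleOpInputs(input_info.input_kernel, &ref_count, &op_output_map);
//   }
//
// Each consumed input decrements its count; when a count reaches zero the
// cached output tensor is erased from op_output_map so its memory can be reused.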

void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                                   const std::map<KernelWithIndex, size_t> &ref_count,
                                   std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map,
                                   GraphOutputInfo *const graph_output_info) const {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_output_map);
  MS_EXCEPTION_IF_NULL(graph_output_info);
  MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
  ValuePtrList output_values;
  if (common::AnfAlgo::IsBpropCutOpExecInBackend(kernel)) {
    output_values = ConvertVectorRefOutputs(op_outputs);
  } else {
    output_values = common::AnfAlgo::TransformVectorRefToMultiValue(op_outputs);
  }
  if (output_values.size() > op_outputs.size()) {
    MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
  }
  size_t out_index = 0;
  for (const auto &output_value : output_values) {
    auto kernel_with_index = make_pair(kernel, out_index++);
    auto output_tensor = output_value->cast<tensor::BaseTensorPtr>();
    bool value_is_tensor = (output_tensor != nullptr);
    if (ref_count.find(kernel_with_index) != ref_count.end() && value_is_tensor) {
      (*op_output_map)[kernel_with_index] = output_tensor;
    }
    const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
    if (iter == graph_output_info->output_indexes.end()) {
      continue;
    }
    const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
    for (const auto &ref_indexes : multiple_ref_indexes) {
      size_t n = 0;
      const VectorRef *cur_vector_ref = graph_output_info->graph_outputs;
      for (; n < ref_indexes.size() - 1; n += 1) {
        size_t index = ref_indexes.at(n);
        if (index >= cur_vector_ref->size()) {
          MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vector ref is "
                            << cur_vector_ref->size();
        }
        const BaseRef &base_ref = (*cur_vector_ref)[index];
        if (!utils::isa<VectorRef>(base_ref)) {
          MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << ", cur n: " << n;
        }
        cur_vector_ref = &utils::cast<VectorRef>(base_ref);
      }
      BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
      tensor_ref = output_value;
      if (value_is_tensor) {
        (void)graph_output_info->graph_output_tensors.emplace_back(output_tensor);
      }
    }
  }
}

ValuePtr SessionBasic::GetValueNodeOutput(const AnfNodePtr &node, size_t output_index) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<ValueNode>()) {
    return nullptr;
  }
  auto value_node = node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = GetValueNode(value_node);
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<ValueTuple>()) {
    auto value_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(value_tuple);
    if (value_tuple->value().empty()) {
      // Empty tuple.
      return value;
    }
    if (output_index >= value_tuple->size()) {
      MS_LOG(EXCEPTION) << "Index " << output_index << " is out of value tuple range";
    }
    auto tensor_value = value_tuple->value()[output_index];
    if (tensor_value->isa<tensor::Tensor>()) {
      return tensor_value;
    } else {
      return value;
    }
  } else if (value->isa<tensor::Tensor>()) {
    if (output_index != 0) {
      MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << output_index;
    }
    return value;
  } else if (value->isa<StringImm>()) {
    auto value_string = GetValue<std::string>(value);
    const ShapeVector shape = {1, SizeToLong(value_string.size())};
    TensorPtr tensor = std::make_shared<Tensor>(kObjectTypeString, shape, value_string.data(), value_string.size());
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->set_sync_status(kNeedSyncHostToDevice);
    return tensor;
  } else if (value->isa<tensor::CSRTensor>()) {
    return value->cast<tensor::CSRTensorPtr>()->GetTensorAt(output_index);
  } else if (value->isa<tensor::COOTensor>()) {
    return value->cast<tensor::COOTensorPtr>()->GetTensorAt(output_index);
  }

  return value;
}
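// Behavior summary (descriptive of the branches above): a ValueTuple yields the
// element at output_index when that element is a Tensor, otherwise the whole
// tuple; a StringImm is packed into a string tensor of shape {1, len} marked
// kNeedSyncHostToDevice; CSR/COO sparse tensors dispatch to GetTensorAt; any
// other value is returned unchanged.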

TensorPtr SessionBasic::GetParameterOutputTensor(const AnfNodePtr &node,
                                                 const std::map<AnfNodePtr, size_t> &parameter_index,
                                                 const std::vector<tensor::TensorPtr> &graph_inputs) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<Parameter>()) {
    return nullptr;
  }
  const auto &iter = parameter_index.find(node);
  if (iter == parameter_index.end()) {
    MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, parameter = " << node->DebugString();
  }
  const size_t index = iter->second;
  if (index >= graph_inputs.size()) {
    MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = " << index
                      << ", input tensor size = " << graph_inputs.size();
  }
  return graph_inputs[index];
}

tensor::BaseTensorPtr SessionBasic::GetCNodeOutputTensor(
  const KernelWithIndex &kernel_with_index, const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output) const {
  const auto &iter = op_output.find(kernel_with_index);
  if (iter == op_output.end()) {
    MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << kernel_with_index.first->DebugString();
  }
  return iter->second;
}

void SessionBasic::GetConstValueDepend(const CNodePtr &cnode, std::set<int64_t> *const_input_attr_index) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(const_input_attr_index);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target != kAscendDevice) {
    return;
  }
  *const_input_attr_index = abstract::GetValueDependArgIndices(cnode);
  if (!const_input_attr_index->empty()) {
    return;
  }
  auto op_name = common::AnfAlgo::GetCNodeName(cnode);
  auto op_adaptation_info = opt::OpAdaptationInfoRegister::GetOpAdaptationInfo(op_name, kAscendDevice, true);
  if (op_adaptation_info == nullptr) {
    return;
  }
  if (op_adaptation_info->is_ascend_mindir()) {
    auto input_to_attr_map = op_adaptation_info->input_attr_map();
    for (const auto &input_attr_info : input_to_attr_map) {
      (void)const_input_attr_index->insert(SizeToLong(input_attr_info.first));
    }
  }
}

static inline BaseShapePtr GetShapeFromTuple(const abstract::AbstractTuplePtr &tuple_abs, const size_t index) {
  MS_EXCEPTION_IF_NULL(tuple_abs);
  const auto &elements = tuple_abs->elements();
  if (!elements.empty()) {
    auto tuple_abs_elem = elements[index];
    MS_EXCEPTION_IF_NULL(tuple_abs_elem);
    return tuple_abs_elem->GetShape();
  }
  // Empty tuple.
  return tuple_abs->GetShape();
}

void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
                                     const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
                                     const std::map<AnfNodePtr, size_t> &parameter_index,
                                     const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_info);
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  std::set<int64_t> const_input_attr_index = {};
  GetConstValueDepend(cnode, &const_input_attr_index);
  const auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
  for (size_t i = 1; i <= input_num; i += 1) {
    const auto &input = cnode->input(i);
    auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
    auto real_input = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);
    ValuePtr input_value = nullptr;
    if (real_input->isa<ValueNode>()) {
      input_value = GetValueNodeOutput(real_input, kernel_with_index.second);
      const auto &value_ptr = GetValueNode(real_input);
      MS_EXCEPTION_IF_NULL(value_ptr);
      auto is_value_node = value_ptr->isa<StringImm>();
      if (!const_input_attr_index.empty()) {
        is_value_node = (const_input_attr_index.count(SizeToLong(i - 1)) != 0);
      }

      bool is_forward_output = false;
      if (value_ptr->isa<tensor::Tensor>()) {
        auto forward_tensor = value_ptr->cast<tensor::TensorPtr>();
        if (forward_tensor->is_forward_output()) {
          is_forward_output = true;
        }
      }

      if (common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, cnode)) {
        auto is_tensor = input_value->isa<tensor::Tensor>();
        (void)input_info->input_types.emplace_back(
          ((is_value_node && !is_forward_output) || !is_tensor) ? InputType::kConstant : InputType::kOpOutput);
      } else {
        (void)input_info->input_types.emplace_back((is_value_node || !is_forward_output) ? InputType::kConstant
                                                                                         : InputType::kOpOutput);
      }
    } else if (real_input->isa<Parameter>()) {
      auto tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
      MS_EXCEPTION_IF_NULL(tensor);
      input_value = tensor;
      input_info->input_types.emplace_back(tensor->is_parameter() ? InputType::kParameter : InputType::kInput);
    } else if (real_input->isa<CNode>()) {
      auto tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
      MS_EXCEPTION_IF_NULL(tensor);
      input_value = tensor;
      if (common::AnfAlgo::IsBpropCutOpExecInBackend(real_input)) {
        CheckInputTensorShape(tensor, cnode, i - 1);
      }
      input_info->input_kernel.insert(kernel_with_index);
      input_info->input_types.emplace_back(tensor->is_parameter() ? InputType::kParameter : InputType::kOpOutput);
    } else {
      MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
    }
    MS_EXCEPTION_IF_NULL(input_value);
    MS_LOG(DEBUG) << "Get " << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
                  << real_input->fullname_with_scope() << "-" << kernel_with_index.second;
    BaseShapePtr base_shape = nullptr;
    auto real_input_abs = real_input->abstract();
    MS_EXCEPTION_IF_NULL(real_input_abs);
    if (real_input_abs->isa<abstract::AbstractTuple>()) {
      auto tuple_abs = real_input_abs->cast<abstract::AbstractTuplePtr>();
      base_shape = GetShapeFromTuple(tuple_abs, kernel_with_index.second);
    } else {
      base_shape = real_input_abs->BuildShape();
    }
    MS_EXCEPTION_IF_NULL(base_shape);
    if (base_shape->IsDynamic()) {
      // In this case, input_value must be a Tensor.
      auto tensor = input_value->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      tensor->set_base_shape(base_shape);
    }
    (void)input_info->input_abs.emplace_back(real_input->abstract());
    (void)input_info->input_values.emplace_back(input_value);
  }
}

void SessionBasic::GetOpInputTensorsFromCNode(const CNodePtr &cnode,
                                              const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
                                              const std::map<AnfNodePtr, size_t> &parameter_index,
                                              const std::vector<tensor::TensorPtr> &graph_inputs,
                                              InputInfo *input_info) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_info);
  std::function<ValuePtr(const KernelWithIndex &)> fn = [&](const KernelWithIndex &kernel_with_index) -> ValuePtr {
    auto real_input = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);
    ValuePtr input_value = nullptr;
    if (real_input->isa<CNode>()) {
      if (IsPrimitiveCNode(real_input, prim::kPrimMakeTuple)) {
        const auto &c_make_tuple = real_input->cast<CNodePtr>();
        ValuePtrList v_list;
        for (size_t j = 1; j < c_make_tuple->size(); ++j) {
          auto kernel_with_index_input = common::AnfAlgo::VisitKernel(c_make_tuple->input(j), 0);
          (void)v_list.emplace_back(fn(kernel_with_index_input));
          input_info->input_kernel.insert(kernel_with_index_input);
        }
        input_value = std::make_shared<ValueTuple>(v_list);
      } else {
        input_value = GetCNodeOutputTensor(kernel_with_index, op_output);
        input_info->input_kernel.insert(kernel_with_index);
      }
    } else if (real_input->isa<ValueNode>()) {
      input_value = GetValueNodeOutput(real_input, kernel_with_index.second);
    } else if (real_input->isa<Parameter>()) {
      auto tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
      input_value = tensor;
    } else {
      MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
    }
    return input_value;
  };

  const auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
  input_info->input_values.resize(input_num);
  input_info->input_abs.resize(input_num);
  for (size_t i = 1; i <= input_num; ++i) {
    const auto &input = cnode->input(i);
    KernelWithIndex kernel_with_index;
    // Pyboost tuple inputs can not be planted, like op concat, addn, filln and so on.
    if (cnode->HasAttr(kAttrIsPyboostTupleInput)) {
      kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(input, 0, false, {prim::kPrimMakeTuple});
    } else {
      kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
    }
    ValuePtr input_value = fn(kernel_with_index);
    MS_EXCEPTION_IF_NULL(input_value);
    MS_LOG(DEBUG) << "Get " << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
                  << kernel_with_index.first->fullname_with_scope() << "-" << kernel_with_index.second;
    input_info->input_values[i - 1] = input_value;
    input_info->input_abs[i - 1] = kernel_with_index.first->abstract();
  }
}

tensor::BaseTensorPtr SessionBasic::GetOpInputTensorByIndex(
  const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
  const std::map<AnfNodePtr, size_t> &parameter_index, const std::vector<tensor::TensorPtr> &graph_inputs,
  InputInfo *input_info, size_t input_index) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_info);
  if (input_index >= cnode->size() - 1) {
    MS_LOG(EXCEPTION) << "Input index is out of range: " << cnode->size() << ", cnode: " << cnode->DebugString();
  }

  const auto &input = cnode->input(input_index + 1);
  auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
  auto real_input = kernel_with_index.first;
  MS_EXCEPTION_IF_NULL(real_input);

  if (real_input->isa<Parameter>()) {
    return GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
  } else if (real_input->isa<CNode>()) {
    tensor::BaseTensorPtr tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
    if (common::AnfAlgo::IsBpropCutOpExecInBackend(real_input)) {
      CheckInputTensorShape(tensor, cnode, input_index);
    }
    input_info->input_kernel.insert(kernel_with_index);
    return tensor;
  } else {
    MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
  }
}

void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
                                 const std::vector<tensor::TensorPtr> &input_tensors,
                                 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  KernelMapTensor node_to_tensor;
  auto anf_outputs = kernel_graph->outputs();
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(DEBUG) << "Update output[" << item->DebugString() << "]";
    (void)outputs->emplace_back(
      CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor));
  }

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (auto &item : *tensor_to_node) {
    auto &tensor = item.first;
    auto &node = item.second.first;
    auto &output_index = item.second.second;
    DeviceAddressPtr address = nullptr;
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
        ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
      address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
    } else {
      address = AnfAlgo::GetMutableOutputAddr(node, output_index);
    }
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->set_device_address(address);
    MS_LOG(DEBUG) << "Debug address: Output tensor obj " << tensor.get() << ", tensor id " << tensor->id()
                  << ", device address " << tensor->device_address().get();
    if (common::AnfAlgo::IsDynamicShape(node)) {
      const auto &updated_shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
      (void)tensor->set_shape(updated_shape);
    }
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
      tensor->data_sync(false);
      tensor->set_sync_status(kNeedSyncHostToDevice);
    }
  }
}

void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
                                       VectorRef *outputs,
                                       std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
                                       KernelMapTensor *node_to_tensor) {
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  auto anf_outputs = kernel_graph->outputs();
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
    (void)outputs->emplace_back(
      CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, node_to_tensor));
  }
}

void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
                                       const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                                       std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (device::KernelRuntime::UseMemScheduler()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(outputs);
  for (const auto &item : *outputs) {
    if (utils::isa<VectorRefPtr>(item)) {
      const auto &vector_ref = utils::cast<VectorRef>(item);
      std::map<DeviceAddressPtr, DeviceAddressPtr> new_to_old_device_address;
      UpdateOutputTensors(&vector_ref, tensor_to_node, &new_to_old_device_address);
    } else if (utils::isa<tensor::TensorPtr>(item)) {
      const auto &tensor = utils::cast<tensor::TensorPtr>(item);
      MS_EXCEPTION_IF_NULL(tensor);
      const auto &iter = tensor_to_node.find(tensor);
      if (iter != tensor_to_node.end()) {
        const auto &node = iter->second.first;
        const auto &output_index = iter->second.second;
        if (!AnfAlgo::OutputAddrExist(node, output_index, true)) {
          continue;
        }
        const auto &address = AnfAlgo::GetMutableOutputAddr(node, output_index);
        tensor->set_device_address(address);

        if (common::AnfAlgo::IsDynamicShape(node)) {
          const auto &updated_shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
          (void)tensor->set_shape(updated_shape);
        }
      }
      if (tensor->NeedSyncDeviceToHostImmediately()) {
        tensor->data_sync(false);
        tensor->set_device_address(nullptr);
        tensor->set_sync_status(kNeedSyncHostToDevice);
      }
    }
  }
}
1089 
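// Collect the non-weight Parameter inputs of a compiled graph as tensors (built from the selected device
// type and shape), together with the parameter names.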
void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
                                      std::vector<std::string> *inputs_name) const {
  MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(inputs);
  MS_EXCEPTION_IF_NULL(inputs_name);
  auto kernel_graph_inputs = kernel_graph->inputs();
  // find parameters of graph inputs
  for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
    if (!kernel_graph_inputs[i]->isa<Parameter>()) {
      MS_LOG(ERROR) << "Kernel graph inputs contain an AnfNode which is not a Parameter.";
      continue;
    }
    auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
    if (!common::AnfAlgo::IsParameterWeight(parameter)) {
      auto input_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
      auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
      auto data_type = kernel_build_info->GetOutputDeviceType(0);
      auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
      (void)inputs->emplace_back(ms_tensor);
      (void)inputs_name->emplace_back(parameter->name());
    }
  }
}

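// Derive the model outputs by building output tensors from the graph's output nodes; the outputs are named
// "output0", "output1", ... in order.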
void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
                                       std::vector<std::string> *output_names) const {
  std::vector<tensor::TensorPtr> inputs;
  std::vector<std::string> input_names;
  GetModelInputsInfo(graph_id, &inputs, &input_names);

  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_names);

  VectorRef vector_outputs;
  std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
  KernelMapTensor node_to_tensor;
  auto anf_outputs = kernel_graph->outputs();
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
    vector_outputs.emplace_back(CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node, &node_to_tensor));
  }
  *outputs = TransformVectorRefToMultiTensor(vector_outputs);
  for (size_t i = 0; i < outputs->size(); i++) {
    (void)output_names->emplace_back("output" + std::to_string(i));
  }
}

#ifndef ENABLE_SECURITY
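// Summary helpers (compiled out when ENABLE_SECURITY is defined); these delegate to the Summary singleton.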
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  Summary::GetInstance().RegisterSummaryCallBackFunc(callback);
}

void SessionBasic::RecurseSetSummaryNodesForAllGraphs(KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Recurse set summary nodes for all graphs in graph: " << graph->graph_id() << " start";
  Summary::GetInstance().RecurseSetSummaryNodesForAllGraphs(graph);
}

void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  Summary::GetInstance().SetSummaryNodes(graph);
}

void SessionBasic::Summary(KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  static bool is_first = true;
  if (is_first && !IsSupportSummary()) {
    is_first = false;
    MS_LOG(WARNING) << "The Summary operator cannot collect data correctly. Detail: data sink mode is used and the"
                       " sink size (in the model.train() Python API) is not equal to 1.";
  }
  Summary::GetInstance().SummaryTensor(graph);
}
#endif

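// Wrap a single-op CNode's outputs into a MakeTuple and set it as the graph output; multi-output ops get
// one TupleGetItem per output so that each result can be fetched individually.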
void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) const {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> make_tuple_inputs;
  (void)make_tuple_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(*prim::kPrimMakeTuple)));
  MS_EXCEPTION_IF_NULL(graph);
  if (AnfAlgo::GetOutputElementNum(cnode) > 1) {
    for (size_t output_index = 0; output_index < AnfAlgo::GetOutputElementNum(cnode); output_index++) {
      auto idx = NewValueNode(SizeToLong(output_index));
      MS_EXCEPTION_IF_NULL(idx);
      auto imm = std::make_shared<Int64Imm>(output_index);
      idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
      auto getitem = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(*prim::kPrimTupleGetItem)), cnode, idx});
      std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(cnode, output_index)};
      auto shapes = {common::AnfAlgo::GetOutputInferShape(cnode, output_index)};
      common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
      (void)make_tuple_inputs.emplace_back(getitem);
    }
  } else {
    (void)make_tuple_inputs.emplace_back(cnode);
  }
  // create output
  auto g_output = graph->NewCNode(make_tuple_inputs);
  graph->set_output(g_output);
}

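// Build a temporary single-op KernelGraph for PyNative mode: input[0] is a copy of the op primitive
// (decoupled from the frontend PrimitivePy), constant inputs become ValueNodes, and the remaining inputs
// become graph parameters.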
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const BackendOpRunInfoPtr &op_run_info,
                                                                  const std::vector<ValuePtr> &input_values,
                                                                  const std::vector<InputType> &input_type) {
  auto graph = NewPynativeKernelGraph();
  std::vector<AnfNodePtr> inputs;
  // set input[0]
  auto op_prim = op_run_info->op_prim;
  MS_EXCEPTION_IF_NULL(op_prim);
  // Decoupling of frontend PrimitivePy and backend Primitive
  auto new_prim = std::make_shared<Primitive>(*op_prim);
  if (op_run_info->base_op_run_info.use_dynamic_shape_process) {
    AnfAlgo::SetDynamicAttrToPrim(new_prim);
  }
  (void)inputs.emplace_back(std::make_shared<ValueNode>(new_prim));
  // set input parameters
  if (input_values.size() != input_type.size()) {
    MS_LOG(EXCEPTION) << "Input values size " << input_values.size() << " should be equal to input types size "
                      << input_type.size();
  }
  for (size_t i = 0; i < input_values.size(); ++i) {
    if (input_type[i] == InputType::kConstant) {
      auto value_node = graph->NewValueNode(input_values[i]);
      (void)inputs.emplace_back(value_node);
      continue;
    }
    auto parameter =
      ConstructRunOpParameter(graph, input_values[i]->cast<tensor::BaseTensorPtr>(), op_run_info, input_type[i]);
    (void)inputs.emplace_back(parameter);
    auto mutable_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(mutable_inputs);
    (void)mutable_inputs->emplace_back(parameter);
  }
  // create the single-op cnode
  auto cnode = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  auto is_mutable = common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, cnode);
  if (is_mutable) {
    graph->set_flag(kAttrMutableKernel, true);
  }
  // set abstract, which includes the inferred shape and type
  cnode->set_abstract(op_run_info->base_op_run_info.abstract);
  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info->base_op_run_info.has_dynamic_output),
                               cnode);
  if (op_run_info->base_op_run_info.is_mixed_precision_cast) {
    common::AnfAlgo::SetNodeAttr(kAttrPynativeNextOpName, MakeValue(op_run_info->base_op_run_info.next_op_name), cnode);
    common::AnfAlgo::SetNodeAttr(kAttrPynativeNextIndex, MakeValue(op_run_info->base_op_run_info.next_input_index),
                                 cnode);
  }
  // set execution order
  graph->set_execution_order({cnode});
  CreateOutputNode(cnode, graph);
  graph->SetInputNodes();
  auto manager = MakeManager({graph});
  if (manager != nullptr) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    UnifyMindIR(graph);
  }
  graph->UpdateGraphDynamicAttr();
  return graph;
}

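// Locate the Pull node paired with a Push node in parameter-server training; any other consumer of the
// Push output is treated as an invalid edge.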
AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) const {
  MS_EXCEPTION_IF_NULL(push_node);
  for (auto &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      for (auto input : node->cast<CNodePtr>()->inputs()) {
        if (push_node == common::AnfAlgo::VisitKernel(input, 0).first) {
          if (common::AnfAlgo::GetCNodeName(node) != kPullOpName) {
            MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid.";
          }
          return node;
        }
      }
    }
  }
  return nullptr;
}

GraphId SessionBasic::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  return executor_->CompileGraph(shared_from_this(), segment, outputs);
}

GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  MS_EXCEPTION_IF_NULL(executor_);
  return executor_->CompileGraph(shared_from_this(), func_graph);
}

void SessionBasic::BuildGraph(GraphId graph_id) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->BuildGraph(shared_from_this(), graph_id);
}

void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
}

void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                 VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
}

void SessionBasic::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                VectorRef *outputs) {
  MS_LOG(INFO) << "Status record: start run graph. graph id: " << graph_id;
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Skip execution if no child graph is executable and no anf output exists.
  if (!kernel_graph->executable()) {
    MS_LOG(INFO) << "No child graph has anf output";
    return;
  }
  PreExecuteGraph(kernel_graph, inputs, outputs);
  ExecuteGraph(kernel_graph);
  PostExecuteGraph(kernel_graph, inputs, outputs);
  MS_LOG(INFO) << "Status record: end run graph. graph id: " << graph_id;
}

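// For heterogeneous execution: if an input tensor's device address lives on a device other than the current
// target, sync the data back to host and drop the stale address so it can be re-allocated on the right device.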
void SessionBasic::ProcessInputTensorsForHeterogeneous(const std::string &cur_target,
                                                       const std::vector<tensor::TensorPtr> &input_tensors) const {
  for (auto &tensor : input_tensors) {
    MS_EXCEPTION_IF_NULL(tensor);
    auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (device_address != nullptr) {
      if (device_address->GetDeviceType() != device::GetDeviceTypeByName(cur_target)) {
        tensor->data_sync();
        tensor->set_device_address(nullptr);
      }
    }
  }
}

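// Drop the tensors that correspond to constant (ValueNode) inputs, keeping only real runtime inputs.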
void SessionBasic::EraseValueNodeTensor(const std::vector<InputType> &input_types,
                                        std::vector<tensor::TensorPtr> *input_tensors) const {
  MS_EXCEPTION_IF_NULL(input_tensors);
  if (input_tensors->size() != input_types.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size()
                      << " should be equal to input types size " << input_types.size();
  }
  std::vector<tensor::TensorPtr> new_input_tensors;
  for (size_t index = 0; index < input_types.size(); ++index) {
    if (input_types[index] != InputType::kConstant) {
      (void)new_input_tensors.emplace_back(input_tensors->at(index));
    }
  }
  *input_tensors = new_input_tensors;
}

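// Return true if the graph contains a GetNext op, and report its data channel name through channel_name.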
bool SessionBasic::IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  for (const auto &kernel_node : kernel_graph->execution_order()) {
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
    if (kernel_name == kGetNextOpName) {
      auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
      MS_EXCEPTION_IF_NULL(prim);
      *channel_name = GetValue<std::string>(prim->GetAttr("shared_name"));
      return true;
    }
  }
  return false;
}

void SessionBasic::RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    opt::RemoveNopNode(kernel_graph.get());
  }
}

void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    opt::HideNopNode(kernel_graph.get());
  }
}

std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string group = GetCommWorldGroup();
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  // PyNative does not support multi-group allreduce.
  group += "sum1";
  return parallel_context->GetAllReduceFusionSplitIndices(group);
}

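// Count the gradient outputs of a bprop graph. Only CNode outputs need a gradient; ValueNode outputs are
// constants and are excluded from the count.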
uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) {
  auto outputs = common::AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  MS_LOG(DEBUG) << "Get total graph output size:" << outputs.size();
  // The type of output is CNode or ValueNode.
  // There is no need to calculate grad if the type of output is not CNode.
  return static_cast<uint32_t>(std::count_if(outputs.begin(), outputs.end(), [](const AnfNodePtr &output) {
    return output != nullptr && output->isa<CNode>();
  }));
}

void SetGraphBpropAttr(const KernelGraphPtr &graph) {
  auto &execution_orders = graph->execution_order();
  if (std::any_of(execution_orders.begin(), execution_orders.end(),
                  [](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
    graph->set_flag(kFlagIsPynativeBpropGraph, true);
    MS_LOG(INFO) << "Match bprop graph";
  }
}

void CheckSplitIndexValid(const vector<uint32_t> &split_index) {
  uint32_t last = 0;
  for (size_t i = 0; i < split_index.size(); ++i) {
    if (split_index[i] <= last && i != 0) {
      MS_LOG(EXCEPTION) << "Invalid split index:" << split_index;
    }
    last = split_index[i];
  }
}

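// Validate and complete the user-configured AllReduce fusion split indices: the indices must be strictly
// increasing, the upper bound must stay below the total grads count (otherwise everything is fused into one
// AllReduce), and the last index is padded to grads_count - 1 so every gradient falls into a fusion group.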
void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
  MS_EXCEPTION_IF_NULL(split_index);
  if (split_index->empty()) {
    return;
  }

  CheckSplitIndexValid(*split_index);
  // calculate split index num
  auto split_index_num = split_index->back();
  // obtain graph output tensor num
  auto grads_count = GetBpropGraphGradsCount(graph);
  if (split_index_num >= grads_count) {
    MS_LOG(WARNING) << "The context configuration all_reduce_fusion_config's upper boundary value should be smaller "
                    << "than total grads count: " << grads_count << ", but got: " << *split_index
                    << ". Now all AllReduce operators will be fused into one AllReduce operator.";
    split_index->clear();
    split_index->push_back(grads_count - 1);
  } else if (split_index_num < grads_count - 1) {
    split_index->push_back(grads_count - 1);
  }
}

void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const {
  MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id();
  opt::CommonFinalOptimization(graph);
  MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id();
}

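// Dump the compiled graphs when IR dump, e2e/async data dump, or RDR is enabled: record the graph and its
// execution order to RDR, write .ir/.pb files, and on Ascend (graph mode with e2e dump) additionally write
// per-rank graph protos and a CSV of the execution order.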
void SessionBasic::DumpGraphs(const std::vector<KernelGraphPtr> &graphs) const {
#ifdef ENABLE_DUMP_IR
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->CanDump(kIntroductory);
  auto &json_parser = DumpJsonParser::GetInstance();
  json_parser.Parse();
  if (!save_graphs && !json_parser.e2e_dump_enabled() && !json_parser.async_dump_enabled() &&
      !mindspore::RecorderManager::Instance().RdrEnable()) {
    return;
  }
  for (auto &graph : graphs) {
    MS_EXCEPTION_IF_NULL(graph);

    if (graph->memory_managed_by_ge()) {
      continue;
    }

    std::string name = "graph_build." + std::to_string(graph->graph_id());
    DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
    (void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, name, graph, dump_params, ".ir;.pb");

    auto &kernels = graph->execution_order();
    std::string exec_order_name = "graph_exec_order." + std::to_string(graph->graph_id());
    (void)mindspore::RDR::RecordGraphExecOrder(SUBMODULE_ID, exec_order_name, kernels);
    if (save_graphs) {
      std::string file_name = "graph_build_" + std::to_string(graph->graph_id()) + ".ir";
      DumpIR(file_name, graph, true, kWholeStack);
      DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id()));
      DumpIR("trace_code_graph", graph, true, kWholeStack);
    }
    std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    if (device_target != kAscendDevice) {
      // Data is dumped only on Ascend.
      continue;
    }
    // If the new runtime is used, get rank_id from context via GetRankID(), else get rank_id from rank_id_.
    uint32_t rank_id = rank_id_;
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
      rank_id = GetRankId();
    }
    std::string final_graph = "trace_code_graph_" + std::to_string(graph->graph_id());
    if (json_parser.e2e_dump_enabled() && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
      std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
      MS_LOG(INFO) << "Dump graph and exeorder for graph: " << graph->graph_id()
                   << ", root_graph_id: " << graph->root_graph_id() << ", rank_id: " << rank_id;
      std::string target_dir = root_dir + "/graphs";
      std::string cst_file_dir = GenerateDumpPath(graph->root_graph_id(), rank_id, true);
      std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
      DumpIRProtoWithSrcInfo(graph, final_graph, target_dir, kDebugWholeStack);
      if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
        // Dump constant data for the old runtime on Ascend.
        DumpConstantInfo(graph, cst_file_dir);
      }
      DumpIR("trace_code_graph", graph, true, kWholeStack, ir_file_path);
      DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(graph->graph_id()) + ".csv", root_dir,
                        graph->execution_order());
    }
  }
#endif
}
}  // namespace session

void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir,
                       const std::vector<CNodePtr> &execution_order) {
  std::string file_path = target_dir + "/execution_order/" + file_name;
  auto realpath = Common::CreatePrefixPath(file_path);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Failed to get real path: [" << file_path << "] in dump graph execution order.";
    return;
  }
  file_path = realpath.value();

  ChangeFileMode(file_path, S_IWUSR);
  // write to csv file
  std::ofstream ofs(file_path);
  if (!ofs.is_open()) {
    MS_LOG(ERROR) << "Failed to open file [" << file_path
                  << "] in dump graph execution order, please check the file access permission and whether disk space "
                     "is available.";
    return;
  }
  ofs << "NodeExecutionOrder-FullNameWithScope\n";
  for (const CNodePtr &node : execution_order) {
    ofs << node->fullname_with_scope() << "\n";
  }
  ofs.close();
  // set file mode to read only by user
  ChangeFileMode(file_path, S_IRUSR);
}

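// Resolve the rank id from the communication world group: kHcclWorldGroup on Ascend, kNcclWorldGroup on GPU.
// Returns 0 when HCCL is disabled, RANK_ID is unset, or the backend is unsupported.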
uint32_t GetRankId() {
  uint32_t rank_id = 0;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  std::string world_group;
  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (backend == kAscendDevice) {
    world_group = kHcclWorldGroup;
  } else if (backend == kGPUDevice) {
    world_group = kNcclWorldGroup;
  } else {
    MS_LOG(ERROR) << "Invalid backend: " << backend;
    return rank_id;
  }
  auto env_rank_id = common::GetEnv("RANK_ID");
  if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
    if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
      MS_LOG(INFO) << "Failed to get rank id.";
    }
  }
  return rank_id;
}
}  // namespace mindspore