/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "backend/common/session/session_basic.h"

#include <algorithm>
#include <set>
#include <queue>
#include <utility>
#include <functional>
#include <unordered_map>

#include "ops/ascend_op_name.h"
#include "ops/structure_op_name.h"
#include "ops/framework_op_name.h"
#include "ops/sequence_ops.h"
#include "utils/hash_map.h"
#include "ops/primitive_c.h"
#include "ir/manager.h"
#include "abstract/utils.h"
#include "kernel/common_utils.h"
#include "base/base_ref_utils.h"
#include "runtime/device/ms_device_shape_transfer.h"
#include "include/common/utils/config_manager.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "backend/common/session/executor_manager.h"
#include "backend/common/optimizer/common_backend_optimization.h"
#include "include/backend/optimizer/helper.h"
#include "include/backend/optimizer/op_adaptation_info_factory.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/pynative/op_compiler.h"
#include "utils/ms_utils.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#include "include/common/utils/utils.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/debug/dump_proto.h"
#include "utils/file_utils.h"
#include "utils/trace_base.h"
#include "include/common/utils/parallel_context.h"
#include "kernel/oplib/oplib.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
#include "include/backend/distributed/ps/constants.h"
#include "include/backend/distributed/ps/util.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "abstract/abstract_value.h"
#endif
#include "backend/common/session/session_factory.h"
#include "runtime/pynative/op_executor.h"
#ifdef ENABLE_DEBUGGER
#include "debug/tensor_load.h"
#include "debug/debugger/proto_exporter.h"
#endif
#include "include/backend/debug/debugger/proto_exporter.h"
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/graph_exec_order_recorder.h"
#include "include/common/debug/rdr/recorder_manager.h"
#include "debug/rdr/graph_recorder.h"
#include "runtime/hardware/device_context_manager.h"
#endif
#ifndef ENABLE_SECURITY
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "include/backend/debug/data_dump/e2e_dump.h"
#endif

namespace mindspore {
namespace session {
MS_REG_SESSION(kSessionBasic, SessionBasic);

namespace {
constexpr int64_t kInvalidShape = -2;
static bool IsPynativeMode() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
}

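// Fetch a node output directly from the graph-level inputs when no device sync is needed:
// monad outputs map to a dummy scalar tensor, value nodes return their value, and parameters
// not used by any real kernel are looked up in input_tensors. Returns nullptr when the output
// must instead be created and synchronized from the device.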
BaseRef GetNodeOutputTensorFromInputs(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                      const std::vector<tensor::TensorPtr> &input_tensors) {
  auto &node = node_output_pair.first;
  MS_EXCEPTION_IF_NULL(node);
  if (HasAbstractMonad(node)) {
    return std::make_shared<tensor::Tensor>(int64_t(0), kBool);
  }
  // If the node is a value node, there is no need to sync its addr from device to host.
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    return value_node->value();
  }
  if (IsPynativeMode()) {
    return nullptr;
  }
  if (!node->isa<Parameter>()) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(graph);
  auto param_node = node->cast<ParameterPtr>();
  if (param_node != nullptr && param_node->IsUsedByRealKernelInGraph(graph->graph_id())) {
    return nullptr;
  }
  for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
    if (input_idx >= input_tensors.size()) {
      MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " is out of range:" << input_tensors.size();
    }
    if (graph->inputs()[input_idx] == node) {
      return input_tensors[input_idx];
    }
  }
  return nullptr;
}

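// Create (or reuse) the host tensor that receives the output of node_output_pair. Internal
// outputs are cached on the graph and skip syncing; for other outputs the sync status depends
// on the execution mode and the device target. The tensor-to-node mapping is recorded in
// tensor_to_node for later address binding.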
BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                               const std::vector<tensor::TensorPtr> &input_tensors,
                               std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) {
  auto &node = node_output_pair.first;
  size_t output_index = node_output_pair.second;
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  auto tensor_from_input = GetNodeOutputTensorFromInputs(node_output_pair, graph, input_tensors);
  if (tensor_from_input != nullptr) {
    return tensor_from_input;
  }
  TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  if (type_id == kTypeUnknown) {
    type_id = common::AnfAlgo::GetOutputInferDataType(node, output_index);
  }

  auto shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
  if (common::AnfAlgo::IsDynamicShape(node)) {
    auto max_shape = common::AnfAlgo::GetOutputMaxShape(node, output_index);
    if (abstract::ShapeSize(max_shape) > abstract::ShapeSize(shape)) {
      shape = max_shape;
    }
  }
  tensor::TensorPtr tensor;
  bool is_internal_output = graph->IsInternalOutput(node, output_index);
  if (is_internal_output) {
    tensor = graph->GetInternalOutputTensor(node, output_index);
    if (tensor == nullptr) {
      tensor = std::make_shared<tensor::Tensor>(type_id, shape);
      graph->AddInternalOutputTensor(node, output_index, tensor);
    }
  } else {
    tensor = std::make_shared<tensor::Tensor>(type_id, shape);
  }
  MS_EXCEPTION_IF_NULL(tensor);
  if (is_internal_output) {
    tensor->set_sync_status(kNoNeedSync);
  } else {
    // In PyNative mode, data is only copied to the host when the user wants to inspect it.
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode &&
        ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
      tensor->set_sync_status(kNeedSyncDeviceToHostImmediately);
    } else {
      tensor->set_sync_status(kNeedSyncDeviceToHost);
    }
  }
  tensor->SetIsGraphOutput();
  (*tensor_to_node)[tensor] = node_output_pair;
  return tensor;
}

std::string GetOpRunDeviceTarget(const PrimitivePtr &op_prim) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  const std::string &device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);

  MS_EXCEPTION_IF_NULL(op_prim);
  const auto &attr_map = op_prim->attrs();
  auto iter = attr_map.find(kAttrPrimitiveTarget);
  if (iter != attr_map.end()) {
    return GetValue<std::string>(iter->second);
  }
  return device_target;
}

// The input tensor's properties need to be discarded in heterogeneous scenarios.
// For example, the device_address in input_tensor may use a 5D format,
// which is invalid for a CPU graph parameter.
bool NeedDiscardTensorProperties(const std::string &op_device_target,
                                 const device::DeviceAddressPtr &tensor_device_address) {
  if (tensor_device_address == nullptr) {
    return true;
  }

  if (op_device_target == device::GetDeviceNameByType(tensor_device_address->GetDeviceType())) {
    return false;
  }
  return true;
}

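// Create a graph parameter for one input of a single-op graph. The kernel build info (format,
// device type, reshape type) is taken from the tensor's device address unless the properties
// must be discarded for a heterogeneous target, in which case defaults are used.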
ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
                                     const tensor::BaseTensorPtr &input_tensor, const BackendOpRunInfoPtr &op_run_info,
                                     InputType input_type) {
  MS_EXCEPTION_IF_NULL(graph);
  auto param = graph->NewParameter();
  MS_EXCEPTION_IF_NULL(param);
  if (input_type == InputType::kParameter) {
    param->set_default_param(input_tensor);
  }

  // Set the kernel info of the parameter.
  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
  MS_EXCEPTION_IF_NULL(input_tensor);
  auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(input_tensor->device_address());
  if (NeedDiscardTensorProperties(op_run_info->base_op_run_info.device_target, device_address)) {
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
    TypeId param_init_data_type = common::AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type();
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{param_init_data_type});
  } else {
    kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{device_address->type_id()});
    kernel_build_info_builder->SetOutputsReshapeType({device_address->padding_type()});
    kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{device_address->format()});
  }
  if (input_tensor->isa<tensor::MapTensor>()) {
    auto map_tensor = input_tensor->cast<tensor::MapTensorPtr>();
    auto map_tensor_abs = std::make_shared<abstract::AbstractMapTensor>(map_tensor);
    AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
    param->set_abstract(map_tensor_abs);
    return param;
  }
  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
  // Construct the abstract of the parameter.
  auto type_of_tensor = input_tensor->Dtype();
  std::shared_ptr<abstract::AbstractTensor> abstract;
  // base_shape_ptr is set in the dynamic shape scenario; if it is nullptr, the shape is not dynamic.
  if (input_tensor->base_shape_ptr() != nullptr) {
    abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, input_tensor->base_shape_ptr());
  } else {
    abstract = std::make_shared<abstract::AbstractTensor>(type_of_tensor, input_tensor->shape());
  }
  param->set_abstract(abstract);
  return param;
}

void DumpGraphOutput(const Any &any, size_t recurse_level = 0) {
  MS_LOG(INFO) << "Graph outputs:";
  const size_t max_deep = 10;
  if (recurse_level > max_deep) {
    MS_LOG(INFO) << "Recurse too deep";
    return;
  }
  std::string tab_str;
  for (size_t i = 0; i < recurse_level; i++) {
    tab_str = tab_str.append(" ");
  }
  if (any.is<AnyList>()) {
    (void)tab_str.append("{");
    MS_LOG(INFO) << tab_str;
    auto any_list = any.cast<AnyList>();
    for (auto &it : any_list) {
      DumpGraphOutput(it, recurse_level + 1);
    }
    (void)tab_str.append("}");
    MS_LOG(INFO) << tab_str;
  }
  (void)tab_str.append(any.ToString());
  MS_LOG(INFO) << tab_str;
}

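// Create a placeholder for one graph output. Value nodes and parameters are resolved right away;
// for kernel outputs, the position inside the nested output structure is recorded in
// output_indexes so the real tensor can be filled in after the op has run.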
BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  auto &node = node_output_pair.first;
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(DEBUG) << "Create placeholder for output[" << node->DebugString() << "] index[" << node_output_pair.second
                << "]";
  // If the node is a value node, there is no need to sync its addr from device to host.
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    return value_node->value();
  }
  if (node->isa<Parameter>()) {
    const auto &input_nodes = graph->input_nodes();
    for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) {
      if (input_idx >= input_tensors.size()) {
        MS_LOG(EXCEPTION) << "Input idx:" << input_idx << " is out of range:" << input_tensors.size();
      }
      if (input_nodes[input_idx] == node) {
        return input_tensors[input_idx];
      }
    }
    MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
  }
  (*output_indexes)[node_output_pair].emplace_back(indexes);
  BaseRef output_placeholder = std::make_shared<BaseRef>();
  return output_placeholder;
}

BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                    const std::vector<tensor::TensorPtr> &input_tensors,
                                    const std::vector<size_t> &indexes,
                                    std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(output_indexes);
  MS_LOG(DEBUG) << "Create placeholder for output[" << anf->DebugString() << "]";
  auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(DEBUG) << "Create placeholder for output after visit:" << item_with_index.first->DebugString();
  // Special handling for MakeTuple.
  if (common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->size(); ++i) {
      std::vector<size_t> cur_index = indexes;
      cur_index.emplace_back(i - 1);
      auto out = CreateNodeOutputPlaceholder(cnode->input(i), graph, input_tensors, cur_index, output_indexes);
      ret.push_back(out);
    }
    return ret;
  }
  // If the graph returns nothing, the function should return an empty VectorRef.
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }
  return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
}

void CheckInputTensorShape(const tensor::BaseTensorPtr &tensor, const CNodePtr &kernel, size_t input_index) {
  MS_EXCEPTION_IF_NULL(tensor);
  const auto &tensor_shape = tensor->shape();
  const auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel, input_index);
  if (tensor_shape.size() != input_shape.size()) {
    MS_LOG(EXCEPTION) << "The input tensor's shape size: " << tensor_shape.size()
                      << " is not equal to expected size: " << input_shape.size() << " for input[" << input_index
                      << "] of kernel: " << common::AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
  }
  for (size_t i = 0; i < tensor_shape.size(); i++) {
    if (tensor_shape[i] < 0 || (tensor_shape[i] != input_shape[i] && input_shape[i] >= 0)) {
      MS_LOG(EXCEPTION) << "The input tensor's shape: " << tensor_shape
                        << " is not equal to expected shape: " << input_shape << " for input[" << input_index
                        << "] of kernel: " << common::AnfAlgo::GetCNodeName(kernel) << trace::DumpSourceLines(kernel);
    }
  }
}

bool is_param_scalar(const size_t &param_shape_size, const size_t &input_shape_size) {
  if (param_shape_size == 1 && input_shape_size == 0) {
    return true;
  }
  if (param_shape_size == 0 && input_shape_size == 1) {
    return true;
  }
  return false;
}

ValuePtrList ConvertVectorRefOutputs(const VectorRef &op_outputs) {
  ValuePtrList output_values;
  for (auto value : op_outputs.elements_) {
    (void)output_values.emplace_back(utils::cast<ValuePtr>(value));
  }
  return output_values;
}
}  // namespace

BaseRef SessionBasic::CreateNodeOutputTensors(const AnfNodePtr &anf, const KernelGraphPtr &graph,
                                              const std::vector<tensor::TensorPtr> &input_tensors,
                                              std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
                                              KernelMapTensor *node_to_tensor) {
  MS_EXCEPTION_IF_NULL(anf);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  MS_EXCEPTION_IF_NULL(node_to_tensor);
  MS_LOG(DEBUG) << "Create tensor for output[" << anf->DebugString() << "]";
  auto item_with_index = common::AnfAlgo::VisitKernelWithReturnType(anf, 0);
  MS_EXCEPTION_IF_NULL(item_with_index.first);
  MS_LOG(DEBUG) << "Create tensor for output after visit:" << item_with_index.first->DebugString();
  // Special handling for MakeTuple.
  if (common::AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
    auto cnode = item_with_index.first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    VectorRef ret;
    for (size_t i = 1; i < cnode->size(); ++i) {
      auto out = CreateNodeOutputTensors(cnode->input(i), graph, input_tensors, tensor_to_node, node_to_tensor);
      (void)ret.emplace_back(out);
    }
    return ret;
  }
  // If the graph returns nothing, the function should return an empty VectorRef.
  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
  if (size == 0) {
    return VectorRef();
  }

  // Outputs of the graph may share the same kernel node; in that case there is no need to create a new tensor.
  const auto &iter = node_to_tensor->find(item_with_index);
  if (iter != node_to_tensor->end()) {
    return iter->second;
  }

  const auto &tensor = CreateNodeOutputTensor(item_with_index, graph, input_tensors, tensor_to_node);
  (*node_to_tensor)[item_with_index] = tensor;
  return tensor;
}

void SessionBasic::InitExecutor(const std::string &device_name, uint32_t device_id) {
  device_id_ = device_id;
  context_ = std::make_shared<Context>(device_name, device_id);
  executor_ = ExecutorManager::Instance().GetExecutor(device_name, device_id);
}

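// Assemble the BackendOpRunInfo for running a single op in PyNative mode: primitive, abstract,
// expanded input values and types, the run device target, and the graph output indexes this op
// feeds (a non-empty set also marks the op as producing a gradient output).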
BackendOpRunInfoPtr SessionBasic::GetSingleOpRunInfo(const CNodePtr &cnode, const InputInfo &input_info,
                                                     const GraphOutputInfo *const graph_output_info) const {
  MS_EXCEPTION_IF_NULL(cnode);
  auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
  MS_EXCEPTION_IF_NULL(primitive);
  const auto &abstract = cnode->abstract();
  if (abstract == nullptr) {
    MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
  }
  const auto &shape = abstract->BuildShape();
  MS_EXCEPTION_IF_NULL(shape);

  std::vector<size_t> output_indexes;
  bool is_gradient_out = false;
  if (graph_output_info != nullptr) {
    for (auto &item : graph_output_info->output_indexes) {
      if (item.first.first == cnode) {
        is_gradient_out = true;
        (void)output_indexes.emplace_back(item.first.second);
      }
    }
  }

  pynative::BaseOpRunInfo base_op_run_info;
  base_op_run_info.is_mixed_precision_cast = false;
  base_op_run_info.has_dynamic_output = shape->IsDynamic();
  base_op_run_info.op_name = primitive->name();
  base_op_run_info.next_op_name = std::string();
  base_op_run_info.device_target = GetOpRunDeviceTarget(primitive);
  base_op_run_info.next_input_index = 0;
  base_op_run_info.expanded_input_values.clear();
  for (auto const &value : input_info.input_values) {
    base_op_run_info.expanded_input_values.emplace_back(value);
  }
  base_op_run_info.input_types = input_info.input_types;
  base_op_run_info.abstract = abstract;
  base_op_run_info.output_indexes = output_indexes;
  return std::make_shared<BackendOpRunInfo>(base_op_run_info, primitive, false, is_gradient_out);
}

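// Map every parameter of the graph to the index of the input tensor that feeds it, checking on
// the way that each input shape is compatible with the parameter shape (with relaxed rules for
// dynamic shapes, scalars, and parallel forward jit graphs).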
void SessionBasic::GetParameterIndex(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
                                     std::map<AnfNodePtr, size_t> *parameter_index) const {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter_index);
  size_t index = 0;
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  bool is_parallel_forward_jit =
    !graph->has_flag(kFlagIsPynativeBpropGraph) &&
    (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel);
  for (const auto &input_node : graph->input_nodes()) {
    auto params = common::AnfAlgo::GetAllOutput(input_node);
    for (const auto &param : params) {
      if (index >= inputs.size()) {
        MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
                          << ", input size: " << inputs.size();
      }
      const auto &input = inputs[index];
      MS_EXCEPTION_IF_NULL(input);
      MS_EXCEPTION_IF_NULL(param);
      // Check the shapes of the input and the parameter.
      const auto &input_shape = input->shape();
      const auto &param_shape = common::AnfAlgo::GetOutputInferShape(param, 0);
      bool is_dynamic = param->Shape()->IsDynamic();
      // Dynamic shape feed mode: the shape is dynamic but the max shape is ().
      if (!is_dynamic || !param_shape.empty()) {
        if (!is_parallel_forward_jit && input_shape.size() != param_shape.size()) {
          // An inferred shape of -2 indicates that the shape cannot be inferred currently.
          if (param_shape.size() == 1 && param_shape[0] == kInvalidShape) {
            parameter_index->emplace(param, index++);
            continue;
          }
          // The input is a scalar: the param shape will be [1] and the input shape will be [].
          if (is_param_scalar(param_shape.size(), input_shape.size())) {
            parameter_index->emplace(param, index++);
            continue;
          }
          MS_LOG(EXCEPTION) << "Shape size of input tensor(" << input_shape << ") and parameter(" << param_shape
                            << ") are different, input index: " << index << ", parameter: " << param->DebugString();
        }
        for (size_t i = 0; i < input_shape.size(); i += 1) {
          if (input_shape[i] < 0 || (!is_parallel_forward_jit && input_shape[i] != param_shape[i] && !is_dynamic)) {
            MS_LOG(EXCEPTION) << "Input tensor shape(" << input_shape << ") and parameter shape(" << param_shape
                              << ") are different, input index: " << index << ", parameter: " << param->DebugString();
          }
        }
      }
      parameter_index->emplace(param, index++);
    }
  }
}

void SessionBasic::CreateOutputPlaceholder(
  const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *const outputs,
  std::map<KernelWithIndex, std::vector<std::vector<size_t>>> *output_indexes) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_indexes);
  auto anf_outputs = kernel_graph->outputs();
  size_t index = 0;
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    std::vector<size_t> indexes{index++};
    (void)outputs->emplace_back(
      CreateNodeOutputPlaceholder(item, kernel_graph, input_tensors, indexes, output_indexes));
  }
}

void SessionBasic::GetRefCount(const KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) const {
  MS_EXCEPTION_IF_NULL(graph);
  for (const auto &kernel : graph->execution_order()) {
    for (size_t i = 1; i < kernel->size(); i += 1) {
      auto input = kernel->inputs()[i];
      CalculateRefCount(input, ref_count);
    }
  }
}

void SessionBasic::CalculateRefCount(const AnfNodePtr &node, std::map<KernelWithIndex, size_t> *ref_count) const {
  if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
    auto kernel_with_index = common::AnfAlgo::VisitKernel(node, 0);
    const auto &real_input = kernel_with_index.first;
    if (real_input->isa<CNode>()) {
      (*ref_count)[kernel_with_index] += 1;
    }
    return;
  }
  auto cnode = node->cast<CNodePtr>();
  for (size_t i = 1; i < cnode->size(); ++i) {
    auto input = cnode->input(i);
    CalculateRefCount(input, ref_count);
  }
}

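// Count how many kernels of the bprop graph consume each forward op output tensor, keyed by
// tensor id, so device memory can be released after the last use. Skipped on CPU, where device
// and host memory are the same.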
void SessionBasic::GetForwardOpOutputRefCount(const KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
                                              std::map<std::string, size_t> *forward_op_output_tensor_id,
                                              const std::map<AnfNodePtr, size_t> &parameter_index) const {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // CPU cannot clear the device address, because its device address and host address are the same.
  if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kCPUDevice) {
    return;
  }
  MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
  for (const auto &kernel : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(kernel);
    const auto input_tensor_num = common::AnfAlgo::GetInputTensorNum(kernel);
    for (size_t i = 1; i <= input_tensor_num; ++i) {
      const auto &input = kernel->input(i);
      auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
      auto real_input = kernel_with_index.first;
      MS_EXCEPTION_IF_NULL(real_input);
      if (real_input->isa<ValueNode>()) {
        const auto &value = GetValueNodeOutput(real_input, kernel_with_index.second);
        if (value == nullptr || !value->isa<tensor::Tensor>()) {
          continue;
        }
        auto tensor = value->cast<tensor::TensorPtr>();
        if (tensor->is_forward_output()) {
          (*forward_op_output_tensor_id)[tensor->id()] += 1;
        }
      } else if (real_input->isa<Parameter>()) {
        // A forward op output may be used as sens, so a reference needs to be added.
        auto iter = parameter_index.find(real_input);
        if (iter != parameter_index.end()) {
          auto tensor = inputs[iter->second];
          if (tensor->is_forward_output()) {
            (*forward_op_output_tensor_id)[tensor->id()] += 1;
          }
        }
      }
    }
  }
  MS_LOG(DEBUG) << "Forward op output tensor in bprop graph size " << forward_op_output_tensor_id->size();
}

void SessionBasic::ReleaseForwardOpOutput(const std::vector<ValuePtr> &input_values,
                                          std::map<std::string, size_t> *forward_op_output_tensor_id) const {
  MS_EXCEPTION_IF_NULL(forward_op_output_tensor_id);
  for (const auto &value : input_values) {
    auto tensor = value->cast<tensor::BaseTensorPtr>();
    if (tensor == nullptr) {
      continue;
    }

    if (!tensor->is_forward_output()) {
      continue;
    }
    auto it = forward_op_output_tensor_id->find(tensor->id());
    if (it != forward_op_output_tensor_id->end()) {
      if (--(it->second) == 0) {
        tensor->set_device_address(nullptr);
        forward_op_output_tensor_id->erase(it);
      }
    }
  }
}

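// After an op has run, decrease the reference count of every CNode output it consumed; when a
// count reaches zero, the cached output is erased from op_output_map so its memory can be
// reclaimed.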
void SessionBasic::HandleOpInputs(const std::set<KernelWithIndex> &input_kernel,
                                  std::map<KernelWithIndex, size_t> *ref_count,
                                  std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map) const {
  MS_EXCEPTION_IF_NULL(ref_count);
  MS_EXCEPTION_IF_NULL(op_output_map);
  for (const auto &kernel_with_index : input_kernel) {
    if (!kernel_with_index.first->isa<CNode>()) {
      continue;
    }

    // Release the previous output.
    auto ref_iter = ref_count->find(kernel_with_index);
    if (ref_iter == ref_count->end()) {
      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
                        << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
    }
    // Reduce the reference count; when it reaches zero, release the now-useless output of the previous node.
    ref_iter->second -= 1;
    if (ref_iter->second != 0) {
      continue;
    }
    ref_count->erase(ref_iter);
    auto output_iter = op_output_map->find(kernel_with_index);
    if (output_iter == op_output_map->end()) {
      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
                        << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
    }
    op_output_map->erase(output_iter);
  }
}

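// Store the outputs of the op that just ran: tensors still referenced by later kernels are
// cached in op_output_map, and values that belong to the graph outputs are written into the
// nested VectorRef structure prepared by CreateOutputPlaceholder.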
void SessionBasic::HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
                                   const std::map<KernelWithIndex, size_t> &ref_count,
                                   std::map<KernelWithIndex, tensor::BaseTensorPtr> *op_output_map,
                                   GraphOutputInfo *const graph_output_info) const {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(op_output_map);
  MS_EXCEPTION_IF_NULL(graph_output_info);
  MS_EXCEPTION_IF_NULL(graph_output_info->graph_outputs);
  ValuePtrList output_values;
  if (common::AnfAlgo::IsBpropCutOpExecInBackend(kernel)) {
    output_values = ConvertVectorRefOutputs(op_outputs);
  } else {
    output_values = common::AnfAlgo::TransformVectorRefToMultiValue(op_outputs);
  }
  if (output_values.size() > op_outputs.size()) {
    MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
  }
  size_t out_index = 0;
  for (const auto &output_value : output_values) {
    auto kernel_with_index = make_pair(kernel, out_index++);
    auto output_tensor = output_value->cast<tensor::BaseTensorPtr>();
    bool value_is_tensor = (output_tensor != nullptr);
    if (ref_count.find(kernel_with_index) != ref_count.end() && value_is_tensor) {
      (*op_output_map)[kernel_with_index] = output_tensor;
    }
    const auto &iter = graph_output_info->output_indexes.find(kernel_with_index);
    if (iter == graph_output_info->output_indexes.end()) {
      continue;
    }
    const std::vector<std::vector<size_t>> &multiple_ref_indexes = iter->second;
    for (const auto &ref_indexes : multiple_ref_indexes) {
      size_t n = 0;
      const VectorRef *cur_vector_ref = graph_output_info->graph_outputs;
      for (; n < ref_indexes.size() - 1; n += 1) {
        size_t index = ref_indexes.at(n);
        if (index >= cur_vector_ref->size()) {
          MS_LOG(EXCEPTION) << "Get invalid output ref index: " << index << ", size of vector ref is "
                            << cur_vector_ref->size();
        }
        const BaseRef &base_ref = (*cur_vector_ref)[index];
        if (!utils::isa<VectorRef>(base_ref)) {
          MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, index: " << index << ", cur n: " << n;
        }
        cur_vector_ref = &utils::cast<VectorRef>(base_ref);
      }
      BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
      tensor_ref = output_value;
      if (value_is_tensor) {
        (void)graph_output_info->graph_output_tensors.emplace_back(output_tensor);
      }
    }
  }
}

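// Resolve the value held by a ValueNode for the given output index: tuples are indexed, strings
// are converted to a string tensor, and CSR/COO tensors return the requested component. Returns
// nullptr when the node is not a ValueNode.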
ValuePtr SessionBasic::GetValueNodeOutput(const AnfNodePtr &node, size_t output_index) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<ValueNode>()) {
    return nullptr;
  }
  auto value_node = node->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = GetValueNode(value_node);
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<ValueTuple>()) {
    auto value_tuple = value->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(value_tuple);
    if (value_tuple->value().empty()) {
      // Empty tuple.
      return value;
    }
    if (output_index >= value_tuple->size()) {
      MS_LOG(EXCEPTION) << "Index " << output_index << " is out of value tuple range";
    }
    auto tensor_value = value_tuple->value()[output_index];
    if (tensor_value->isa<tensor::Tensor>()) {
      return tensor_value;
    } else {
      return value;
    }
  } else if (value->isa<tensor::Tensor>()) {
    if (output_index != 0) {
      MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << output_index;
    }
    return value;
  } else if (value->isa<StringImm>()) {
    auto value_string = GetValue<std::string>(value);
    const ShapeVector shape = {1, SizeToLong(value_string.size())};
    TensorPtr tensor = std::make_shared<Tensor>(kObjectTypeString, shape, value_string.data(), value_string.size());
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->set_sync_status(kNeedSyncHostToDevice);
    return tensor;
  } else if (value->isa<tensor::CSRTensor>()) {
    return value->cast<tensor::CSRTensorPtr>()->GetTensorAt(output_index);
  } else if (value->isa<tensor::COOTensor>()) {
    return value->cast<tensor::COOTensorPtr>()->GetTensorAt(output_index);
  }

  return value;
}

TensorPtr SessionBasic::GetParameterOutputTensor(const AnfNodePtr &node,
                                                 const std::map<AnfNodePtr, size_t> &parameter_index,
                                                 const std::vector<tensor::TensorPtr> &graph_inputs) const {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<Parameter>()) {
    return nullptr;
  }
  const auto &iter = parameter_index.find(node);
  if (iter == parameter_index.end()) {
    MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, parameter = " << node->DebugString();
  }
  const size_t index = iter->second;
  if (index >= graph_inputs.size()) {
    MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = " << index
                      << ", input tensor size = " << graph_inputs.size();
  }
  return graph_inputs[index];
}

tensor::BaseTensorPtr SessionBasic::GetCNodeOutputTensor(
  const KernelWithIndex &kernel_with_index, const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output) const {
  const auto &iter = op_output.find(kernel_with_index);
  if (iter == op_output.end()) {
    MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << kernel_with_index.first->DebugString();
  }
  return iter->second;
}

void SessionBasic::GetConstValueDepend(const CNodePtr &cnode, std::set<int64_t> *const_input_attr_index) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(const_input_attr_index);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target != kAscendDevice) {
    return;
  }
  *const_input_attr_index = abstract::GetValueDependArgIndices(cnode);
  if (!const_input_attr_index->empty()) {
    return;
  }
  auto op_name = common::AnfAlgo::GetCNodeName(cnode);
  auto op_adaptation_info = opt::OpAdaptationInfoRegister::GetOpAdaptationInfo(op_name, kAscendDevice, true);
  if (op_adaptation_info == nullptr) {
    return;
  }
  if (op_adaptation_info->is_ascend_mindir()) {
    auto input_to_attr_map = op_adaptation_info->input_attr_map();
    for (const auto &input_attr_info : input_to_attr_map) {
      (void)const_input_attr_index->insert(SizeToLong(input_attr_info.first));
    }
  }
}

static inline BaseShapePtr GetShapeFromTuple(const abstract::AbstractTuplePtr &tuple_abs, const size_t index) {
  MS_EXCEPTION_IF_NULL(tuple_abs);
  const auto &elements = tuple_abs->elements();
  if (!elements.empty()) {
    auto tuple_abs_elem = elements[index];
    MS_EXCEPTION_IF_NULL(tuple_abs_elem);
    return tuple_abs_elem->GetShape();
  }
  // Empty tuple.
  return tuple_abs->GetShape();
}

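// Collect the inputs of a CNode that is about to run as a single op: values from value nodes,
// graph parameters, and cached outputs of previously executed CNodes. Each input is also
// classified (constant / parameter / op output) and dynamic shape info is propagated onto the
// input tensors.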
void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
                                     const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
                                     const std::map<AnfNodePtr, size_t> &parameter_index,
                                     const std::vector<tensor::TensorPtr> &graph_inputs, InputInfo *input_info) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_info);
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  std::set<int64_t> const_input_attr_index = {};
  GetConstValueDepend(cnode, &const_input_attr_index);
  const auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
  for (size_t i = 1; i <= input_num; i += 1) {
    const auto &input = cnode->input(i);
    auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
    auto real_input = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);
    ValuePtr input_value = nullptr;
    if (real_input->isa<ValueNode>()) {
      input_value = GetValueNodeOutput(real_input, kernel_with_index.second);
      const auto &value_ptr = GetValueNode(real_input);
      MS_EXCEPTION_IF_NULL(value_ptr);
      auto is_value_node = value_ptr->isa<StringImm>();
      if (!const_input_attr_index.empty()) {
        is_value_node = (const_input_attr_index.count(SizeToLong(i - 1)) != 0);
      }

      bool is_forward_output = false;
      if (value_ptr->isa<tensor::Tensor>()) {
        auto forward_tensor = value_ptr->cast<tensor::TensorPtr>();
        if (forward_tensor->is_forward_output()) {
          is_forward_output = true;
        }
      }

      if (common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, cnode)) {
        auto is_tensor = input_value->isa<tensor::Tensor>();
        (void)input_info->input_types.emplace_back(
          ((is_value_node && !is_forward_output) || !is_tensor) ? InputType::kConstant : InputType::kOpOutput);
      } else {
        (void)input_info->input_types.emplace_back((is_value_node || !is_forward_output) ? InputType::kConstant
                                                                                         : InputType::kOpOutput);
      }
    } else if (real_input->isa<Parameter>()) {
      auto tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
      MS_EXCEPTION_IF_NULL(tensor);
      input_value = tensor;
      input_info->input_types.emplace_back(tensor->is_parameter() ? InputType::kParameter : InputType::kInput);
    } else if (real_input->isa<CNode>()) {
      auto tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
      MS_EXCEPTION_IF_NULL(tensor);
      input_value = tensor;
      if (common::AnfAlgo::IsBpropCutOpExecInBackend(real_input)) {
        CheckInputTensorShape(tensor, cnode, i - 1);
      }
      input_info->input_kernel.insert(kernel_with_index);
      input_info->input_types.emplace_back(tensor->is_parameter() ? InputType::kParameter : InputType::kOpOutput);
    } else {
      MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
    }
    MS_EXCEPTION_IF_NULL(input_value);
    MS_LOG(DEBUG) << "Get " << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
                  << real_input->fullname_with_scope() << "-" << kernel_with_index.second;
    BaseShapePtr base_shape = nullptr;
    auto real_input_abs = real_input->abstract();
    MS_EXCEPTION_IF_NULL(real_input_abs);
    if (real_input_abs->isa<abstract::AbstractTuple>()) {
      auto tuple_abs = real_input_abs->cast<abstract::AbstractTuplePtr>();
      base_shape = GetShapeFromTuple(tuple_abs, kernel_with_index.second);
    } else {
      base_shape = real_input_abs->BuildShape();
    }
    MS_EXCEPTION_IF_NULL(base_shape);
    if (base_shape->IsDynamic()) {
      // In this case, input_value must be a Tensor.
      auto tensor = input_value->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      tensor->set_base_shape(base_shape);
    }
    (void)input_info->input_abs.emplace_back(real_input->abstract());
    (void)input_info->input_values.emplace_back(input_value);
  }
}

void SessionBasic::GetOpInputTensorsFromCNode(const CNodePtr &cnode,
                                              const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
                                              const std::map<AnfNodePtr, size_t> &parameter_index,
                                              const std::vector<tensor::TensorPtr> &graph_inputs,
                                              InputInfo *input_info) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_info);
  std::function<ValuePtr(const KernelWithIndex &)> fn = [&](const KernelWithIndex &kernel_with_index) -> ValuePtr {
    auto real_input = kernel_with_index.first;
    MS_EXCEPTION_IF_NULL(real_input);
    ValuePtr input_value = nullptr;
    if (real_input->isa<CNode>()) {
      if (IsPrimitiveCNode(real_input, prim::kPrimMakeTuple)) {
        const auto &c_make_tuple = real_input->cast<CNodePtr>();
        ValuePtrList v_list;
        for (size_t j = 1; j < c_make_tuple->size(); ++j) {
          auto kernel_with_index_input = common::AnfAlgo::VisitKernel(c_make_tuple->input(j), 0);
          (void)v_list.emplace_back(fn(kernel_with_index_input));
          input_info->input_kernel.insert(kernel_with_index_input);
        }
        input_value = std::make_shared<ValueTuple>(v_list);
      } else {
        input_value = GetCNodeOutputTensor(kernel_with_index, op_output);
        input_info->input_kernel.insert(kernel_with_index);
      }
    } else if (real_input->isa<ValueNode>()) {
      input_value = GetValueNodeOutput(real_input, kernel_with_index.second);
    } else if (real_input->isa<Parameter>()) {
      auto tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
      input_value = tensor;
    } else {
      MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
    }
    return input_value;
  };

  const auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
  input_info->input_values.resize(input_num);
  input_info->input_abs.resize(input_num);
  for (size_t i = 1; i <= input_num; ++i) {
    const auto &input = cnode->input(i);
    KernelWithIndex kernel_with_index;
    // Pyboost tuple inputs can not be planted (flattened), e.g. for ops like Concat, AddN, and FillN.
    if (cnode->HasAttr(kAttrIsPyboostTupleInput)) {
      kernel_with_index = common::AnfAlgo::VisitKernelWithReturnType(input, 0, false, {prim::kPrimMakeTuple});
    } else {
      kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
    }
    ValuePtr input_value = fn(kernel_with_index);
    MS_EXCEPTION_IF_NULL(input_value);
    MS_LOG(DEBUG) << "Get " << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
                  << kernel_with_index.first->fullname_with_scope() << "-" << kernel_with_index.second;
    input_info->input_values[i - 1] = input_value;
    input_info->input_abs[i - 1] = kernel_with_index.first->abstract();
  }
}

tensor::BaseTensorPtr SessionBasic::GetOpInputTensorByIndex(
  const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::BaseTensorPtr> &op_output,
  const std::map<AnfNodePtr, size_t> &parameter_index, const std::vector<tensor::TensorPtr> &graph_inputs,
  InputInfo *input_info, size_t input_index) const {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_info);
  if (input_index >= cnode->size() - 1) {
    MS_LOG(EXCEPTION) << "Input index is out of range: " << cnode->size() << ", cnode: " << cnode->DebugString();
  }

  const auto &input = cnode->input(input_index + 1);
  auto kernel_with_index = common::AnfAlgo::VisitKernel(input, 0);
  auto real_input = kernel_with_index.first;
  MS_EXCEPTION_IF_NULL(real_input);

  if (real_input->isa<Parameter>()) {
    return GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
  } else if (real_input->isa<CNode>()) {
    tensor::BaseTensorPtr tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
    if (common::AnfAlgo::IsBpropCutOpExecInBackend(real_input)) {
      CheckInputTensorShape(tensor, cnode, input_index);
    }
    input_info->input_kernel.insert(kernel_with_index);
    return tensor;
  } else {
    MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
  }
}

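// Create the output tensors of a kernel graph, bind each one to the device address of its
// producing node, refresh shapes of dynamic-shape outputs, and (outside PyNative mode) sync the
// data back to the host.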
void SessionBasic::UpdateOutputs(const std::shared_ptr<KernelGraph> &kernel_graph, VectorRef *const outputs,
                                 const std::vector<tensor::TensorPtr> &input_tensors,
                                 std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  KernelMapTensor node_to_tensor;
  auto anf_outputs = kernel_graph->outputs();
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(DEBUG) << "Update output[" << item->DebugString() << "]";
    (void)outputs->emplace_back(
      CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, &node_to_tensor));
  }

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  for (auto &item : *tensor_to_node) {
    auto &tensor = item.first;
    auto &node = item.second.first;
    auto &output_index = item.second.second;
    DeviceAddressPtr address = nullptr;
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode &&
        ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
      address = AnfAlgo::GetMutableOutputAddr(node, output_index, false);
    } else {
      address = AnfAlgo::GetMutableOutputAddr(node, output_index);
    }
    MS_EXCEPTION_IF_NULL(tensor);
    tensor->set_device_address(address);
    MS_LOG(DEBUG) << "Debug address: Output tensor obj " << tensor.get() << ", tensor id " << tensor->id()
                  << ", device address " << tensor->device_address().get();
    if (common::AnfAlgo::IsDynamicShape(node)) {
      const auto &updated_shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
      (void)tensor->set_shape(updated_shape);
    }
    if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
      tensor->data_sync(false);
      tensor->set_sync_status(kNeedSyncHostToDevice);
    }
  }
}

void SessionBasic::CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
                                       VectorRef *outputs,
                                       std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node,
                                       KernelMapTensor *node_to_tensor) {
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(tensor_to_node);
  auto anf_outputs = kernel_graph->outputs();
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
    (void)outputs->emplace_back(
      CreateNodeOutputTensors(item, kernel_graph, input_tensors, tensor_to_node, node_to_tensor));
  }
}

void SessionBasic::UpdateOutputTensors(const VectorRef *outputs,
                                       const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node,
                                       std::map<DeviceAddressPtr, DeviceAddressPtr> *) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (device::KernelRuntime::UseMemScheduler()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(outputs);
  for (const auto &item : *outputs) {
    if (utils::isa<VectorRefPtr>(item)) {
      const auto &vector_ref = utils::cast<VectorRef>(item);
      std::map<DeviceAddressPtr, DeviceAddressPtr> new_to_old_device_address;
      UpdateOutputTensors(&vector_ref, tensor_to_node, &new_to_old_device_address);
    } else if (utils::isa<tensor::TensorPtr>(item)) {
      const auto &tensor = utils::cast<tensor::TensorPtr>(item);
      MS_EXCEPTION_IF_NULL(tensor);
      const auto &iter = tensor_to_node.find(tensor);
      if (iter != tensor_to_node.end()) {
        const auto &node = iter->second.first;
        const auto &output_index = iter->second.second;
        if (!AnfAlgo::OutputAddrExist(node, output_index, true)) {
          continue;
        }
        const auto &address = AnfAlgo::GetMutableOutputAddr(node, output_index);
        tensor->set_device_address(address);

        if (common::AnfAlgo::IsDynamicShape(node)) {
          const auto &updated_shape = common::AnfAlgo::GetOutputInferShape(node, output_index);
          (void)tensor->set_shape(updated_shape);
        }
      }
      if (tensor->NeedSyncDeviceToHostImmediately()) {
        tensor->data_sync(false);
        tensor->set_device_address(nullptr);
        tensor->set_sync_status(kNeedSyncHostToDevice);
      }
    }
  }
}

void SessionBasic::GetModelInputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *inputs,
                                      std::vector<std::string> *inputs_name) const {
  MS_LOG(INFO) << "Start get model inputs, graph id : " << graph_id;
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(inputs);
  MS_EXCEPTION_IF_NULL(inputs_name);
  auto kernel_graph_inputs = kernel_graph->inputs();
  // Find the parameters among the graph inputs.
  for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
    if (!kernel_graph_inputs[i]->isa<Parameter>()) {
      MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter.";
      continue;
    }
    auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
    if (!common::AnfAlgo::IsParameterWeight(parameter)) {
      auto input_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
      auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
      auto data_type = kernel_build_info->GetOutputDeviceType(0);
      auto ms_tensor = std::make_shared<tensor::Tensor>(data_type, input_shape);
      (void)inputs->emplace_back(ms_tensor);
      (void)inputs_name->emplace_back(parameter->name());
    }
  }
}

void SessionBasic::GetModelOutputsInfo(uint32_t graph_id, std::vector<tensor::TensorPtr> *outputs,
                                       std::vector<std::string> *output_names) const {
  std::vector<tensor::TensorPtr> inputs;
  std::vector<std::string> input_names;
  GetModelInputsInfo(graph_id, &inputs, &input_names);

  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  MS_EXCEPTION_IF_NULL(outputs);
  MS_EXCEPTION_IF_NULL(output_names);

  VectorRef vector_outputs;
  std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
  KernelMapTensor node_to_tensor;
  auto anf_outputs = kernel_graph->outputs();
  for (auto &item : anf_outputs) {
    MS_EXCEPTION_IF_NULL(item);
    MS_LOG(INFO) << "Create node output[" << item->DebugString() << "]";
    vector_outputs.emplace_back(CreateNodeOutputTensors(item, kernel_graph, inputs, &tensor_to_node, &node_to_tensor));
  }
  *outputs = TransformVectorRefToMultiTensor(vector_outputs);
  for (size_t i = 0; i < outputs->size(); i++) {
    (void)output_names->emplace_back("output" + std::to_string(i));
  }
}

#ifndef ENABLE_SECURITY
void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) {
  MS_EXCEPTION_IF_NULL(callback);
  Summary::GetInstance().RegisterSummaryCallBackFunc(callback);
}

void SessionBasic::RecurseSetSummaryNodesForAllGraphs(KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_LOG(INFO) << "Recurse set summary nodes for all graphs in graph: " << graph->graph_id() << " start";
  Summary::GetInstance().RecurseSetSummaryNodesForAllGraphs(graph);
}

void SessionBasic::SetSummaryNodes(KernelGraph *graph) {
  MS_LOG(DEBUG) << "Update summary Start";
  MS_EXCEPTION_IF_NULL(graph);
  Summary::GetInstance().SetSummaryNodes(graph);
}

void SessionBasic::Summary(KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  static bool is_first = true;
  if (is_first && !IsSupportSummary()) {
    is_first = false;
    MS_LOG(WARNING) << "The Summary operator can not collect data correctly. Detail: the data sink mode is used and the"
                       " sink size(in model.train() python api) is not equal to 1.";
  }
  Summary::GetInstance().SummaryTensor(graph);
}
#endif

void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) const {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> make_tuple_inputs;
  (void)make_tuple_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(*prim::kPrimMakeTuple)));
  MS_EXCEPTION_IF_NULL(graph);
  if (AnfAlgo::GetOutputElementNum(cnode) > 1) {
    for (size_t output_index = 0; output_index < AnfAlgo::GetOutputElementNum(cnode); output_index++) {
      auto idx = NewValueNode(SizeToLong(output_index));
      MS_EXCEPTION_IF_NULL(idx);
      auto imm = std::make_shared<Int64Imm>(SizeToLong(output_index));
      idx->set_abstract(std::make_shared<abstract::AbstractScalar>(imm));
      auto getitem = graph->NewCNode({NewValueNode(std::make_shared<Primitive>(*prim::kPrimTupleGetItem)), cnode, idx});
      std::vector<TypeId> types = {common::AnfAlgo::GetOutputInferDataType(cnode, output_index)};
      auto shapes = {common::AnfAlgo::GetOutputInferShape(cnode, output_index)};
      common::AnfAlgo::SetOutputInferTypeAndShape(types, shapes, getitem.get());
      (void)make_tuple_inputs.emplace_back(getitem);
    }
  } else {
    (void)make_tuple_inputs.emplace_back(cnode);
  }
  // create output
  auto g_output = graph->NewCNode(make_tuple_inputs);
  graph->set_output(g_output);
}

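// ConstructSingleOpGraph builds a one-kernel KernelGraph for PyNative op execution.
// A rough sketch of the resulting graph structure, as built below:
//   inputs[0]: ValueNode(Primitive)  -- a copy decoupled from the frontend PrimitivePy
//   inputs[1..n]: Parameter (tensor input) or ValueNode (InputType::kConstant input)
//   execution_order: { cnode }; output: the MakeTuple created by CreateOutputNode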
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const BackendOpRunInfoPtr &op_run_info,
                                                                  const std::vector<ValuePtr> &input_values,
                                                                  const std::vector<InputType> &input_type) {
  auto graph = NewPynativeKernelGraph();
  std::vector<AnfNodePtr> inputs;
  // set input[0]
  auto op_prim = op_run_info->op_prim;
  MS_EXCEPTION_IF_NULL(op_prim);
  // Decoupling of frontend PrimitivePy and backend Primitive
  auto new_prim = std::make_shared<Primitive>(*op_prim);
  if (op_run_info->base_op_run_info.use_dynamic_shape_process) {
    AnfAlgo::SetDynamicAttrToPrim(new_prim);
  }
  (void)inputs.emplace_back(std::make_shared<ValueNode>(new_prim));
  // set input parameters
  if (input_values.size() != input_type.size()) {
    MS_LOG(EXCEPTION) << "Input values size " << input_values.size() << " should be equal to input types size "
                      << input_type.size();
  }
  for (size_t i = 0; i < input_values.size(); ++i) {
    if (input_type[i] == InputType::kConstant) {
      auto value_node = graph->NewValueNode(input_values[i]);
      (void)inputs.emplace_back(value_node);
      continue;
    }
    auto parameter =
      ConstructRunOpParameter(graph, input_values[i]->cast<tensor::BaseTensorPtr>(), op_run_info, input_type[i]);
    (void)inputs.emplace_back(parameter);
    auto mutable_inputs = graph->MutableInputs();
    MS_EXCEPTION_IF_NULL(mutable_inputs);
    (void)mutable_inputs->emplace_back(parameter);
  }
  // create the single-op cnode
  auto cnode = graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(cnode);
  auto is_mutable = common::AnfAlgo::HasNodeAttr(kAttrMutableKernel, cnode);
  if (is_mutable) {
    graph->set_flag(kAttrMutableKernel, true);
  }
  // set abstract, which includes the inferred shape and type
  cnode->set_abstract(op_run_info->base_op_run_info.abstract);
  common::AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(op_run_info->base_op_run_info.has_dynamic_output),
                               cnode);
  if (op_run_info->base_op_run_info.is_mixed_precision_cast) {
    common::AnfAlgo::SetNodeAttr(kAttrPynativeNextOpName, MakeValue(op_run_info->base_op_run_info.next_op_name), cnode);
    common::AnfAlgo::SetNodeAttr(kAttrPynativeNextIndex, MakeValue(op_run_info->base_op_run_info.next_input_index),
                                 cnode);
  }
  // set execution order
  graph->set_execution_order({cnode});
  CreateOutputNode(cnode, graph);
  graph->SetInputNodes();
  auto manager = MakeManager({graph});
  if (manager != nullptr) {
    manager->AddFuncGraph(graph);
    graph->set_manager(manager);
  }
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    UnifyMindIR(graph);
  }
  graph->UpdateGraphDynamicAttr();
  return graph;
}

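// FindPullNode locates the Pull node paired with a given Push node in
// parameter-server training: the Pull is expected to consume the Push's output,
// and any other consumer of that output is treated as an invalid edge.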
AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list) const {
  MS_EXCEPTION_IF_NULL(push_node);
  for (auto &node : node_list) {
    if (node != nullptr && node->isa<CNode>()) {
      for (auto input : node->cast<CNodePtr>()->inputs()) {
        if (push_node == common::AnfAlgo::VisitKernel(input, 0).first) {
          if (common::AnfAlgo::GetCNodeName(node) != kPullOpName) {
            MS_LOG(EXCEPTION) << "The edge between Push and Pull node is invalid.";
          }
          return node;
        }
      }
    }
  }
  return nullptr;
}

GraphId SessionBasic::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  return executor_->CompileGraph(shared_from_this(), segment, outputs);
}

GraphId SessionBasic::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  MS_EXCEPTION_IF_NULL(executor_);
  return executor_->CompileGraph(shared_from_this(), func_graph);
}

void SessionBasic::BuildGraph(GraphId graph_id) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->BuildGraph(shared_from_this(), graph_id);
}

void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
}

void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                 VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(executor_);
  executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
}

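// RunGraphImpl is the synchronous execution path behind RunGraph and
// RunGraphAsync: pre-process inputs, execute the kernel graph, then post-process
// outputs. Subclasses hook in via PreExecuteGraph/ExecuteGraph/PostExecuteGraph.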
void SessionBasic::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                                VectorRef *outputs) {
  MS_LOG(INFO) << "Status record: start run graph. graph id: " << graph_id;
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  // Skip execution when the graph is not executable, i.e. no child graph has an anf output.
  if (!kernel_graph->executable()) {
    MS_LOG(INFO) << "No child graph has anf output";
    return;
  }
  PreExecuteGraph(kernel_graph, inputs, outputs);
  ExecuteGraph(kernel_graph);
  PostExecuteGraph(kernel_graph, inputs, outputs);
  MS_LOG(INFO) << "Status record: end run graph. graph id: " << graph_id;
}

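// For heterogeneous execution, an input tensor may still hold a device address
// from a different device target. Sync its data back to host and drop the stale
// device address so the tensor can be re-placed on the current target.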
void SessionBasic::ProcessInputTensorsForHeterogeneous(const std::string &cur_target,
                                                       const std::vector<tensor::TensorPtr> &input_tensors) const {
  for (auto &tensor : input_tensors) {
    MS_EXCEPTION_IF_NULL(tensor);
    auto device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
    if (device_address != nullptr) {
      if (device_address->GetDeviceType() != device::GetDeviceTypeByName(cur_target)) {
        tensor->data_sync();
        tensor->set_device_address(nullptr);
      }
    }
  }
}

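// EraseValueNodeTensor filters out the tensors that were folded into value nodes
// (InputType::kConstant), keeping only the tensors that remain real graph inputs.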
void SessionBasic::EraseValueNodeTensor(const std::vector<InputType> &input_types,
                                        std::vector<tensor::TensorPtr> *input_tensors) const {
  MS_EXCEPTION_IF_NULL(input_tensors);
  if (input_tensors->size() != input_types.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size()
                      << " should be equal to input types size " << input_types.size();
  }
  std::vector<tensor::TensorPtr> new_input_tensors;
  for (size_t index = 0; index < input_types.size(); ++index) {
    if (input_types[index] != InputType::kConstant) {
      (void)new_input_tensors.emplace_back(input_tensors->at(index));
    }
  }
  *input_tensors = new_input_tensors;
}

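// IsGetNextGraph reports whether the graph contains a GetNext kernel (i.e. it is
// fed by a data channel) and, if so, returns the channel name taken from the
// kernel's "shared_name" attribute.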
bool SessionBasic::IsGetNextGraph(const std::shared_ptr<KernelGraph> &kernel_graph, std::string *channel_name) const {
  MS_EXCEPTION_IF_NULL(kernel_graph);
  for (const auto &kernel_node : kernel_graph->execution_order()) {
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
    if (kernel_name == kGetNextOpName) {
      auto prim = common::AnfAlgo::GetCNodePrimitive(kernel_node);
      MS_EXCEPTION_IF_NULL(prim);
      *channel_name = GetValue<std::string>(prim->GetAttr("shared_name"));
      return true;
    }
  }
  return false;
}

void SessionBasic::RunOpRemoveNopNode(const KernelGraphPtr &kernel_graph) const {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    opt::RemoveNopNode(kernel_graph.get());
  }
}

void SessionBasic::RunOpHideNopNode(const KernelGraphPtr &kernel_graph) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (!ms_context->get_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER)) {
    opt::HideNopNode(kernel_graph.get());
  }
}

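// GetAllReduceSplitIndex looks up the user-configured AllReduce fusion split
// indices. PyNative supports only a single allreduce group, so the lookup key is
// the world group name with a fixed "sum1" suffix (assumed to match the naming
// used when the fusion config was registered).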
std::vector<uint32_t> SessionBasic::GetAllReduceSplitIndex() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string group = GetCommWorldGroup();
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  // PyNative does not support multi-group allreduce, so the group name is fixed.
  group += "sum1";
  return parallel_context->GetAllReduceFusionSplitIndices(group);
}

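// GetBpropGraphGradsCount counts the gradient outputs of a bprop graph: every
// real (CNode) output of the graph, with tuple outputs flattened through
// TupleGetItem.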
uint32_t GetBpropGraphGradsCount(const KernelGraphPtr &graph) {
  auto outputs = common::AnfAlgo::GetAllOutput(graph->output(), {prim::kPrimTupleGetItem});
  MS_LOG(DEBUG) << "Get total graph output size:" << outputs.size();
  // The type of output is CNode or ValueNode.
  // There is no need to calculate grad if the type of output is not CNode.
  return static_cast<uint32_t>(std::count_if(outputs.begin(), outputs.end(), [](const AnfNodePtr &output) {
    return output != nullptr && output->isa<CNode>();
  }));
}

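// A graph is marked as a PyNative bprop graph when any node in its execution
// order lives in a scope whose name starts with "Gradient".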
void SetGraphBpropAttr(const KernelGraphPtr &graph) {
  auto &execution_orders = graph->execution_order();
  if (std::any_of(execution_orders.begin(), execution_orders.end(),
                  [](const AnfNodePtr &node) { return node->scope()->name().rfind("Gradient", 0) == 0; })) {
    graph->set_flag(kFlagIsPynativeBpropGraph, true);
    MS_LOG(INFO) << "Match bprop graph";
  }
}

void CheckSplitIndexValid(const vector<uint32_t> &split_index) {
  uint32_t last = 0;
  for (size_t i = 0; i < split_index.size(); ++i) {
    if (split_index[i] <= last && i != 0) {
      MS_LOG(EXCEPTION) << "Invalid split index:" << split_index;
    }
    last = split_index[i];
  }
}

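// PreProcessOnSplitIndex normalizes the user-provided AllReduce fusion split
// indices against the actual grads count. Illustrative examples (grads_count = 10):
//   {3, 6}  -> {3, 6, 9}  (append the last grad index so the tail grads are fused)
//   {3, 12} -> {9}        (out of range: warn and fuse everything into one AllReduce)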
void PreProcessOnSplitIndex(const KernelGraphPtr &graph, vector<uint32_t> *split_index) {
  MS_EXCEPTION_IF_NULL(split_index);
  if (split_index->empty()) {
    return;
  }

  CheckSplitIndexValid(*split_index);
  // calculate split index num
  auto split_index_num = split_index->back();
  // obtain graph output tensor num
  auto grads_count = GetBpropGraphGradsCount(graph);
  if (split_index_num >= grads_count) {
    MS_LOG(WARNING) << "The context configuration all_reduce_fusion_config's upper boundary value should be smaller "
                    << "than total grads count: " << grads_count << ", but got: " << *split_index
                    << ". Now all AllReduce operators will be fused into one AllReduce operator.";
    split_index->clear();
    split_index->push_back(grads_count - 1);
  } else if (split_index_num < grads_count - 1) {
    split_index->push_back(grads_count - 1);
  }
}

void SessionBasic::FinalOptimize(const KernelGraphPtr &graph) const {
  MS_LOG(INFO) << "Start FinalOptimize for graph: " << graph->graph_id();
  opt::CommonFinalOptimization(graph);
  MS_LOG(INFO) << "End FinalOptimize for graph: " << graph->graph_id();
}

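// DumpGraphs records compiled graphs for debugging: RDR records, .ir/.pb dumps
// when IR saving is enabled, and (Ascend-only, graph mode) e2e-dump artifacts
// under <dump_path>/rank_<rank_id>, where <dump_path> comes from the dump json
// configuration.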
void SessionBasic::DumpGraphs(const std::vector<KernelGraphPtr> &graphs) const {
#ifdef ENABLE_DUMP_IR
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool save_graphs = context_ptr->CanDump(kIntroductory);
  auto &json_parser = DumpJsonParser::GetInstance();
  json_parser.Parse();
  if (!save_graphs && !json_parser.e2e_dump_enabled() && !json_parser.async_dump_enabled() &&
      !mindspore::RecorderManager::Instance().RdrEnable()) {
    return;
  }
  for (auto &graph : graphs) {
    MS_EXCEPTION_IF_NULL(graph);

    if (graph->memory_managed_by_ge()) {
      continue;
    }

    std::string name = "graph_build." + std::to_string(graph->graph_id());
    DumpGraphParams dump_params = {true, static_cast<int>(kWholeStack)};
    (void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, name, graph, dump_params, ".ir;.pb");

    auto &kernels = graph->execution_order();
    std::string exec_order_name = "graph_exec_order." + std::to_string(graph->graph_id());
    (void)mindspore::RDR::RecordGraphExecOrder(SUBMODULE_ID, exec_order_name, kernels);
    if (save_graphs) {
      std::string file_name = "graph_build_" + std::to_string(graph->graph_id()) + ".ir";
      DumpIR(file_name, graph, true, kWholeStack);
      DumpIRProto(graph, "vm_build_" + std::to_string(graph->graph_id()));
      DumpIR("trace_code_graph", graph, true, kWholeStack);
    }
    std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    if (device_target != kAscendDevice) {
      // Data dump below is only supported on Ascend.
      continue;
    }
    // If the new runtime (MindRT) is used, get rank_id via GetRankId(); else use rank_id_.
    uint32_t rank_id = rank_id_;
    if (MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
      rank_id = GetRankId();
    }
    std::string final_graph = "trace_code_graph_" + std::to_string(graph->graph_id());
    if (json_parser.e2e_dump_enabled() && context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
      std::string root_dir = json_parser.path() + "/rank_" + std::to_string(rank_id);
      MS_LOG(INFO) << "Dump graph and exeorder for graph: " << graph->graph_id()
                   << ", root_graph_id: " << graph->root_graph_id() << ", rank_id: " << rank_id;
      std::string target_dir = root_dir + "/graphs";
      std::string cst_file_dir = GenerateDumpPath(graph->root_graph_id(), rank_id, true);
      std::string ir_file_path = target_dir + "/" + "ms_output_" + final_graph + ".ir";
      DumpIRProtoWithSrcInfo(graph, final_graph, target_dir, kDebugWholeStack);
      if (!MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
        // Dump constant data for old runtime ascend.
        DumpConstantInfo(graph, cst_file_dir);
      }
      DumpIR("trace_code_graph", graph, true, kWholeStack, ir_file_path);
      DumpGraphExeOrder("ms_execution_order_graph_" + std::to_string(graph->graph_id()) + ".csv", root_dir,
                        graph->execution_order());
    }
  }
#endif
}
}  // namespace session
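// DumpGraphExeOrder writes the kernel execution order to a one-column CSV file
// under <target_dir>/execution_order/. Illustrative file contents (the node
// names are made up for the example):
//   NodeExecutionOrder-FullNameWithScope
//   Default/Add-op0
//   Default/ReLU-op1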
void DumpGraphExeOrder(const std::string &file_name, const std::string &target_dir,
                       const std::vector<CNodePtr> &execution_order) {
  std::string file_path = target_dir + "/execution_order/" + file_name;
  auto realpath = Common::CreatePrefixPath(file_path);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Failed to get real path: [" << file_path << "] in dump graph execution order.";
    return;
  }
  file_path = realpath.value();

  ChangeFileMode(file_path, S_IWUSR);
  // write to csv file
  std::ofstream ofs(file_path);
  if (!ofs.is_open()) {
    MS_LOG(ERROR) << "Failed to open file [" << file_path
                  << "] in dump graph execution order, please check the file access permission and whether disk space "
                     "is available.";
    return;
  }
  ofs << "NodeExecutionOrder-FullNameWithScope\n";
  for (const CNodePtr &node : execution_order) {
    ofs << node->fullname_with_scope() << "\n";
  }
  ofs.close();
  // set file mode to read only by user
  ChangeFileMode(file_path, S_IRUSR);
}

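// GetRankId resolves the device rank for distributed runs: it queries the
// communication manager over the world group (HCCL on Ascend, NCCL on GPU) when
// HCCL is enabled and the RANK_ID environment variable is set; otherwise it
// falls back to rank 0.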
uint32_t GetRankId() {
  uint32_t rank_id = 0;
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);

  std::string world_group;
  std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (backend == kAscendDevice) {
    world_group = kHcclWorldGroup;
  } else if (backend == kGPUDevice) {
    world_group = kNcclWorldGroup;
  } else {
    MS_LOG(ERROR) << "Invalid backend: " << backend;
    return rank_id;
  }
  auto env_rank_id = common::GetEnv("RANK_ID");
  if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL) && !env_rank_id.empty()) {
    if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
      MS_LOG(INFO) << "Failed to get rank id.";
    }
  }
  return rank_id;
}
}  // namespace mindspore