1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/common/func_graph_utils.h"
18 #include <algorithm>
19 #include <memory>
20 #include "tools/common/graph_util.h"
21 #include "tools/converter/converter_context.h"
22 namespace mindspore {
GetAbstractFromNode(const std::pair<AnfNodePtr,int64_t> & node)23 AbstractBasePtr FuncGraphUtils::GetAbstractFromNode(const std::pair<AnfNodePtr, int64_t> &node) {
24 auto anfnode = node.first;
25 MS_EXCEPTION_IF_NULL(anfnode);
26 AbstractBasePtr abstract = anfnode->abstract();
27 if (abstract == nullptr) {
28 return nullptr;
29 }
30 auto index = static_cast<size_t>(node.second);
31
32 if (utils::isa<abstract::AbstractSequencePtr>(abstract)) {
33 auto abstract_tuple = utils::cast<abstract::AbstractSequencePtr>(abstract);
34 MS_EXCEPTION_IF_NULL(abstract_tuple);
35 auto abstract_list = abstract_tuple->elements();
36 if (abstract_list.size() <= index) {
37 MS_LOG(WARNING) << "AbstractSequence's size[" << abstract_list.size() << "] is smaller than index " << index
38 << "]";
39 return nullptr;
40 }
41 abstract = abstract_list[index];
42 }
43 return abstract;
44 }
45
GetOutputName(const std::pair<AnfNodePtr,int64_t> & node_index)46 std::string FuncGraphUtils::GetOutputName(const std::pair<AnfNodePtr, int64_t> &node_index) {
47 auto node = node_index.first;
48 auto idx = node_index.second;
49 MS_EXCEPTION_IF_NULL(node);
50 AbstractBasePtr abstract = GetAbstractFromNode(node_index);
51 MS_EXCEPTION_IF_NULL(abstract);
52
53 std::string output_name;
54 if (!abstract->name().empty()) {
55 output_name = abstract->name();
56 } else if (idx >= 0) {
57 output_name = node->fullname_with_scope() + "_" + std::to_string(idx);
58 } else {
59 output_name = node->fullname_with_scope();
60 }
61
62 return output_name;
63 }
64
SetOutputName(const std::pair<AnfNodePtr,int64_t> & node,const std::string & name)65 void FuncGraphUtils::SetOutputName(const std::pair<AnfNodePtr, int64_t> &node, const std::string &name) {
66 AbstractBasePtr abstract = GetAbstractFromNode(node);
67 if (abstract != nullptr) {
68 abstract->set_name(name);
69 }
70 }
71
GetFuncGraphOutputNames(const FuncGraphPtr & func_graph)72 std::vector<std::string> FuncGraphUtils::GetFuncGraphOutputNames(const FuncGraphPtr &func_graph) {
73 std::vector<std::string> output_names;
74 // the 3rd model will save the tensor name to ConverterInnerContext
75 output_names = lite::ConverterInnerContext::GetInstance()->GetGraphOutputTensorNames();
76 if (!output_names.empty()) {
77 return output_names;
78 }
79 std::vector<std::pair<AnfNodePtr, int64_t>> outputs;
80 std::vector<std::string> tmp_names;
81 std::vector<std::vector<int64_t>> tmp_dims;
82 auto ret = lite::GetFuncGraphOutputsInfo(func_graph, &outputs, &tmp_names, &tmp_dims);
83 MS_EXCEPTION_IF_CHECK_FAIL((ret == lite::RET_OK), "Get outputs info of funcgraph failed");
84
85 output_names.resize(outputs.size());
86 std::transform(outputs.begin(), outputs.end(), output_names.begin(), GetOutputName);
87 return output_names;
88 }
89
SetFuncGraphOutputNames(const FuncGraphPtr & func_graph,const std::vector<std::string> & output_names)90 void FuncGraphUtils::SetFuncGraphOutputNames(const FuncGraphPtr &func_graph,
91 const std::vector<std::string> &output_names) {
92 std::vector<std::pair<AnfNodePtr, int64_t>> outputs;
93 std::vector<std::string> tmp_names;
94 std::vector<std::vector<int64_t>> tmp_dims;
95 auto ret = lite::GetFuncGraphOutputsInfo(func_graph, &outputs, &tmp_names, &tmp_dims);
96 MS_EXCEPTION_IF_CHECK_FAIL((ret == lite::RET_OK), "Get outputs info of funcgraph failed");
97 // the control flow model may be not equal, it will be updated by metagraph
98 if (outputs.size() != output_names.size()) {
99 MS_LOG(INFO)
100 << "the size of output nodes is not equal to the size of output names, it will be updated by metagraph";
101 return;
102 }
103
104 for (size_t i = 0; i < output_names.size(); ++i) {
105 SetOutputName(outputs[i], output_names[i]);
106 }
107 return;
108 }
109
GetParameterConstValue(const AnfNodePtr & anf_node)110 tensor::TensorPtr FuncGraphUtils::GetParameterConstValue(const AnfNodePtr &anf_node) {
111 if (anf_node == nullptr) {
112 MS_LOG(ERROR) << "Input argument anf node is nullptr";
113 return nullptr;
114 }
115 auto parameter = anf_node->cast<ParameterPtr>();
116 if (parameter == nullptr) {
117 MS_LOG(ERROR) << "Node " << anf_node->fullname_with_scope() << " is not a Parameter";
118 return nullptr;
119 }
120 auto default_param = parameter->default_param();
121 if (default_param == nullptr) {
122 MS_LOG(ERROR) << "Parameter " << anf_node->fullname_with_scope() << " has not default value";
123 return nullptr;
124 }
125 if (!default_param->isa<tensor::Tensor>()) {
126 MS_LOG(ERROR) << "Parameter " << anf_node->fullname_with_scope()
127 << " default value is not a tensor::Tensor, real type " << default_param->type_name();
128 return nullptr;
129 }
130 auto tensor = default_param->cast<std::shared_ptr<tensor::Tensor>>();
131 if (tensor == nullptr) {
132 MS_LOG(ERROR) << "Parameter " << anf_node->fullname_with_scope() << " tensor value is nullptr";
133 return nullptr;
134 }
135 return tensor;
136 }
137 } // namespace mindspore
138