• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/common/func_graph_utils.h"
18 #include <algorithm>
19 #include <memory>
20 #include "tools/common/graph_util.h"
21 #include "tools/converter/converter_context.h"
22 namespace mindspore {
GetAbstractFromNode(const std::pair<AnfNodePtr,int64_t> & node)23 AbstractBasePtr FuncGraphUtils::GetAbstractFromNode(const std::pair<AnfNodePtr, int64_t> &node) {
24   auto anfnode = node.first;
25   MS_EXCEPTION_IF_NULL(anfnode);
26   AbstractBasePtr abstract = anfnode->abstract();
27   if (abstract == nullptr) {
28     return nullptr;
29   }
30   auto index = static_cast<size_t>(node.second);
31 
32   if (utils::isa<abstract::AbstractSequencePtr>(abstract)) {
33     auto abstract_tuple = utils::cast<abstract::AbstractSequencePtr>(abstract);
34     MS_EXCEPTION_IF_NULL(abstract_tuple);
35     auto abstract_list = abstract_tuple->elements();
36     if (abstract_list.size() <= index) {
37       MS_LOG(WARNING) << "AbstractSequence's size[" << abstract_list.size() << "] is smaller than index " << index
38                       << "]";
39       return nullptr;
40     }
41     abstract = abstract_list[index];
42   }
43   return abstract;
44 }
45 
GetOutputName(const std::pair<AnfNodePtr,int64_t> & node_index)46 std::string FuncGraphUtils::GetOutputName(const std::pair<AnfNodePtr, int64_t> &node_index) {
47   auto node = node_index.first;
48   auto idx = node_index.second;
49   MS_EXCEPTION_IF_NULL(node);
50   AbstractBasePtr abstract = GetAbstractFromNode(node_index);
51   MS_EXCEPTION_IF_NULL(abstract);
52 
53   std::string output_name;
54   if (!abstract->name().empty()) {
55     output_name = abstract->name();
56   } else if (idx >= 0) {
57     output_name = node->fullname_with_scope() + "_" + std::to_string(idx);
58   } else {
59     output_name = node->fullname_with_scope();
60   }
61 
62   return output_name;
63 }
64 
SetOutputName(const std::pair<AnfNodePtr,int64_t> & node,const std::string & name)65 void FuncGraphUtils::SetOutputName(const std::pair<AnfNodePtr, int64_t> &node, const std::string &name) {
66   AbstractBasePtr abstract = GetAbstractFromNode(node);
67   if (abstract != nullptr) {
68     abstract->set_name(name);
69   }
70 }
71 
GetFuncGraphOutputNames(const FuncGraphPtr & func_graph)72 std::vector<std::string> FuncGraphUtils::GetFuncGraphOutputNames(const FuncGraphPtr &func_graph) {
73   std::vector<std::string> output_names;
74   // the 3rd model will save the tensor name to ConverterInnerContext
75   output_names = lite::ConverterInnerContext::GetInstance()->GetGraphOutputTensorNames();
76   if (!output_names.empty()) {
77     return output_names;
78   }
79   std::vector<std::pair<AnfNodePtr, int64_t>> outputs;
80   std::vector<std::string> tmp_names;
81   std::vector<std::vector<int64_t>> tmp_dims;
82   auto ret = lite::GetFuncGraphOutputsInfo(func_graph, &outputs, &tmp_names, &tmp_dims);
83   MS_EXCEPTION_IF_CHECK_FAIL((ret == lite::RET_OK), "Get outputs info of funcgraph failed");
84 
85   output_names.resize(outputs.size());
86   std::transform(outputs.begin(), outputs.end(), output_names.begin(), GetOutputName);
87   return output_names;
88 }
89 
SetFuncGraphOutputNames(const FuncGraphPtr & func_graph,const std::vector<std::string> & output_names)90 void FuncGraphUtils::SetFuncGraphOutputNames(const FuncGraphPtr &func_graph,
91                                              const std::vector<std::string> &output_names) {
92   std::vector<std::pair<AnfNodePtr, int64_t>> outputs;
93   std::vector<std::string> tmp_names;
94   std::vector<std::vector<int64_t>> tmp_dims;
95   auto ret = lite::GetFuncGraphOutputsInfo(func_graph, &outputs, &tmp_names, &tmp_dims);
96   MS_EXCEPTION_IF_CHECK_FAIL((ret == lite::RET_OK), "Get outputs info of funcgraph failed");
97   // the control flow model may be not equal, it will be updated by metagraph
98   if (outputs.size() != output_names.size()) {
99     MS_LOG(INFO)
100       << "the size of output nodes is not equal to the size of output names, it will be updated by metagraph";
101     return;
102   }
103 
104   for (size_t i = 0; i < output_names.size(); ++i) {
105     SetOutputName(outputs[i], output_names[i]);
106   }
107   return;
108 }
109 
GetParameterConstValue(const AnfNodePtr & anf_node)110 tensor::TensorPtr FuncGraphUtils::GetParameterConstValue(const AnfNodePtr &anf_node) {
111   if (anf_node == nullptr) {
112     MS_LOG(ERROR) << "Input argument anf node is nullptr";
113     return nullptr;
114   }
115   auto parameter = anf_node->cast<ParameterPtr>();
116   if (parameter == nullptr) {
117     MS_LOG(ERROR) << "Node " << anf_node->fullname_with_scope() << " is not a Parameter";
118     return nullptr;
119   }
120   auto default_param = parameter->default_param();
121   if (default_param == nullptr) {
122     MS_LOG(ERROR) << "Parameter " << anf_node->fullname_with_scope() << " has not default value";
123     return nullptr;
124   }
125   if (!default_param->isa<tensor::Tensor>()) {
126     MS_LOG(ERROR) << "Parameter " << anf_node->fullname_with_scope()
127                   << " default value is not a tensor::Tensor, real type " << default_param->type_name();
128     return nullptr;
129   }
130   auto tensor = default_param->cast<std::shared_ptr<tensor::Tensor>>();
131   if (tensor == nullptr) {
132     MS_LOG(ERROR) << "Parameter " << anf_node->fullname_with_scope() << " tensor value is nullptr";
133     return nullptr;
134   }
135   return tensor;
136 }
137 }  // namespace mindspore
138