1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "common/graph_output_name_keeper.h"
#include <algorithm>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "common/anf_util.h"
#include "common/check_base.h"
#include "common/op_enum.h"
#include "ops/make_tuple.h"
#include "ops/depend.h"
#include "include/registry/converter_context.h"
27
28 namespace mindspore {
29 namespace dpico {
GetInstance()30 GraphOutputNameKeeper *GraphOutputNameKeeper::GetInstance() {
31 static GraphOutputNameKeeper instance;
32 return &instance;
33 }
34
SaveOriginalOutputs(const api::FuncGraphPtr & func_graph)35 int GraphOutputNameKeeper::SaveOriginalOutputs(const api::FuncGraphPtr &func_graph) {
36 MS_CHECK_TRUE_MSG(func_graph != nullptr, RET_ERROR, "func's input parameter is a nullptr.");
37 auto return_cnode = func_graph->get_return();
38 MS_CHECK_TRUE_MSG(return_cnode != nullptr, RET_ERROR, "func_graph's return node is invalid.");
39 original_outputs_ = return_cnode->inputs();
40 original_outputs_.erase(original_outputs_.begin());
41 bool change{true};
42 while (change) {
43 change = false;
44 std::vector<api::AnfNodePtr> outputs_tmp;
45 for (size_t j = 0; j < original_outputs_.size(); ++j) {
46 auto output_node = original_outputs_[j];
47 MS_CHECK_TRUE_MSG(output_node != nullptr, RET_ERROR, "existing node is a nullptr.");
48 if (dpico::CheckPrimitiveType(output_node, api::MakeShared<ops::MakeTuple>())) {
49 auto make_tuple_cnode = output_node->cast<api::CNodePtr>();
50 MS_CHECK_TRUE_MSG(make_tuple_cnode != nullptr, RET_ERROR, "make tuple node is invalid.");
51 auto make_tuple_inputs = make_tuple_cnode->inputs();
52 outputs_tmp.insert(outputs_tmp.end(), make_tuple_inputs.begin() + 1, make_tuple_inputs.end());
53 change = true;
54 continue;
55 }
56 if (dpico::CheckPrimitiveType(output_node, api::MakeShared<ops::Depend>())) {
57 auto depend_cnode = output_node->cast<api::CNodePtr>();
58 MS_CHECK_TRUE_MSG(depend_cnode != nullptr, RET_ERROR, "depend node is invalid.");
59 MS_CHECK_TRUE_MSG(depend_cnode->size() == kInputIndex3, RET_ERROR, "depend node's input size should be 3.");
60 outputs_tmp.push_back(depend_cnode->input(1));
61 change = true;
62 continue;
63 }
64 outputs_tmp.push_back(output_node);
65 }
66 original_outputs_ = outputs_tmp;
67 }
68
69 auto origin_outputs_name = converter::ConverterContext::GetGraphOutputTensorNames();
70 for (auto &output_name : origin_outputs_name) {
71 const std::string top_name_suffix = "duplicate";
72 const size_t max_loop = 1000;
73 for (size_t i = 0; i < max_loop; i++) {
74 std::string top_name_tmp = output_name + "_" + top_name_suffix + std::to_string(i);
75 auto attr = func_graph->get_attr(top_name_tmp);
76 if (attr != nullptr) {
77 auto op_name = api::GetValue<std::string>(attr);
78 ori_output_info_[op_name] = output_name;
79 } else {
80 break;
81 }
82 }
83 }
84 return RET_OK;
85 }
86
DetermineOmOpInputName(const api::AnfNodePtr & in_node,std::string * input_name)87 int GraphOutputNameKeeper::DetermineOmOpInputName(const api::AnfNodePtr &in_node, std::string *input_name) {
88 if (original_outputs_.empty()) {
89 return RET_OK;
90 }
91 MS_CHECK_TRUE_MSG(in_node != nullptr && input_name != nullptr, RET_ERROR, "func's input parameter is a nullptr.");
92 if (ori_output_info_.find(in_node->fullname_with_scope()) != ori_output_info_.end()) {
93 *input_name = ori_output_info_[in_node->fullname_with_scope()];
94 return RET_OK;
95 }
96 auto iter = std::find(original_outputs_.begin(), original_outputs_.end(), in_node);
97 if (iter == original_outputs_.end()) {
98 return RET_OK;
99 }
100 auto index = iter - original_outputs_.begin();
101 auto origin_outputs_name = converter::ConverterContext::GetGraphOutputTensorNames();
102 if (origin_outputs_name.size() <= static_cast<size_t>(index)) {
103 return RET_OK;
104 }
105 input_name->swap(origin_outputs_name[index]);
106 return RET_OK;
107 }
108
DetermineOmOpOutputName(const api::AnfNodePtr & node,std::string * output_name,bool is_subgraph_input)109 int GraphOutputNameKeeper::DetermineOmOpOutputName(const api::AnfNodePtr &node, std::string *output_name,
110 bool is_subgraph_input) {
111 MS_CHECK_TRUE_MSG(node != nullptr && output_name != nullptr, RET_ERROR, "func's input parameter is a nullptr.");
112 MS_CHECK_TRUE_MSG(!original_outputs_.empty(), RET_ERROR, "has no outputs.");
113 if (ori_output_info_.find(node->fullname_with_scope()) != ori_output_info_.end()) {
114 *output_name = ori_output_info_[node->fullname_with_scope()];
115 return RET_OK;
116 }
117 auto iter = std::find(original_outputs_.begin(), original_outputs_.end(), node);
118 if (iter == original_outputs_.end()) {
119 return RET_OK;
120 }
121 auto index = iter - original_outputs_.begin();
122 auto origin_outputs_name = converter::ConverterContext::GetGraphOutputTensorNames();
123 if (origin_outputs_name.size() <= static_cast<size_t>(index)) {
124 return RET_OK;
125 }
126 if (!is_subgraph_input) {
127 MS_CHECK_TRUE_MSG(om_to_anf_mapper_.find(origin_outputs_name[index]) == om_to_anf_mapper_.end(), RET_ERROR,
128 "find the output has been existed.");
129 om_to_anf_mapper_.emplace(origin_outputs_name[index], *output_name);
130 }
131 output_name->swap(origin_outputs_name[index]);
132 return RET_OK;
133 }
134
CanKeepOutputNames(const std::vector<std::string> & om_outputs)135 bool GraphOutputNameKeeper::CanKeepOutputNames(const std::vector<std::string> &om_outputs) {
136 size_t has_find{0};
137 for (const auto &output : om_outputs) {
138 if (om_to_anf_mapper_.find(output) != om_to_anf_mapper_.end()) {
139 ++has_find;
140 }
141 }
142 return has_find == om_to_anf_mapper_.size();
143 }
144
GetAnfOutputNameFromOm(const std::string & om_out_name)145 std::string GraphOutputNameKeeper::GetAnfOutputNameFromOm(const std::string &om_out_name) {
146 if (om_to_anf_mapper_.find(om_out_name) != om_to_anf_mapper_.end()) {
147 return om_to_anf_mapper_[om_out_name];
148 }
149 return om_out_name;
150 }
151 } // namespace dpico
152 } // namespace mindspore
153