1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/specify_graph_input_format.h"
19 #include <memory>
20 #include <vector>
21 #include <stack>
22 #include <set>
23 #include "mindspore/core/ops/array_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "tools/converter/parser/parser_utils.h"
26 #include "tools/optimizer/common/format_utils.h"
27 #include "src/common/log_adapter.h"
28 #include "nnacl/op_base.h"
29 #include "ops/op_utils.h"
30 #include "ops/auto_generate/gen_lite_ops.h"
31
32 namespace mindspore {
33 namespace opt {
Run(const FuncGraphPtr & graph)34 bool SpecifyGraphInputFormat::Run(const FuncGraphPtr &graph) {
35 MS_ASSERT(graph != nullptr);
36 if (exp_graph_input_format_ == cur_graph_input_format_) {
37 return true;
38 }
39 if ((exp_graph_input_format_ != mindspore::NHWC && exp_graph_input_format_ != mindspore::NCHW) ||
40 (cur_graph_input_format_ != mindspore::NHWC && cur_graph_input_format_ != mindspore::NCHW)) {
41 MS_LOG(ERROR) << "this pass only support to transfer graph input format between nhwc with nchw.";
42 return false;
43 }
44 auto manager = Manage(graph);
45 MS_CHECK_TRUE_MSG(manager != nullptr, false, "manager is nullptr.");
46 if (HandleGraphInput(graph) != lite::RET_OK) {
47 MS_LOG(ERROR) << "Specify graph-input format failed.";
48 return false;
49 }
50 return true;
51 }
52
HandleGraphInput(const FuncGraphPtr & graph)53 STATUS SpecifyGraphInputFormat::HandleGraphInput(const FuncGraphPtr &graph) {
54 MS_ASSERT(graph != nullptr);
55 auto manager = graph->manager();
56 MS_ASSERT(manager != nullptr);
57 auto graph_inputs = graph->get_inputs();
58 for (const auto &input : graph_inputs) {
59 auto input_node = input->cast<ParameterPtr>();
60 MS_ASSERT(input_node != nullptr);
61 auto abstract = input_node->abstract();
62 MS_CHECK_TRUE_MSG(abstract != nullptr, lite::RET_NULL_PTR, "abstract is nullptr");
63
64 ShapeVector shape;
65 if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
66 MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
67 return lite::RET_ERROR;
68 }
69 if (shape.size() != kInputSizeFour) {
70 continue;
71 }
72 ShapeVector transfer_shape;
73 if (exp_graph_input_format_ == mindspore::NCHW) {
74 transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
75 } else {
76 transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
77 }
78 CNodePtr trans_cnode;
79 if (exp_graph_input_format_ == mindspore::NCHW) {
80 trans_cnode = opt::GenTransposeNode(graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
81 } else {
82 trans_cnode = opt::GenTransposeNode(graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
83 }
84 if (trans_cnode == nullptr) {
85 MS_LOG(ERROR) << "create transpose cnode failed.";
86 return lite::RET_ERROR;
87 }
88 auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
89 MS_CHECK_TRUE_MSG(trans_prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
90 if (exp_graph_input_format_ == mindspore::NCHW) {
91 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
92 } else {
93 trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
94 }
95 trans_cnode->set_abstract(abstract->Clone());
96 abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
97 (void)manager->Replace(input, trans_cnode);
98 }
99 return lite::RET_OK;
100 }
101
CheckInputsFormatNHWC(const FuncGraphPtr & func_graph)102 bool CheckInputsFormatNHWC(const FuncGraphPtr &func_graph) {
103 MS_ASSERT(func_graph != nullptr);
104 auto manager = func_graph->manager();
105 if (manager == nullptr) {
106 manager = Manage(func_graph, true);
107 MS_CHECK_TRUE_RET(manager != nullptr, {});
108 std::set<FuncGraphPtr> all_func_graphs;
109 lite::GetAllFuncGraph(func_graph, &all_func_graphs);
110 for (auto &graph : all_func_graphs) {
111 manager->AddFuncGraph(graph);
112 }
113 }
114
115 auto node_users = manager->node_users();
116 std::vector<AnfNodePtr> nodes;
117 auto inputs = func_graph->get_inputs();
118 (void)std::for_each(inputs.begin(), inputs.end(), [&nodes](const AnfNodePtr &input) {
119 if (opt::GetAnfNodeOutputShape(input, 0).size() == DIMENSION_4D) {
120 nodes.push_back(input);
121 }
122 });
123 for (auto input : nodes) {
124 auto itr = node_users.find(input);
125 for (auto pair : itr->second) {
126 auto used_node = pair.first;
127 MS_CHECK_TRUE_RET(used_node != nullptr && used_node->isa<CNode>(), false);
128 if (!opt::CheckPrimitiveType(used_node, prim::kPrimTranspose)) {
129 return false;
130 }
131 std::vector<int> perm;
132 if (GetTransposePerm(used_node->cast<CNodePtr>(), &perm) != RET_OK) {
133 MS_LOG(ERROR) << "fetch transpose perm failed.";
134 return false;
135 }
136 if (perm != kNH2NC) {
137 return false;
138 }
139 }
140 }
141 return true;
142 }
143
GetTracedCnodes(const FuncGraphPtr & func_graph)144 std::vector<AnfNodePtr> GetTracedCnodes(const FuncGraphPtr &func_graph) {
145 MS_ASSERT(func_graph != nullptr);
146 auto manager = func_graph->manager();
147 MS_CHECK_TRUE_RET(manager != nullptr, {});
148 auto node_users = manager->node_users();
149 auto nhwc_ops = GetNHWCOpMap();
150 std::stack<AnfNodePtr> nodes;
151 for (auto input : func_graph->get_inputs()) {
152 if (opt::GetAnfNodeOutputShape(input, 0).size() == DIMENSION_4D) {
153 nodes.push(input);
154 }
155 }
156
157 std::vector<AnfNodePtr> traced_nodes;
158 std::vector<AnfNodePtr> checked_nodes;
159 while (!nodes.empty()) {
160 auto node = nodes.top();
161 nodes.pop();
162 if (std::find(checked_nodes.begin(), checked_nodes.end(), node) != checked_nodes.end() ||
163 opt::CheckPrimitiveType(node, prim::kPrimReturn)) {
164 continue;
165 }
166 if (node->isa<CNode>()) {
167 auto cnode = node->cast<CNodePtr>();
168 MS_CHECK_TRUE_RET(cnode != nullptr, {});
169 MS_CHECK_TRUE_RET(cnode->size() > 0, {});
170 if (cnode->size() > 1) {
171 auto input_node = cnode->input(1);
172 auto itr = std::find(traced_nodes.begin(), traced_nodes.end(), input_node);
173 if (itr != traced_nodes.end()) {
174 (void)traced_nodes.erase(itr + 1, traced_nodes.end());
175 }
176 }
177 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
178 if (prim != nullptr && nhwc_ops.find(prim->name()) != nhwc_ops.end()) {
179 return traced_nodes;
180 }
181 traced_nodes.push_back(node);
182 }
183 auto itr = node_users.find(node);
184 MS_CHECK_TRUE_RET(itr != node_users.end(), {});
185 for (auto &pair : itr->second) {
186 nodes.push(pair.first);
187 }
188 checked_nodes.push_back(node);
189 }
190 return {};
191 }
192
GetCurGraphInputFormat(const FuncGraphPtr & func_graph,converter::FmkType fmk_type,mindspore::Format * input_format)193 bool SpecifyGraphInputFormat::GetCurGraphInputFormat(const FuncGraphPtr &func_graph, converter::FmkType fmk_type,
194 mindspore::Format *input_format) {
195 MS_ASSERT(func_graph != nullptr);
196 MS_ASSERT(input_format != nullptr);
197 if (fmk_type == converter::kFmkTypeTf || fmk_type == converter::kFmkTypeTflite) {
198 *input_format = NHWC;
199 } else {
200 *input_format = NCHW;
201 }
202
203 if (CheckInputsFormatNHWC(func_graph)) {
204 *input_format = NHWC;
205 return true;
206 }
207 auto traced_nodes = GetTracedCnodes(func_graph);
208 for (auto node : traced_nodes) {
209 if (opt::CheckPrimitiveType(node, prim::kPrimTranspose)) {
210 auto cnode = node->cast<CNodePtr>();
211 MS_CHECK_TRUE_RET(cnode != nullptr, false);
212 std::vector<int> perm;
213 if (GetTransposePerm(cnode, &perm) != RET_OK) {
214 MS_LOG(ERROR) << "fetch transpose perm failed.";
215 return false;
216 }
217 if (perm == kNC2NH) {
218 *input_format = NCHW;
219 return true;
220 } else if (perm == kNH2NC) {
221 *input_format = NHWC;
222 return true;
223 }
224 }
225 }
226 return true;
227 }
228 } // namespace opt
229 } // namespace mindspore
230