• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/specify_graph_input_format.h"
19 #include <memory>
20 #include <vector>
21 #include <stack>
22 #include <set>
23 #include "mindspore/core/ops/array_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "tools/converter/parser/parser_utils.h"
26 #include "tools/optimizer/common/format_utils.h"
27 #include "src/common/log_adapter.h"
28 #include "nnacl/op_base.h"
29 #include "ops/op_utils.h"
30 #include "ops/auto_generate/gen_lite_ops.h"
31 
32 namespace mindspore {
33 namespace opt {
Run(const FuncGraphPtr & graph)34 bool SpecifyGraphInputFormat::Run(const FuncGraphPtr &graph) {
35   MS_ASSERT(graph != nullptr);
36   if (exp_graph_input_format_ == cur_graph_input_format_) {
37     return true;
38   }
39   if ((exp_graph_input_format_ != mindspore::NHWC && exp_graph_input_format_ != mindspore::NCHW) ||
40       (cur_graph_input_format_ != mindspore::NHWC && cur_graph_input_format_ != mindspore::NCHW)) {
41     MS_LOG(ERROR) << "this pass only support to transfer graph input format between nhwc with nchw.";
42     return false;
43   }
44   auto manager = Manage(graph);
45   MS_CHECK_TRUE_MSG(manager != nullptr, false, "manager is nullptr.");
46   if (HandleGraphInput(graph) != lite::RET_OK) {
47     MS_LOG(ERROR) << "Specify graph-input format failed.";
48     return false;
49   }
50   return true;
51 }
52 
HandleGraphInput(const FuncGraphPtr & graph)53 STATUS SpecifyGraphInputFormat::HandleGraphInput(const FuncGraphPtr &graph) {
54   MS_ASSERT(graph != nullptr);
55   auto manager = graph->manager();
56   MS_ASSERT(manager != nullptr);
57   auto graph_inputs = graph->get_inputs();
58   for (const auto &input : graph_inputs) {
59     auto input_node = input->cast<ParameterPtr>();
60     MS_ASSERT(input_node != nullptr);
61     auto abstract = input_node->abstract();
62     MS_CHECK_TRUE_MSG(abstract != nullptr, lite::RET_NULL_PTR, "abstract is nullptr");
63 
64     ShapeVector shape;
65     if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
66       MS_LOG(ERROR) << "fetch shape failed." << input->fullname_with_scope();
67       return lite::RET_ERROR;
68     }
69     if (shape.size() != kInputSizeFour) {
70       continue;
71     }
72     ShapeVector transfer_shape;
73     if (exp_graph_input_format_ == mindspore::NCHW) {
74       transfer_shape = {shape[0], shape[kInputIndexThree], shape[1], shape[kInputIndexTwo]};
75     } else {
76       transfer_shape = {shape[0], shape[kInputIndexTwo], shape[kInputIndexThree], shape[1]};
77     }
78     CNodePtr trans_cnode;
79     if (exp_graph_input_format_ == mindspore::NCHW) {
80       trans_cnode = opt::GenTransposeNode(graph, input, kNC2NH, input->fullname_with_scope() + "_nc2nh");
81     } else {
82       trans_cnode = opt::GenTransposeNode(graph, input, kNH2NC, input->fullname_with_scope() + "_nh2nc");
83     }
84     if (trans_cnode == nullptr) {
85       MS_LOG(ERROR) << "create transpose cnode failed.";
86       return lite::RET_ERROR;
87     }
88     auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
89     MS_CHECK_TRUE_MSG(trans_prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
90     if (exp_graph_input_format_ == mindspore::NCHW) {
91       trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NCHW));
92     } else {
93       trans_prim->AddAttr(ops::kFormat, MakeValue<int64_t>(NHWC));
94     }
95     trans_cnode->set_abstract(abstract->Clone());
96     abstract->set_shape(std::make_shared<abstract::Shape>(transfer_shape));
97     (void)manager->Replace(input, trans_cnode);
98   }
99   return lite::RET_OK;
100 }
101 
CheckInputsFormatNHWC(const FuncGraphPtr & func_graph)102 bool CheckInputsFormatNHWC(const FuncGraphPtr &func_graph) {
103   MS_ASSERT(func_graph != nullptr);
104   auto manager = func_graph->manager();
105   if (manager == nullptr) {
106     manager = Manage(func_graph, true);
107     MS_CHECK_TRUE_RET(manager != nullptr, {});
108     std::set<FuncGraphPtr> all_func_graphs;
109     lite::GetAllFuncGraph(func_graph, &all_func_graphs);
110     for (auto &graph : all_func_graphs) {
111       manager->AddFuncGraph(graph);
112     }
113   }
114 
115   auto node_users = manager->node_users();
116   std::vector<AnfNodePtr> nodes;
117   auto inputs = func_graph->get_inputs();
118   (void)std::for_each(inputs.begin(), inputs.end(), [&nodes](const AnfNodePtr &input) {
119     if (opt::GetAnfNodeOutputShape(input, 0).size() == DIMENSION_4D) {
120       nodes.push_back(input);
121     }
122   });
123   for (auto input : nodes) {
124     auto itr = node_users.find(input);
125     for (auto pair : itr->second) {
126       auto used_node = pair.first;
127       MS_CHECK_TRUE_RET(used_node != nullptr && used_node->isa<CNode>(), false);
128       if (!opt::CheckPrimitiveType(used_node, prim::kPrimTranspose)) {
129         return false;
130       }
131       std::vector<int> perm;
132       if (GetTransposePerm(used_node->cast<CNodePtr>(), &perm) != RET_OK) {
133         MS_LOG(ERROR) << "fetch transpose perm failed.";
134         return false;
135       }
136       if (perm != kNH2NC) {
137         return false;
138       }
139     }
140   }
141   return true;
142 }
143 
GetTracedCnodes(const FuncGraphPtr & func_graph)144 std::vector<AnfNodePtr> GetTracedCnodes(const FuncGraphPtr &func_graph) {
145   MS_ASSERT(func_graph != nullptr);
146   auto manager = func_graph->manager();
147   MS_CHECK_TRUE_RET(manager != nullptr, {});
148   auto node_users = manager->node_users();
149   auto nhwc_ops = GetNHWCOpMap();
150   std::stack<AnfNodePtr> nodes;
151   for (auto input : func_graph->get_inputs()) {
152     if (opt::GetAnfNodeOutputShape(input, 0).size() == DIMENSION_4D) {
153       nodes.push(input);
154     }
155   }
156 
157   std::vector<AnfNodePtr> traced_nodes;
158   std::vector<AnfNodePtr> checked_nodes;
159   while (!nodes.empty()) {
160     auto node = nodes.top();
161     nodes.pop();
162     if (std::find(checked_nodes.begin(), checked_nodes.end(), node) != checked_nodes.end() ||
163         opt::CheckPrimitiveType(node, prim::kPrimReturn)) {
164       continue;
165     }
166     if (node->isa<CNode>()) {
167       auto cnode = node->cast<CNodePtr>();
168       MS_CHECK_TRUE_RET(cnode != nullptr, {});
169       MS_CHECK_TRUE_RET(cnode->size() > 0, {});
170       if (cnode->size() > 1) {
171         auto input_node = cnode->input(1);
172         auto itr = std::find(traced_nodes.begin(), traced_nodes.end(), input_node);
173         if (itr != traced_nodes.end()) {
174           (void)traced_nodes.erase(itr + 1, traced_nodes.end());
175         }
176       }
177       auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
178       if (prim != nullptr && nhwc_ops.find(prim->name()) != nhwc_ops.end()) {
179         return traced_nodes;
180       }
181       traced_nodes.push_back(node);
182     }
183     auto itr = node_users.find(node);
184     MS_CHECK_TRUE_RET(itr != node_users.end(), {});
185     for (auto &pair : itr->second) {
186       nodes.push(pair.first);
187     }
188     checked_nodes.push_back(node);
189   }
190   return {};
191 }
192 
GetCurGraphInputFormat(const FuncGraphPtr & func_graph,converter::FmkType fmk_type,mindspore::Format * input_format)193 bool SpecifyGraphInputFormat::GetCurGraphInputFormat(const FuncGraphPtr &func_graph, converter::FmkType fmk_type,
194                                                      mindspore::Format *input_format) {
195   MS_ASSERT(func_graph != nullptr);
196   MS_ASSERT(input_format != nullptr);
197   if (fmk_type == converter::kFmkTypeTf || fmk_type == converter::kFmkTypeTflite) {
198     *input_format = NHWC;
199   } else {
200     *input_format = NCHW;
201   }
202 
203   if (CheckInputsFormatNHWC(func_graph)) {
204     *input_format = NHWC;
205     return true;
206   }
207   auto traced_nodes = GetTracedCnodes(func_graph);
208   for (auto node : traced_nodes) {
209     if (opt::CheckPrimitiveType(node, prim::kPrimTranspose)) {
210       auto cnode = node->cast<CNodePtr>();
211       MS_CHECK_TRUE_RET(cnode != nullptr, false);
212       std::vector<int> perm;
213       if (GetTransposePerm(cnode, &perm) != RET_OK) {
214         MS_LOG(ERROR) << "fetch transpose perm failed.";
215         return false;
216       }
217       if (perm == kNC2NH) {
218         *input_format = NCHW;
219         return true;
220       } else if (perm == kNH2NC) {
221         *input_format = NHWC;
222         return true;
223       }
224     }
225   }
226   return true;
227 }
228 }  // namespace opt
229 }  // namespace mindspore
230