1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "tools/converter/parser/parser_utils.h"
17 #include <algorithm>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <vector>
22 #include <unordered_map>
23 #include "ops/transpose.h"
24 #include "tools/common/tensor_util.h"
25 #include "tools/converter/parser/conv1d_inout_adjust.h"
26 #include "tools/converter/parser/inputs_adjust.h"
27 #include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
28 #include "tools/converter/parser/unused_node_remove_pass.h"
29 #include "tools/converter/quant_param_holder.h"
30 #include "tools/optimizer/common/gllo_utils.h"
31 #include "tools/optimizer/format/to_format_base.h"
32 #include "nnacl/op_base.h"
33
34 namespace mindspore::lite {
35 namespace {
36 std::unordered_map<std::string, size_t> weight_indexs = {{ops::kNameConv2DFusion, 2},
37 {ops::kNameConv2DBackpropInputFusion, 2},
38 {ops::kNameConv2dTransposeFusion, 2},
39 {ops::kNameApplyMomentum, 1},
40 {ops::kNameSGD, 1},
41 {ops::kNameAdam, 1}};
42 } // namespace
43
GetAllFuncGraph(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * all_func_graphs)44 void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
45 MS_ASSERT(all_func_graphs);
46 MS_ASSERT(func_graph);
47 if (all_func_graphs->find(func_graph) == all_func_graphs->end()) {
48 all_func_graphs->insert(func_graph);
49 } else {
50 return;
51 }
52 auto nodes = func_graph->nodes();
53 for (auto &node : nodes) {
54 if (IsValueNode<FuncGraph>(node)) {
55 MS_ASSERT(node->cast<ValueNodePtr>() != nullptr);
56 MS_ASSERT(node->cast<ValueNodePtr>()->value() != nullptr);
57 MS_ASSERT((node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>() != nullptr);
58 auto new_fg = (node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
59 GetAllFuncGraph(new_fg, all_func_graphs);
60 }
61 if (utils::isa<CNodePtr>(node)) {
62 auto cnode = node->cast<CNodePtr>();
63 MS_ASSERT(cnode != nullptr);
64 for (auto &input : cnode->inputs()) {
65 if (input->isa<ValueNode>()) {
66 if (IsValueNode<FuncGraph>(input)) {
67 MS_ASSERT(input->cast<ValueNodePtr>() != nullptr);
68 MS_ASSERT(input->cast<ValueNodePtr>()->value() != nullptr);
69 MS_ASSERT((input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>() != nullptr);
70 auto new_fg = (input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
71 GetAllFuncGraph(new_fg, all_func_graphs);
72 }
73 }
74 }
75 }
76 }
77 }
78
CommonAnfAdjust(const std::set<FuncGraphPtr> & all_func_graphs)79 int CommonAnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
80 for (auto func_graph : all_func_graphs) {
81 {
82 auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>();
83 MS_CHECK_TRUE_MSG(asylic_optimizer != nullptr, RET_NULL_PTR, "asylic_optimizer is nullptr.");
84 auto asylic_pm = std::make_shared<opt::PassManager>("asylic pass manager", false);
85 MS_CHECK_TRUE_MSG(asylic_pm != nullptr, RET_NULL_PTR, "asylic_pm is nullptr.");
86
87 // fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic
88 asylic_pm->AddPass(std::make_shared<opt::TfBidirectionGruCfFusion>());
89 // remove remaining cyclic nodes
90 asylic_pm->AddPass(std::make_shared<opt::UnusedNodeRemovePass>());
91 asylic_optimizer->AddPassManager(asylic_pm);
92 if (!asylic_optimizer->Optimize(func_graph)) {
93 MS_LOG(ERROR) << "gru cf fusion pass failed.";
94 ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
95 return RET_ERROR;
96 }
97 }
98 auto adjust_input = std::make_shared<InputAdjust>();
99 MS_CHECK_TRUE_MSG(adjust_input != nullptr, RET_NULL_PTR, "adjust_input is nullptr.");
100 if (!adjust_input->Run(func_graph)) {
101 MS_LOG(ERROR) << "adjust input failed.";
102 return RET_ERROR;
103 }
104 // adjust for conv1d
105 auto conv1d_adjust = std::make_shared<Conv1DInOutAdjust>();
106 MS_CHECK_TRUE_MSG(conv1d_adjust != nullptr, RET_NULL_PTR, "conv1d_adjust is nullptr.");
107 if (!conv1d_adjust->Run(func_graph)) {
108 MS_LOG(ERROR) << "adjust conv1d failed.";
109 return RET_ERROR;
110 }
111 }
112 return RET_OK;
113 }
114
GetTransposePerm(schema::Format src_format,schema::Format dst_format,std::vector<int> * perm)115 int GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
116 MS_CHECK_TRUE_MSG(perm != nullptr, RET_NULL_PTR, "perm is nullptr.");
117 auto src_format_str = std::string(schema::EnumNameFormat(src_format));
118 auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
119 if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
120 MS_LOG(ERROR) << "src_format or dst_format is error.";
121 return lite::RET_ERROR;
122 }
123 for (size_t i = 0; i < src_format_str.size(); ++i) {
124 auto pos = src_format_str.find(dst_format_str[i]);
125 if (pos == std::string::npos) {
126 MS_LOG(ERROR) << "src_format and dst_format don't match.";
127 return lite::RET_ERROR;
128 }
129 perm->push_back(static_cast<int>(pos));
130 }
131 return lite::RET_OK;
132 }
133
GetTransposePermSharing(schema::Format src_format,schema::Format dst_format,std::vector<int> * perm)134 int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
135 MS_CHECK_TRUE_MSG(perm != nullptr, RET_NULL_PTR, "perm is nullptr.");
136 auto src_format_str = std::string(schema::EnumNameFormat(src_format));
137 auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
138 if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
139 MS_LOG(ERROR) << "src_format or dst_format is error.";
140 return lite::RET_ERROR;
141 }
142 for (size_t i = 0; i < src_format_str.size(); ++i) {
143 auto pos = dst_format_str.find(src_format_str[i]);
144 if (pos == std::string::npos) {
145 MS_LOG(ERROR) << "src_format and dst_format don't match.";
146 return lite::RET_ERROR;
147 }
148 perm->push_back(static_cast<int>(pos));
149 }
150 return lite::RET_OK;
151 }
152
GetRealConvWeightNode(const FuncGraphPtr & graph,const CNodePtr & cnode,size_t index)153 AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index) {
154 MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "graph is nullptr.");
155 MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "cnode is nullptr.");
156 auto weight_node = cnode->input(index);
157 MS_CHECK_TRUE_MSG(weight_node != nullptr, nullptr, "weight_node is nullptr.");
158 bool is_real_weight =
159 !opt::CheckPrimitiveType(weight_node, opt::kPrimIdentity) && !opt::CheckPrimitiveType(weight_node, prim::kPrimLoad);
160 while (!is_real_weight) {
161 if (!utils::isa<CNode>(weight_node)) {
162 MS_LOG(ERROR) << "weight node is invalid.";
163 return nullptr;
164 }
165 auto weight_cnode = weight_node->cast<CNodePtr>();
166 weight_node = weight_cnode->input(1);
167 MS_CHECK_TRUE_MSG(weight_node != nullptr, nullptr, "weight_node is nullptr.");
168 is_real_weight = !opt::CheckPrimitiveType(weight_node, opt::kPrimIdentity) &&
169 !opt::CheckPrimitiveType(weight_node, prim::kPrimLoad);
170 }
171 auto manager = Manage(graph);
172 MS_CHECK_TRUE_MSG(manager != nullptr, nullptr, "manager is nullptr.");
173 manager->Replace(cnode->input(index), weight_node);
174 return weight_node;
175 }
176
UnifyConvWeightFormat(const FuncGraphPtr & graph,const CNodePtr & cnode,schema::Format src_format,schema::Format dst_format,std::set<AnfNodePtr> * has_visited)177 int UnifyConvWeightFormat(const FuncGraphPtr &graph, const CNodePtr &cnode, schema::Format src_format,
178 schema::Format dst_format, std::set<AnfNodePtr> *has_visited) {
179 MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
180 MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr.");
181 MS_CHECK_TRUE_MSG(has_visited != nullptr, RET_NULL_PTR, "has_visited is nullptr.");
182 if (src_format == dst_format) {
183 return lite::RET_OK;
184 }
185 auto primitive_ptr = GetValueNode<PrimitivePtr>(cnode->input(0));
186 auto primitive_name = primitive_ptr->name();
187 if (weight_indexs.find(primitive_name) == weight_indexs.end()) {
188 MS_LOG(ERROR) << primitive_name << " is not a member of convolution's family.";
189 return RET_ERROR;
190 }
191 size_t index = weight_indexs[primitive_name];
192 if (GetRealConvWeightNode(graph, cnode, index) == nullptr) {
193 MS_LOG(ERROR) << "current conv node is invalid, node name is " << cnode->fullname_with_scope();
194 return RET_ERROR;
195 }
196 bool is_const_weight = true;
197 auto weight_node = cnode->input(index);
198 MS_CHECK_TRUE_MSG(weight_node != nullptr, RET_NULL_PTR, "weight_node is nullptr.");
199 if (utils::isa<CNode>(weight_node)) {
200 is_const_weight = false;
201 } else if (utils::isa<Parameter>(weight_node)) {
202 auto weight_param_node = weight_node->cast<ParameterPtr>();
203 MS_CHECK_TRUE_MSG(weight_param_node != nullptr, RET_NULL_PTR, "weight_param_node is nullptr.");
204 if (!weight_param_node->has_default()) {
205 is_const_weight = false;
206 }
207 }
208 int status;
209 if (is_const_weight) {
210 status = UnifyConstConvWeight(graph, weight_node, src_format, dst_format, has_visited);
211 } else {
212 status = UnifyVariableConvWeight(graph, weight_node, src_format, dst_format, has_visited);
213 }
214 if (status != RET_OK) {
215 MS_LOG(ERROR) << "unfiy coneight failed, cnode name is " << cnode->fullname_with_scope();
216 }
217 return status;
218 }
219
UnifyVariableConvWeight(const FuncGraphPtr & graph,const AnfNodePtr & weight_node,schema::Format src_format,schema::Format dst_format,std::set<AnfNodePtr> * has_visited)220 int UnifyVariableConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
221 schema::Format dst_format, std::set<AnfNodePtr> *has_visited) {
222 MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
223 MS_CHECK_TRUE_MSG(weight_node != nullptr, RET_NULL_PTR, "weight_node is nullptr.");
224 MS_CHECK_TRUE_MSG(has_visited != nullptr, RET_NULL_PTR, "has_visited is nullptr.");
225 if (src_format == dst_format) {
226 return lite::RET_OK;
227 }
228 std::vector<int> perm;
229 auto status = GetTransposePerm(src_format, dst_format, &perm);
230 if (status != lite::RET_OK) {
231 MS_LOG(ERROR) << "get perm failed.";
232 return status;
233 }
234 auto manager = Manage(graph);
235 MS_CHECK_TRUE_MSG(manager != nullptr, RET_NULL_PTR, "manager is nullptr.");
236 CNodePtr trans_cnode = nullptr;
237 auto weight_node_users = manager->node_users()[weight_node];
238 for (auto &weight_node_user : weight_node_users) {
239 auto post_node = weight_node_user.first;
240 if (!utils::isa<CNodePtr>(post_node)) {
241 MS_LOG(ERROR) << "post node is invalid.";
242 return RET_ERROR;
243 }
244 if (!opt::ToFormatBase::IsWeightNodeSensitive(post_node)) {
245 continue;
246 }
247 has_visited->insert(post_node);
248 if (trans_cnode == nullptr) {
249 trans_cnode = opt::GenTransposeNode(graph, weight_node, perm, weight_node->fullname_with_scope() + "_post_perm");
250 MS_CHECK_TRUE_MSG(trans_cnode != nullptr, RET_NULL_PTR, "trans_cnode is nullptr.");
251 auto abstract = weight_node->abstract();
252 ShapeVector shape;
253 if (abstract != nullptr) {
254 ShapeVector weight_shape;
255 if (opt::FetchShapeFromAbstract(abstract, &weight_shape) != RET_OK) {
256 MS_LOG(ERROR) << "fetch shape from abstract failed.";
257 return RET_ERROR;
258 }
259 if (!weight_shape.empty()) {
260 if (weight_shape.size() != opt::kInputSizeFour) {
261 MS_LOG(ERROR) << "conv weight shape is invalid, which is not 4D, now is " << weight_shape.size();
262 return RET_ERROR;
263 }
264 std::transform(perm.begin(), perm.end(), std::back_inserter(shape),
265 [&weight_shape](const int index) { return weight_shape[index]; });
266 }
267 abstract = abstract->Clone();
268 } else {
269 abstract = CreateTensorAbstract(shape, TypeId::kNumberTypeFloat32);
270 MS_CHECK_TRUE_MSG(abstract != nullptr, RET_NULL_PTR, "abstract is nullptr.");
271 }
272 auto shape_ptr = std::make_shared<abstract::Shape>(shape);
273 MS_CHECK_TRUE_MSG(shape_ptr != nullptr, RET_NULL_PTR, "shape_ptr is nullptr.");
274 abstract->set_shape(shape_ptr);
275 trans_cnode->set_abstract(abstract);
276 }
277 auto post_cnode = post_node->cast<CNodePtr>();
278 MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "post_cnode is nullptr.");
279 auto tr = manager->Transact();
280 tr.SetEdge(post_cnode, weight_node_user.second, trans_cnode);
281 tr.Commit();
282 }
283 return RET_OK;
284 }
285
UnifyConstConvWeight(const FuncGraphPtr & graph,const AnfNodePtr & weight_node,schema::Format src_format,schema::Format dst_format,std::set<AnfNodePtr> * has_visited)286 int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
287 schema::Format dst_format, std::set<AnfNodePtr> *has_visited) {
288 MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
289 MS_CHECK_TRUE_MSG(weight_node != nullptr, RET_NULL_PTR, "weight_node is nullptr.");
290 MS_CHECK_TRUE_MSG(has_visited != nullptr, RET_NULL_PTR, "has_visited is nullptr.");
291 if (src_format == dst_format) {
292 return lite::RET_OK;
293 }
294 auto weight_value = opt::GetTensorInfo(weight_node);
295 if (weight_value == nullptr) {
296 MS_LOG(ERROR) << "conv weight is non-const.";
297 return RET_ERROR;
298 }
299 if (weight_value->shape().size() != kShape4dDims) {
300 return lite::RET_OK;
301 }
302 auto status = opt::TransFilterFormat(weight_value, src_format, dst_format);
303 if (status != RET_OK) {
304 MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(src_format) << "To" << EnumNameFormat(dst_format)
305 << " failed, node : " << weight_node->fullname_with_scope();
306 return RET_ERROR;
307 }
308 auto type_id = static_cast<TypeId>(weight_value->data_type());
309 auto shape = weight_value->shape();
310 auto abstract = CreateTensorAbstract(shape, type_id);
311 if (abstract == nullptr) {
312 MS_LOG(ERROR) << "Create tensor abstarct failed";
313 return RET_ERROR;
314 }
315 weight_node->set_abstract(abstract);
316 if (HandleConstConvWeightShared(graph, weight_node, src_format, dst_format, has_visited) != RET_OK) {
317 MS_LOG(ERROR) << "handle const conv weight-shared failed, node name is " << weight_node->fullname_with_scope();
318 return RET_ERROR;
319 }
320 return RET_OK;
321 }
322
HandleConstConvWeightShared(const FuncGraphPtr & graph,const AnfNodePtr & weight_node,schema::Format src_format,schema::Format dst_format,std::set<AnfNodePtr> * has_visited)323 int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
324 schema::Format dst_format, std::set<AnfNodePtr> *has_visited) {
325 MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
326 MS_CHECK_TRUE_MSG(weight_node != nullptr, RET_NULL_PTR, "weight_node is nullptr.");
327 MS_CHECK_TRUE_MSG(has_visited != nullptr, RET_NULL_PTR, "has_visited is nullptr.");
328 if (src_format == dst_format) {
329 return RET_OK;
330 }
331 std::vector<int> perm;
332 auto status = GetTransposePermSharing(src_format, dst_format, &perm);
333 if (status != RET_OK) {
334 MS_LOG(ERROR) << "get perm failed.";
335 return status;
336 }
337 auto manager = Manage(graph);
338 MS_CHECK_TRUE_MSG(manager != nullptr, RET_NULL_PTR, "manager is nullptr.");
339 CNodePtr trans_cnode = nullptr;
340 auto weight_node_users = manager->node_users()[weight_node];
341 for (auto &weight_node_user : weight_node_users) {
342 auto post_node = weight_node_user.first;
343 if (!utils::isa<CNodePtr>(post_node)) {
344 MS_LOG(ERROR) << "post node is invalid.";
345 return RET_ERROR;
346 }
347 if (opt::ToFormatBase::IsWeightNodeSensitive(post_node)) {
348 has_visited->insert(post_node);
349 continue;
350 }
351 if (trans_cnode == nullptr) {
352 trans_cnode = opt::GenTransposeNode(graph, weight_node, perm, weight_node->fullname_with_scope() + "_post_perm");
353 MS_CHECK_TRUE_MSG(trans_cnode != nullptr, RET_NULL_PTR, "trans_cnode is nullptr.");
354 auto prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
355 MS_CHECK_TRUE_MSG(prim != nullptr, RET_NULL_PTR, "prim is nullptr.");
356 prim->AddAttr(ops::kFormat, MakeValue<int64_t>(dst_format));
357 auto weight_value = opt::GetTensorInfo(weight_node);
358 MS_CHECK_TRUE_MSG(weight_value != nullptr, RET_NULL_PTR, "weight_value is nullptr.");
359
360 auto weight_shape = weight_value->shape();
361 ShapeVector shape;
362 if (!weight_shape.empty()) {
363 if (weight_shape.size() != opt::kInputSizeFour) {
364 MS_LOG(ERROR) << "conv weight shape is invalid, which is not 4D, now is " << weight_shape.size();
365 return RET_ERROR;
366 }
367 std::transform(perm.begin(), perm.end(), std::back_inserter(shape),
368 [&weight_shape](const int index) { return weight_shape[index]; });
369 }
370 auto abstract = weight_node->abstract();
371 MS_CHECK_TRUE_MSG(abstract != nullptr, RET_NULL_PTR, "abstract is nullptr.");
372 abstract = abstract->Clone();
373 auto shape_ptr = std::make_shared<abstract::Shape>(shape);
374 MS_CHECK_TRUE_MSG(shape_ptr != nullptr, RET_NULL_PTR, "shape_ptr is nullptr.");
375 abstract->set_shape(shape_ptr);
376 trans_cnode->set_abstract(abstract);
377 }
378 auto post_cnode = post_node->cast<CNodePtr>();
379 MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "post_cnode is nullptr.");
380 auto tr = manager->Transact();
381 tr.SetEdge(post_cnode, weight_node_user.second, trans_cnode);
382 tr.Commit();
383 }
384 return RET_OK;
385 }
386 } // namespace mindspore::lite
387