• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "tools/converter/parser/parser_utils.h"
17 #include <algorithm>
18 #include <memory>
19 #include <set>
20 #include <string>
21 #include <vector>
22 #include <unordered_map>
23 #include "ops/transpose.h"
24 #include "tools/common/tensor_util.h"
25 #include "tools/converter/parser/conv1d_inout_adjust.h"
26 #include "tools/converter/parser/inputs_adjust.h"
27 #include "tools/converter/parser/tf_bidirection_gru_cf_fusion.h"
28 #include "tools/converter/parser/unused_node_remove_pass.h"
29 #include "tools/converter/quant_param_holder.h"
30 #include "tools/optimizer/common/gllo_utils.h"
31 #include "tools/optimizer/format/to_format_base.h"
32 #include "nnacl/op_base.h"
33 
34 namespace mindspore::lite {
35 namespace {
36 std::unordered_map<std::string, size_t> weight_indexs = {{ops::kNameConv2DFusion, 2},
37                                                          {ops::kNameConv2DBackpropInputFusion, 2},
38                                                          {ops::kNameConv2dTransposeFusion, 2},
39                                                          {ops::kNameApplyMomentum, 1},
40                                                          {ops::kNameSGD, 1},
41                                                          {ops::kNameAdam, 1}};
42 }  // namespace
43 
GetAllFuncGraph(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * all_func_graphs)44 void GetAllFuncGraph(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
45   MS_ASSERT(all_func_graphs);
46   MS_ASSERT(func_graph);
47   if (all_func_graphs->find(func_graph) == all_func_graphs->end()) {
48     all_func_graphs->insert(func_graph);
49   } else {
50     return;
51   }
52   auto nodes = func_graph->nodes();
53   for (auto &node : nodes) {
54     if (IsValueNode<FuncGraph>(node)) {
55       MS_ASSERT(node->cast<ValueNodePtr>() != nullptr);
56       MS_ASSERT(node->cast<ValueNodePtr>()->value() != nullptr);
57       MS_ASSERT((node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>() != nullptr);
58       auto new_fg = (node->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
59       GetAllFuncGraph(new_fg, all_func_graphs);
60     }
61     if (utils::isa<CNodePtr>(node)) {
62       auto cnode = node->cast<CNodePtr>();
63       MS_ASSERT(cnode != nullptr);
64       for (auto &input : cnode->inputs()) {
65         if (input->isa<ValueNode>()) {
66           if (IsValueNode<FuncGraph>(input)) {
67             MS_ASSERT(input->cast<ValueNodePtr>() != nullptr);
68             MS_ASSERT(input->cast<ValueNodePtr>()->value() != nullptr);
69             MS_ASSERT((input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>() != nullptr);
70             auto new_fg = (input->cast<ValueNodePtr>()->value())->cast<FuncGraphPtr>();
71             GetAllFuncGraph(new_fg, all_func_graphs);
72           }
73         }
74       }
75     }
76   }
77 }
78 
CommonAnfAdjust(const std::set<FuncGraphPtr> & all_func_graphs)79 int CommonAnfAdjust(const std::set<FuncGraphPtr> &all_func_graphs) {
80   for (auto func_graph : all_func_graphs) {
81     {
82       auto asylic_optimizer = std::make_shared<opt::GraphOptimizer>();
83       MS_CHECK_TRUE_MSG(asylic_optimizer != nullptr, RET_NULL_PTR, "asylic_optimizer is nullptr.");
84       auto asylic_pm = std::make_shared<opt::PassManager>("asylic pass manager", false);
85       MS_CHECK_TRUE_MSG(asylic_pm != nullptr, RET_NULL_PTR, "asylic_pm is nullptr.");
86 
87       // fuse tf1.x bidirection_gru into GRU, must be placed here because graph is cyclic
88       asylic_pm->AddPass(std::make_shared<opt::TfBidirectionGruCfFusion>());
89       // remove remaining cyclic nodes
90       asylic_pm->AddPass(std::make_shared<opt::UnusedNodeRemovePass>());
91       asylic_optimizer->AddPassManager(asylic_pm);
92       if (!asylic_optimizer->Optimize(func_graph)) {
93         MS_LOG(ERROR) << "gru cf fusion pass failed.";
94         ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_ERROR);
95         return RET_ERROR;
96       }
97     }
98     auto adjust_input = std::make_shared<InputAdjust>();
99     MS_CHECK_TRUE_MSG(adjust_input != nullptr, RET_NULL_PTR, "adjust_input is nullptr.");
100     if (!adjust_input->Run(func_graph)) {
101       MS_LOG(ERROR) << "adjust input failed.";
102       return RET_ERROR;
103     }
104     // adjust for conv1d
105     auto conv1d_adjust = std::make_shared<Conv1DInOutAdjust>();
106     MS_CHECK_TRUE_MSG(conv1d_adjust != nullptr, RET_NULL_PTR, "conv1d_adjust is nullptr.");
107     if (!conv1d_adjust->Run(func_graph)) {
108       MS_LOG(ERROR) << "adjust conv1d failed.";
109       return RET_ERROR;
110     }
111   }
112   return RET_OK;
113 }
114 
GetTransposePerm(schema::Format src_format,schema::Format dst_format,std::vector<int> * perm)115 int GetTransposePerm(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
116   MS_CHECK_TRUE_MSG(perm != nullptr, RET_NULL_PTR, "perm is nullptr.");
117   auto src_format_str = std::string(schema::EnumNameFormat(src_format));
118   auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
119   if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
120     MS_LOG(ERROR) << "src_format or dst_format is error.";
121     return lite::RET_ERROR;
122   }
123   for (size_t i = 0; i < src_format_str.size(); ++i) {
124     auto pos = src_format_str.find(dst_format_str[i]);
125     if (pos == std::string::npos) {
126       MS_LOG(ERROR) << "src_format and dst_format don't match.";
127       return lite::RET_ERROR;
128     }
129     perm->push_back(static_cast<int>(pos));
130   }
131   return lite::RET_OK;
132 }
133 
GetTransposePermSharing(schema::Format src_format,schema::Format dst_format,std::vector<int> * perm)134 int GetTransposePermSharing(schema::Format src_format, schema::Format dst_format, std::vector<int> *perm) {
135   MS_CHECK_TRUE_MSG(perm != nullptr, RET_NULL_PTR, "perm is nullptr.");
136   auto src_format_str = std::string(schema::EnumNameFormat(src_format));
137   auto dst_format_str = std::string(schema::EnumNameFormat(dst_format));
138   if (src_format_str.empty() || dst_format_str.empty() || src_format_str.size() != dst_format_str.size()) {
139     MS_LOG(ERROR) << "src_format or dst_format is error.";
140     return lite::RET_ERROR;
141   }
142   for (size_t i = 0; i < src_format_str.size(); ++i) {
143     auto pos = dst_format_str.find(src_format_str[i]);
144     if (pos == std::string::npos) {
145       MS_LOG(ERROR) << "src_format and dst_format don't match.";
146       return lite::RET_ERROR;
147     }
148     perm->push_back(static_cast<int>(pos));
149   }
150   return lite::RET_OK;
151 }
152 
GetRealConvWeightNode(const FuncGraphPtr & graph,const CNodePtr & cnode,size_t index)153 AnfNodePtr GetRealConvWeightNode(const FuncGraphPtr &graph, const CNodePtr &cnode, size_t index) {
154   MS_CHECK_TRUE_MSG(graph != nullptr, nullptr, "graph is nullptr.");
155   MS_CHECK_TRUE_MSG(cnode != nullptr, nullptr, "cnode is nullptr.");
156   auto weight_node = cnode->input(index);
157   MS_CHECK_TRUE_MSG(weight_node != nullptr, nullptr, "weight_node is nullptr.");
158   bool is_real_weight =
159     !opt::CheckPrimitiveType(weight_node, opt::kPrimIdentity) && !opt::CheckPrimitiveType(weight_node, prim::kPrimLoad);
160   while (!is_real_weight) {
161     if (!utils::isa<CNode>(weight_node)) {
162       MS_LOG(ERROR) << "weight node is invalid.";
163       return nullptr;
164     }
165     auto weight_cnode = weight_node->cast<CNodePtr>();
166     weight_node = weight_cnode->input(1);
167     MS_CHECK_TRUE_MSG(weight_node != nullptr, nullptr, "weight_node is nullptr.");
168     is_real_weight = !opt::CheckPrimitiveType(weight_node, opt::kPrimIdentity) &&
169                      !opt::CheckPrimitiveType(weight_node, prim::kPrimLoad);
170   }
171   auto manager = Manage(graph);
172   MS_CHECK_TRUE_MSG(manager != nullptr, nullptr, "manager is nullptr.");
173   manager->Replace(cnode->input(index), weight_node);
174   return weight_node;
175 }
176 
UnifyConvWeightFormat(const FuncGraphPtr & graph,const CNodePtr & cnode,schema::Format src_format,schema::Format dst_format,std::set<AnfNodePtr> * has_visited)177 int UnifyConvWeightFormat(const FuncGraphPtr &graph, const CNodePtr &cnode, schema::Format src_format,
178                           schema::Format dst_format, std::set<AnfNodePtr> *has_visited) {
179   MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
180   MS_CHECK_TRUE_MSG(cnode != nullptr, RET_NULL_PTR, "cnode is nullptr.");
181   MS_CHECK_TRUE_MSG(has_visited != nullptr, RET_NULL_PTR, "has_visited is nullptr.");
182   if (src_format == dst_format) {
183     return lite::RET_OK;
184   }
185   auto primitive_ptr = GetValueNode<PrimitivePtr>(cnode->input(0));
186   auto primitive_name = primitive_ptr->name();
187   if (weight_indexs.find(primitive_name) == weight_indexs.end()) {
188     MS_LOG(ERROR) << primitive_name << " is not a member of convolution's family.";
189     return RET_ERROR;
190   }
191   size_t index = weight_indexs[primitive_name];
192   if (GetRealConvWeightNode(graph, cnode, index) == nullptr) {
193     MS_LOG(ERROR) << "current conv node is invalid, node name is " << cnode->fullname_with_scope();
194     return RET_ERROR;
195   }
196   bool is_const_weight = true;
197   auto weight_node = cnode->input(index);
198   MS_CHECK_TRUE_MSG(weight_node != nullptr, RET_NULL_PTR, "weight_node is nullptr.");
199   if (utils::isa<CNode>(weight_node)) {
200     is_const_weight = false;
201   } else if (utils::isa<Parameter>(weight_node)) {
202     auto weight_param_node = weight_node->cast<ParameterPtr>();
203     MS_CHECK_TRUE_MSG(weight_param_node != nullptr, RET_NULL_PTR, "weight_param_node is nullptr.");
204     if (!weight_param_node->has_default()) {
205       is_const_weight = false;
206     }
207   }
208   int status;
209   if (is_const_weight) {
210     status = UnifyConstConvWeight(graph, weight_node, src_format, dst_format, has_visited);
211   } else {
212     status = UnifyVariableConvWeight(graph, weight_node, src_format, dst_format, has_visited);
213   }
214   if (status != RET_OK) {
215     MS_LOG(ERROR) << "unfiy coneight failed, cnode name is " << cnode->fullname_with_scope();
216   }
217   return status;
218 }
219 
UnifyVariableConvWeight(const FuncGraphPtr & graph,const AnfNodePtr & weight_node,schema::Format src_format,schema::Format dst_format,std::set<AnfNodePtr> * has_visited)220 int UnifyVariableConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
221                             schema::Format dst_format, std::set<AnfNodePtr> *has_visited) {
222   MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
223   MS_CHECK_TRUE_MSG(weight_node != nullptr, RET_NULL_PTR, "weight_node is nullptr.");
224   MS_CHECK_TRUE_MSG(has_visited != nullptr, RET_NULL_PTR, "has_visited is nullptr.");
225   if (src_format == dst_format) {
226     return lite::RET_OK;
227   }
228   std::vector<int> perm;
229   auto status = GetTransposePerm(src_format, dst_format, &perm);
230   if (status != lite::RET_OK) {
231     MS_LOG(ERROR) << "get perm failed.";
232     return status;
233   }
234   auto manager = Manage(graph);
235   MS_CHECK_TRUE_MSG(manager != nullptr, RET_NULL_PTR, "manager is nullptr.");
236   CNodePtr trans_cnode = nullptr;
237   auto weight_node_users = manager->node_users()[weight_node];
238   for (auto &weight_node_user : weight_node_users) {
239     auto post_node = weight_node_user.first;
240     if (!utils::isa<CNodePtr>(post_node)) {
241       MS_LOG(ERROR) << "post node is invalid.";
242       return RET_ERROR;
243     }
244     if (!opt::ToFormatBase::IsWeightNodeSensitive(post_node)) {
245       continue;
246     }
247     has_visited->insert(post_node);
248     if (trans_cnode == nullptr) {
249       trans_cnode = opt::GenTransposeNode(graph, weight_node, perm, weight_node->fullname_with_scope() + "_post_perm");
250       MS_CHECK_TRUE_MSG(trans_cnode != nullptr, RET_NULL_PTR, "trans_cnode is nullptr.");
251       auto abstract = weight_node->abstract();
252       ShapeVector shape;
253       if (abstract != nullptr) {
254         ShapeVector weight_shape;
255         if (opt::FetchShapeFromAbstract(abstract, &weight_shape) != RET_OK) {
256           MS_LOG(ERROR) << "fetch shape from abstract failed.";
257           return RET_ERROR;
258         }
259         if (!weight_shape.empty()) {
260           if (weight_shape.size() != opt::kInputSizeFour) {
261             MS_LOG(ERROR) << "conv weight shape is invalid, which is not 4D, now is " << weight_shape.size();
262             return RET_ERROR;
263           }
264           std::transform(perm.begin(), perm.end(), std::back_inserter(shape),
265                          [&weight_shape](const int index) { return weight_shape[index]; });
266         }
267         abstract = abstract->Clone();
268       } else {
269         abstract = CreateTensorAbstract(shape, TypeId::kNumberTypeFloat32);
270         MS_CHECK_TRUE_MSG(abstract != nullptr, RET_NULL_PTR, "abstract is nullptr.");
271       }
272       auto shape_ptr = std::make_shared<abstract::Shape>(shape);
273       MS_CHECK_TRUE_MSG(shape_ptr != nullptr, RET_NULL_PTR, "shape_ptr is nullptr.");
274       abstract->set_shape(shape_ptr);
275       trans_cnode->set_abstract(abstract);
276     }
277     auto post_cnode = post_node->cast<CNodePtr>();
278     MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "post_cnode is nullptr.");
279     auto tr = manager->Transact();
280     tr.SetEdge(post_cnode, weight_node_user.second, trans_cnode);
281     tr.Commit();
282   }
283   return RET_OK;
284 }
285 
UnifyConstConvWeight(const FuncGraphPtr & graph,const AnfNodePtr & weight_node,schema::Format src_format,schema::Format dst_format,std::set<AnfNodePtr> * has_visited)286 int UnifyConstConvWeight(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
287                          schema::Format dst_format, std::set<AnfNodePtr> *has_visited) {
288   MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
289   MS_CHECK_TRUE_MSG(weight_node != nullptr, RET_NULL_PTR, "weight_node is nullptr.");
290   MS_CHECK_TRUE_MSG(has_visited != nullptr, RET_NULL_PTR, "has_visited is nullptr.");
291   if (src_format == dst_format) {
292     return lite::RET_OK;
293   }
294   auto weight_value = opt::GetTensorInfo(weight_node);
295   if (weight_value == nullptr) {
296     MS_LOG(ERROR) << "conv weight is non-const.";
297     return RET_ERROR;
298   }
299   if (weight_value->shape().size() != kShape4dDims) {
300     return lite::RET_OK;
301   }
302   auto status = opt::TransFilterFormat(weight_value, src_format, dst_format);
303   if (status != RET_OK) {
304     MS_LOG(ERROR) << "TransFilter " << EnumNameFormat(src_format) << "To" << EnumNameFormat(dst_format)
305                   << " failed, node : " << weight_node->fullname_with_scope();
306     return RET_ERROR;
307   }
308   auto type_id = static_cast<TypeId>(weight_value->data_type());
309   auto shape = weight_value->shape();
310   auto abstract = CreateTensorAbstract(shape, type_id);
311   if (abstract == nullptr) {
312     MS_LOG(ERROR) << "Create tensor abstarct failed";
313     return RET_ERROR;
314   }
315   weight_node->set_abstract(abstract);
316   if (HandleConstConvWeightShared(graph, weight_node, src_format, dst_format, has_visited) != RET_OK) {
317     MS_LOG(ERROR) << "handle const conv weight-shared failed, node name is " << weight_node->fullname_with_scope();
318     return RET_ERROR;
319   }
320   return RET_OK;
321 }
322 
HandleConstConvWeightShared(const FuncGraphPtr & graph,const AnfNodePtr & weight_node,schema::Format src_format,schema::Format dst_format,std::set<AnfNodePtr> * has_visited)323 int HandleConstConvWeightShared(const FuncGraphPtr &graph, const AnfNodePtr &weight_node, schema::Format src_format,
324                                 schema::Format dst_format, std::set<AnfNodePtr> *has_visited) {
325   MS_CHECK_TRUE_MSG(graph != nullptr, RET_NULL_PTR, "graph is nullptr.");
326   MS_CHECK_TRUE_MSG(weight_node != nullptr, RET_NULL_PTR, "weight_node is nullptr.");
327   MS_CHECK_TRUE_MSG(has_visited != nullptr, RET_NULL_PTR, "has_visited is nullptr.");
328   if (src_format == dst_format) {
329     return RET_OK;
330   }
331   std::vector<int> perm;
332   auto status = GetTransposePermSharing(src_format, dst_format, &perm);
333   if (status != RET_OK) {
334     MS_LOG(ERROR) << "get perm failed.";
335     return status;
336   }
337   auto manager = Manage(graph);
338   MS_CHECK_TRUE_MSG(manager != nullptr, RET_NULL_PTR, "manager is nullptr.");
339   CNodePtr trans_cnode = nullptr;
340   auto weight_node_users = manager->node_users()[weight_node];
341   for (auto &weight_node_user : weight_node_users) {
342     auto post_node = weight_node_user.first;
343     if (!utils::isa<CNodePtr>(post_node)) {
344       MS_LOG(ERROR) << "post node is invalid.";
345       return RET_ERROR;
346     }
347     if (opt::ToFormatBase::IsWeightNodeSensitive(post_node)) {
348       has_visited->insert(post_node);
349       continue;
350     }
351     if (trans_cnode == nullptr) {
352       trans_cnode = opt::GenTransposeNode(graph, weight_node, perm, weight_node->fullname_with_scope() + "_post_perm");
353       MS_CHECK_TRUE_MSG(trans_cnode != nullptr, RET_NULL_PTR, "trans_cnode is nullptr.");
354       auto prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0));
355       MS_CHECK_TRUE_MSG(prim != nullptr, RET_NULL_PTR, "prim is nullptr.");
356       prim->AddAttr(ops::kFormat, MakeValue<int64_t>(dst_format));
357       auto weight_value = opt::GetTensorInfo(weight_node);
358       MS_CHECK_TRUE_MSG(weight_value != nullptr, RET_NULL_PTR, "weight_value is nullptr.");
359 
360       auto weight_shape = weight_value->shape();
361       ShapeVector shape;
362       if (!weight_shape.empty()) {
363         if (weight_shape.size() != opt::kInputSizeFour) {
364           MS_LOG(ERROR) << "conv weight shape is invalid, which is not 4D, now is " << weight_shape.size();
365           return RET_ERROR;
366         }
367         std::transform(perm.begin(), perm.end(), std::back_inserter(shape),
368                        [&weight_shape](const int index) { return weight_shape[index]; });
369       }
370       auto abstract = weight_node->abstract();
371       MS_CHECK_TRUE_MSG(abstract != nullptr, RET_NULL_PTR, "abstract is nullptr.");
372       abstract = abstract->Clone();
373       auto shape_ptr = std::make_shared<abstract::Shape>(shape);
374       MS_CHECK_TRUE_MSG(shape_ptr != nullptr, RET_NULL_PTR, "shape_ptr is nullptr.");
375       abstract->set_shape(shape_ptr);
376       trans_cnode->set_abstract(abstract);
377     }
378     auto post_cnode = post_node->cast<CNodePtr>();
379     MS_CHECK_TRUE_MSG(post_cnode != nullptr, RET_NULL_PTR, "post_cnode is nullptr.");
380     auto tr = manager->Transact();
381     tr.SetEdge(post_cnode, weight_node_user.second, trans_cnode);
382     tr.Commit();
383   }
384   return RET_OK;
385 }
386 }  // namespace mindspore::lite
387