• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
18 #include <memory>
19 #include <vector>
20 #include <utility>
21 #include "include/errorcode.h"
22 #include "tools/anf_exporter/fetch_content.h"
23 #include "tools/converter/ops/ops_def.h"
24 #include "ops/depend.h"
25 #include "ops/fusion/pad_fusion.h"
26 #include "ops/op_utils.h"
27 #include "nnacl/op_base.h"
28 
29 namespace mindspore::opt {
30 namespace {
ProcessInputIsMonad(const FuncGraphPtr & func_graph,const CNodePtr & cnode)31 int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
32   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
33   auto first_input = cnode->input(1);
34   MS_ASSERT(first_input != nullptr);
35   if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) {
36     first_input = cnode->input(1)->cast<CNodePtr>()->input(1);
37     MS_CHECK_TRUE_MSG(first_input != nullptr, RET_ERROR, "first_input is nullptr");
38   }
39   auto second_input = cnode->input(kInputIndexTwo);
40   MS_ASSERT(seconde_input != nullptr);
41   if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) {
42     second_input = cnode->input(kInputIndexTwo)->cast<CNodePtr>()->input(1);
43     MS_CHECK_TRUE_MSG(second_input != nullptr, RET_ERROR, "second_input is nullptr");
44   }
45   AnfNodePtr must_monad = nullptr;
46   AnfNodePtr not_must_monad = nullptr;
47   if (utils::isa<ValueNode>(first_input)) {
48     auto value_node = first_input->cast<ValueNodePtr>();
49     MS_ASSERT(value_node->value() != nullptr);
50     if (utils::isa<Monad>(value_node->value())) {
51       must_monad = first_input;
52       not_must_monad = second_input;
53     }
54   }
55   if (utils::isa<ValueNode>(second_input)) {
56     auto value_node = second_input->cast<ValueNodePtr>();
57     MS_ASSERT(value_node->value() != nullptr);
58     if (utils::isa<Monad>(value_node->value())) {
59       must_monad = second_input;
60       not_must_monad = first_input;
61     }
62   }
63   if (must_monad == nullptr) {
64     return lite::RET_NO_CHANGE;
65   }
66   auto manager = func_graph->manager();
67   MS_ASSERT(manager != nullptr);
68   if (!utils::isa<CNode>(not_must_monad) || CheckIsAllInputsParam(not_must_monad)) {
69     manager->Replace(cnode, must_monad);
70   } else {
71     manager->Replace(cnode, not_must_monad);
72   }
73   return lite::RET_OK;
74 }
75 
ProcessDependencyWithTwoNodes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool pre_node_is_first)76 int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool pre_node_is_first) {
77   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
78   AnfNodePtr pre_node = cnode->input(1);
79   AnfNodePtr post_node = cnode->input(kInputIndexTwo);
80   MS_ASSERT(pre_node != nullptr);
81   MS_ASSERT(post_node != nullptr);
82   if (!pre_node_is_first) {
83     pre_node = cnode->input(kInputIndexTwo);
84     post_node = cnode->input(1);
85   }
86   if (CheckPrimitiveType(pre_node, prim::kPrimTranspose)) {
87     pre_node = cnode->input(1)->cast<CNodePtr>()->input(1);
88     MS_CHECK_TRUE_MSG(pre_node != nullptr, RET_ERROR, "pre_node is nullptr");
89   }
90   if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
91     post_node = cnode->input(kInputIndexTwo)->cast<CNodePtr>()->input(1);
92     MS_CHECK_TRUE_MSG(post_node != nullptr, RET_ERROR, "post_node is nullptr");
93   }
94   auto manager = func_graph->manager();
95   MS_ASSERT(manager != nullptr);
96   auto node_users = manager->node_users()[pre_node];
97   auto iter =
98     std::find_if(node_users.begin(), node_users.end(),
99                  [&post_node](const std::pair<AnfNodePtr, int> &post_pair) { return post_pair.first == post_node; });
100   if (iter == node_users.end()) {
101     return lite::RET_NO_CHANGE;
102   }
103   auto tr = manager->Transact();
104   tr.SetEdge(post_node, iter->second, NewValueNode(std::make_shared<UMonad>()));
105   tr.Commit();
106   auto depend_prim = std::make_shared<ops::Depend>();
107   auto depend_node = func_graph->NewCNode(depend_prim, {post_node, pre_node});
108   MS_CHECK_TRUE_MSG(depend_prim != nullptr, lite::RET_NULL_PTR, "NewCNode Failed");
109   MS_CHECK_TRUE_MSG(depend_node != nullptr, lite::RET_NULL_PTR, "NewCNode Failed");
110   depend_node->set_fullname_with_scope(cnode->fullname_with_scope());
111   manager->Replace(cnode, depend_node);
112   return lite::RET_OK;
113 }
114 
ProcessInputHaveDependency(const FuncGraphPtr & func_graph,const CNodePtr & cnode)115 int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
116   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
117   if (ProcessDependencyWithTwoNodes(func_graph, cnode, true) == lite::RET_OK) {
118     return lite::RET_OK;
119   }
120   if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
121     return lite::RET_OK;
122   }
123   auto make_tuple_prim = NewValueNode(std::make_shared<lite::MakeTuple>());
124   auto manager = func_graph->manager();
125   MS_CHECK_TRUE_MSG(make_tuple_prim != nullptr, lite::RET_NULL_PTR, "NewCNode Failed");
126   MS_ASSERT(manager != nullptr);
127   if (CheckPrimitiveType(cnode->input(0), prim::kPrimTranspose)) {
128     manager->Replace(cnode->input(0)->cast<CNodePtr>()->input(0), make_tuple_prim);
129     return RET_OK;
130   }
131   manager->Replace(cnode->input(0), make_tuple_prim);
132   return lite::RET_OK;
133 }
134 }  // namespace
135 
ReplaceOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)136 int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
137   MS_CHECK_TRUE_MSG(anf_node != nullptr, RET_ERROR, "anf_node is nullptr");
138   MS_CHECK_TRUE_MSG(manager != nullptr, RET_ERROR, "manager is nullptr");
139   if (!utils::isa<CNodePtr>(anf_node)) {
140     MS_LOG(DEBUG) << "anf node is node a cnode.";
141     return lite::RET_NO_CHANGE;
142   }
143   auto cnode = anf_node->cast<CNodePtr>();
144   MS_ASSERT(cnode != nullptr);
145   if (CheckPrimitiveType(anf_node, kPrimIdentity)) {
146     if (cnode->size() != kInputSizeTwo) {
147       MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
148       remove_cnode_.insert(anf_node);
149       return lite::RET_NO_CHANGE;
150     }
151   }
152   if (CheckPrimitiveType(anf_node, prim::kPrimDepend)) {
153     if (cnode->size() != kInputSizeTwo) {
154       MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
155       remove_cnode_.insert(anf_node);
156       return lite::RET_NO_CHANGE;
157     }
158   }
159   if (CheckPrimitiveType(anf_node, prim::kPrimTranspose)) {
160     if (cnode->size() != kInputSizeThree) {
161       MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
162       remove_cnode_.insert(anf_node);
163       return lite::RET_NO_CHANGE;
164     }
165   }
166 
167   bool replace_succ = manager->Replace(anf_node, cnode->input(1));
168   if (!replace_succ) {
169     MS_LOG(ERROR) << "replace redundant op failed.";
170     return lite::RET_ERROR;
171   }
172   return RET_OK;
173 }
174 
ReplaceUpdateStateOp(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node)175 int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) {
176   if (!utils::isa<CNodePtr>(anf_node)) {
177     MS_LOG(DEBUG) << "anf node is node a cnode.";
178     return lite::RET_NO_CHANGE;
179   }
180   auto cnode = anf_node->cast<CNodePtr>();
181   MS_ASSERT(cnode != nullptr);
182   if (ProcessInputIsMonad(func_graph, cnode) == lite::RET_OK) {
183     return lite::RET_OK;
184   }
185   // both of two inputs are not monad, but have dependency.
186   return ProcessInputHaveDependency(func_graph, cnode);
187 }
188 
ReplaceTupleGetItem(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)189 int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
190   if (!utils::isa<CNodePtr>(anf_node)) {
191     MS_LOG(DEBUG) << "anf node is node a cnode.";
192     return lite::RET_NO_CHANGE;
193   }
194   if (!CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
195     return lite::RET_NO_CHANGE;
196   }
197   auto cnode = anf_node->cast<CNodePtr>();
198   MS_ASSERT(cnode != nullptr);
199   if (cnode->inputs().size() != kInputSizeThree) {
200     MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size();
201     return RET_ERROR;
202   }
203   if (!CheckPrimitiveType(cnode->input(1), kPrimIdentity)) {
204     return lite::RET_NO_CHANGE;
205   }
206   auto get_item_input_cnode = cnode->input(1)->cast<CNodePtr>();
207   auto index_vnode = cnode->input(kInputIndexTwo);
208   if (!utils::isa<ValueNode>(index_vnode)) {
209     MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
210     return lite::RET_ERROR;
211   }
212   MS_CHECK_TRUE_MSG(!CastToInt(index_vnode->cast<ValueNodePtr>()->value()).empty(), RET_ERROR, "value is empty");
213   int index = CastToInt(index_vnode->cast<ValueNodePtr>()->value()).front();
214   int input_cnode_inputs_size = get_item_input_cnode->inputs().size();
215   if ((index + 1) >= input_cnode_inputs_size) {
216     MS_LOG(ERROR) << "value node index is out of range.";
217     return lite::RET_ERROR;
218   }
219   bool replace_succ = manager->Replace(anf_node, get_item_input_cnode->input(index + 1));
220   if (!replace_succ) {
221     MS_LOG(ERROR) << "replace identity failed.";
222     return lite::RET_ERROR;
223   }
224   return lite::RET_OK;
225 }
226 
RemoveDropoutOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)227 int RemoveRedundantOpPass::RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
228   MS_ASSERT(anf_node != nullptr);
229   MS_ASSERT(manager != nullptr);
230   if (!utils::isa<CNodePtr>(anf_node)) {
231     MS_LOG(DEBUG) << "anf node is node a cnode.";
232     return lite::RET_NO_CHANGE;
233   }
234   auto cnode = anf_node->cast<CNodePtr>();
235   MS_ASSERT(cnode != nullptr);
236   if (cnode->size() > kInputSizeTwo) {
237     MS_LOG(ERROR) << "dropout input invalid.";
238     return lite::RET_ERROR;
239   }
240   if (!utils::isa<abstract::AbstractTuplePtr>(anf_node->abstract())) {
241     MS_LOG(DEBUG) << "dropout output size is one.";
242     manager->Replace(anf_node, cnode->input(1));
243   } else {
244     auto node_users = manager->node_users()[anf_node];
245     for (auto &node_user : node_users) {
246       auto node = node_user.first;
247       if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
248         MS_LOG(ERROR) << "dropout out node is invalid.";
249         return lite::RET_ERROR;
250       }
251       auto get_index_node = node->cast<CNodePtr>()->input(kInputIndexTwo)->cast<ValueNodePtr>();
252       if (get_index_node == nullptr) {
253         MS_LOG(ERROR) << "tuple get item node is invalid.";
254         return lite::RET_ERROR;
255       }
256       auto get_index = CastToInt(get_index_node->value()).front();
257       if (get_index > 0 && !manager->node_users()[node].empty()) {
258         MS_LOG(ERROR) << "dropout's second output is useful.";
259         return lite::RET_ERROR;
260       }
261       manager->Replace(node, cnode->input(1));
262     }
263   }
264   return lite::RET_OK;
265 }
266 
GetConstDataFromInputNode(const CNodePtr & cnode,lite::DataInfo * data_info)267 int RemoveRedundantOpPass::GetConstDataFromInputNode(const CNodePtr &cnode, lite::DataInfo *data_info) {
268   MS_ASSERT(cnode != nullptr);
269   MS_ASSERT(data_info != nullptr);
270   auto padding_node = cnode->input(kInputIndexTwo);
271   MS_ASSERT(padding_node != nullptr);
272   if (utils::isa<Parameter>(padding_node)) {
273     auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
274     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
275       MS_LOG(ERROR) << "fetch data from parameter node failed.";
276       return lite::RET_ERROR;
277     }
278   } else if (utils::isa<ValueNode>(padding_node)) {
279     auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
280     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
281       MS_LOG(ERROR) << "fetch data from value node failed.";
282       return lite::RET_ERROR;
283     }
284   }
285   return lite::RET_OK;
286 }
287 
RemoveInvalidPadOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)288 int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
289   if (!utils::isa<CNodePtr>(anf_node)) {
290     MS_LOG(DEBUG) << "anf node is node a cnode.";
291     return lite::RET_NO_CHANGE;
292   }
293   auto cnode = anf_node->cast<CNodePtr>();
294   MS_ASSERT(cnode != nullptr);
295   auto primitive = GetValueNode<mindspore::PrimitivePtr>(cnode->input(0));
296   if (primitive == nullptr) {
297     MS_LOG(ERROR) << "primitive is nullptr:" << cnode->fullname_with_scope();
298     return lite::RET_NO_CHANGE;
299   }
300   auto is_invalid = true;
301   if (cnode->size() > kInputSizeTwo) {
302     lite::DataInfo data_info;
303     if (GetConstDataFromInputNode(cnode, &data_info) != RET_OK) {
304       MS_LOG(ERROR) << "Get pad data failed.";
305       return lite::RET_ERROR;
306     }
307     if (!data_info.data_.empty()) {
308       auto pad_data = reinterpret_cast<int *>(data_info.data_.data());
309       size_t num = data_info.data_.size() / sizeof(int);
310       for (size_t i = 0; i < num; ++i) {
311         if (pad_data[i] != 0) {
312           is_invalid = false;
313           break;
314         }
315       }
316     } else {
317       is_invalid = false;
318     }
319   } else {
320     auto pad_prim = utils::cast<std::shared_ptr<mindspore::ops::PadFusion>>(primitive);
321     MS_ASSERT(pad_prim != nullptr);
322     MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPadding) != nullptr, lite::RET_ERROR);
323     auto pad_data = pad_prim->get_paddings();
324     for (size_t i = 0; i < pad_data.size(); i++) {
325       for (size_t j = 0; j < pad_data[i].size(); j++) {
326         if (pad_data[i][j] != 0) {
327           is_invalid = false;
328           break;
329         }
330       }
331       if (is_invalid == false) {
332         break;
333       }
334     }
335   }
336   if (is_invalid) {
337     return ReplaceOp(anf_node, manager);
338   }
339   return lite::RET_OK;
340 }
341 
RemoveInvalidTransposeOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)342 int RemoveRedundantOpPass::RemoveInvalidTransposeOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
343   auto cnode = anf_node->cast<CNodePtr>();
344   MS_ASSERT(cnode != nullptr);
345   if (cnode->size() != kInputSizeThree) {
346     MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
347     return lite::RET_NO_CHANGE;
348   }
349   auto index_node = cnode->inputs()[kInputIndexTwo]->cast<ParameterPtr>();
350   if (index_node == nullptr) {
351     return RET_OK;
352   }
353   auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(index_node->default_param());
354   MS_ASSERT(tensor_info != nullptr);
355   if (tensor_info->Size() != 0) {
356     return RET_OK;
357   }
358   return ReplaceOp(anf_node, manager);
359 }
360 
Run(const FuncGraphPtr & func_graph)361 bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
362   MS_ASSERT(func_graph != nullptr);
363   auto manager = func_graph->manager();
364   MS_ASSERT(manager != nullptr);
365   auto node_list = TopoSort(func_graph->get_return());
366   int status = RET_OK;
367   for (auto &node : node_list) {
368     if (!utils::isa<CNodePtr>(node)) {
369       continue;
370     }
371     if (CheckPrimitiveType(node, kPrimIdentity)) {
372       status = ReplaceOp(node, manager);
373     }
374     if (CheckPrimitiveType(node, prim::kPrimLoad)) {
375       status = ReplaceOp(node, manager);
376     }
377     if (CheckPrimitiveType(node, prim::kPrimUpdateState)) {
378       status = ReplaceUpdateStateOp(func_graph, node);
379     }
380     if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
381       status = ReplaceTupleGetItem(node, manager);
382     }
383     if (!is_train_model_ && CheckPrimitiveType(node, prim::kPrimDropout)) {
384       status = RemoveDropoutOp(node, manager);
385     }
386     if (CheckPrimitiveType(node, prim::kPrimPadFusion)) {
387       status = RemoveInvalidPadOp(node, manager);
388     }
389     if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
390       status = RemoveInvalidTransposeOp(node, manager);
391     }
392     if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
393       auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
394       if (sub_func_graph == nullptr) {
395         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
396         return false;
397       }
398       (void)Run(sub_func_graph);
399       sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(2));
400       if (sub_func_graph == nullptr) {
401         lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
402         return false;
403       }
404       (void)Run(sub_func_graph);
405     }
406     if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
407       MS_LOG(ERROR) << "remove identity pass is failed.";
408       return false;
409     }
410   }
411   for (auto &node : remove_cnode_) {
412     func_graph->DropNode(node);
413   }
414   return true;
415 }
416 }  // namespace mindspore::opt
417