1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
18 #include <memory>
19 #include <vector>
20 #include <utility>
21 #include "include/errorcode.h"
22 #include "tools/anf_exporter/fetch_content.h"
23 #include "tools/converter/ops/ops_def.h"
24 #include "ops/depend.h"
25 #include "ops/fusion/pad_fusion.h"
26 #include "ops/op_utils.h"
27 #include "nnacl/op_base.h"
28
29 namespace mindspore::opt {
30 namespace {
ProcessInputIsMonad(const FuncGraphPtr & func_graph,const CNodePtr & cnode)31 int ProcessInputIsMonad(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
32 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
33 auto first_input = cnode->input(1);
34 MS_ASSERT(first_input != nullptr);
35 if (CheckPrimitiveType(first_input, prim::kPrimTranspose)) {
36 first_input = cnode->input(1)->cast<CNodePtr>()->input(1);
37 MS_CHECK_TRUE_MSG(first_input != nullptr, RET_ERROR, "first_input is nullptr");
38 }
39 auto second_input = cnode->input(kInputIndexTwo);
40 MS_ASSERT(seconde_input != nullptr);
41 if (CheckPrimitiveType(second_input, prim::kPrimTranspose)) {
42 second_input = cnode->input(kInputIndexTwo)->cast<CNodePtr>()->input(1);
43 MS_CHECK_TRUE_MSG(second_input != nullptr, RET_ERROR, "second_input is nullptr");
44 }
45 AnfNodePtr must_monad = nullptr;
46 AnfNodePtr not_must_monad = nullptr;
47 if (utils::isa<ValueNode>(first_input)) {
48 auto value_node = first_input->cast<ValueNodePtr>();
49 MS_ASSERT(value_node->value() != nullptr);
50 if (utils::isa<Monad>(value_node->value())) {
51 must_monad = first_input;
52 not_must_monad = second_input;
53 }
54 }
55 if (utils::isa<ValueNode>(second_input)) {
56 auto value_node = second_input->cast<ValueNodePtr>();
57 MS_ASSERT(value_node->value() != nullptr);
58 if (utils::isa<Monad>(value_node->value())) {
59 must_monad = second_input;
60 not_must_monad = first_input;
61 }
62 }
63 if (must_monad == nullptr) {
64 return lite::RET_NO_CHANGE;
65 }
66 auto manager = func_graph->manager();
67 MS_ASSERT(manager != nullptr);
68 if (!utils::isa<CNode>(not_must_monad) || CheckIsAllInputsParam(not_must_monad)) {
69 manager->Replace(cnode, must_monad);
70 } else {
71 manager->Replace(cnode, not_must_monad);
72 }
73 return lite::RET_OK;
74 }
75
ProcessDependencyWithTwoNodes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool pre_node_is_first)76 int ProcessDependencyWithTwoNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, bool pre_node_is_first) {
77 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
78 AnfNodePtr pre_node = cnode->input(1);
79 AnfNodePtr post_node = cnode->input(kInputIndexTwo);
80 MS_ASSERT(pre_node != nullptr);
81 MS_ASSERT(post_node != nullptr);
82 if (!pre_node_is_first) {
83 pre_node = cnode->input(kInputIndexTwo);
84 post_node = cnode->input(1);
85 }
86 if (CheckPrimitiveType(pre_node, prim::kPrimTranspose)) {
87 pre_node = cnode->input(1)->cast<CNodePtr>()->input(1);
88 MS_CHECK_TRUE_MSG(pre_node != nullptr, RET_ERROR, "pre_node is nullptr");
89 }
90 if (CheckPrimitiveType(post_node, prim::kPrimTranspose)) {
91 post_node = cnode->input(kInputIndexTwo)->cast<CNodePtr>()->input(1);
92 MS_CHECK_TRUE_MSG(post_node != nullptr, RET_ERROR, "post_node is nullptr");
93 }
94 auto manager = func_graph->manager();
95 MS_ASSERT(manager != nullptr);
96 auto node_users = manager->node_users()[pre_node];
97 auto iter =
98 std::find_if(node_users.begin(), node_users.end(),
99 [&post_node](const std::pair<AnfNodePtr, int> &post_pair) { return post_pair.first == post_node; });
100 if (iter == node_users.end()) {
101 return lite::RET_NO_CHANGE;
102 }
103 auto tr = manager->Transact();
104 tr.SetEdge(post_node, iter->second, NewValueNode(std::make_shared<UMonad>()));
105 tr.Commit();
106 auto depend_prim = std::make_shared<ops::Depend>();
107 auto depend_node = func_graph->NewCNode(depend_prim, {post_node, pre_node});
108 MS_CHECK_TRUE_MSG(depend_prim != nullptr, lite::RET_NULL_PTR, "NewCNode Failed");
109 MS_CHECK_TRUE_MSG(depend_node != nullptr, lite::RET_NULL_PTR, "NewCNode Failed");
110 depend_node->set_fullname_with_scope(cnode->fullname_with_scope());
111 manager->Replace(cnode, depend_node);
112 return lite::RET_OK;
113 }
114
ProcessInputHaveDependency(const FuncGraphPtr & func_graph,const CNodePtr & cnode)115 int ProcessInputHaveDependency(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
116 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
117 if (ProcessDependencyWithTwoNodes(func_graph, cnode, true) == lite::RET_OK) {
118 return lite::RET_OK;
119 }
120 if (ProcessDependencyWithTwoNodes(func_graph, cnode, false) == lite::RET_OK) {
121 return lite::RET_OK;
122 }
123 auto make_tuple_prim = NewValueNode(std::make_shared<lite::MakeTuple>());
124 auto manager = func_graph->manager();
125 MS_CHECK_TRUE_MSG(make_tuple_prim != nullptr, lite::RET_NULL_PTR, "NewCNode Failed");
126 MS_ASSERT(manager != nullptr);
127 if (CheckPrimitiveType(cnode->input(0), prim::kPrimTranspose)) {
128 manager->Replace(cnode->input(0)->cast<CNodePtr>()->input(0), make_tuple_prim);
129 return RET_OK;
130 }
131 manager->Replace(cnode->input(0), make_tuple_prim);
132 return lite::RET_OK;
133 }
134 } // namespace
135
ReplaceOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)136 int RemoveRedundantOpPass::ReplaceOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
137 MS_CHECK_TRUE_MSG(anf_node != nullptr, RET_ERROR, "anf_node is nullptr");
138 MS_CHECK_TRUE_MSG(manager != nullptr, RET_ERROR, "manager is nullptr");
139 if (!utils::isa<CNodePtr>(anf_node)) {
140 MS_LOG(DEBUG) << "anf node is node a cnode.";
141 return lite::RET_NO_CHANGE;
142 }
143 auto cnode = anf_node->cast<CNodePtr>();
144 MS_ASSERT(cnode != nullptr);
145 if (CheckPrimitiveType(anf_node, kPrimIdentity)) {
146 if (cnode->size() != kInputSizeTwo) {
147 MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
148 remove_cnode_.insert(anf_node);
149 return lite::RET_NO_CHANGE;
150 }
151 }
152 if (CheckPrimitiveType(anf_node, prim::kPrimDepend)) {
153 if (cnode->size() != kInputSizeTwo) {
154 MS_LOG(DEBUG) << "The node inputs size is bigger than 1";
155 remove_cnode_.insert(anf_node);
156 return lite::RET_NO_CHANGE;
157 }
158 }
159 if (CheckPrimitiveType(anf_node, prim::kPrimTranspose)) {
160 if (cnode->size() != kInputSizeThree) {
161 MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
162 remove_cnode_.insert(anf_node);
163 return lite::RET_NO_CHANGE;
164 }
165 }
166
167 bool replace_succ = manager->Replace(anf_node, cnode->input(1));
168 if (!replace_succ) {
169 MS_LOG(ERROR) << "replace redundant op failed.";
170 return lite::RET_ERROR;
171 }
172 return RET_OK;
173 }
174
ReplaceUpdateStateOp(const FuncGraphPtr & func_graph,const AnfNodePtr & anf_node)175 int RemoveRedundantOpPass::ReplaceUpdateStateOp(const FuncGraphPtr &func_graph, const AnfNodePtr &anf_node) {
176 if (!utils::isa<CNodePtr>(anf_node)) {
177 MS_LOG(DEBUG) << "anf node is node a cnode.";
178 return lite::RET_NO_CHANGE;
179 }
180 auto cnode = anf_node->cast<CNodePtr>();
181 MS_ASSERT(cnode != nullptr);
182 if (ProcessInputIsMonad(func_graph, cnode) == lite::RET_OK) {
183 return lite::RET_OK;
184 }
185 // both of two inputs are not monad, but have dependency.
186 return ProcessInputHaveDependency(func_graph, cnode);
187 }
188
ReplaceTupleGetItem(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)189 int RemoveRedundantOpPass::ReplaceTupleGetItem(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
190 if (!utils::isa<CNodePtr>(anf_node)) {
191 MS_LOG(DEBUG) << "anf node is node a cnode.";
192 return lite::RET_NO_CHANGE;
193 }
194 if (!CheckPrimitiveType(anf_node, prim::kPrimTupleGetItem)) {
195 return lite::RET_NO_CHANGE;
196 }
197 auto cnode = anf_node->cast<CNodePtr>();
198 MS_ASSERT(cnode != nullptr);
199 if (cnode->inputs().size() != kInputSizeThree) {
200 MS_LOG(ERROR) << "TupleGetItem should have 3 inputs, got " << cnode->inputs().size();
201 return RET_ERROR;
202 }
203 if (!CheckPrimitiveType(cnode->input(1), kPrimIdentity)) {
204 return lite::RET_NO_CHANGE;
205 }
206 auto get_item_input_cnode = cnode->input(1)->cast<CNodePtr>();
207 auto index_vnode = cnode->input(kInputIndexTwo);
208 if (!utils::isa<ValueNode>(index_vnode)) {
209 MS_LOG(ERROR) << "TupleGetItem's input 2 is not valuenode";
210 return lite::RET_ERROR;
211 }
212 MS_CHECK_TRUE_MSG(!CastToInt(index_vnode->cast<ValueNodePtr>()->value()).empty(), RET_ERROR, "value is empty");
213 int index = CastToInt(index_vnode->cast<ValueNodePtr>()->value()).front();
214 int input_cnode_inputs_size = get_item_input_cnode->inputs().size();
215 if ((index + 1) >= input_cnode_inputs_size) {
216 MS_LOG(ERROR) << "value node index is out of range.";
217 return lite::RET_ERROR;
218 }
219 bool replace_succ = manager->Replace(anf_node, get_item_input_cnode->input(index + 1));
220 if (!replace_succ) {
221 MS_LOG(ERROR) << "replace identity failed.";
222 return lite::RET_ERROR;
223 }
224 return lite::RET_OK;
225 }
226
RemoveDropoutOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)227 int RemoveRedundantOpPass::RemoveDropoutOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
228 MS_ASSERT(anf_node != nullptr);
229 MS_ASSERT(manager != nullptr);
230 if (!utils::isa<CNodePtr>(anf_node)) {
231 MS_LOG(DEBUG) << "anf node is node a cnode.";
232 return lite::RET_NO_CHANGE;
233 }
234 auto cnode = anf_node->cast<CNodePtr>();
235 MS_ASSERT(cnode != nullptr);
236 if (cnode->size() > kInputSizeTwo) {
237 MS_LOG(ERROR) << "dropout input invalid.";
238 return lite::RET_ERROR;
239 }
240 if (!utils::isa<abstract::AbstractTuplePtr>(anf_node->abstract())) {
241 MS_LOG(DEBUG) << "dropout output size is one.";
242 manager->Replace(anf_node, cnode->input(1));
243 } else {
244 auto node_users = manager->node_users()[anf_node];
245 for (auto &node_user : node_users) {
246 auto node = node_user.first;
247 if (!CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
248 MS_LOG(ERROR) << "dropout out node is invalid.";
249 return lite::RET_ERROR;
250 }
251 auto get_index_node = node->cast<CNodePtr>()->input(kInputIndexTwo)->cast<ValueNodePtr>();
252 if (get_index_node == nullptr) {
253 MS_LOG(ERROR) << "tuple get item node is invalid.";
254 return lite::RET_ERROR;
255 }
256 auto get_index = CastToInt(get_index_node->value()).front();
257 if (get_index > 0 && !manager->node_users()[node].empty()) {
258 MS_LOG(ERROR) << "dropout's second output is useful.";
259 return lite::RET_ERROR;
260 }
261 manager->Replace(node, cnode->input(1));
262 }
263 }
264 return lite::RET_OK;
265 }
266
GetConstDataFromInputNode(const CNodePtr & cnode,lite::DataInfo * data_info)267 int RemoveRedundantOpPass::GetConstDataFromInputNode(const CNodePtr &cnode, lite::DataInfo *data_info) {
268 MS_ASSERT(cnode != nullptr);
269 MS_ASSERT(data_info != nullptr);
270 auto padding_node = cnode->input(kInputIndexTwo);
271 MS_ASSERT(padding_node != nullptr);
272 if (utils::isa<Parameter>(padding_node)) {
273 auto status = lite::FetchDataFromParameterNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
274 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
275 MS_LOG(ERROR) << "fetch data from parameter node failed.";
276 return lite::RET_ERROR;
277 }
278 } else if (utils::isa<ValueNode>(padding_node)) {
279 auto status = lite::FetchDataFromValueNode(cnode, 2, converter::kFmkTypeMs, false, data_info);
280 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
281 MS_LOG(ERROR) << "fetch data from value node failed.";
282 return lite::RET_ERROR;
283 }
284 }
285 return lite::RET_OK;
286 }
287
RemoveInvalidPadOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)288 int RemoveRedundantOpPass::RemoveInvalidPadOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
289 if (!utils::isa<CNodePtr>(anf_node)) {
290 MS_LOG(DEBUG) << "anf node is node a cnode.";
291 return lite::RET_NO_CHANGE;
292 }
293 auto cnode = anf_node->cast<CNodePtr>();
294 MS_ASSERT(cnode != nullptr);
295 auto primitive = GetValueNode<mindspore::PrimitivePtr>(cnode->input(0));
296 if (primitive == nullptr) {
297 MS_LOG(ERROR) << "primitive is nullptr:" << cnode->fullname_with_scope();
298 return lite::RET_NO_CHANGE;
299 }
300 auto is_invalid = true;
301 if (cnode->size() > kInputSizeTwo) {
302 lite::DataInfo data_info;
303 if (GetConstDataFromInputNode(cnode, &data_info) != RET_OK) {
304 MS_LOG(ERROR) << "Get pad data failed.";
305 return lite::RET_ERROR;
306 }
307 if (!data_info.data_.empty()) {
308 auto pad_data = reinterpret_cast<int *>(data_info.data_.data());
309 size_t num = data_info.data_.size() / sizeof(int);
310 for (size_t i = 0; i < num; ++i) {
311 if (pad_data[i] != 0) {
312 is_invalid = false;
313 break;
314 }
315 }
316 } else {
317 is_invalid = false;
318 }
319 } else {
320 auto pad_prim = utils::cast<std::shared_ptr<mindspore::ops::PadFusion>>(primitive);
321 MS_ASSERT(pad_prim != nullptr);
322 MS_CHECK_TRUE_RET(pad_prim->GetAttr(ops::kPadding) != nullptr, lite::RET_ERROR);
323 auto pad_data = pad_prim->get_paddings();
324 for (size_t i = 0; i < pad_data.size(); i++) {
325 for (size_t j = 0; j < pad_data[i].size(); j++) {
326 if (pad_data[i][j] != 0) {
327 is_invalid = false;
328 break;
329 }
330 }
331 if (is_invalid == false) {
332 break;
333 }
334 }
335 }
336 if (is_invalid) {
337 return ReplaceOp(anf_node, manager);
338 }
339 return lite::RET_OK;
340 }
341
RemoveInvalidTransposeOp(const AnfNodePtr & anf_node,const FuncGraphManagerPtr & manager)342 int RemoveRedundantOpPass::RemoveInvalidTransposeOp(const AnfNodePtr &anf_node, const FuncGraphManagerPtr &manager) {
343 auto cnode = anf_node->cast<CNodePtr>();
344 MS_ASSERT(cnode != nullptr);
345 if (cnode->size() != kInputSizeThree) {
346 MS_LOG(DEBUG) << "The node inputs size is bigger than 2";
347 return lite::RET_NO_CHANGE;
348 }
349 auto index_node = cnode->inputs()[kInputIndexTwo]->cast<ParameterPtr>();
350 if (index_node == nullptr) {
351 return RET_OK;
352 }
353 auto tensor_info = std::dynamic_pointer_cast<tensor::Tensor>(index_node->default_param());
354 MS_ASSERT(tensor_info != nullptr);
355 if (tensor_info->Size() != 0) {
356 return RET_OK;
357 }
358 return ReplaceOp(anf_node, manager);
359 }
360
Run(const FuncGraphPtr & func_graph)361 bool RemoveRedundantOpPass::Run(const FuncGraphPtr &func_graph) {
362 MS_ASSERT(func_graph != nullptr);
363 auto manager = func_graph->manager();
364 MS_ASSERT(manager != nullptr);
365 auto node_list = TopoSort(func_graph->get_return());
366 int status = RET_OK;
367 for (auto &node : node_list) {
368 if (!utils::isa<CNodePtr>(node)) {
369 continue;
370 }
371 if (CheckPrimitiveType(node, kPrimIdentity)) {
372 status = ReplaceOp(node, manager);
373 }
374 if (CheckPrimitiveType(node, prim::kPrimLoad)) {
375 status = ReplaceOp(node, manager);
376 }
377 if (CheckPrimitiveType(node, prim::kPrimUpdateState)) {
378 status = ReplaceUpdateStateOp(func_graph, node);
379 }
380 if (CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
381 status = ReplaceTupleGetItem(node, manager);
382 }
383 if (!is_train_model_ && CheckPrimitiveType(node, prim::kPrimDropout)) {
384 status = RemoveDropoutOp(node, manager);
385 }
386 if (CheckPrimitiveType(node, prim::kPrimPadFusion)) {
387 status = RemoveInvalidPadOp(node, manager);
388 }
389 if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
390 status = RemoveInvalidTransposeOp(node, manager);
391 }
392 if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) {
393 auto sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(1));
394 if (sub_func_graph == nullptr) {
395 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
396 return false;
397 }
398 (void)Run(sub_func_graph);
399 sub_func_graph = GetValueNode<FuncGraphPtr>(node->cast<CNodePtr>()->input(2));
400 if (sub_func_graph == nullptr) {
401 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
402 return false;
403 }
404 (void)Run(sub_func_graph);
405 }
406 if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
407 MS_LOG(ERROR) << "remove identity pass is failed.";
408 return false;
409 }
410 }
411 for (auto &node : remove_cnode_) {
412 func_graph->DropNode(node);
413 }
414 return true;
415 }
416 } // namespace mindspore::opt
417