• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/reshape_transpose_fusion.h"
19 #include <numeric>
20 #include <vector>
21 #include <unordered_map>
22 #include "mindspore/core/ops/array_ops.h"
23 #include "ops/op_utils.h"
24 #include "ops/auto_generate/gen_lite_ops.h"
25 #include "tools/lite_exporter/fetch_content.h"
26 #include "tools/optimizer/common/format_utils.h"
27 #include "nnacl/op_base.h"
28 
29 namespace mindspore::opt {
namespace {
// File-local alias for the first std::bind placeholder; used when binding IsOpType for the CondVar predicates.
const decltype(std::placeholders::_1) &p1 = std::placeholders::_1;
}  // namespace
33 
DefineReshapeTransposePattern() const34 VectorRef ReshapeTransposeFusion::DefineReshapeTransposePattern() const {
35   auto input = std::make_shared<Var>();
36   MS_CHECK_TRUE_RET(input != nullptr, {});
37   auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
38   MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
39   auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
40   MS_CHECK_TRUE_RET(is_const != nullptr, {});
41   auto reshape = VectorRef({is_reshape, input, is_const});
42   auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
43   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
44   auto is_const_perm = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
45   MS_CHECK_TRUE_RET(is_const_perm != nullptr, {});
46   return VectorRef({is_transpose, reshape, is_const_perm});
47 }
48 
DefineTransposeReshapePattern() const49 VectorRef ReshapeTransposeFusion::DefineTransposeReshapePattern() const {
50   auto input = std::make_shared<Var>();
51   MS_CHECK_TRUE_RET(input != nullptr, {});
52   auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
53   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
54   auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
55   MS_CHECK_TRUE_RET(is_const != nullptr, {});
56   auto transpose = VectorRef({is_transpose, input, is_const});
57   auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
58   MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
59   auto is_const_shape = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
60   MS_CHECK_TRUE_RET(is_const_shape != nullptr, {});
61   return VectorRef({is_reshape, transpose, is_const_shape});
62 }
63 
DefinePatterns() const64 std::unordered_map<std::string, VectorRef> ReshapeTransposeFusion::DefinePatterns() const {
65   std::unordered_map<std::string, VectorRef> patterns;
66   patterns["ReshapeTranspose"] = DefineReshapeTransposePattern();
67   patterns["TransposeReshape"] = DefineTransposeReshapePattern();
68   return patterns;
69 }
70 
CheckTransposeCanFused(const FuncGraphPtr & func_graph,const CNodePtr & transpose,const std::vector<int> & perm)71 bool CheckTransposeCanFused(const FuncGraphPtr &func_graph, const CNodePtr &transpose, const std::vector<int> &perm) {
72   MS_ASSERT(func_graph != nullptr && transpose != nullptr);
73   MS_CHECK_TRUE_RET(transpose->size() == kInputSizeThree, false);
74   if (perm == kNH2NC || perm == kNC2NH) {
75     return false;
76   }
77   auto input_abstract = GetCNodeInputAbstract(transpose, 1);
78   MS_CHECK_TRUE_RET(input_abstract != nullptr, false);
79   ShapeVector input_shape;
80   if (FetchShapeFromAbstract(input_abstract, &input_shape) != lite::RET_OK) {
81     MS_LOG(ERROR) << "Get shape from abstract failed.";
82     return false;
83   }
84   auto output_abstract = transpose->abstract();
85   MS_CHECK_TRUE_RET(output_abstract != nullptr, false);
86   ShapeVector output_shape;
87   if (FetchShapeFromAbstract(output_abstract, &output_shape) != lite::RET_OK) {
88     MS_LOG(ERROR) << "Get shape from abstract failed.";
89     return false;
90   }
91   if (input_shape.empty() || std::find(input_shape.begin(), input_shape.end(), -1) != input_shape.end() ||
92       output_shape.empty() || std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end()) {
93     MS_LOG(DEBUG) << "The input shape or output shape of transpose is invalid.";
94     return false;
95   }
96   int dim_size = static_cast<int>(input_shape.size());
97   std::vector<int> in_dim_index_valid;
98   for (int i = 0; i < dim_size; ++i) {
99     if (input_shape[i] > 1) {
100       in_dim_index_valid.push_back(i);
101     }
102   }
103   std::vector<int> out_dim_index_valid;
104   for (size_t i = 0; i < perm.size(); ++i) {
105     if (perm[i] < 0 || perm[i] >= dim_size) {
106       return false;
107     }
108     if (input_shape[perm[i]] > 1) {
109       out_dim_index_valid.push_back(perm[i]);
110     }
111   }
112   return in_dim_index_valid == out_dim_index_valid;
113 }
114 
GetShapeOfReshape(const CNodePtr & reshape_cnode,bool * changed=nullptr)115 std::vector<int> GetShapeOfReshape(const CNodePtr &reshape_cnode, bool *changed = nullptr) {
116   MS_ASSERT(reshape_cnode != nullptr);
117   lite::DataInfo data_info;
118   if (lite::FetchConstData(reshape_cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true) != lite::RET_OK) {
119     return {};
120   }
121   MS_CHECK_TRUE_RET(data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32, {});
122   std::vector<int> shape(data_info.data_.size() / C4NUM);
123   if (memcpy_s(shape.data(), shape.size() * sizeof(int), data_info.data_.data(), data_info.data_.size()) != EOK) {
124     return {};
125   }
126   auto abstract = GetCNodeInputAbstract(reshape_cnode, 1);
127   MS_CHECK_TRUE_RET(abstract != nullptr, {});
128   ShapeVector input_shape;
129   if (FetchShapeFromAbstract(abstract, &input_shape) != lite::RET_OK) {
130     MS_LOG(ERROR) << "Get shape from abstract failed.";
131     return {};
132   }
133   for (size_t i = 0; i < input_shape.size() && i < shape.size(); i++) {
134     if (changed != nullptr && !(*changed) && shape[i] == 0) {
135       *changed = true;
136     }
137     shape[i] = shape[i] == 0 ? input_shape[i] : shape[i];
138   }
139   return shape;
140 }
141 
ReshapeTransFusion(const FuncGraphPtr & func_graph,const CNodePtr & transpose) const142 AnfNodePtr ReshapeTransposeFusion::ReshapeTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &transpose) const {
143   MS_ASSERT(func_graph != nullptr && transpose != nullptr);
144   auto reshape = transpose->input(1);
145   MS_CHECK_TRUE_RET(reshape != nullptr, nullptr);
146   auto reshape_cnode = reshape->cast<CNodePtr>();
147   MS_CHECK_TRUE_RET(reshape_cnode != nullptr && reshape_cnode->size() == kInputSizeThree, nullptr);
148   if (CheckPrimitiveType(reshape_cnode->input(1), prim::kPrimTranspose)) {
149     return TransReshapeTransFusion(func_graph, transpose);
150   }
151 
152   if (IsMultiOutputTensors(func_graph, reshape)) {
153     return nullptr;
154   }
155   std::vector<int> perm;
156   if (GetTransposePerm(transpose, &perm) != RET_OK) {
157     MS_LOG(ERROR) << "fetch transpose's perm failed.";
158     return nullptr;
159   }
160   if (!CheckTransposeCanFused(func_graph, transpose, perm)) {
161     return nullptr;
162   }
163   auto shape = GetShapeOfReshape(reshape_cnode);
164   MS_CHECK_TRUE_RET(shape.size() == perm.size(), nullptr);
165   std::vector<int> new_shape(shape.size());
166   for (size_t i = 0; i < perm.size(); i++) {
167     MS_CHECK_TRUE_RET(perm.at(i) >= 0 && static_cast<size_t>(perm.at(i)) < shape.size(), nullptr);
168     new_shape.at(i) = shape.at(perm.at(i));
169   }
170   auto new_shape_param = BuildIntVecParameterNode(func_graph, new_shape, reshape->fullname_with_scope() + "_transpose");
171   MS_CHECK_TRUE_RET(new_shape_param != nullptr, nullptr);
172   auto manager = func_graph->manager();
173   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
174   if (transpose->abstract() != nullptr) {
175     reshape->set_abstract(transpose->abstract()->Clone());
176   }
177   manager->SetEdge(reshape, kInputIndexTwo, new_shape_param);
178   return reshape;
179 }
180 
TransReshapeFusion(const FuncGraphPtr & func_graph,const CNodePtr & reshape_cnode) const181 AnfNodePtr ReshapeTransposeFusion::TransReshapeFusion(const FuncGraphPtr &func_graph,
182                                                       const CNodePtr &reshape_cnode) const {
183   MS_ASSERT(func_graph != nullptr && reshape_cnode != nullptr);
184   MS_CHECK_TRUE_RET(reshape_cnode->size() == kInputSizeThree, nullptr);
185   auto transpose = reshape_cnode->input(1);
186   MS_CHECK_TRUE_RET(transpose != nullptr, nullptr);
187   auto transpose_cnode = transpose->cast<CNodePtr>();
188   MS_CHECK_TRUE_RET(transpose_cnode != nullptr, nullptr);
189   std::vector<int> perm;
190   if (GetTransposePerm(transpose_cnode, &perm) != RET_OK) {
191     MS_LOG(ERROR) << "fetch transpose's perm failed.";
192     return nullptr;
193   }
194   if (!CheckTransposeCanFused(func_graph, transpose_cnode, perm)) {
195     return nullptr;
196   }
197   auto manager = func_graph->manager();
198   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
199   bool changed = false;
200   auto shape = GetShapeOfReshape(reshape_cnode, &changed);
201   if (changed) {
202     MS_CHECK_TRUE_RET(!shape.empty(), nullptr);
203     auto new_shape_param =
204       BuildIntVecParameterNode(func_graph, shape, reshape_cnode->fullname_with_scope() + "_new_shape");
205     MS_CHECK_TRUE_RET(new_shape_param != nullptr, nullptr);
206     manager->SetEdge(reshape_cnode, kInputIndexTwo, new_shape_param);
207   }
208 
209   MS_CHECK_TRUE_RET(transpose_cnode->size() == kInputSizeThree, nullptr);
210   manager->SetEdge(reshape_cnode, 1, transpose_cnode->input(1));
211 
212   return reshape_cnode;
213 }
214 
FindFixedPositionOfReshape(const ShapeVector & input_shape,const ShapeVector & shape,const std::vector<int> & pre_perm,const std::vector<int> & post_perm,std::vector<size_t> * in_pos,std::vector<size_t> * out_pos)215 int FindFixedPositionOfReshape(const ShapeVector &input_shape, const ShapeVector &shape,
216                                const std::vector<int> &pre_perm, const std::vector<int> &post_perm,
217                                std::vector<size_t> *in_pos, std::vector<size_t> *out_pos) {
218   size_t i = 0;
219   size_t j = 0;
220   std::vector<size_t> tmp_in_pos;
221   std::vector<size_t> tmp_out_pos;
222   while (i < input_shape.size() && j < shape.size()) {
223     if (input_shape.at(i) == shape.at(j)) {
224       tmp_in_pos.push_back(i++);
225       tmp_out_pos.push_back(j++);
226     } else {
227       size_t in_num = input_shape.at(i++);
228       size_t out_num = shape.at(j++);
229       while (in_num != out_num) {
230         if (in_num < out_num) {
231           MS_CHECK_TRUE_RET(i < input_shape.size(), lite::RET_ERROR);
232           in_num = in_num * input_shape.at(i++);
233         } else {
234           MS_CHECK_TRUE_RET(j < shape.size(), lite::RET_ERROR);
235           out_num = out_num * shape.at(j++);
236         }
237       }
238     }
239   }
240   for (auto ele : tmp_in_pos) {
241     MS_CHECK_TRUE_RET(ele < pre_perm.size(), lite::RET_ERROR);
242     in_pos->push_back(pre_perm.at(ele));
243   }
244   for (auto ele : tmp_out_pos) {
245     auto itr = std::find(post_perm.begin(), post_perm.end(), ele);
246     MS_CHECK_TRUE_RET(itr != post_perm.end(), lite::RET_ERROR);
247     out_pos->push_back(itr - post_perm.begin());
248   }
249   return lite::RET_OK;
250 }
251 
CheckPermAndShape(const std::vector<int> & input_shape,const std::vector<int> & output_shape,const std::vector<int> & pre_perm,const std::vector<int> & post_perm,const std::vector<size_t> & in_fixed_pos,const std::vector<size_t> & out_fixed_pos)252 bool CheckPermAndShape(const std::vector<int> &input_shape, const std::vector<int> &output_shape,
253                        const std::vector<int> &pre_perm, const std::vector<int> &post_perm,
254                        const std::vector<size_t> &in_fixed_pos, const std::vector<size_t> &out_fixed_pos) {
255   if (in_fixed_pos.empty() || out_fixed_pos.empty()) {
256     return false;
257   }
258   for (size_t i = 0; i < in_fixed_pos.size() || i < out_fixed_pos.size(); i++) {
259     size_t pre_num = 1;
260     auto in_begin = i < in_fixed_pos.size() ? in_fixed_pos.at(i) : input_shape.size() - 1;
261     auto in_end = i < in_fixed_pos.size() - 1 ? in_fixed_pos.at(i + 1) : input_shape.size();
262     auto itr = std::find(pre_perm.begin(), pre_perm.end(), in_begin + 1) - 1;
263     for (auto j = in_begin + 1; j < in_end && j < input_shape.size(); j++) {
264       auto tmp_itr = std::find(pre_perm.begin(), pre_perm.end(), j);
265       if (tmp_itr - itr != 1) {
266         return false;
267       }
268       itr = tmp_itr;
269 
270       MS_CHECK_INT_MUL_NOT_OVERFLOW(static_cast<int>(pre_num), static_cast<int>(input_shape.at(j)), false);
271       pre_num *= input_shape.at(j);
272     }
273 
274     size_t post_num = 1;
275     auto out_begin = i < out_fixed_pos.size() ? out_fixed_pos.at(i) : output_shape.size() - 1;
276     auto out_end = i < out_fixed_pos.size() - 1 ? out_fixed_pos.at(i + 1) : output_shape.size();
277     auto pos = out_begin + 1 < post_perm.size() ? post_perm.at(out_begin + 1) - 1 : 0;
278     for (auto j = out_begin + 1; j < out_end && j < output_shape.size(); j++) {
279       auto tmp_pos = post_perm.at(j);
280       if (tmp_pos - pos != 1) {
281         return false;
282       }
283       pos = tmp_pos;
284 
285       MS_CHECK_INT_MUL_NOT_OVERFLOW(static_cast<int>(post_num), static_cast<int>(output_shape.at(j)), false);
286       post_num *= output_shape.at(j);
287     }
288     if (pre_num != post_num) {
289       return false;
290     }
291   }
292   return true;
293 }
294 
CheckTransReshapeTransCanFused(const ShapeVector & input_shape,const ShapeVector & output_shape,const std::vector<int> & pre_perm,const std::vector<int> & post_perm)295 bool CheckTransReshapeTransCanFused(const ShapeVector &input_shape, const ShapeVector &output_shape,
296                                     const std::vector<int> &pre_perm, const std::vector<int> &post_perm) {
297   if (input_shape.size() != pre_perm.size() || output_shape.size() != post_perm.size()) {
298     return false;
299   }
300   std::vector<size_t> in_fixed_pos;
301   std::vector<size_t> out_fixed_pos;
302   if (FindFixedPositionOfReshape(input_shape, output_shape, pre_perm, post_perm, &in_fixed_pos, &out_fixed_pos) !=
303       lite::RET_OK) {
304     MS_LOG(ERROR) << "Find fixed position of reshape failed.";
305     return false;
306   }
307 
308   std::vector<int> pre_trans_in_shape;
309   for (int i = 0; i < static_cast<int>(pre_perm.size()); i++) {
310     auto itr = std::find(pre_perm.begin(), pre_perm.end(), i);
311     MS_CHECK_TRUE_RET(itr != pre_perm.end(), false);
312     pre_trans_in_shape.push_back(input_shape.at(itr - pre_perm.begin()));
313   }
314   std::vector<int> trans_out_shape;
315   for (auto dim : post_perm) {
316     MS_CHECK_TRUE_RET(static_cast<size_t>(dim) < output_shape.size(), false);
317     trans_out_shape.push_back(output_shape.at(static_cast<size_t>(dim)));
318   }
319 
320   return CheckPermAndShape(pre_trans_in_shape, trans_out_shape, pre_perm, post_perm, in_fixed_pos, out_fixed_pos);
321 }
322 
DealReshapeWithMultiOutputs(const FuncGraphPtr & func_graph,const CNodePtr & reshape,const CNodePtr & transpose,const std::vector<int> & post_perm)323 STATUS DealReshapeWithMultiOutputs(const FuncGraphPtr &func_graph, const CNodePtr &reshape, const CNodePtr &transpose,
324                                    const std::vector<int> &post_perm) {
325   MS_ASSERT(func_graph != nullptr && reshape != nullptr && transpose != nullptr);
326   std::vector<int> new_perm;
327   std::vector<int> tmp_perm(post_perm.size());
328   std::iota(tmp_perm.begin(), tmp_perm.end(), 0);
329   for (auto ele : tmp_perm) {
330     auto itr = std::find(post_perm.begin(), post_perm.end(), ele);
331     MS_CHECK_TRUE_RET(itr != post_perm.end(), lite::RET_ERROR);
332     new_perm.push_back(itr - post_perm.begin());
333   }
334   auto insert_trans_perm_param = BuildIntVecParameterNode(func_graph, new_perm, "transpose_perm");
335   MS_CHECK_TRUE_RET(insert_trans_perm_param != nullptr, lite::RET_ERROR);
336   auto transpose_prim = std::make_shared<ops::Transpose>();
337   if (transpose_prim == nullptr) {
338     MS_LOG(ERROR) << "Build reshape primitive failed.";
339     return lite::RET_ERROR;
340   }
341   auto transpose_prim_c = transpose_prim->GetPrim();
342   MS_CHECK_TRUE_RET(transpose_prim_c != nullptr, lite::RET_ERROR);
343   auto value_node = NewValueNode(transpose_prim_c);
344   MS_CHECK_TRUE_RET(value_node != nullptr, lite::RET_ERROR);
345   auto new_trans = func_graph->NewCNode({value_node, transpose, insert_trans_perm_param});
346   MS_CHECK_TRUE_RET(new_trans != nullptr, lite::RET_ERROR);
347   auto output_node_list = GetRealNodeUsedList(func_graph, reshape);
348 
349   auto manager = func_graph->manager();
350   MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_ERROR);
351   for (auto output_node_pair : *output_node_list) {
352     if (output_node_pair.first != transpose) {
353       manager->SetEdge(output_node_pair.first, output_node_pair.second, new_trans);
354     }
355   }
356   return lite::RET_OK;
357 }
358 
TransReshapeTransFusion(const FuncGraphPtr & func_graph,const CNodePtr & trans_cnode) const359 AnfNodePtr ReshapeTransposeFusion::TransReshapeTransFusion(const FuncGraphPtr &func_graph,
360                                                            const CNodePtr &trans_cnode) const {
361   MS_ASSERT(func_graph != nullptr && trans_cnode != nullptr);
362   MS_CHECK_TRUE_RET(trans_cnode->size() == kInputSizeThree, nullptr);
363   auto reshape = trans_cnode->input(1);
364   MS_CHECK_TRUE_RET(reshape != nullptr, nullptr);
365 
366   auto reshape_cnode = reshape->cast<CNodePtr>();
367   MS_CHECK_TRUE_RET(reshape_cnode != nullptr && reshape_cnode->size() == kInputSizeThree, nullptr);
368   auto pre_trans = reshape_cnode->input(1);
369   MS_CHECK_TRUE_RET(pre_trans != nullptr, nullptr);
370   if (IsMultiOutputTensors(func_graph, pre_trans)) {
371     return nullptr;
372   }
373 
374   auto abstract = pre_trans->abstract();
375   MS_CHECK_TRUE_RET(abstract != nullptr, nullptr);
376   ShapeVector input_shape;
377   if (FetchShapeFromAbstract(abstract, &input_shape) != lite::RET_OK) {
378     MS_LOG(ERROR) << "Get shape from abstract failed.";
379     return nullptr;
380   }
381   auto reshape_abstract = reshape->abstract();
382   MS_CHECK_TRUE_RET(reshape_abstract != nullptr, nullptr);
383   ShapeVector output_shape;
384   if (FetchShapeFromAbstract(reshape_abstract, &output_shape) != lite::RET_OK) {
385     MS_LOG(ERROR) << "Get shape from abstract failed.";
386     return nullptr;
387   }
388   if (input_shape.empty() || std::find(input_shape.begin(), input_shape.end(), -1) != input_shape.end() ||
389       output_shape.empty() || std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end()) {
390     MS_LOG(DEBUG) << "The input shape or output shape of reshape is invalid.";
391     return nullptr;
392   }
393 
394   std::vector<int> pre_perm;
395   auto pre_trans_cnode = pre_trans->cast<CNodePtr>();
396   MS_CHECK_TRUE_RET(pre_trans_cnode != nullptr, nullptr);
397   if (GetTransposePerm(pre_trans_cnode, &pre_perm) != RET_OK) {
398     MS_LOG(ERROR) << "fetch transpose's perm failed.";
399     return nullptr;
400   }
401   std::vector<int> post_perm;
402   if (GetTransposePerm(trans_cnode, &post_perm) != RET_OK) {
403     MS_LOG(ERROR) << "fetch transpose's perm failed.";
404     return nullptr;
405   }
406 
407   if (!CheckTransReshapeTransCanFused(input_shape, output_shape, pre_perm, post_perm)) {
408     return nullptr;
409   }
410 
411   if (IsMultiOutputTensors(func_graph, reshape) &&
412       DealReshapeWithMultiOutputs(func_graph, reshape_cnode, trans_cnode, post_perm) != lite::RET_OK) {
413     MS_LOG(ERROR) << "deal with multi-output reshape failed.";
414     return nullptr;
415   }
416   std::vector<int> new_shape;
417   for (auto ele : post_perm) {
418     MS_CHECK_TRUE_RET(static_cast<size_t>(ele) < output_shape.size(), nullptr);
419     new_shape.push_back(output_shape.at(ele));
420   }
421   auto new_shape_param =
422     BuildIntVecParameterNode(func_graph, new_shape, reshape_cnode->fullname_with_scope() + "new_shape");
423   MS_CHECK_TRUE_RET(new_shape_param != nullptr, nullptr);
424   if (trans_cnode->abstract() != nullptr) {
425     reshape->set_abstract(trans_cnode->abstract()->Clone());
426   }
427   auto manager = func_graph->manager();
428   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
429   manager->SetEdge(reshape, 1, pre_trans_cnode->input(1));
430   manager->SetEdge(reshape, kInputIndexTwo, new_shape_param);
431   return reshape;
432 }
433 
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr & equiv) const434 AnfNodePtr ReshapeTransposeFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
435                                            const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
436   if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
437     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
438     return nullptr;
439   }
440   auto cnode = node->cast<CNodePtr>();
441   MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
442   if (IsMarkedTrainOp(cnode)) {
443     return nullptr;
444   }
445   if (pattern_name == "ReshapeTranspose") {
446     return ReshapeTransFusion(func_graph, cnode);
447   } else if (pattern_name == "TransposeReshape") {
448     return TransReshapeFusion(func_graph, cnode);
449   }
450 
451   return nullptr;
452 }
453 }  // namespace mindspore::opt
454