1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
#include "tools/optimizer/fusion/reshape_transpose_fusion.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>
#include "mindspore/core/ops/array_ops.h"
#include "ops/op_utils.h"
#include "ops/auto_generate/gen_lite_ops.h"
#include "tools/lite_exporter/fetch_content.h"
#include "tools/optimizer/common/format_utils.h"
#include "nnacl/op_base.h"
28
29 namespace mindspore::opt {
namespace {
// Placeholder forwarded as the node argument when IsOpType(node, prim) is bound into the CondVar predicates below.
const auto &p1 = std::placeholders::_1;
} // namespace
33
DefineReshapeTransposePattern() const34 VectorRef ReshapeTransposeFusion::DefineReshapeTransposePattern() const {
35 auto input = std::make_shared<Var>();
36 MS_CHECK_TRUE_RET(input != nullptr, {});
37 auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
38 MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
39 auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
40 MS_CHECK_TRUE_RET(is_const != nullptr, {});
41 auto reshape = VectorRef({is_reshape, input, is_const});
42 auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
43 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
44 auto is_const_perm = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
45 MS_CHECK_TRUE_RET(is_const_perm != nullptr, {});
46 return VectorRef({is_transpose, reshape, is_const_perm});
47 }
48
DefineTransposeReshapePattern() const49 VectorRef ReshapeTransposeFusion::DefineTransposeReshapePattern() const {
50 auto input = std::make_shared<Var>();
51 MS_CHECK_TRUE_RET(input != nullptr, {});
52 auto is_transpose = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose));
53 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
54 auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
55 MS_CHECK_TRUE_RET(is_const != nullptr, {});
56 auto transpose = VectorRef({is_transpose, input, is_const});
57 auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape));
58 MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
59 auto is_const_shape = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
60 MS_CHECK_TRUE_RET(is_const_shape != nullptr, {});
61 return VectorRef({is_reshape, transpose, is_const_shape});
62 }
63
DefinePatterns() const64 std::unordered_map<std::string, VectorRef> ReshapeTransposeFusion::DefinePatterns() const {
65 std::unordered_map<std::string, VectorRef> patterns;
66 patterns["ReshapeTranspose"] = DefineReshapeTransposePattern();
67 patterns["TransposeReshape"] = DefineTransposeReshapePattern();
68 return patterns;
69 }
70
CheckTransposeCanFused(const FuncGraphPtr & func_graph,const CNodePtr & transpose,const std::vector<int> & perm)71 bool CheckTransposeCanFused(const FuncGraphPtr &func_graph, const CNodePtr &transpose, const std::vector<int> &perm) {
72 MS_ASSERT(func_graph != nullptr && transpose != nullptr);
73 MS_CHECK_TRUE_RET(transpose->size() == kInputSizeThree, false);
74 if (perm == kNH2NC || perm == kNC2NH) {
75 return false;
76 }
77 auto input_abstract = GetCNodeInputAbstract(transpose, 1);
78 MS_CHECK_TRUE_RET(input_abstract != nullptr, false);
79 ShapeVector input_shape;
80 if (FetchShapeFromAbstract(input_abstract, &input_shape) != lite::RET_OK) {
81 MS_LOG(ERROR) << "Get shape from abstract failed.";
82 return false;
83 }
84 auto output_abstract = transpose->abstract();
85 MS_CHECK_TRUE_RET(output_abstract != nullptr, false);
86 ShapeVector output_shape;
87 if (FetchShapeFromAbstract(output_abstract, &output_shape) != lite::RET_OK) {
88 MS_LOG(ERROR) << "Get shape from abstract failed.";
89 return false;
90 }
91 if (input_shape.empty() || std::find(input_shape.begin(), input_shape.end(), -1) != input_shape.end() ||
92 output_shape.empty() || std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end()) {
93 MS_LOG(DEBUG) << "The input shape or output shape of transpose is invalid.";
94 return false;
95 }
96 int dim_size = static_cast<int>(input_shape.size());
97 std::vector<int> in_dim_index_valid;
98 for (int i = 0; i < dim_size; ++i) {
99 if (input_shape[i] > 1) {
100 in_dim_index_valid.push_back(i);
101 }
102 }
103 std::vector<int> out_dim_index_valid;
104 for (size_t i = 0; i < perm.size(); ++i) {
105 if (perm[i] < 0 || perm[i] >= dim_size) {
106 return false;
107 }
108 if (input_shape[perm[i]] > 1) {
109 out_dim_index_valid.push_back(perm[i]);
110 }
111 }
112 return in_dim_index_valid == out_dim_index_valid;
113 }
114
GetShapeOfReshape(const CNodePtr & reshape_cnode,bool * changed=nullptr)115 std::vector<int> GetShapeOfReshape(const CNodePtr &reshape_cnode, bool *changed = nullptr) {
116 MS_ASSERT(reshape_cnode != nullptr);
117 lite::DataInfo data_info;
118 if (lite::FetchConstData(reshape_cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true) != lite::RET_OK) {
119 return {};
120 }
121 MS_CHECK_TRUE_RET(data_info.data_type_ == kNumberTypeInt || data_info.data_type_ == kNumberTypeInt32, {});
122 std::vector<int> shape(data_info.data_.size() / C4NUM);
123 if (memcpy_s(shape.data(), shape.size() * sizeof(int), data_info.data_.data(), data_info.data_.size()) != EOK) {
124 return {};
125 }
126 auto abstract = GetCNodeInputAbstract(reshape_cnode, 1);
127 MS_CHECK_TRUE_RET(abstract != nullptr, {});
128 ShapeVector input_shape;
129 if (FetchShapeFromAbstract(abstract, &input_shape) != lite::RET_OK) {
130 MS_LOG(ERROR) << "Get shape from abstract failed.";
131 return {};
132 }
133 for (size_t i = 0; i < input_shape.size() && i < shape.size(); i++) {
134 if (changed != nullptr && !(*changed) && shape[i] == 0) {
135 *changed = true;
136 }
137 shape[i] = shape[i] == 0 ? input_shape[i] : shape[i];
138 }
139 return shape;
140 }
141
ReshapeTransFusion(const FuncGraphPtr & func_graph,const CNodePtr & transpose) const142 AnfNodePtr ReshapeTransposeFusion::ReshapeTransFusion(const FuncGraphPtr &func_graph, const CNodePtr &transpose) const {
143 MS_ASSERT(func_graph != nullptr && transpose != nullptr);
144 auto reshape = transpose->input(1);
145 MS_CHECK_TRUE_RET(reshape != nullptr, nullptr);
146 auto reshape_cnode = reshape->cast<CNodePtr>();
147 MS_CHECK_TRUE_RET(reshape_cnode != nullptr && reshape_cnode->size() == kInputSizeThree, nullptr);
148 if (CheckPrimitiveType(reshape_cnode->input(1), prim::kPrimTranspose)) {
149 return TransReshapeTransFusion(func_graph, transpose);
150 }
151
152 if (IsMultiOutputTensors(func_graph, reshape)) {
153 return nullptr;
154 }
155 std::vector<int> perm;
156 if (GetTransposePerm(transpose, &perm) != RET_OK) {
157 MS_LOG(ERROR) << "fetch transpose's perm failed.";
158 return nullptr;
159 }
160 if (!CheckTransposeCanFused(func_graph, transpose, perm)) {
161 return nullptr;
162 }
163 auto shape = GetShapeOfReshape(reshape_cnode);
164 MS_CHECK_TRUE_RET(shape.size() == perm.size(), nullptr);
165 std::vector<int> new_shape(shape.size());
166 for (size_t i = 0; i < perm.size(); i++) {
167 MS_CHECK_TRUE_RET(perm.at(i) >= 0 && static_cast<size_t>(perm.at(i)) < shape.size(), nullptr);
168 new_shape.at(i) = shape.at(perm.at(i));
169 }
170 auto new_shape_param = BuildIntVecParameterNode(func_graph, new_shape, reshape->fullname_with_scope() + "_transpose");
171 MS_CHECK_TRUE_RET(new_shape_param != nullptr, nullptr);
172 auto manager = func_graph->manager();
173 MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
174 if (transpose->abstract() != nullptr) {
175 reshape->set_abstract(transpose->abstract()->Clone());
176 }
177 manager->SetEdge(reshape, kInputIndexTwo, new_shape_param);
178 return reshape;
179 }
180
TransReshapeFusion(const FuncGraphPtr & func_graph,const CNodePtr & reshape_cnode) const181 AnfNodePtr ReshapeTransposeFusion::TransReshapeFusion(const FuncGraphPtr &func_graph,
182 const CNodePtr &reshape_cnode) const {
183 MS_ASSERT(func_graph != nullptr && reshape_cnode != nullptr);
184 MS_CHECK_TRUE_RET(reshape_cnode->size() == kInputSizeThree, nullptr);
185 auto transpose = reshape_cnode->input(1);
186 MS_CHECK_TRUE_RET(transpose != nullptr, nullptr);
187 auto transpose_cnode = transpose->cast<CNodePtr>();
188 MS_CHECK_TRUE_RET(transpose_cnode != nullptr, nullptr);
189 std::vector<int> perm;
190 if (GetTransposePerm(transpose_cnode, &perm) != RET_OK) {
191 MS_LOG(ERROR) << "fetch transpose's perm failed.";
192 return nullptr;
193 }
194 if (!CheckTransposeCanFused(func_graph, transpose_cnode, perm)) {
195 return nullptr;
196 }
197 auto manager = func_graph->manager();
198 MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
199 bool changed = false;
200 auto shape = GetShapeOfReshape(reshape_cnode, &changed);
201 if (changed) {
202 MS_CHECK_TRUE_RET(!shape.empty(), nullptr);
203 auto new_shape_param =
204 BuildIntVecParameterNode(func_graph, shape, reshape_cnode->fullname_with_scope() + "_new_shape");
205 MS_CHECK_TRUE_RET(new_shape_param != nullptr, nullptr);
206 manager->SetEdge(reshape_cnode, kInputIndexTwo, new_shape_param);
207 }
208
209 MS_CHECK_TRUE_RET(transpose_cnode->size() == kInputSizeThree, nullptr);
210 manager->SetEdge(reshape_cnode, 1, transpose_cnode->input(1));
211
212 return reshape_cnode;
213 }
214
FindFixedPositionOfReshape(const ShapeVector & input_shape,const ShapeVector & shape,const std::vector<int> & pre_perm,const std::vector<int> & post_perm,std::vector<size_t> * in_pos,std::vector<size_t> * out_pos)215 int FindFixedPositionOfReshape(const ShapeVector &input_shape, const ShapeVector &shape,
216 const std::vector<int> &pre_perm, const std::vector<int> &post_perm,
217 std::vector<size_t> *in_pos, std::vector<size_t> *out_pos) {
218 size_t i = 0;
219 size_t j = 0;
220 std::vector<size_t> tmp_in_pos;
221 std::vector<size_t> tmp_out_pos;
222 while (i < input_shape.size() && j < shape.size()) {
223 if (input_shape.at(i) == shape.at(j)) {
224 tmp_in_pos.push_back(i++);
225 tmp_out_pos.push_back(j++);
226 } else {
227 size_t in_num = input_shape.at(i++);
228 size_t out_num = shape.at(j++);
229 while (in_num != out_num) {
230 if (in_num < out_num) {
231 MS_CHECK_TRUE_RET(i < input_shape.size(), lite::RET_ERROR);
232 in_num = in_num * input_shape.at(i++);
233 } else {
234 MS_CHECK_TRUE_RET(j < shape.size(), lite::RET_ERROR);
235 out_num = out_num * shape.at(j++);
236 }
237 }
238 }
239 }
240 for (auto ele : tmp_in_pos) {
241 MS_CHECK_TRUE_RET(ele < pre_perm.size(), lite::RET_ERROR);
242 in_pos->push_back(pre_perm.at(ele));
243 }
244 for (auto ele : tmp_out_pos) {
245 auto itr = std::find(post_perm.begin(), post_perm.end(), ele);
246 MS_CHECK_TRUE_RET(itr != post_perm.end(), lite::RET_ERROR);
247 out_pos->push_back(itr - post_perm.begin());
248 }
249 return lite::RET_OK;
250 }
251
CheckPermAndShape(const std::vector<int> & input_shape,const std::vector<int> & output_shape,const std::vector<int> & pre_perm,const std::vector<int> & post_perm,const std::vector<size_t> & in_fixed_pos,const std::vector<size_t> & out_fixed_pos)252 bool CheckPermAndShape(const std::vector<int> &input_shape, const std::vector<int> &output_shape,
253 const std::vector<int> &pre_perm, const std::vector<int> &post_perm,
254 const std::vector<size_t> &in_fixed_pos, const std::vector<size_t> &out_fixed_pos) {
255 if (in_fixed_pos.empty() || out_fixed_pos.empty()) {
256 return false;
257 }
258 for (size_t i = 0; i < in_fixed_pos.size() || i < out_fixed_pos.size(); i++) {
259 size_t pre_num = 1;
260 auto in_begin = i < in_fixed_pos.size() ? in_fixed_pos.at(i) : input_shape.size() - 1;
261 auto in_end = i < in_fixed_pos.size() - 1 ? in_fixed_pos.at(i + 1) : input_shape.size();
262 auto itr = std::find(pre_perm.begin(), pre_perm.end(), in_begin + 1) - 1;
263 for (auto j = in_begin + 1; j < in_end && j < input_shape.size(); j++) {
264 auto tmp_itr = std::find(pre_perm.begin(), pre_perm.end(), j);
265 if (tmp_itr - itr != 1) {
266 return false;
267 }
268 itr = tmp_itr;
269
270 MS_CHECK_INT_MUL_NOT_OVERFLOW(static_cast<int>(pre_num), static_cast<int>(input_shape.at(j)), false);
271 pre_num *= input_shape.at(j);
272 }
273
274 size_t post_num = 1;
275 auto out_begin = i < out_fixed_pos.size() ? out_fixed_pos.at(i) : output_shape.size() - 1;
276 auto out_end = i < out_fixed_pos.size() - 1 ? out_fixed_pos.at(i + 1) : output_shape.size();
277 auto pos = out_begin + 1 < post_perm.size() ? post_perm.at(out_begin + 1) - 1 : 0;
278 for (auto j = out_begin + 1; j < out_end && j < output_shape.size(); j++) {
279 auto tmp_pos = post_perm.at(j);
280 if (tmp_pos - pos != 1) {
281 return false;
282 }
283 pos = tmp_pos;
284
285 MS_CHECK_INT_MUL_NOT_OVERFLOW(static_cast<int>(post_num), static_cast<int>(output_shape.at(j)), false);
286 post_num *= output_shape.at(j);
287 }
288 if (pre_num != post_num) {
289 return false;
290 }
291 }
292 return true;
293 }
294
CheckTransReshapeTransCanFused(const ShapeVector & input_shape,const ShapeVector & output_shape,const std::vector<int> & pre_perm,const std::vector<int> & post_perm)295 bool CheckTransReshapeTransCanFused(const ShapeVector &input_shape, const ShapeVector &output_shape,
296 const std::vector<int> &pre_perm, const std::vector<int> &post_perm) {
297 if (input_shape.size() != pre_perm.size() || output_shape.size() != post_perm.size()) {
298 return false;
299 }
300 std::vector<size_t> in_fixed_pos;
301 std::vector<size_t> out_fixed_pos;
302 if (FindFixedPositionOfReshape(input_shape, output_shape, pre_perm, post_perm, &in_fixed_pos, &out_fixed_pos) !=
303 lite::RET_OK) {
304 MS_LOG(ERROR) << "Find fixed position of reshape failed.";
305 return false;
306 }
307
308 std::vector<int> pre_trans_in_shape;
309 for (int i = 0; i < static_cast<int>(pre_perm.size()); i++) {
310 auto itr = std::find(pre_perm.begin(), pre_perm.end(), i);
311 MS_CHECK_TRUE_RET(itr != pre_perm.end(), false);
312 pre_trans_in_shape.push_back(input_shape.at(itr - pre_perm.begin()));
313 }
314 std::vector<int> trans_out_shape;
315 for (auto dim : post_perm) {
316 MS_CHECK_TRUE_RET(static_cast<size_t>(dim) < output_shape.size(), false);
317 trans_out_shape.push_back(output_shape.at(static_cast<size_t>(dim)));
318 }
319
320 return CheckPermAndShape(pre_trans_in_shape, trans_out_shape, pre_perm, post_perm, in_fixed_pos, out_fixed_pos);
321 }
322
DealReshapeWithMultiOutputs(const FuncGraphPtr & func_graph,const CNodePtr & reshape,const CNodePtr & transpose,const std::vector<int> & post_perm)323 STATUS DealReshapeWithMultiOutputs(const FuncGraphPtr &func_graph, const CNodePtr &reshape, const CNodePtr &transpose,
324 const std::vector<int> &post_perm) {
325 MS_ASSERT(func_graph != nullptr && reshape != nullptr && transpose != nullptr);
326 std::vector<int> new_perm;
327 std::vector<int> tmp_perm(post_perm.size());
328 std::iota(tmp_perm.begin(), tmp_perm.end(), 0);
329 for (auto ele : tmp_perm) {
330 auto itr = std::find(post_perm.begin(), post_perm.end(), ele);
331 MS_CHECK_TRUE_RET(itr != post_perm.end(), lite::RET_ERROR);
332 new_perm.push_back(itr - post_perm.begin());
333 }
334 auto insert_trans_perm_param = BuildIntVecParameterNode(func_graph, new_perm, "transpose_perm");
335 MS_CHECK_TRUE_RET(insert_trans_perm_param != nullptr, lite::RET_ERROR);
336 auto transpose_prim = std::make_shared<ops::Transpose>();
337 if (transpose_prim == nullptr) {
338 MS_LOG(ERROR) << "Build reshape primitive failed.";
339 return lite::RET_ERROR;
340 }
341 auto transpose_prim_c = transpose_prim->GetPrim();
342 MS_CHECK_TRUE_RET(transpose_prim_c != nullptr, lite::RET_ERROR);
343 auto value_node = NewValueNode(transpose_prim_c);
344 MS_CHECK_TRUE_RET(value_node != nullptr, lite::RET_ERROR);
345 auto new_trans = func_graph->NewCNode({value_node, transpose, insert_trans_perm_param});
346 MS_CHECK_TRUE_RET(new_trans != nullptr, lite::RET_ERROR);
347 auto output_node_list = GetRealNodeUsedList(func_graph, reshape);
348
349 auto manager = func_graph->manager();
350 MS_CHECK_TRUE_RET(manager != nullptr, lite::RET_ERROR);
351 for (auto output_node_pair : *output_node_list) {
352 if (output_node_pair.first != transpose) {
353 manager->SetEdge(output_node_pair.first, output_node_pair.second, new_trans);
354 }
355 }
356 return lite::RET_OK;
357 }
358
TransReshapeTransFusion(const FuncGraphPtr & func_graph,const CNodePtr & trans_cnode) const359 AnfNodePtr ReshapeTransposeFusion::TransReshapeTransFusion(const FuncGraphPtr &func_graph,
360 const CNodePtr &trans_cnode) const {
361 MS_ASSERT(func_graph != nullptr && trans_cnode != nullptr);
362 MS_CHECK_TRUE_RET(trans_cnode->size() == kInputSizeThree, nullptr);
363 auto reshape = trans_cnode->input(1);
364 MS_CHECK_TRUE_RET(reshape != nullptr, nullptr);
365
366 auto reshape_cnode = reshape->cast<CNodePtr>();
367 MS_CHECK_TRUE_RET(reshape_cnode != nullptr && reshape_cnode->size() == kInputSizeThree, nullptr);
368 auto pre_trans = reshape_cnode->input(1);
369 MS_CHECK_TRUE_RET(pre_trans != nullptr, nullptr);
370 if (IsMultiOutputTensors(func_graph, pre_trans)) {
371 return nullptr;
372 }
373
374 auto abstract = pre_trans->abstract();
375 MS_CHECK_TRUE_RET(abstract != nullptr, nullptr);
376 ShapeVector input_shape;
377 if (FetchShapeFromAbstract(abstract, &input_shape) != lite::RET_OK) {
378 MS_LOG(ERROR) << "Get shape from abstract failed.";
379 return nullptr;
380 }
381 auto reshape_abstract = reshape->abstract();
382 MS_CHECK_TRUE_RET(reshape_abstract != nullptr, nullptr);
383 ShapeVector output_shape;
384 if (FetchShapeFromAbstract(reshape_abstract, &output_shape) != lite::RET_OK) {
385 MS_LOG(ERROR) << "Get shape from abstract failed.";
386 return nullptr;
387 }
388 if (input_shape.empty() || std::find(input_shape.begin(), input_shape.end(), -1) != input_shape.end() ||
389 output_shape.empty() || std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end()) {
390 MS_LOG(DEBUG) << "The input shape or output shape of reshape is invalid.";
391 return nullptr;
392 }
393
394 std::vector<int> pre_perm;
395 auto pre_trans_cnode = pre_trans->cast<CNodePtr>();
396 MS_CHECK_TRUE_RET(pre_trans_cnode != nullptr, nullptr);
397 if (GetTransposePerm(pre_trans_cnode, &pre_perm) != RET_OK) {
398 MS_LOG(ERROR) << "fetch transpose's perm failed.";
399 return nullptr;
400 }
401 std::vector<int> post_perm;
402 if (GetTransposePerm(trans_cnode, &post_perm) != RET_OK) {
403 MS_LOG(ERROR) << "fetch transpose's perm failed.";
404 return nullptr;
405 }
406
407 if (!CheckTransReshapeTransCanFused(input_shape, output_shape, pre_perm, post_perm)) {
408 return nullptr;
409 }
410
411 if (IsMultiOutputTensors(func_graph, reshape) &&
412 DealReshapeWithMultiOutputs(func_graph, reshape_cnode, trans_cnode, post_perm) != lite::RET_OK) {
413 MS_LOG(ERROR) << "deal with multi-output reshape failed.";
414 return nullptr;
415 }
416 std::vector<int> new_shape;
417 for (auto ele : post_perm) {
418 MS_CHECK_TRUE_RET(static_cast<size_t>(ele) < output_shape.size(), nullptr);
419 new_shape.push_back(output_shape.at(ele));
420 }
421 auto new_shape_param =
422 BuildIntVecParameterNode(func_graph, new_shape, reshape_cnode->fullname_with_scope() + "new_shape");
423 MS_CHECK_TRUE_RET(new_shape_param != nullptr, nullptr);
424 if (trans_cnode->abstract() != nullptr) {
425 reshape->set_abstract(trans_cnode->abstract()->Clone());
426 }
427 auto manager = func_graph->manager();
428 MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
429 manager->SetEdge(reshape, 1, pre_trans_cnode->input(1));
430 manager->SetEdge(reshape, kInputIndexTwo, new_shape_param);
431 return reshape;
432 }
433
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr & equiv) const434 AnfNodePtr ReshapeTransposeFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
435 const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
436 if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
437 lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
438 return nullptr;
439 }
440 auto cnode = node->cast<CNodePtr>();
441 MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
442 if (IsMarkedTrainOp(cnode)) {
443 return nullptr;
444 }
445 if (pattern_name == "ReshapeTranspose") {
446 return ReshapeTransFusion(func_graph, cnode);
447 } else if (pattern_name == "TransposeReshape") {
448 return TransReshapeFusion(func_graph, cnode);
449 }
450
451 return nullptr;
452 }
453 } // namespace mindspore::opt
454