• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#define USE_DEPRECATED_API
#include "tools/optimizer/graph/transpose_strategy.h"
#include <algorithm>
#include <cstring>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ops/crop.h"
#include "src/common/utils.h"
#include "ops/fusion/activation.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/op_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "nnacl/op_base.h"
37 
38 namespace mindspore {
39 namespace opt {
40 namespace {
// Index of a Transpose cnode's data operand (input 0 is the primitive).
constexpr size_t kFirstInput{1};
// Divisor used when comparing the transpose count against the number of neighbouring nodes.
constexpr size_t kHalfDivisor{2};
// Expected cnode->size() of a StridedSlice this pass can rewrite (primitive + 5 operands).
constexpr size_t kOnnxStridedSlice{6};
// Number of scalar pad values handled by ChangeOpPad (4 dims x {before, after}).
constexpr int kPaddingListLength{8};
GetPostNodes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,std::vector<AnfNodePtr> * out_nodes)45 STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) {
46   MS_ASSERT(func_graph != nullptr && cnode != nullptr && out_nodes != nullptr);
47   auto manager = func_graph->manager();
48   if (manager == nullptr) {
49     manager = Manage(func_graph, true);
50   }
51   if (manager == nullptr) {
52     MS_LOG(ERROR) << "manager is nullptr.";
53     return lite::RET_ERROR;
54   }
55   auto node_users = manager->node_users()[cnode];
56   if (node_users.empty()) {
57     MS_LOG(ERROR) << "cnode is isolated.";
58     return lite::RET_ERROR;
59   }
60   std::transform(node_users.begin(), node_users.end(), std::back_inserter(*out_nodes),
61                  [](const std::pair<AnfNodePtr, int> &node_user) { return node_user.first; });
62   return lite::RET_OK;
63 }
64 
JudgeIs4DInput(NodeInferShape * node_infer_shape,const CNodePtr & cnode)65 bool JudgeIs4DInput(NodeInferShape *node_infer_shape, const CNodePtr &cnode) {
66   MS_ASSERT(node_infer_shape != nullptr && cnode != nullptr);
67   auto shape = node_infer_shape->GetInputShape(cnode, 1);
68   if (shape.size() != kInputSizeFour) {
69     if (cnode->size() > kInputSizeTwo) {
70       shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
71       if (shape.size() != kInputSizeFour && !lite::JudgeDynamicShape(shape)) {
72         return false;
73       }
74     } else {
75       return false;
76     }
77   }
78   return true;
79 }
80 
TransformOpAxesAttr(const std::vector<int> & origin_axes,FormatTransNodeType trans_type)81 std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type) {
82   std::vector<int> cur_axes;
83   for (size_t i = 0; i < origin_axes.size(); ++i) {
84     int axis = origin_axes[i];
85     if (axis < 0) {
86       axis += kInputSizeFour;
87     }
88     MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
89     int cur_axis = kNH2NC[axis];
90     if (trans_type == kNHWC2NCHW) {
91       cur_axis = kNC2NH[axis];
92     }
93     cur_axes.push_back(cur_axis);
94   }
95   std::sort(cur_axes.begin(), cur_axes.end());
96   return cur_axes;
97 }
98 
TransformAttrByAxes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t input_index,const std::vector<int> & axes,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)99 int TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
100                         const std::vector<int> &axes, FormatTransNodeType trans_type,
101                         NodeInferShape *node_infer_shape) {
102   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
103   if (input_index >= cnode->size() || axes.empty()) {
104     return lite::RET_ERROR;
105   }
106   auto origin_input = node_infer_shape->GetIntVecInput(cnode, input_index);
107   if (origin_input.size() != axes.size()) {
108     return lite::RET_ERROR;
109   }
110   std::vector<int> cur_input;
111   for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
112     for (size_t index = 0; index < axes.size(); ++index) {
113       int axis = axes[index];
114       if (axis < 0) {
115         axis += kInputSizeFour;
116       }
117       MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
118       int cur_axis = kNH2NC[axis];
119       if (trans_type == kNHWC2NCHW) {
120         cur_axis = kNC2NH[axis];
121       }
122       if (cur_axis == dim) {
123         cur_input.push_back(origin_input[index]);
124       }
125     }
126   }
127   auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope());
128   MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "BuildIntVecParameterNode failed");
129   func_graph->manager()->SetEdge(cnode, input_index, param_node);
130   return lite::RET_OK;
131 }
132 
ChangeCommonOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)133 STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
134                       NodeInferShape *node_infer_shape) {
135   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
136   if (trans_type == kNONE) {
137     MS_LOG(ERROR) << "trans_type is invalid.";
138     return lite::RET_ERROR;
139   }
140   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
141   MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
142   if (prim->GetAttr(ops::kAxis) == nullptr) {
143     return lite::RET_NOT_SUPPORT;
144   }
145   MS_CHECK_TRUE_MSG(prim->GetAttr(ops::kAxis) != nullptr, lite::RET_NULL_PTR, "GetAttr Failed.");
146   auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
147   if (axis < 0) {
148     axis += kInputSizeFour;
149   }
150   auto new_axis = kNH2NC[axis];
151   if (trans_type == kNHWC2NCHW) {
152     new_axis = kNC2NH[axis];
153   }
154   prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
155   return lite::RET_OK;
156 }
157 
ChangeOpCrop(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)158 STATUS ChangeOpCrop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
159                     NodeInferShape *node_infer_shape) {
160   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
161   if (trans_type == kNONE) {
162     MS_LOG(ERROR) << "trans_type is invalid.";
163     return lite::RET_ERROR;
164   }
165   auto crop_prim = ops::GetOperator<ops::Crop>(cnode->input(0));
166   if (crop_prim == nullptr) {
167     MS_LOG(ERROR) << "cnode is invalid.";
168     return lite::RET_ERROR;
169   }
170   MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kAxis) != nullptr, lite::RET_ERROR);
171   auto axis = crop_prim->get_axis();
172   if (axis < 0) {
173     axis += kInputSizeFour;
174   }
175   MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
176   MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kOffsets) != nullptr, lite::RET_ERROR);
177   auto offsets = crop_prim->get_offsets();
178   if (trans_type == kNCHW2NHWC) {
179     auto new_axis = kNH2NC[axis];
180     if (new_axis == 0) {
181       MS_CHECK_GE(offsets.size(), kInputIndexFour, lite::RET_ERROR);
182       offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
183     } else if (new_axis == kInputIndexThree) {
184       MS_CHECK_GE(offsets.size(), kInputIndexThree, lite::RET_ERROR);
185       offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
186     } else {
187       offsets.push_back(0);
188     }
189     crop_prim->set_axis(new_axis);
190     crop_prim->set_offsets(offsets);
191   } else {
192     auto new_axis = kNC2NH[axis];
193     if (new_axis == 0) {
194       offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]};
195     } else if (new_axis == kInputIndexThree) {
196       offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]};
197     } else {
198       offsets.pop_back();
199     }
200     crop_prim->set_axis(new_axis);
201     crop_prim->set_offsets(offsets);
202   }
203   return lite::RET_OK;
204 }
205 
ChangeOpPad(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)206 STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
207                    NodeInferShape *node_infer_shape) {
208   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
209   if (trans_type == kNONE) {
210     MS_LOG(ERROR) << "trans_type is invalid.";
211     return lite::RET_ERROR;
212   }
213   if (cnode->size() < kInputSizeThree) {
214     MS_LOG(ERROR) << "pad op need three inputs.";
215     return lite::RET_INPUT_TENSOR_ERROR;
216   }
217   auto second_input = cnode->input(kInputIndexTwo);
218   lite::DataInfo data_info;
219   int status;
220   if (utils::isa<Parameter>(second_input)) {
221     status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true);
222   } else if (utils::isa<ValueNode>(second_input)) {
223     status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
224   } else {
225     return lite::RET_NOT_SUPPORT;
226   }
227   if (status != lite::RET_OK) {
228     MS_LOG(ERROR) << "get paddings failed.";
229     return status;
230   }
231   if (std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<int>()) !=
232       kPaddingListLength) {
233     return lite::RET_OK;
234   }
235   std::vector<std::vector<int32_t>> padding_list(kInputSizeFour, std::vector<int32_t>(kInputSizeTwo));
236   auto data = reinterpret_cast<int32_t *>(data_info.data_.data());
237   for (int i = 0; i < kPaddingListLength; ++i) {
238     padding_list[i / kInputIndexTwo][i % kInputIndexTwo] = *data;
239     data += 1;
240   }
241   if (trans_type == kNCHW2NHWC) {
242     auto chanel_pad = padding_list[1];
243     padding_list.erase(padding_list.begin() + 1);
244     padding_list.push_back(chanel_pad);
245   } else {
246     auto chanel_pad = padding_list.back();
247     padding_list.pop_back();
248     padding_list.insert(padding_list.begin() + 1, chanel_pad);
249   }
250   auto param_node =
251     BuildIntVec2DParameterNode(func_graph, padding_list, cnode->input(kInputIndexTwo)->fullname_with_scope());
252   MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_NULL_PTR, "BuildParameterNode Failed");
253   auto manager = func_graph->manager();
254   MS_ASSERT(manager != nullptr);
255   manager->Replace(cnode->input(kInputIndexTwo), param_node);
256   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
257   MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
258   if (prim->GetAttr(ops::kPaddings) != nullptr) {
259     std::vector<std::vector<int64_t>> padding_attr;
260     (void)std::transform(padding_list.begin(), padding_list.end(), std::back_inserter(padding_attr),
261                          [](const std::vector<int> &val) { return std::vector<int64_t>(val.begin(), val.end()); });
262     prim->AddAttr(ops::kPaddings, MakeValue(padding_attr));
263   }
264   return lite::RET_OK;
265 }
266 
ChangeOpSlice(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)267 STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
268                      NodeInferShape *node_infer_shape) {
269   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
270   if (trans_type == kNONE) {
271     MS_LOG(ERROR) << "trans_type is invalid.";
272     return lite::RET_ERROR;
273   }
274   for (size_t i = 2; i < cnode->size(); ++i) {
275     if (utils::isa<CNodePtr>(cnode->input(i))) {
276       return lite::RET_NOT_SUPPORT;
277     }
278   }
279   auto shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
280   if (lite::JudgeDynamicShape(shape)) {
281     return lite::RET_NOT_SUPPORT;
282   }
283   int element_num = shape.front();
284   auto prim = ops::GetOperator<ops::SliceFusion>(cnode->input(0));
285   MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
286   std::vector<int> axes;
287   if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
288     for (int index = 0; index < element_num; ++index) {
289       axes.push_back(index);
290     }
291   } else {
292     auto origin_axes = prim->get_axes();
293     std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes),
294                    [](int64_t v) { return static_cast<int>(v); });
295   }
296   for (size_t i = 2; i < cnode->size(); ++i) {
297     if (TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape) != RET_OK) {
298       MS_LOG(ERROR) << "Transform axes failed.";
299       return RET_ERROR;
300     }
301   }
302   auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
303   std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
304   prim->set_axes(new_axes);
305   return lite::RET_OK;
306 }
307 
ChangeOpStrideSlice(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)308 STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
309                            NodeInferShape *node_infer_shape) {
310   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
311   if (trans_type == kNONE) {
312     MS_LOG(ERROR) << "trans_type is invalid.";
313     return lite::RET_ERROR;
314   }
315   if (cnode->size() != kOnnxStridedSlice) {
316     return lite::RET_NOT_SUPPORT;
317   }
318   for (size_t i = 2; i < cnode->size(); ++i) {
319     if (utils::isa<CNodePtr>(cnode->input(i))) {
320       return lite::RET_NOT_SUPPORT;
321     }
322   }
323   std::vector<int> axes = node_infer_shape->GetIntVecInput(cnode, kInputIndexFour);
324   if (axes.empty()) {
325     MS_LOG(ERROR) << "strided slice input invalid.";
326     return lite::RET_ERROR;
327   }
328   for (size_t index = 2; index < cnode->size(); ++index) {
329     if (index == kInputIndexFour) {
330       continue;
331     }
332     if (TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape) != RET_OK) {
333       MS_LOG(ERROR) << "transform axes failed.";
334       return lite::RET_ERROR;
335     }
336   }
337   auto cur_axes = TransformOpAxesAttr(axes, trans_type);
338   auto param_node =
339     BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
340   MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "BuildIntVecParameterNode failed");
341   auto manager = func_graph->manager();
342   MS_ASSERT(manager != nullptr);
343   manager->SetEdge(cnode, kInputIndexFour, param_node);
344   return lite::RET_OK;
345 }
346 }  // namespace
347 
TransposePairFuseWhenInsert(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)348 AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
349                                                           const std::vector<int> &perm, bool before, size_t index) {
350   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
351   AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
352   // judge pair transpose after insert.
353   if (CheckPrimitiveType(trans_input_node, prim::kPrimTranspose)) {
354     std::vector<int> trans_perm;
355     auto input_cnode = trans_input_node->cast<CNodePtr>();
356     if (input_cnode == nullptr) {
357       MS_LOG(ERROR) << "input node is invalid.";
358       return nullptr;
359     }
360     if (GetTransposePerm(input_cnode, &trans_perm) != lite::RET_OK) {
361       MS_LOG(ERROR) << "transpose perm get failed.";
362       return nullptr;
363     }
364     if ((perm == kNH2NC && trans_perm == kNC2NH) || (perm == kNC2NH && trans_perm == kNH2NC)) {
365       return input_cnode->input(kFirstInput);
366     }
367   }
368   // insert depend on shape
369   return TransposeDependOnShape(func_graph, cnode, perm, before, index);
370 }
371 
TransposeDependOnShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)372 AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
373                                                      const std::vector<int> &perm, bool before, size_t index) {
374   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
375   AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
376   auto status = TransposeInsertDependOnShape(func_graph, cnode, before, index);
377   if (status == lite::RET_ERROR) {
378     return nullptr;
379   } else if (status == lite::RET_NO_CHANGE) {
380     return before ? cnode->input(index) : cnode;
381   }
382   // insert tranpsoe
383   std::string trans_name =
384     before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
385   auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
386   return trans_insert_node;
387 }
388 
CanFusionIfInsert(const FuncGraphPtr & func_graph,const CNodePtr & cnode,TransTypePair * trans_info,TransTypePair * trans_insert_info)389 bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
390                                           TransTypePair *trans_info, TransTypePair *trans_insert_info) {
391   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
392   MS_ASSERT(pre_type != nullptr && post_type != nullptr);
393   size_t trans_count = 0;
394   std::vector<AnfNodePtr> in_nodes;
395   auto graph_inputs = func_graph->get_inputs();
396   for (size_t i = 1; i < cnode->size(); ++i) {
397     if (utils::isa<CNodePtr>(cnode->input(i)) ||
398         std::find(graph_inputs.begin(), graph_inputs.end(), cnode->input(i)) != graph_inputs.end()) {
399       in_nodes.push_back(cnode->input(i));
400     }
401   }
402   if (!IsInOutCanFuison(in_nodes, &trans_count, &trans_info->pre_)) {
403     return false;
404   }
405   std::vector<AnfNodePtr> out_nodes;
406   if (GetPostNodes(func_graph, cnode, &out_nodes) != lite::RET_OK) {
407     return false;
408   }
409   if (!IsInOutCanFuison(out_nodes, &trans_count, &trans_info->post_)) {
410     return false;
411   }
412   if (trans_info->pre_ == trans_info->post_) {
413     return false;
414   }
415   auto total_node_count = in_nodes.size() + out_nodes.size();
416   bool can_insert = trans_count > total_node_count / kHalfDivisor;
417   if (CheckPrimitiveType(cnode, prim::kPrimActivation)) {
418     auto prim_act = ops::GetOperator<ops::Activation>(cnode->input(0));
419     MS_CHECK_TRUE_MSG(prim_act != nullptr, false, "GetValueNode Failed");
420     if (prim_act->get_activation_type() == mindspore::ActivationType::LEAKY_RELU) {
421       can_insert = trans_count >= total_node_count / kHalfDivisor;
422     }
423   }
424   if (CheckPrimitiveType(cnode, prim::kPrimSplit) || CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
425     can_insert = trans_count >= total_node_count / kHalfDivisor;
426   }
427   if (!can_insert) {
428     return can_insert;
429   }
430   DecidePreAndPostTransType(trans_info, trans_insert_info);
431   return can_insert;
432 }
433 
CanChangeOpAxis(const CNodePtr & cnode)434 bool TransposeStrategy::CanChangeOpAxis(const CNodePtr &cnode) {
435   MS_ASSERT(cnode != nullptr);
436   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
437   MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode Failed");
438   if (!IsDynamicFormatOp(prim->name())) {
439     return false;
440   }
441   if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
442     return false;
443   }
444   if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice) ||
445       CheckPrimitiveType(cnode, prim::kPrimPadFusion)) {
446     for (size_t i = 2; i < cnode->size(); ++i) {
447       if (utils::isa<CNodePtr>(cnode->input(i))) {
448         return false;
449       }
450       if (utils::isa<Parameter>(cnode->input(i)) && !cnode->input(i)->cast<ParameterPtr>()->has_default()) {
451         return false;
452       }
453     }
454     if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) {
455       return false;
456     }
457   } else if (CheckPrimitiveType(cnode, prim::kPrimScaleFusion)) {
458     MS_CHECK_TRUE_RET(cnode->size() >= kInputSizeThree, false);
459     auto weight_param = cnode->input(kInputIndexTwo);
460     MS_CHECK_TRUE_RET(weight_param != nullptr, false);
461     std::vector<int64_t> weight_shape;
462     if (FetchShapeFromAbstract(weight_param->abstract(), &weight_shape) != lite::RET_OK) {
463       MS_LOG(ERROR) << "Get shape from abstract failed.";
464       return false;
465     }
466     if (weight_shape.size() != 1) {
467       return false;
468     }
469   } else if (IsDynamicFormatOpWithAxis(prim->name())) {
470     if (prim->GetAttr(ops::kAxis) == nullptr) {
471       return false;
472     }
473   }
474   return true;
475 }
476 
ChangeOpAxis(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type)477 STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
478                                        FormatTransNodeType trans_type) {
479   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
480   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
481   MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
482   if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
483     return lite::RET_NOT_SUPPORT;
484   }
485   std::map<std::string,
486            std::function<STATUS(const FuncGraphPtr &, const CNodePtr &, FormatTransNodeType, NodeInferShape *)>>
487     process_funcs = {
488       {prim::kPrimConcat->name(), ChangeCommonOp},     {prim::kPrimSplit->name(), ChangeCommonOp},
489       {prim::kPrimCrop->name(), ChangeOpCrop},         {prim::kPrimPadFusion->name(), ChangeOpPad},
490       {prim::kPrimSliceFusion->name(), ChangeOpSlice}, {prim::kPrimStridedSlice->name(), ChangeOpStrideSlice},
491       {prim::kPrimScaleFusion->name(), ChangeCommonOp}};
492   auto iter = process_funcs.find(prim->name());
493   if (iter != process_funcs.end()) {
494     return iter->second(func_graph, cnode, trans_type, &node_infer_shape_);
495   }
496   return lite::RET_OK;
497 }
498 
TransposeInsertDependOnShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool before,size_t index)499 STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
500                                                        bool before, size_t index) {
501   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
502   auto manager = func_graph->manager();
503   if (manager == nullptr) {
504     manager = Manage(func_graph, true);
505   }
506   if (manager == nullptr) {
507     MS_LOG(ERROR) << "manager is nullptr.";
508     return lite::RET_ERROR;
509   }
510   auto node_users = manager->node_users()[cnode];
511   if (node_users.empty()) {
512     MS_LOG(ERROR) << "cnode is isolated.";
513     return lite::RET_ERROR;
514   }
515   if (!utils::isa<CNodePtr>(node_users.front().first)) {
516     return lite::RET_ERROR;
517   }
518   CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
519   MS_ASSERT(base_node != nullptr);
520   size_t input_index = before ? index : static_cast<size_t>(node_users.front().second);
521   auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
522   if (!lite::JudgeDynamicShape(shape) && shape.size() != kNH2NC.size()) {
523     return lite::RET_NO_CHANGE;
524   }
525   return lite::RET_OK;
526 }
527 
IsInOutCanFuison(const std::vector<AnfNodePtr> & nodes,size_t * trans_count,FormatTransNodeType * trans_type)528 bool TransposeStrategy::IsInOutCanFuison(const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
529                                          FormatTransNodeType *trans_type) {
530   MS_ASSERT(trans_count != nullptr && trans_type != nullptr);
531   for (auto &node : nodes) {
532     if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
533       FormatTransNodeType cur_type;
534       std::vector<int> perm;
535       auto cnode = node->cast<CNodePtr>();
536       if (cnode == nullptr) {
537         return false;
538       }
539       if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
540         return false;
541       }
542       if (perm == kNH2NC) {
543         cur_type = kNHWC2NCHW;
544       } else if (perm == kNC2NH) {
545         cur_type = kNCHW2NHWC;
546       } else {
547         return false;
548       }
549       if (*trans_type == kNONE) {
550         *trans_type = cur_type;
551       } else if (*trans_type != cur_type) {
552         return false;
553       }
554       *trans_count += 1;
555     }
556   }
557   return true;
558 }
559 
DecidePreAndPostTransType(const TransTypePair * trans_info,TransTypePair * trans_insert_info) const560 void TransposeStrategy::DecidePreAndPostTransType(const TransTypePair *trans_info,
561                                                   TransTypePair *trans_insert_info) const {
562   if (trans_info->pre_ == trans_info->post_) {
563     return;
564   }
565   if (trans_info->pre_ != kNONE && trans_info->post_ != kNONE) {
566     trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
567     trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
568   } else if (trans_info->pre_ == kNONE) {
569     trans_insert_info->pre_ = trans_info->post_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
570     trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
571   } else {
572     trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
573     trans_insert_info->post_ = trans_info->pre_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
574   }
575 }
576 }  // namespace opt
577 }  // namespace mindspore
578