• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/optimizer/graph/transpose_strategy.h"
18 #include <algorithm>
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <vector>
23 #include <string>
24 #include <utility>
25 #include "ops/crop.h"
26 #include "ops/fusion/activation.h"
27 #include "ops/fusion/slice_fusion.h"
28 #include "ops/op_utils.h"
29 #include "tools/anf_exporter/fetch_content.h"
30 #include "nnacl/op_base.h"
31 
32 namespace mindspore {
33 namespace opt {
34 namespace {
// Index of the first data input of a CNode; input 0 carries the primitive value node.
constexpr size_t kFirstInput{1};
// Divisor applied to the neighbour count when deciding whether inserting transposes pays off.
constexpr size_t kHalfDivisor{2};
// CNode size a StridedSlice must have for its axes to be rewritten in this file.
constexpr size_t kOnnxStridedSlice{6};
// Number of padding values handled by ChangeOpPad: four dimensions, two values each.
constexpr int kPaddingListLength{8};
GetPostNodes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,std::vector<AnfNodePtr> * out_nodes)39 STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) {
40   MS_ASSERT(func_graph != nullptr && cnode != nullptr && out_nodes != nullptr);
41   auto manager = func_graph->manager();
42   if (manager == nullptr) {
43     manager = Manage(func_graph, true);
44   }
45   if (manager == nullptr) {
46     MS_LOG(ERROR) << "manager is nullptr.";
47     return lite::RET_ERROR;
48   }
49   auto node_users = manager->node_users()[cnode];
50   if (node_users.empty()) {
51     MS_LOG(ERROR) << "cnode is isolated.";
52     return lite::RET_ERROR;
53   }
54   std::transform(node_users.begin(), node_users.end(), std::back_inserter(*out_nodes),
55                  [](const std::pair<AnfNodePtr, int> &node_user) { return node_user.first; });
56   return lite::RET_OK;
57 }
58 
JudgeIs4DInput(NodeInferShape * node_infer_shape,const CNodePtr & cnode)59 bool JudgeIs4DInput(NodeInferShape *node_infer_shape, const CNodePtr &cnode) {
60   MS_ASSERT(node_infer_shape != nullptr && cnode != nullptr);
61   auto shape = node_infer_shape->GetInputShape(cnode, 1);
62   if (shape.size() != kInputSizeFour) {
63     if (cnode->size() > kInputSizeTwo) {
64       shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
65       if (shape.size() != kInputSizeFour && !shape.empty()) {
66         return false;
67       }
68     } else {
69       return false;
70     }
71   }
72   return true;
73 }
74 
TransformOpAxesAttr(const std::vector<int> & origin_axes,FormatTransNodeType trans_type)75 std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type) {
76   std::vector<int> cur_axes;
77   for (size_t i = 0; i < origin_axes.size(); ++i) {
78     int axis = origin_axes[i];
79     if (axis < 0) {
80       axis += kInputSizeFour;
81     }
82     MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
83     int cur_axis = kNH2NC[axis];
84     if (trans_type == kNHWC2NCHW) {
85       cur_axis = kNC2NH[axis];
86     }
87     cur_axes.push_back(cur_axis);
88   }
89   std::sort(cur_axes.begin(), cur_axes.end());
90   return cur_axes;
91 }
92 
TransformAttrByAxes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t input_index,const std::vector<int> & axes,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)93 int TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
94                         const std::vector<int> &axes, FormatTransNodeType trans_type,
95                         NodeInferShape *node_infer_shape) {
96   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
97   if (input_index >= cnode->size() || axes.empty()) {
98     return lite::RET_ERROR;
99   }
100   auto origin_input = node_infer_shape->GetIntVecInput(cnode, input_index);
101   if (origin_input.size() != axes.size()) {
102     return lite::RET_ERROR;
103   }
104   std::vector<int> cur_input;
105   for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
106     for (size_t index = 0; index < axes.size(); ++index) {
107       int axis = axes[index];
108       if (axis < 0) {
109         axis += kInputSizeFour;
110       }
111       MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
112       int cur_axis = kNH2NC[axis];
113       if (trans_type == kNHWC2NCHW) {
114         cur_axis = kNC2NH[axis];
115       }
116       if (cur_axis == dim) {
117         cur_input.push_back(origin_input[index]);
118       }
119     }
120   }
121   auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope());
122   MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "BuildIntVecParameterNode failed");
123   func_graph->manager()->Replace(cnode->input(input_index), param_node);
124   return lite::RET_OK;
125 }
126 
ChangeCommonOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)127 STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
128                       NodeInferShape *node_infer_shape) {
129   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
130   if (trans_type == kNONE) {
131     MS_LOG(ERROR) << "trans_type is invalid.";
132     return lite::RET_ERROR;
133   }
134   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
135   MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
136   if (prim->GetAttr(ops::kAxis) == nullptr) {
137     return lite::RET_NOT_SUPPORT;
138   }
139   MS_CHECK_TRUE_MSG(prim->GetAttr(ops::kAxis) != nullptr, lite::RET_NULL_PTR, "GetAttr Failed.");
140   auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
141   if (axis < 0) {
142     axis += kInputSizeFour;
143   }
144   auto new_axis = kNH2NC[axis];
145   if (trans_type == kNHWC2NCHW) {
146     new_axis = kNC2NH[axis];
147   }
148   prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
149   return lite::RET_OK;
150 }
151 
ChangeOpCrop(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)152 STATUS ChangeOpCrop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
153                     NodeInferShape *node_infer_shape) {
154   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
155   if (trans_type == kNONE) {
156     MS_LOG(ERROR) << "trans_type is invalid.";
157     return lite::RET_ERROR;
158   }
159   auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
160   if (crop_prim == nullptr) {
161     MS_LOG(ERROR) << "cnode is invalid.";
162     return lite::RET_ERROR;
163   }
164   MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kAxis) != nullptr, lite::RET_ERROR);
165   auto axis = crop_prim->get_axis();
166   if (axis < 0) {
167     axis += kInputSizeFour;
168   }
169   MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
170   MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kOffsets) != nullptr, lite::RET_ERROR);
171   auto offsets = crop_prim->get_offsets();
172   if (trans_type == kNCHW2NHWC) {
173     auto new_axis = kNH2NC[axis];
174     if (new_axis == 0) {
175       MS_CHECK_GE(offsets.size(), kInputIndexFour, lite::RET_ERROR);
176       offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
177     } else if (new_axis == kInputIndexThree) {
178       MS_CHECK_GE(offsets.size(), kInputIndexThree, lite::RET_ERROR);
179       offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
180     } else {
181       offsets.push_back(0);
182     }
183     crop_prim->set_axis(new_axis);
184     crop_prim->set_offsets(offsets);
185   } else {
186     auto new_axis = kNC2NH[axis];
187     if (new_axis == 0) {
188       offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]};
189     } else if (new_axis == kInputIndexThree) {
190       offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]};
191     } else {
192       offsets.pop_back();
193     }
194     crop_prim->set_axis(new_axis);
195     crop_prim->set_offsets(offsets);
196   }
197   return lite::RET_OK;
198 }
199 
ChangeOpPad(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)200 STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
201                    NodeInferShape *node_infer_shape) {
202   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
203   if (trans_type == kNONE) {
204     MS_LOG(ERROR) << "trans_type is invalid.";
205     return lite::RET_ERROR;
206   }
207   if (cnode->size() < kInputSizeThree) {
208     MS_LOG(ERROR) << "pad op need three inputs.";
209     return lite::RET_INPUT_TENSOR_ERROR;
210   }
211   auto second_input = cnode->input(kInputIndexTwo);
212   lite::DataInfo data_info;
213   int status;
214   if (utils::isa<Parameter>(second_input)) {
215     status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
216   } else if (utils::isa<ValueNode>(second_input)) {
217     status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
218   } else {
219     return lite::RET_NOT_SUPPORT;
220   }
221   if (status != lite::RET_OK) {
222     MS_LOG(ERROR) << "get paddings failed.";
223     return status;
224   }
225   if (std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<int>()) !=
226       kPaddingListLength) {
227     return lite::RET_OK;
228   }
229   std::vector<std::vector<int32_t>> padding_list(kInputSizeFour, std::vector<int32_t>(kInputSizeTwo));
230   auto data = reinterpret_cast<int32_t *>(data_info.data_.data());
231   for (int i = 0; i < kPaddingListLength; ++i) {
232     padding_list[i / kInputIndexTwo][i % kInputIndexTwo] = *data;
233     data += 1;
234   }
235   if (trans_type == kNCHW2NHWC) {
236     auto chanel_pad = padding_list[1];
237     padding_list.erase(padding_list.begin() + 1);
238     padding_list.push_back(chanel_pad);
239   } else {
240     auto chanel_pad = padding_list.back();
241     padding_list.pop_back();
242     padding_list.insert(padding_list.begin() + 1, chanel_pad);
243   }
244   auto param_node =
245     BuildIntVec2DParameterNode(func_graph, padding_list, cnode->input(kInputIndexTwo)->fullname_with_scope());
246   MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_NULL_PTR, "BuildParameterNode Failed");
247   auto manager = func_graph->manager();
248   MS_ASSERT(manager != nullptr);
249   manager->Replace(cnode->input(kInputIndexTwo), param_node);
250   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
251   MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
252   if (prim->GetAttr(ops::kPaddings) != nullptr) {
253     std::vector<std::vector<int64_t>> padding_attr;
254     (void)std::transform(padding_list.begin(), padding_list.end(), std::back_inserter(padding_attr),
255                          [](const std::vector<int> &val) { return std::vector<int64_t>(val.begin(), val.end()); });
256     prim->AddAttr(ops::kPaddings, MakeValue(padding_attr));
257   }
258   return lite::RET_OK;
259 }
260 
ChangeOpSlice(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)261 STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
262                      NodeInferShape *node_infer_shape) {
263   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
264   if (trans_type == kNONE) {
265     MS_LOG(ERROR) << "trans_type is invalid.";
266     return lite::RET_ERROR;
267   }
268   for (size_t i = 2; i < cnode->size(); ++i) {
269     if (utils::isa<CNodePtr>(cnode->input(i))) {
270       return lite::RET_NOT_SUPPORT;
271     }
272   }
273   auto shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
274   if (shape.empty()) {
275     return lite::RET_NOT_SUPPORT;
276   }
277   int element_num = shape.front();
278   auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0));
279   MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
280   std::vector<int> axes;
281   if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
282     for (int index = 0; index < element_num; ++index) {
283       axes.push_back(index);
284     }
285   } else {
286     auto origin_axes = prim->get_axes();
287     std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes),
288                    [](int64_t v) { return static_cast<int>(v); });
289   }
290   for (size_t i = 2; i < cnode->size(); ++i) {
291     if (TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape) != RET_OK) {
292       MS_LOG(ERROR) << "Transform axes failed.";
293       return RET_ERROR;
294     }
295   }
296   auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
297   std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
298   prim->set_axes(new_axes);
299   return lite::RET_OK;
300 }
301 
ChangeOpStrideSlice(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)302 STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
303                            NodeInferShape *node_infer_shape) {
304   MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
305   if (trans_type == kNONE) {
306     MS_LOG(ERROR) << "trans_type is invalid.";
307     return lite::RET_ERROR;
308   }
309   if (cnode->size() != kOnnxStridedSlice) {
310     return lite::RET_NOT_SUPPORT;
311   }
312   for (size_t i = 2; i < cnode->size(); ++i) {
313     if (utils::isa<CNodePtr>(cnode->input(i))) {
314       return lite::RET_NOT_SUPPORT;
315     }
316   }
317   std::vector<int> axes = node_infer_shape->GetIntVecInput(cnode, kInputIndexFour);
318   if (axes.empty()) {
319     MS_LOG(ERROR) << "strided slice input invalid.";
320     return lite::RET_ERROR;
321   }
322   for (size_t index = 2; index < cnode->size(); ++index) {
323     if (index == kInputIndexFour) {
324       continue;
325     }
326     if (TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape) != RET_OK) {
327       MS_LOG(ERROR) << "transform axes failed.";
328       return lite::RET_ERROR;
329     }
330   }
331   auto cur_axes = TransformOpAxesAttr(axes, trans_type);
332   auto param_node =
333     BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
334   MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "BuildIntVecParameterNode failed");
335   auto manager = func_graph->manager();
336   MS_ASSERT(manager != nullptr);
337   manager->Replace(cnode->input(kInputIndexFour), param_node);
338   return lite::RET_OK;
339 }
340 }  // namespace
341 
TransposePairFuseWhenInsert(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)342 AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
343                                                           const std::vector<int> &perm, bool before, size_t index) {
344   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
345   AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
346   // judge pair transpose after insert.
347   if (CheckPrimitiveType(trans_input_node, prim::kPrimTranspose)) {
348     std::vector<int> trans_perm;
349     auto input_cnode = trans_input_node->cast<CNodePtr>();
350     if (input_cnode == nullptr) {
351       MS_LOG(ERROR) << "input node is invalid.";
352       return nullptr;
353     }
354     if (GetTransposePerm(input_cnode, &trans_perm) != lite::RET_OK) {
355       MS_LOG(ERROR) << "transpose perm get failed.";
356       return nullptr;
357     }
358     if ((perm == kNH2NC && trans_perm == kNC2NH) || (perm == kNC2NH && trans_perm == kNH2NC)) {
359       return input_cnode->input(kFirstInput);
360     }
361   }
362   // insert depend on shape
363   return TransposeDependOnShape(func_graph, cnode, perm, before, index);
364 }
365 
TransposeDependOnShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)366 AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
367                                                      const std::vector<int> &perm, bool before, size_t index) {
368   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
369   AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
370   auto status = TransposeInsertDependOnShape(func_graph, cnode, before, index);
371   if (status == lite::RET_ERROR) {
372     return nullptr;
373   } else if (status == lite::RET_NO_CHANGE) {
374     return before ? cnode->input(index) : cnode;
375   }
376   // insert tranpsoe
377   std::string trans_name =
378     before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
379   auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
380   return trans_insert_node;
381 }
382 
CanFusionIfInsert(const FuncGraphPtr & func_graph,const CNodePtr & cnode,TransTypePair * trans_info,TransTypePair * trans_insert_info)383 bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
384                                           TransTypePair *trans_info, TransTypePair *trans_insert_info) {
385   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
386   MS_ASSERT(pre_type != nullptr && post_type != nullptr);
387   size_t trans_count = 0;
388   std::vector<AnfNodePtr> in_nodes;
389   auto graph_inputs = func_graph->get_inputs();
390   for (size_t i = 1; i < cnode->size(); ++i) {
391     if (utils::isa<CNodePtr>(cnode->input(i)) ||
392         std::find(graph_inputs.begin(), graph_inputs.end(), cnode->input(i)) != graph_inputs.end()) {
393       in_nodes.push_back(cnode->input(i));
394     }
395   }
396   if (!IsInOutCanFuison(in_nodes, &trans_count, &trans_info->pre_)) {
397     return false;
398   }
399   std::vector<AnfNodePtr> out_nodes;
400   if (GetPostNodes(func_graph, cnode, &out_nodes) != lite::RET_OK) {
401     return false;
402   }
403   if (!IsInOutCanFuison(out_nodes, &trans_count, &trans_info->post_)) {
404     return false;
405   }
406   if (trans_info->pre_ == trans_info->post_) {
407     return false;
408   }
409   auto total_node_count = in_nodes.size() + out_nodes.size();
410   bool can_insert = trans_count > total_node_count / kHalfDivisor;
411   if (CheckPrimitiveType(cnode, prim::kPrimActivation)) {
412     auto prim_act = GetValueNode<std::shared_ptr<ops::Activation>>(cnode->input(0));
413     MS_CHECK_TRUE_MSG(prim_act != nullptr, false, "GetValueNode Failed");
414     if (prim_act->get_activation_type() == mindspore::ActivationType::LEAKY_RELU) {
415       can_insert = trans_count >= total_node_count / kHalfDivisor;
416     }
417   }
418   if (CheckPrimitiveType(cnode, prim::kPrimSplit) || CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
419     can_insert = trans_count >= total_node_count / kHalfDivisor;
420   }
421   if (!can_insert) {
422     return can_insert;
423   }
424   DecidePreAndPostTransType(trans_info, trans_insert_info);
425   return can_insert;
426 }
427 
CanChangeOpAxis(const CNodePtr & cnode)428 bool TransposeStrategy::CanChangeOpAxis(const CNodePtr &cnode) {
429   MS_ASSERT(cnode != nullptr);
430   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
431   MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode Failed");
432   if (!IsDynamicFormatOp(prim->name())) {
433     return false;
434   }
435   if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
436     return false;
437   }
438   if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice) ||
439       CheckPrimitiveType(cnode, prim::kPrimPadFusion)) {
440     for (size_t i = 2; i < cnode->size(); ++i) {
441       if (utils::isa<CNodePtr>(cnode->input(i))) {
442         return false;
443       }
444       if (utils::isa<Parameter>(cnode->input(i)) && !cnode->input(i)->cast<ParameterPtr>()->has_default()) {
445         return false;
446       }
447     }
448     if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) {
449       return false;
450     }
451   } else if (IsDynamicFormatOpWithAxis(prim->name())) {
452     if (prim->GetAttr(ops::kAxis) == nullptr) {
453       return false;
454     }
455   }
456   return true;
457 }
458 
ChangeOpAxis(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type)459 STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
460                                        FormatTransNodeType trans_type) {
461   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
462   auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
463   MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
464   if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
465     return lite::RET_NOT_SUPPORT;
466   }
467   std::map<std::string,
468            std::function<STATUS(const FuncGraphPtr &, const CNodePtr &, FormatTransNodeType, NodeInferShape *)>>
469     process_funcs = {
470       {prim::kPrimConcat->name(), ChangeCommonOp},     {prim::kPrimSplit->name(), ChangeCommonOp},
471       {prim::kPrimCrop->name(), ChangeOpCrop},         {prim::kPrimPadFusion->name(), ChangeOpPad},
472       {prim::kPrimSliceFusion->name(), ChangeOpSlice}, {prim::kPrimStridedSlice->name(), ChangeOpStrideSlice}};
473   auto iter = process_funcs.find(prim->name());
474   if (iter != process_funcs.end()) {
475     return iter->second(func_graph, cnode, trans_type, &node_infer_shape_);
476   }
477   return lite::RET_OK;
478 }
479 
TransposeInsertDependOnShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool before,size_t index)480 STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
481                                                        bool before, size_t index) {
482   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
483   auto manager = func_graph->manager();
484   if (manager == nullptr) {
485     manager = Manage(func_graph, true);
486   }
487   if (manager == nullptr) {
488     MS_LOG(ERROR) << "manager is nullptr.";
489     return lite::RET_ERROR;
490   }
491   auto node_users = manager->node_users()[cnode];
492   if (node_users.empty()) {
493     MS_LOG(ERROR) << "cnode is isolated.";
494     return lite::RET_ERROR;
495   }
496   if (!utils::isa<CNodePtr>(node_users.front().first)) {
497     return lite::RET_ERROR;
498   }
499   CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
500   MS_ASSERT(base_node != nullptr);
501   size_t input_index = before ? index : node_users.front().second;
502   auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
503   if (!shape.empty() && shape.size() != kNH2NC.size()) {
504     return lite::RET_NO_CHANGE;
505   }
506   return lite::RET_OK;
507 }
508 
IsInOutCanFuison(const std::vector<AnfNodePtr> & nodes,size_t * trans_count,FormatTransNodeType * trans_type)509 bool TransposeStrategy::IsInOutCanFuison(const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
510                                          FormatTransNodeType *trans_type) {
511   MS_ASSERT(trans_count != nullptr && trans_type != nullptr);
512   for (auto &node : nodes) {
513     if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
514       FormatTransNodeType cur_type;
515       std::vector<int> perm;
516       auto cnode = node->cast<CNodePtr>();
517       if (cnode == nullptr) {
518         return false;
519       }
520       if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
521         return false;
522       }
523       if (perm == kNH2NC) {
524         cur_type = kNHWC2NCHW;
525       } else if (perm == kNC2NH) {
526         cur_type = kNCHW2NHWC;
527       } else {
528         return false;
529       }
530       if (*trans_type == kNONE) {
531         *trans_type = cur_type;
532       } else if (*trans_type != cur_type) {
533         return false;
534       }
535       *trans_count += 1;
536     }
537   }
538   return true;
539 }
540 
DecidePreAndPostTransType(TransTypePair * trans_info,TransTypePair * trans_insert_info) const541 void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) const {
542   if (trans_info->pre_ == trans_info->post_) {
543     return;
544   }
545   if (trans_info->pre_ != kNONE && trans_info->post_ != kNONE) {
546     trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
547     trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
548   } else if (trans_info->pre_ == kNONE) {
549     trans_insert_info->pre_ = trans_info->post_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
550     trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
551   } else {
552     trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
553     trans_insert_info->post_ = trans_info->pre_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
554   }
555 }
556 }  // namespace opt
557 }  // namespace mindspore
558