1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/optimizer/graph/transpose_strategy.h"
18 #include <algorithm>
19 #include <functional>
20 #include <map>
21 #include <memory>
22 #include <vector>
23 #include <string>
24 #include <utility>
25 #include "ops/crop.h"
26 #include "ops/fusion/activation.h"
27 #include "ops/fusion/slice_fusion.h"
28 #include "ops/op_utils.h"
29 #include "tools/anf_exporter/fetch_content.h"
30 #include "nnacl/op_base.h"
31
32 namespace mindspore {
33 namespace opt {
34 namespace {
// Index of the data input of a transpose CNode (input 0 holds the primitive).
constexpr size_t kFirstInput{1};
// Divisor used to compare the number of fusible transposes against half of the surrounding nodes.
constexpr size_t kHalfDivisor{2};
// Expected cnode->size() (primitive + 5 inputs) of the StridedSlice form that can be rewritten.
constexpr size_t kOnnxStridedSlice{6};
// Number of padding values handled for a Pad op (pairs of values for 4 dimensions).
constexpr int kPaddingListLength{8};
GetPostNodes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,std::vector<AnfNodePtr> * out_nodes)39 STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) {
40 MS_ASSERT(func_graph != nullptr && cnode != nullptr && out_nodes != nullptr);
41 auto manager = func_graph->manager();
42 if (manager == nullptr) {
43 manager = Manage(func_graph, true);
44 }
45 if (manager == nullptr) {
46 MS_LOG(ERROR) << "manager is nullptr.";
47 return lite::RET_ERROR;
48 }
49 auto node_users = manager->node_users()[cnode];
50 if (node_users.empty()) {
51 MS_LOG(ERROR) << "cnode is isolated.";
52 return lite::RET_ERROR;
53 }
54 std::transform(node_users.begin(), node_users.end(), std::back_inserter(*out_nodes),
55 [](const std::pair<AnfNodePtr, int> &node_user) { return node_user.first; });
56 return lite::RET_OK;
57 }
58
JudgeIs4DInput(NodeInferShape * node_infer_shape,const CNodePtr & cnode)59 bool JudgeIs4DInput(NodeInferShape *node_infer_shape, const CNodePtr &cnode) {
60 MS_ASSERT(node_infer_shape != nullptr && cnode != nullptr);
61 auto shape = node_infer_shape->GetInputShape(cnode, 1);
62 if (shape.size() != kInputSizeFour) {
63 if (cnode->size() > kInputSizeTwo) {
64 shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
65 if (shape.size() != kInputSizeFour && !shape.empty()) {
66 return false;
67 }
68 } else {
69 return false;
70 }
71 }
72 return true;
73 }
74
TransformOpAxesAttr(const std::vector<int> & origin_axes,FormatTransNodeType trans_type)75 std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type) {
76 std::vector<int> cur_axes;
77 for (size_t i = 0; i < origin_axes.size(); ++i) {
78 int axis = origin_axes[i];
79 if (axis < 0) {
80 axis += kInputSizeFour;
81 }
82 MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
83 int cur_axis = kNH2NC[axis];
84 if (trans_type == kNHWC2NCHW) {
85 cur_axis = kNC2NH[axis];
86 }
87 cur_axes.push_back(cur_axis);
88 }
89 std::sort(cur_axes.begin(), cur_axes.end());
90 return cur_axes;
91 }
92
TransformAttrByAxes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t input_index,const std::vector<int> & axes,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)93 int TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
94 const std::vector<int> &axes, FormatTransNodeType trans_type,
95 NodeInferShape *node_infer_shape) {
96 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
97 if (input_index >= cnode->size() || axes.empty()) {
98 return lite::RET_ERROR;
99 }
100 auto origin_input = node_infer_shape->GetIntVecInput(cnode, input_index);
101 if (origin_input.size() != axes.size()) {
102 return lite::RET_ERROR;
103 }
104 std::vector<int> cur_input;
105 for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
106 for (size_t index = 0; index < axes.size(); ++index) {
107 int axis = axes[index];
108 if (axis < 0) {
109 axis += kInputSizeFour;
110 }
111 MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
112 int cur_axis = kNH2NC[axis];
113 if (trans_type == kNHWC2NCHW) {
114 cur_axis = kNC2NH[axis];
115 }
116 if (cur_axis == dim) {
117 cur_input.push_back(origin_input[index]);
118 }
119 }
120 }
121 auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope());
122 MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "BuildIntVecParameterNode failed");
123 func_graph->manager()->Replace(cnode->input(input_index), param_node);
124 return lite::RET_OK;
125 }
126
ChangeCommonOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)127 STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
128 NodeInferShape *node_infer_shape) {
129 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
130 if (trans_type == kNONE) {
131 MS_LOG(ERROR) << "trans_type is invalid.";
132 return lite::RET_ERROR;
133 }
134 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
135 MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
136 if (prim->GetAttr(ops::kAxis) == nullptr) {
137 return lite::RET_NOT_SUPPORT;
138 }
139 MS_CHECK_TRUE_MSG(prim->GetAttr(ops::kAxis) != nullptr, lite::RET_NULL_PTR, "GetAttr Failed.");
140 auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
141 if (axis < 0) {
142 axis += kInputSizeFour;
143 }
144 auto new_axis = kNH2NC[axis];
145 if (trans_type == kNHWC2NCHW) {
146 new_axis = kNC2NH[axis];
147 }
148 prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
149 return lite::RET_OK;
150 }
151
ChangeOpCrop(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)152 STATUS ChangeOpCrop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
153 NodeInferShape *node_infer_shape) {
154 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
155 if (trans_type == kNONE) {
156 MS_LOG(ERROR) << "trans_type is invalid.";
157 return lite::RET_ERROR;
158 }
159 auto crop_prim = GetValueNode<std::shared_ptr<ops::Crop>>(cnode->input(0));
160 if (crop_prim == nullptr) {
161 MS_LOG(ERROR) << "cnode is invalid.";
162 return lite::RET_ERROR;
163 }
164 MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kAxis) != nullptr, lite::RET_ERROR);
165 auto axis = crop_prim->get_axis();
166 if (axis < 0) {
167 axis += kInputSizeFour;
168 }
169 MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
170 MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kOffsets) != nullptr, lite::RET_ERROR);
171 auto offsets = crop_prim->get_offsets();
172 if (trans_type == kNCHW2NHWC) {
173 auto new_axis = kNH2NC[axis];
174 if (new_axis == 0) {
175 MS_CHECK_GE(offsets.size(), kInputIndexFour, lite::RET_ERROR);
176 offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
177 } else if (new_axis == kInputIndexThree) {
178 MS_CHECK_GE(offsets.size(), kInputIndexThree, lite::RET_ERROR);
179 offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
180 } else {
181 offsets.push_back(0);
182 }
183 crop_prim->set_axis(new_axis);
184 crop_prim->set_offsets(offsets);
185 } else {
186 auto new_axis = kNC2NH[axis];
187 if (new_axis == 0) {
188 offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]};
189 } else if (new_axis == kInputIndexThree) {
190 offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]};
191 } else {
192 offsets.pop_back();
193 }
194 crop_prim->set_axis(new_axis);
195 crop_prim->set_offsets(offsets);
196 }
197 return lite::RET_OK;
198 }
199
ChangeOpPad(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)200 STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
201 NodeInferShape *node_infer_shape) {
202 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
203 if (trans_type == kNONE) {
204 MS_LOG(ERROR) << "trans_type is invalid.";
205 return lite::RET_ERROR;
206 }
207 if (cnode->size() < kInputSizeThree) {
208 MS_LOG(ERROR) << "pad op need three inputs.";
209 return lite::RET_INPUT_TENSOR_ERROR;
210 }
211 auto second_input = cnode->input(kInputIndexTwo);
212 lite::DataInfo data_info;
213 int status;
214 if (utils::isa<Parameter>(second_input)) {
215 status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
216 } else if (utils::isa<ValueNode>(second_input)) {
217 status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info);
218 } else {
219 return lite::RET_NOT_SUPPORT;
220 }
221 if (status != lite::RET_OK) {
222 MS_LOG(ERROR) << "get paddings failed.";
223 return status;
224 }
225 if (std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<int>()) !=
226 kPaddingListLength) {
227 return lite::RET_OK;
228 }
229 std::vector<std::vector<int32_t>> padding_list(kInputSizeFour, std::vector<int32_t>(kInputSizeTwo));
230 auto data = reinterpret_cast<int32_t *>(data_info.data_.data());
231 for (int i = 0; i < kPaddingListLength; ++i) {
232 padding_list[i / kInputIndexTwo][i % kInputIndexTwo] = *data;
233 data += 1;
234 }
235 if (trans_type == kNCHW2NHWC) {
236 auto chanel_pad = padding_list[1];
237 padding_list.erase(padding_list.begin() + 1);
238 padding_list.push_back(chanel_pad);
239 } else {
240 auto chanel_pad = padding_list.back();
241 padding_list.pop_back();
242 padding_list.insert(padding_list.begin() + 1, chanel_pad);
243 }
244 auto param_node =
245 BuildIntVec2DParameterNode(func_graph, padding_list, cnode->input(kInputIndexTwo)->fullname_with_scope());
246 MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_NULL_PTR, "BuildParameterNode Failed");
247 auto manager = func_graph->manager();
248 MS_ASSERT(manager != nullptr);
249 manager->Replace(cnode->input(kInputIndexTwo), param_node);
250 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
251 MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
252 if (prim->GetAttr(ops::kPaddings) != nullptr) {
253 std::vector<std::vector<int64_t>> padding_attr;
254 (void)std::transform(padding_list.begin(), padding_list.end(), std::back_inserter(padding_attr),
255 [](const std::vector<int> &val) { return std::vector<int64_t>(val.begin(), val.end()); });
256 prim->AddAttr(ops::kPaddings, MakeValue(padding_attr));
257 }
258 return lite::RET_OK;
259 }
260
ChangeOpSlice(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)261 STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
262 NodeInferShape *node_infer_shape) {
263 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
264 if (trans_type == kNONE) {
265 MS_LOG(ERROR) << "trans_type is invalid.";
266 return lite::RET_ERROR;
267 }
268 for (size_t i = 2; i < cnode->size(); ++i) {
269 if (utils::isa<CNodePtr>(cnode->input(i))) {
270 return lite::RET_NOT_SUPPORT;
271 }
272 }
273 auto shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
274 if (shape.empty()) {
275 return lite::RET_NOT_SUPPORT;
276 }
277 int element_num = shape.front();
278 auto prim = GetValueNode<std::shared_ptr<ops::SliceFusion>>(cnode->input(0));
279 MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
280 std::vector<int> axes;
281 if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
282 for (int index = 0; index < element_num; ++index) {
283 axes.push_back(index);
284 }
285 } else {
286 auto origin_axes = prim->get_axes();
287 std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes),
288 [](int64_t v) { return static_cast<int>(v); });
289 }
290 for (size_t i = 2; i < cnode->size(); ++i) {
291 if (TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape) != RET_OK) {
292 MS_LOG(ERROR) << "Transform axes failed.";
293 return RET_ERROR;
294 }
295 }
296 auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
297 std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
298 prim->set_axes(new_axes);
299 return lite::RET_OK;
300 }
301
ChangeOpStrideSlice(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)302 STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
303 NodeInferShape *node_infer_shape) {
304 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
305 if (trans_type == kNONE) {
306 MS_LOG(ERROR) << "trans_type is invalid.";
307 return lite::RET_ERROR;
308 }
309 if (cnode->size() != kOnnxStridedSlice) {
310 return lite::RET_NOT_SUPPORT;
311 }
312 for (size_t i = 2; i < cnode->size(); ++i) {
313 if (utils::isa<CNodePtr>(cnode->input(i))) {
314 return lite::RET_NOT_SUPPORT;
315 }
316 }
317 std::vector<int> axes = node_infer_shape->GetIntVecInput(cnode, kInputIndexFour);
318 if (axes.empty()) {
319 MS_LOG(ERROR) << "strided slice input invalid.";
320 return lite::RET_ERROR;
321 }
322 for (size_t index = 2; index < cnode->size(); ++index) {
323 if (index == kInputIndexFour) {
324 continue;
325 }
326 if (TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape) != RET_OK) {
327 MS_LOG(ERROR) << "transform axes failed.";
328 return lite::RET_ERROR;
329 }
330 }
331 auto cur_axes = TransformOpAxesAttr(axes, trans_type);
332 auto param_node =
333 BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
334 MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "BuildIntVecParameterNode failed");
335 auto manager = func_graph->manager();
336 MS_ASSERT(manager != nullptr);
337 manager->Replace(cnode->input(kInputIndexFour), param_node);
338 return lite::RET_OK;
339 }
340 } // namespace
341
TransposePairFuseWhenInsert(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)342 AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
343 const std::vector<int> &perm, bool before, size_t index) {
344 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
345 AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
346 // judge pair transpose after insert.
347 if (CheckPrimitiveType(trans_input_node, prim::kPrimTranspose)) {
348 std::vector<int> trans_perm;
349 auto input_cnode = trans_input_node->cast<CNodePtr>();
350 if (input_cnode == nullptr) {
351 MS_LOG(ERROR) << "input node is invalid.";
352 return nullptr;
353 }
354 if (GetTransposePerm(input_cnode, &trans_perm) != lite::RET_OK) {
355 MS_LOG(ERROR) << "transpose perm get failed.";
356 return nullptr;
357 }
358 if ((perm == kNH2NC && trans_perm == kNC2NH) || (perm == kNC2NH && trans_perm == kNH2NC)) {
359 return input_cnode->input(kFirstInput);
360 }
361 }
362 // insert depend on shape
363 return TransposeDependOnShape(func_graph, cnode, perm, before, index);
364 }
365
TransposeDependOnShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)366 AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
367 const std::vector<int> &perm, bool before, size_t index) {
368 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
369 AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
370 auto status = TransposeInsertDependOnShape(func_graph, cnode, before, index);
371 if (status == lite::RET_ERROR) {
372 return nullptr;
373 } else if (status == lite::RET_NO_CHANGE) {
374 return before ? cnode->input(index) : cnode;
375 }
376 // insert tranpsoe
377 std::string trans_name =
378 before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
379 auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
380 return trans_insert_node;
381 }
382
CanFusionIfInsert(const FuncGraphPtr & func_graph,const CNodePtr & cnode,TransTypePair * trans_info,TransTypePair * trans_insert_info)383 bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
384 TransTypePair *trans_info, TransTypePair *trans_insert_info) {
385 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
386 MS_ASSERT(pre_type != nullptr && post_type != nullptr);
387 size_t trans_count = 0;
388 std::vector<AnfNodePtr> in_nodes;
389 auto graph_inputs = func_graph->get_inputs();
390 for (size_t i = 1; i < cnode->size(); ++i) {
391 if (utils::isa<CNodePtr>(cnode->input(i)) ||
392 std::find(graph_inputs.begin(), graph_inputs.end(), cnode->input(i)) != graph_inputs.end()) {
393 in_nodes.push_back(cnode->input(i));
394 }
395 }
396 if (!IsInOutCanFuison(in_nodes, &trans_count, &trans_info->pre_)) {
397 return false;
398 }
399 std::vector<AnfNodePtr> out_nodes;
400 if (GetPostNodes(func_graph, cnode, &out_nodes) != lite::RET_OK) {
401 return false;
402 }
403 if (!IsInOutCanFuison(out_nodes, &trans_count, &trans_info->post_)) {
404 return false;
405 }
406 if (trans_info->pre_ == trans_info->post_) {
407 return false;
408 }
409 auto total_node_count = in_nodes.size() + out_nodes.size();
410 bool can_insert = trans_count > total_node_count / kHalfDivisor;
411 if (CheckPrimitiveType(cnode, prim::kPrimActivation)) {
412 auto prim_act = GetValueNode<std::shared_ptr<ops::Activation>>(cnode->input(0));
413 MS_CHECK_TRUE_MSG(prim_act != nullptr, false, "GetValueNode Failed");
414 if (prim_act->get_activation_type() == mindspore::ActivationType::LEAKY_RELU) {
415 can_insert = trans_count >= total_node_count / kHalfDivisor;
416 }
417 }
418 if (CheckPrimitiveType(cnode, prim::kPrimSplit) || CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
419 can_insert = trans_count >= total_node_count / kHalfDivisor;
420 }
421 if (!can_insert) {
422 return can_insert;
423 }
424 DecidePreAndPostTransType(trans_info, trans_insert_info);
425 return can_insert;
426 }
427
CanChangeOpAxis(const CNodePtr & cnode)428 bool TransposeStrategy::CanChangeOpAxis(const CNodePtr &cnode) {
429 MS_ASSERT(cnode != nullptr);
430 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
431 MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode Failed");
432 if (!IsDynamicFormatOp(prim->name())) {
433 return false;
434 }
435 if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
436 return false;
437 }
438 if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice) ||
439 CheckPrimitiveType(cnode, prim::kPrimPadFusion)) {
440 for (size_t i = 2; i < cnode->size(); ++i) {
441 if (utils::isa<CNodePtr>(cnode->input(i))) {
442 return false;
443 }
444 if (utils::isa<Parameter>(cnode->input(i)) && !cnode->input(i)->cast<ParameterPtr>()->has_default()) {
445 return false;
446 }
447 }
448 if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) {
449 return false;
450 }
451 } else if (IsDynamicFormatOpWithAxis(prim->name())) {
452 if (prim->GetAttr(ops::kAxis) == nullptr) {
453 return false;
454 }
455 }
456 return true;
457 }
458
ChangeOpAxis(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type)459 STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
460 FormatTransNodeType trans_type) {
461 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
462 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
463 MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
464 if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
465 return lite::RET_NOT_SUPPORT;
466 }
467 std::map<std::string,
468 std::function<STATUS(const FuncGraphPtr &, const CNodePtr &, FormatTransNodeType, NodeInferShape *)>>
469 process_funcs = {
470 {prim::kPrimConcat->name(), ChangeCommonOp}, {prim::kPrimSplit->name(), ChangeCommonOp},
471 {prim::kPrimCrop->name(), ChangeOpCrop}, {prim::kPrimPadFusion->name(), ChangeOpPad},
472 {prim::kPrimSliceFusion->name(), ChangeOpSlice}, {prim::kPrimStridedSlice->name(), ChangeOpStrideSlice}};
473 auto iter = process_funcs.find(prim->name());
474 if (iter != process_funcs.end()) {
475 return iter->second(func_graph, cnode, trans_type, &node_infer_shape_);
476 }
477 return lite::RET_OK;
478 }
479
TransposeInsertDependOnShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool before,size_t index)480 STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
481 bool before, size_t index) {
482 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
483 auto manager = func_graph->manager();
484 if (manager == nullptr) {
485 manager = Manage(func_graph, true);
486 }
487 if (manager == nullptr) {
488 MS_LOG(ERROR) << "manager is nullptr.";
489 return lite::RET_ERROR;
490 }
491 auto node_users = manager->node_users()[cnode];
492 if (node_users.empty()) {
493 MS_LOG(ERROR) << "cnode is isolated.";
494 return lite::RET_ERROR;
495 }
496 if (!utils::isa<CNodePtr>(node_users.front().first)) {
497 return lite::RET_ERROR;
498 }
499 CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
500 MS_ASSERT(base_node != nullptr);
501 size_t input_index = before ? index : node_users.front().second;
502 auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
503 if (!shape.empty() && shape.size() != kNH2NC.size()) {
504 return lite::RET_NO_CHANGE;
505 }
506 return lite::RET_OK;
507 }
508
IsInOutCanFuison(const std::vector<AnfNodePtr> & nodes,size_t * trans_count,FormatTransNodeType * trans_type)509 bool TransposeStrategy::IsInOutCanFuison(const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
510 FormatTransNodeType *trans_type) {
511 MS_ASSERT(trans_count != nullptr && trans_type != nullptr);
512 for (auto &node : nodes) {
513 if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
514 FormatTransNodeType cur_type;
515 std::vector<int> perm;
516 auto cnode = node->cast<CNodePtr>();
517 if (cnode == nullptr) {
518 return false;
519 }
520 if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
521 return false;
522 }
523 if (perm == kNH2NC) {
524 cur_type = kNHWC2NCHW;
525 } else if (perm == kNC2NH) {
526 cur_type = kNCHW2NHWC;
527 } else {
528 return false;
529 }
530 if (*trans_type == kNONE) {
531 *trans_type = cur_type;
532 } else if (*trans_type != cur_type) {
533 return false;
534 }
535 *trans_count += 1;
536 }
537 }
538 return true;
539 }
540
DecidePreAndPostTransType(TransTypePair * trans_info,TransTypePair * trans_insert_info) const541 void TransposeStrategy::DecidePreAndPostTransType(TransTypePair *trans_info, TransTypePair *trans_insert_info) const {
542 if (trans_info->pre_ == trans_info->post_) {
543 return;
544 }
545 if (trans_info->pre_ != kNONE && trans_info->post_ != kNONE) {
546 trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
547 trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
548 } else if (trans_info->pre_ == kNONE) {
549 trans_insert_info->pre_ = trans_info->post_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
550 trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
551 } else {
552 trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
553 trans_insert_info->post_ = trans_info->pre_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
554 }
555 }
556 } // namespace opt
557 } // namespace mindspore
558