1 /**
2 * Copyright 2021-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#define USE_DEPRECATED_API
#include "tools/optimizer/graph/transpose_strategy.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <numeric>
#include <string>
#include <utility>
#include <vector>
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ops/crop.h"
#include "src/common/utils.h"
#include "ops/fusion/activation.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/op_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "nnacl/op_base.h"
37
38 namespace mindspore {
39 namespace opt {
40 namespace {
// Input slot of a CNode that holds its first data operand (slot 0 holds the primitive).
constexpr size_t kFirstInput{1};
// Divisor used when comparing the number of neighbouring transposes against half of all neighbours.
constexpr size_t kHalfDivisor{2};
// CNode size (primitive + five inputs) of the StridedSlice form this pass is able to rewrite.
constexpr size_t kOnnxStridedSlice{6};
// Number of padding values for a 4-D pad: one (before, after) pair per dimension.
constexpr int kPaddingListLength{8};
GetPostNodes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,std::vector<AnfNodePtr> * out_nodes)45 STATUS GetPostNodes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, std::vector<AnfNodePtr> *out_nodes) {
46 MS_ASSERT(func_graph != nullptr && cnode != nullptr && out_nodes != nullptr);
47 auto manager = func_graph->manager();
48 if (manager == nullptr) {
49 manager = Manage(func_graph, true);
50 }
51 if (manager == nullptr) {
52 MS_LOG(ERROR) << "manager is nullptr.";
53 return lite::RET_ERROR;
54 }
55 auto node_users = manager->node_users()[cnode];
56 if (node_users.empty()) {
57 MS_LOG(ERROR) << "cnode is isolated.";
58 return lite::RET_ERROR;
59 }
60 std::transform(node_users.begin(), node_users.end(), std::back_inserter(*out_nodes),
61 [](const std::pair<AnfNodePtr, int> &node_user) { return node_user.first; });
62 return lite::RET_OK;
63 }
64
JudgeIs4DInput(NodeInferShape * node_infer_shape,const CNodePtr & cnode)65 bool JudgeIs4DInput(NodeInferShape *node_infer_shape, const CNodePtr &cnode) {
66 MS_ASSERT(node_infer_shape != nullptr && cnode != nullptr);
67 auto shape = node_infer_shape->GetInputShape(cnode, 1);
68 if (shape.size() != kInputSizeFour) {
69 if (cnode->size() > kInputSizeTwo) {
70 shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
71 if (shape.size() != kInputSizeFour && !lite::JudgeDynamicShape(shape)) {
72 return false;
73 }
74 } else {
75 return false;
76 }
77 }
78 return true;
79 }
80
TransformOpAxesAttr(const std::vector<int> & origin_axes,FormatTransNodeType trans_type)81 std::vector<int> TransformOpAxesAttr(const std::vector<int> &origin_axes, FormatTransNodeType trans_type) {
82 std::vector<int> cur_axes;
83 for (size_t i = 0; i < origin_axes.size(); ++i) {
84 int axis = origin_axes[i];
85 if (axis < 0) {
86 axis += kInputSizeFour;
87 }
88 MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
89 int cur_axis = kNH2NC[axis];
90 if (trans_type == kNHWC2NCHW) {
91 cur_axis = kNC2NH[axis];
92 }
93 cur_axes.push_back(cur_axis);
94 }
95 std::sort(cur_axes.begin(), cur_axes.end());
96 return cur_axes;
97 }
98
TransformAttrByAxes(const FuncGraphPtr & func_graph,const CNodePtr & cnode,size_t input_index,const std::vector<int> & axes,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)99 int TransformAttrByAxes(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t input_index,
100 const std::vector<int> &axes, FormatTransNodeType trans_type,
101 NodeInferShape *node_infer_shape) {
102 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
103 if (input_index >= cnode->size() || axes.empty()) {
104 return lite::RET_ERROR;
105 }
106 auto origin_input = node_infer_shape->GetIntVecInput(cnode, input_index);
107 if (origin_input.size() != axes.size()) {
108 return lite::RET_ERROR;
109 }
110 std::vector<int> cur_input;
111 for (int dim = 0; dim < static_cast<int>(kInputSizeFour); ++dim) {
112 for (size_t index = 0; index < axes.size(); ++index) {
113 int axis = axes[index];
114 if (axis < 0) {
115 axis += kInputSizeFour;
116 }
117 MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
118 int cur_axis = kNH2NC[axis];
119 if (trans_type == kNHWC2NCHW) {
120 cur_axis = kNC2NH[axis];
121 }
122 if (cur_axis == dim) {
123 cur_input.push_back(origin_input[index]);
124 }
125 }
126 }
127 auto param_node = BuildIntVecParameterNode(func_graph, cur_input, cnode->input(input_index)->fullname_with_scope());
128 MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_ERROR, "BuildIntVecParameterNode failed");
129 func_graph->manager()->SetEdge(cnode, input_index, param_node);
130 return lite::RET_OK;
131 }
132
ChangeCommonOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)133 STATUS ChangeCommonOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
134 NodeInferShape *node_infer_shape) {
135 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
136 if (trans_type == kNONE) {
137 MS_LOG(ERROR) << "trans_type is invalid.";
138 return lite::RET_ERROR;
139 }
140 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
141 MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
142 if (prim->GetAttr(ops::kAxis) == nullptr) {
143 return lite::RET_NOT_SUPPORT;
144 }
145 MS_CHECK_TRUE_MSG(prim->GetAttr(ops::kAxis) != nullptr, lite::RET_NULL_PTR, "GetAttr Failed.");
146 auto axis = GetValue<int64_t>(prim->GetAttr(ops::kAxis));
147 if (axis < 0) {
148 axis += kInputSizeFour;
149 }
150 auto new_axis = kNH2NC[axis];
151 if (trans_type == kNHWC2NCHW) {
152 new_axis = kNC2NH[axis];
153 }
154 prim->AddAttr(ops::kAxis, MakeValue<int64_t>(new_axis));
155 return lite::RET_OK;
156 }
157
ChangeOpCrop(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)158 STATUS ChangeOpCrop(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
159 NodeInferShape *node_infer_shape) {
160 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
161 if (trans_type == kNONE) {
162 MS_LOG(ERROR) << "trans_type is invalid.";
163 return lite::RET_ERROR;
164 }
165 auto crop_prim = ops::GetOperator<ops::Crop>(cnode->input(0));
166 if (crop_prim == nullptr) {
167 MS_LOG(ERROR) << "cnode is invalid.";
168 return lite::RET_ERROR;
169 }
170 MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kAxis) != nullptr, lite::RET_ERROR);
171 auto axis = crop_prim->get_axis();
172 if (axis < 0) {
173 axis += kInputSizeFour;
174 }
175 MS_ASSERT(axis >= 0 && axis < kInputSizeFour);
176 MS_CHECK_TRUE_RET(crop_prim->GetAttr(ops::kOffsets) != nullptr, lite::RET_ERROR);
177 auto offsets = crop_prim->get_offsets();
178 if (trans_type == kNCHW2NHWC) {
179 auto new_axis = kNH2NC[axis];
180 if (new_axis == 0) {
181 MS_CHECK_GE(offsets.size(), kInputIndexFour, lite::RET_ERROR);
182 offsets = {offsets[0], offsets[kInputIndexTwo], offsets[kInputIndexThree], offsets[1]};
183 } else if (new_axis == kInputIndexThree) {
184 MS_CHECK_GE(offsets.size(), kInputIndexThree, lite::RET_ERROR);
185 offsets = {offsets[1], offsets[kInputIndexTwo], offsets[0]};
186 } else {
187 offsets.push_back(0);
188 }
189 crop_prim->set_axis(new_axis);
190 crop_prim->set_offsets(offsets);
191 } else {
192 auto new_axis = kNC2NH[axis];
193 if (new_axis == 0) {
194 offsets = {offsets[0], offsets[kInputIndexThree], offsets[1], offsets[kInputIndexTwo]};
195 } else if (new_axis == kInputIndexThree) {
196 offsets = {offsets[kInputIndexTwo], offsets[0], offsets[1]};
197 } else {
198 offsets.pop_back();
199 }
200 crop_prim->set_axis(new_axis);
201 crop_prim->set_offsets(offsets);
202 }
203 return lite::RET_OK;
204 }
205
ChangeOpPad(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)206 STATUS ChangeOpPad(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
207 NodeInferShape *node_infer_shape) {
208 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
209 if (trans_type == kNONE) {
210 MS_LOG(ERROR) << "trans_type is invalid.";
211 return lite::RET_ERROR;
212 }
213 if (cnode->size() < kInputSizeThree) {
214 MS_LOG(ERROR) << "pad op need three inputs.";
215 return lite::RET_INPUT_TENSOR_ERROR;
216 }
217 auto second_input = cnode->input(kInputIndexTwo);
218 lite::DataInfo data_info;
219 int status;
220 if (utils::isa<Parameter>(second_input)) {
221 status = lite::FetchDataFromParameterNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, &data_info, true);
222 } else if (utils::isa<ValueNode>(second_input)) {
223 status = lite::FetchDataFromValueNode(cnode, kInputIndexTwo, converter::kFmkTypeMs, false, &data_info, true);
224 } else {
225 return lite::RET_NOT_SUPPORT;
226 }
227 if (status != lite::RET_OK) {
228 MS_LOG(ERROR) << "get paddings failed.";
229 return status;
230 }
231 if (std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<int>()) !=
232 kPaddingListLength) {
233 return lite::RET_OK;
234 }
235 std::vector<std::vector<int32_t>> padding_list(kInputSizeFour, std::vector<int32_t>(kInputSizeTwo));
236 auto data = reinterpret_cast<int32_t *>(data_info.data_.data());
237 for (int i = 0; i < kPaddingListLength; ++i) {
238 padding_list[i / kInputIndexTwo][i % kInputIndexTwo] = *data;
239 data += 1;
240 }
241 if (trans_type == kNCHW2NHWC) {
242 auto chanel_pad = padding_list[1];
243 padding_list.erase(padding_list.begin() + 1);
244 padding_list.push_back(chanel_pad);
245 } else {
246 auto chanel_pad = padding_list.back();
247 padding_list.pop_back();
248 padding_list.insert(padding_list.begin() + 1, chanel_pad);
249 }
250 auto param_node =
251 BuildIntVec2DParameterNode(func_graph, padding_list, cnode->input(kInputIndexTwo)->fullname_with_scope());
252 MS_CHECK_TRUE_MSG(param_node != nullptr, lite::RET_NULL_PTR, "BuildParameterNode Failed");
253 auto manager = func_graph->manager();
254 MS_ASSERT(manager != nullptr);
255 manager->Replace(cnode->input(kInputIndexTwo), param_node);
256 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
257 MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
258 if (prim->GetAttr(ops::kPaddings) != nullptr) {
259 std::vector<std::vector<int64_t>> padding_attr;
260 (void)std::transform(padding_list.begin(), padding_list.end(), std::back_inserter(padding_attr),
261 [](const std::vector<int> &val) { return std::vector<int64_t>(val.begin(), val.end()); });
262 prim->AddAttr(ops::kPaddings, MakeValue(padding_attr));
263 }
264 return lite::RET_OK;
265 }
266
ChangeOpSlice(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)267 STATUS ChangeOpSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
268 NodeInferShape *node_infer_shape) {
269 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
270 if (trans_type == kNONE) {
271 MS_LOG(ERROR) << "trans_type is invalid.";
272 return lite::RET_ERROR;
273 }
274 for (size_t i = 2; i < cnode->size(); ++i) {
275 if (utils::isa<CNodePtr>(cnode->input(i))) {
276 return lite::RET_NOT_SUPPORT;
277 }
278 }
279 auto shape = node_infer_shape->GetInputShape(cnode, kInputIndexTwo);
280 if (lite::JudgeDynamicShape(shape)) {
281 return lite::RET_NOT_SUPPORT;
282 }
283 int element_num = shape.front();
284 auto prim = ops::GetOperator<ops::SliceFusion>(cnode->input(0));
285 MS_CHECK_TRUE_MSG(prim != nullptr, RET_ERROR, "GetValueNode failed");
286 std::vector<int> axes;
287 if (prim->GetAttr(ops::kAxes) == nullptr || prim->get_axes().empty()) {
288 for (int index = 0; index < element_num; ++index) {
289 axes.push_back(index);
290 }
291 } else {
292 auto origin_axes = prim->get_axes();
293 std::transform(origin_axes.begin(), origin_axes.end(), std::back_inserter(axes),
294 [](int64_t v) { return static_cast<int>(v); });
295 }
296 for (size_t i = 2; i < cnode->size(); ++i) {
297 if (TransformAttrByAxes(func_graph, cnode, i, axes, trans_type, node_infer_shape) != RET_OK) {
298 MS_LOG(ERROR) << "Transform axes failed.";
299 return RET_ERROR;
300 }
301 }
302 auto tmp_axes = TransformOpAxesAttr(axes, trans_type);
303 std::vector<int64_t> new_axes(tmp_axes.begin(), tmp_axes.end());
304 prim->set_axes(new_axes);
305 return lite::RET_OK;
306 }
307
ChangeOpStrideSlice(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type,NodeInferShape * node_infer_shape)308 STATUS ChangeOpStrideSlice(const FuncGraphPtr &func_graph, const CNodePtr &cnode, FormatTransNodeType trans_type,
309 NodeInferShape *node_infer_shape) {
310 MS_ASSERT(func_graph != nullptr && cnode != nullptr && node_infer_shape != nullptr);
311 if (trans_type == kNONE) {
312 MS_LOG(ERROR) << "trans_type is invalid.";
313 return lite::RET_ERROR;
314 }
315 if (cnode->size() != kOnnxStridedSlice) {
316 return lite::RET_NOT_SUPPORT;
317 }
318 for (size_t i = 2; i < cnode->size(); ++i) {
319 if (utils::isa<CNodePtr>(cnode->input(i))) {
320 return lite::RET_NOT_SUPPORT;
321 }
322 }
323 std::vector<int> axes = node_infer_shape->GetIntVecInput(cnode, kInputIndexFour);
324 if (axes.empty()) {
325 MS_LOG(ERROR) << "strided slice input invalid.";
326 return lite::RET_ERROR;
327 }
328 for (size_t index = 2; index < cnode->size(); ++index) {
329 if (index == kInputIndexFour) {
330 continue;
331 }
332 if (TransformAttrByAxes(func_graph, cnode, index, axes, trans_type, node_infer_shape) != RET_OK) {
333 MS_LOG(ERROR) << "transform axes failed.";
334 return lite::RET_ERROR;
335 }
336 }
337 auto cur_axes = TransformOpAxesAttr(axes, trans_type);
338 auto param_node =
339 BuildIntVecParameterNode(func_graph, cur_axes, cnode->input(kInputIndexFour)->fullname_with_scope());
340 MS_CHECK_TRUE_MSG(param_node != nullptr, RET_ERROR, "BuildIntVecParameterNode failed");
341 auto manager = func_graph->manager();
342 MS_ASSERT(manager != nullptr);
343 manager->SetEdge(cnode, kInputIndexFour, param_node);
344 return lite::RET_OK;
345 }
346 } // namespace
347
TransposePairFuseWhenInsert(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)348 AnfNodePtr TransposeStrategy::TransposePairFuseWhenInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
349 const std::vector<int> &perm, bool before, size_t index) {
350 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
351 AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
352 // judge pair transpose after insert.
353 if (CheckPrimitiveType(trans_input_node, prim::kPrimTranspose)) {
354 std::vector<int> trans_perm;
355 auto input_cnode = trans_input_node->cast<CNodePtr>();
356 if (input_cnode == nullptr) {
357 MS_LOG(ERROR) << "input node is invalid.";
358 return nullptr;
359 }
360 if (GetTransposePerm(input_cnode, &trans_perm) != lite::RET_OK) {
361 MS_LOG(ERROR) << "transpose perm get failed.";
362 return nullptr;
363 }
364 if ((perm == kNH2NC && trans_perm == kNC2NH) || (perm == kNC2NH && trans_perm == kNH2NC)) {
365 return input_cnode->input(kFirstInput);
366 }
367 }
368 // insert depend on shape
369 return TransposeDependOnShape(func_graph, cnode, perm, before, index);
370 }
371
TransposeDependOnShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode,const std::vector<int> & perm,bool before,size_t index)372 AnfNodePtr TransposeStrategy::TransposeDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
373 const std::vector<int> &perm, bool before, size_t index) {
374 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
375 AnfNodePtr trans_input_node = before ? cnode->input(index) : cnode;
376 auto status = TransposeInsertDependOnShape(func_graph, cnode, before, index);
377 if (status == lite::RET_ERROR) {
378 return nullptr;
379 } else if (status == lite::RET_NO_CHANGE) {
380 return before ? cnode->input(index) : cnode;
381 }
382 // insert tranpsoe
383 std::string trans_name =
384 before ? cnode->fullname_with_scope() + "_pre" + std::to_string(index - 1) : cnode->fullname_with_scope() + "_post";
385 auto trans_insert_node = GenTransposeNode(func_graph, trans_input_node, perm, trans_name);
386 return trans_insert_node;
387 }
388
CanFusionIfInsert(const FuncGraphPtr & func_graph,const CNodePtr & cnode,TransTypePair * trans_info,TransTypePair * trans_insert_info)389 bool TransposeStrategy::CanFusionIfInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
390 TransTypePair *trans_info, TransTypePair *trans_insert_info) {
391 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
392 MS_ASSERT(pre_type != nullptr && post_type != nullptr);
393 size_t trans_count = 0;
394 std::vector<AnfNodePtr> in_nodes;
395 auto graph_inputs = func_graph->get_inputs();
396 for (size_t i = 1; i < cnode->size(); ++i) {
397 if (utils::isa<CNodePtr>(cnode->input(i)) ||
398 std::find(graph_inputs.begin(), graph_inputs.end(), cnode->input(i)) != graph_inputs.end()) {
399 in_nodes.push_back(cnode->input(i));
400 }
401 }
402 if (!IsInOutCanFuison(in_nodes, &trans_count, &trans_info->pre_)) {
403 return false;
404 }
405 std::vector<AnfNodePtr> out_nodes;
406 if (GetPostNodes(func_graph, cnode, &out_nodes) != lite::RET_OK) {
407 return false;
408 }
409 if (!IsInOutCanFuison(out_nodes, &trans_count, &trans_info->post_)) {
410 return false;
411 }
412 if (trans_info->pre_ == trans_info->post_) {
413 return false;
414 }
415 auto total_node_count = in_nodes.size() + out_nodes.size();
416 bool can_insert = trans_count > total_node_count / kHalfDivisor;
417 if (CheckPrimitiveType(cnode, prim::kPrimActivation)) {
418 auto prim_act = ops::GetOperator<ops::Activation>(cnode->input(0));
419 MS_CHECK_TRUE_MSG(prim_act != nullptr, false, "GetValueNode Failed");
420 if (prim_act->get_activation_type() == mindspore::ActivationType::LEAKY_RELU) {
421 can_insert = trans_count >= total_node_count / kHalfDivisor;
422 }
423 }
424 if (CheckPrimitiveType(cnode, prim::kPrimSplit) || CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
425 can_insert = trans_count >= total_node_count / kHalfDivisor;
426 }
427 if (!can_insert) {
428 return can_insert;
429 }
430 DecidePreAndPostTransType(trans_info, trans_insert_info);
431 return can_insert;
432 }
433
CanChangeOpAxis(const CNodePtr & cnode)434 bool TransposeStrategy::CanChangeOpAxis(const CNodePtr &cnode) {
435 MS_ASSERT(cnode != nullptr);
436 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
437 MS_CHECK_TRUE_MSG(prim != nullptr, false, "GetValueNode Failed");
438 if (!IsDynamicFormatOp(prim->name())) {
439 return false;
440 }
441 if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
442 return false;
443 }
444 if (CheckPrimitiveType(cnode, prim::kPrimSliceFusion) || CheckPrimitiveType(cnode, prim::kPrimStridedSlice) ||
445 CheckPrimitiveType(cnode, prim::kPrimPadFusion)) {
446 for (size_t i = 2; i < cnode->size(); ++i) {
447 if (utils::isa<CNodePtr>(cnode->input(i))) {
448 return false;
449 }
450 if (utils::isa<Parameter>(cnode->input(i)) && !cnode->input(i)->cast<ParameterPtr>()->has_default()) {
451 return false;
452 }
453 }
454 if (CheckPrimitiveType(cnode, prim::kPrimStridedSlice) && cnode->size() != kOnnxStridedSlice) {
455 return false;
456 }
457 } else if (CheckPrimitiveType(cnode, prim::kPrimScaleFusion)) {
458 MS_CHECK_TRUE_RET(cnode->size() >= kInputSizeThree, false);
459 auto weight_param = cnode->input(kInputIndexTwo);
460 MS_CHECK_TRUE_RET(weight_param != nullptr, false);
461 std::vector<int64_t> weight_shape;
462 if (FetchShapeFromAbstract(weight_param->abstract(), &weight_shape) != lite::RET_OK) {
463 MS_LOG(ERROR) << "Get shape from abstract failed.";
464 return false;
465 }
466 if (weight_shape.size() != 1) {
467 return false;
468 }
469 } else if (IsDynamicFormatOpWithAxis(prim->name())) {
470 if (prim->GetAttr(ops::kAxis) == nullptr) {
471 return false;
472 }
473 }
474 return true;
475 }
476
ChangeOpAxis(const FuncGraphPtr & func_graph,const CNodePtr & cnode,FormatTransNodeType trans_type)477 STATUS TransposeStrategy::ChangeOpAxis(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
478 FormatTransNodeType trans_type) {
479 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
480 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
481 MS_CHECK_TRUE_MSG(prim != nullptr, lite::RET_NULL_PTR, "GetValueNode Failed");
482 if (IsDynamicFormatOpWithAxis(prim->name()) && !JudgeIs4DInput(&node_infer_shape_, cnode)) {
483 return lite::RET_NOT_SUPPORT;
484 }
485 std::map<std::string,
486 std::function<STATUS(const FuncGraphPtr &, const CNodePtr &, FormatTransNodeType, NodeInferShape *)>>
487 process_funcs = {
488 {prim::kPrimConcat->name(), ChangeCommonOp}, {prim::kPrimSplit->name(), ChangeCommonOp},
489 {prim::kPrimCrop->name(), ChangeOpCrop}, {prim::kPrimPadFusion->name(), ChangeOpPad},
490 {prim::kPrimSliceFusion->name(), ChangeOpSlice}, {prim::kPrimStridedSlice->name(), ChangeOpStrideSlice},
491 {prim::kPrimScaleFusion->name(), ChangeCommonOp}};
492 auto iter = process_funcs.find(prim->name());
493 if (iter != process_funcs.end()) {
494 return iter->second(func_graph, cnode, trans_type, &node_infer_shape_);
495 }
496 return lite::RET_OK;
497 }
498
TransposeInsertDependOnShape(const FuncGraphPtr & func_graph,const CNodePtr & cnode,bool before,size_t index)499 STATUS TransposeStrategy::TransposeInsertDependOnShape(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
500 bool before, size_t index) {
501 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
502 auto manager = func_graph->manager();
503 if (manager == nullptr) {
504 manager = Manage(func_graph, true);
505 }
506 if (manager == nullptr) {
507 MS_LOG(ERROR) << "manager is nullptr.";
508 return lite::RET_ERROR;
509 }
510 auto node_users = manager->node_users()[cnode];
511 if (node_users.empty()) {
512 MS_LOG(ERROR) << "cnode is isolated.";
513 return lite::RET_ERROR;
514 }
515 if (!utils::isa<CNodePtr>(node_users.front().first)) {
516 return lite::RET_ERROR;
517 }
518 CNodePtr base_node = before ? cnode : node_users.front().first->cast<CNodePtr>();
519 MS_ASSERT(base_node != nullptr);
520 size_t input_index = before ? index : static_cast<size_t>(node_users.front().second);
521 auto shape = node_infer_shape_.GetInputShape(base_node, input_index);
522 if (!lite::JudgeDynamicShape(shape) && shape.size() != kNH2NC.size()) {
523 return lite::RET_NO_CHANGE;
524 }
525 return lite::RET_OK;
526 }
527
IsInOutCanFuison(const std::vector<AnfNodePtr> & nodes,size_t * trans_count,FormatTransNodeType * trans_type)528 bool TransposeStrategy::IsInOutCanFuison(const std::vector<AnfNodePtr> &nodes, size_t *trans_count,
529 FormatTransNodeType *trans_type) {
530 MS_ASSERT(trans_count != nullptr && trans_type != nullptr);
531 for (auto &node : nodes) {
532 if (CheckPrimitiveType(node, prim::kPrimTranspose)) {
533 FormatTransNodeType cur_type;
534 std::vector<int> perm;
535 auto cnode = node->cast<CNodePtr>();
536 if (cnode == nullptr) {
537 return false;
538 }
539 if (GetTransposePerm(cnode, &perm) != lite::RET_OK) {
540 return false;
541 }
542 if (perm == kNH2NC) {
543 cur_type = kNHWC2NCHW;
544 } else if (perm == kNC2NH) {
545 cur_type = kNCHW2NHWC;
546 } else {
547 return false;
548 }
549 if (*trans_type == kNONE) {
550 *trans_type = cur_type;
551 } else if (*trans_type != cur_type) {
552 return false;
553 }
554 *trans_count += 1;
555 }
556 }
557 return true;
558 }
559
DecidePreAndPostTransType(const TransTypePair * trans_info,TransTypePair * trans_insert_info) const560 void TransposeStrategy::DecidePreAndPostTransType(const TransTypePair *trans_info,
561 TransTypePair *trans_insert_info) const {
562 if (trans_info->pre_ == trans_info->post_) {
563 return;
564 }
565 if (trans_info->pre_ != kNONE && trans_info->post_ != kNONE) {
566 trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
567 trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
568 } else if (trans_info->pre_ == kNONE) {
569 trans_insert_info->pre_ = trans_info->post_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
570 trans_insert_info->post_ = trans_info->post_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
571 } else {
572 trans_insert_info->pre_ = trans_info->pre_ == kNHWC2NCHW ? kNCHW2NHWC : kNHWC2NCHW;
573 trans_insert_info->post_ = trans_info->pre_ == kNHWC2NCHW ? kNHWC2NCHW : kNCHW2NHWC;
574 }
575 }
576 } // namespace opt
577 } // namespace mindspore
578