1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
#include "tools/optimizer/fusion/reduce_stack_fusion.h"
#include <functional>
#include <numeric>
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "ops/op_name.h"
#include "nnacl/op_base.h"
26
27 namespace mindspore {
28 namespace opt {
Run(const FuncGraphPtr & func_graph)29 bool ReduceStackFusion::Run(const FuncGraphPtr &func_graph) {
30 if (func_graph == nullptr) {
31 MS_LOG(ERROR) << "func_graph is a nullptr, cannot do ReduceStackFusion.";
32 return false;
33 }
34 auto node_list = TopoSort(func_graph->get_return());
35 for (auto &node : node_list) {
36 if (!utils::isa<CNode>(node) || !CheckPrimitiveType(node, prim::kPrimStack)) {
37 continue;
38 }
39 auto stack_cnode = node->cast<CNodePtr>();
40 if (Process(func_graph, stack_cnode) != lite::RET_OK) {
41 MS_LOG(ERROR) << "Do ReduceStackFusion failed.";
42 return false;
43 }
44 }
45 return true;
46 }
47
Process(const FuncGraphPtr & func_graph,const CNodePtr & stack)48 int ReduceStackFusion::Process(const FuncGraphPtr &func_graph, const CNodePtr &stack) {
49 MS_ASSERT(func_graph != nullptr && stack != nullptr);
50 if (!CheckCanFusion(func_graph, stack)) {
51 return lite::RET_OK;
52 }
53 reduce_prim_->AddAttr(ops::kKeepDims, MakeValue(true));
54 auto manager = func_graph->manager();
55 if (manager == nullptr) {
56 MS_LOG(ERROR) << "Manager is a nullptr.";
57 return lite::RET_NULL_PTR;
58 }
59 if (!manager->Replace(stack, stack->input(1))) {
60 MS_LOG(ERROR) << "do Manager-Replace failed.";
61 return lite::RET_ERROR;
62 }
63 return lite::RET_OK;
64 }
65
CheckCanFusion(const FuncGraphPtr & func_graph,const CNodePtr & stack)66 bool ReduceStackFusion::CheckCanFusion(const FuncGraphPtr &func_graph, const CNodePtr &stack) {
67 MS_ASSERT(func_graph != nullptr && stack != nullptr);
68 if (IsMarkedTrainOp(stack)) {
69 return false;
70 }
71 if (stack->size() != kInputSizeTwo) {
72 return false;
73 }
74 auto prim = GetCNodePrimitive(stack);
75 MS_CHECK_TRUE_RET(prim != nullptr, false);
76 if (IsQuantParameterNode(prim)) {
77 return false;
78 }
79 auto axis = prim->GetAttr(ops::kAxis) == nullptr ? 0 : GetValue<int64_t>(prim->GetAttr(ops::kAxis));
80 auto input_node = stack->input(1);
81 if (!utils::isa<CNode>(input_node) || !CheckPrimitiveType(input_node, prim::kPrimReduceFusion)) {
82 return false;
83 }
84 auto reduce = input_node->cast<CNodePtr>();
85 return CheckReduce(func_graph, reduce, axis);
86 }
87
CheckReduce(const FuncGraphPtr & func_graph,const CNodePtr & reduce,int stack_axis)88 bool ReduceStackFusion::CheckReduce(const FuncGraphPtr &func_graph, const CNodePtr &reduce, int stack_axis) {
89 MS_ASSERT(func_graph != nullptr && reduce != nullptr);
90 if (IsMarkedTrainOp(reduce)) {
91 return false;
92 }
93 if (IsMultiOutputTensors(func_graph, reduce)) {
94 return false;
95 }
96 if (reduce->size() < kInputSizeThree || reduce->input(ops::kInputIndex2) == nullptr ||
97 utils::isa<CNode>(reduce->input(ops::kInputIndex2))) {
98 return false;
99 }
100 reduce_prim_ = GetCNodePrimitive(reduce);
101 MS_CHECK_TRUE_RET(reduce_prim_ != nullptr, false);
102 if (IsQuantParameterNode(reduce_prim_)) {
103 return false;
104 }
105 bool keep_dim =
106 reduce_prim_->GetAttr(ops::kKeepDims) == nullptr ? false : GetValue<bool>(reduce_prim_->GetAttr(ops::kKeepDims));
107 if (keep_dim) {
108 return false;
109 }
110 lite::DataInfo data_info;
111 if (lite::FetchConstData(reduce, ops::kInputIndex2, converter::kFmkTypeMs, &data_info, false) != lite::RET_OK) {
112 return false;
113 }
114 if ((data_info.data_type_ != kNumberTypeInt && data_info.data_type_ != kNumberTypeInt32) ||
115 data_info.data_ptr_ == nullptr) {
116 return false;
117 }
118 auto num = std::accumulate(data_info.shape_.begin(), data_info.shape_.end(), 1, std::multiplies<>());
119 if (num > 1) {
120 return false;
121 }
122 return *(static_cast<int *>(data_info.data_ptr_)) == stack_axis;
123 }
124 } // namespace opt
125 } // namespace mindspore
126