1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/strided_slice_fusion.h"
19 #include <memory>
20 #include <vector>
21 #include "mindspore/core/ops/array_ops.h"
22 #include "tools/optimizer/fusion/strided_slice_checker.h"
23 #include "tools/optimizer/common/gllo_utils.h"
24 #include "tools/lite_exporter/fetch_content.h"
25 #include "ir/func_graph.h"
26 #include "nnacl/op_base.h"
27 #include "ops/op_name.h"
28
29 namespace mindspore {
30 namespace opt {
31 namespace {
CheckContinuity(const std::vector<CNodePtr> & nodes,int axis)32 bool CheckContinuity(const std::vector<CNodePtr> &nodes, int axis) {
33 MS_ASSERT(!nodes.empty());
34 for (const auto &node : nodes) {
35 if (!StridedSliceChecker::CheckCommonInfo(node)) {
36 return false;
37 }
38 }
39 std::vector<int> first_begin;
40 if (StridedSliceChecker::GetBegin(nodes.front(), &first_begin) != lite::RET_OK) {
41 return false;
42 }
43 std::vector<int> first_end;
44 if (StridedSliceChecker::GetEnd(nodes.front(), &first_end) != lite::RET_OK) {
45 return false;
46 }
47 MS_CHECK_TRUE_RET(first_begin.size() == first_end.size(), false);
48 if (axis >= static_cast<int>(first_begin.size())) {
49 return false;
50 }
51 for (size_t i = 1; i < nodes.size(); ++i) {
52 std::vector<int> second_begin;
53 if (StridedSliceChecker::GetBegin(nodes[i], &second_begin) != lite::RET_OK) {
54 return false;
55 }
56 std::vector<int> second_end;
57 if (StridedSliceChecker::GetEnd(nodes[i], &second_end) != lite::RET_OK) {
58 return false;
59 }
60 MS_CHECK_TRUE_RET(second_begin.size() == second_end.size(), false);
61 if (second_begin.size() != first_begin.size()) {
62 return false;
63 }
64 for (int j = 0; j < static_cast<int>(first_begin.size()); ++j) {
65 if (j == axis) {
66 continue;
67 }
68 if (second_begin[j] != first_begin[j] || second_end[j] != first_end[j]) {
69 return false;
70 }
71 }
72 if (second_begin[axis] != first_end[axis]) {
73 return false;
74 }
75 first_begin = second_begin;
76 first_end = second_end;
77 }
78 return true;
79 }
80 } // namespace
81
Run(const FuncGraphPtr & func_graph)82 bool StridedSliceFusion::Run(const FuncGraphPtr &func_graph) {
83 MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "FuncGraph is a nullptr.");
84 auto manager = func_graph->manager();
85 MS_CHECK_TRUE_MSG(manager != nullptr, false, "The manager of this graph is a nullptr.");
86 auto nodes_list = TopoSort(func_graph->get_return());
87 for (auto &node : nodes_list) {
88 MS_CHECK_TRUE_RET(node != nullptr, false);
89 if (!utils::isa<CNode>(node)) {
90 continue;
91 }
92 if (!CheckPrimitiveType(node, prim::kPrimConcat)) {
93 continue;
94 }
95 auto cnode = node->cast<CNodePtr>();
96 if (IsMarkedTrainOp(cnode)) {
97 continue;
98 }
99 auto prim = GetCNodePrimitive(cnode);
100 MS_CHECK_TRUE_MSG(prim != nullptr, false, "Concat's prim is a nullptr.");
101 axis_ = prim->GetAttr(ops::kAxis) == nullptr ? 0 : static_cast<int>(GetValue<int64_t>(prim->GetAttr(ops::kAxis)));
102 if (axis_ < 0) {
103 continue;
104 }
105 if (Process(func_graph, cnode) != lite::RET_OK) {
106 MS_LOG(ERROR) << "Do StridedSliceFusion failed.";
107 return false;
108 }
109 }
110 UpdateManager(func_graph);
111 return true;
112 }
113
Process(const FuncGraphPtr & func_graph,const CNodePtr & cnode)114 int StridedSliceFusion::Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
115 MS_ASSERT(func_graph != nullptr && cnode != nullptr);
116 FindStridedSliceOp(func_graph, cnode);
117 if (!CheckCanFusion()) {
118 return lite::RET_OK;
119 }
120 auto manager = func_graph->manager();
121 MS_ASSERT(manager != nullptr);
122 for (const auto &nodes : strided_slice_ops_) {
123 auto first_node = nodes.front();
124 auto end_node = nodes.back();
125 manager->SetEdge(first_node, kInputIndexThree, end_node->input(kInputIndexThree));
126 for (size_t i = 1; i < nodes.size(); ++i) {
127 if (!manager->Replace(nodes[i], NewValueNode(std::make_shared<UMonad>()))) {
128 MS_LOG(ERROR) << "Manager Replace strided_slice op with Mond failed.";
129 return lite::RET_ERROR;
130 }
131 }
132 auto first_prim = GetCNodePrimitive(first_node);
133 MS_ASSERT(first_prim != nullptr);
134 auto end_prim = GetCNodePrimitive(end_node);
135 MS_ASSERT(end_prim != nullptr);
136 first_prim->set_attr(ops::kEndMask, end_prim->GetAttr(ops::kEndMask));
137 }
138 auto inputs = cnode->inputs();
139 std::vector<AnfNodePtr> new_inputs;
140 for (const auto &input : inputs) {
141 if (utils::isa<ValueNode>(input) && utils::isa<Monad>(input->cast<ValueNodePtr>()->value())) {
142 continue;
143 }
144 new_inputs.push_back(input);
145 }
146 cnode->set_inputs(new_inputs);
147 return lite::RET_OK;
148 }
149
FindStridedSliceOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)150 void StridedSliceFusion::FindStridedSliceOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
151 strided_slice_ops_.clear();
152 AnfNodePtr input{nullptr};
153 size_t index = 0;
154 for (size_t i = 1; i < cnode->size(); ++i) {
155 if (!utils::isa<CNode>(cnode->input(i)) || !CheckPrimitiveType(cnode->input(i), prim::kPrimStridedSlice)) {
156 continue;
157 }
158 auto pre_cnode = cnode->input(i)->cast<CNodePtr>();
159 if (pre_cnode->size() != kInputSizeFive) {
160 strided_slice_ops_.clear();
161 return;
162 }
163 if (IsMultiOutputTensors(func_graph, pre_cnode)) {
164 continue;
165 }
166 auto input_cur = pre_cnode->input(1);
167 if (input_cur == nullptr) {
168 strided_slice_ops_.clear();
169 return;
170 }
171 if (input_cur == input && i - index == 1) {
172 strided_slice_ops_[strided_slice_ops_.size() - 1].push_back(pre_cnode);
173 } else {
174 strided_slice_ops_.push_back({pre_cnode});
175 input = input_cur;
176 }
177 index = i;
178 }
179 }
180
CheckCanFusion()181 bool StridedSliceFusion::CheckCanFusion() {
182 std::vector<std::vector<CNodePtr>> strided_slice_ops = strided_slice_ops_;
183 strided_slice_ops_.clear();
184 for (auto &nodes : strided_slice_ops) {
185 if (nodes.size() <= 1) {
186 continue;
187 }
188 if (CheckContinuity(nodes, axis_)) {
189 strided_slice_ops_.push_back(nodes);
190 }
191 }
192 if (strided_slice_ops_.empty()) {
193 return false;
194 }
195 return true;
196 }
197 } // namespace opt
198 } // namespace mindspore
199