• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
#include "tools/optimizer/fusion/strided_slice_fusion.h"
#include <memory>
#include <utility>
#include <vector>
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/fusion/strided_slice_checker.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/lite_exporter/fetch_content.h"
#include "ir/func_graph.h"
#include "nnacl/op_base.h"
#include "ops/op_name.h"
28 
29 namespace mindspore {
30 namespace opt {
31 namespace {
CheckContinuity(const std::vector<CNodePtr> & nodes,int axis)32 bool CheckContinuity(const std::vector<CNodePtr> &nodes, int axis) {
33   MS_ASSERT(!nodes.empty());
34   for (const auto &node : nodes) {
35     if (!StridedSliceChecker::CheckCommonInfo(node)) {
36       return false;
37     }
38   }
39   std::vector<int> first_begin;
40   if (StridedSliceChecker::GetBegin(nodes.front(), &first_begin) != lite::RET_OK) {
41     return false;
42   }
43   std::vector<int> first_end;
44   if (StridedSliceChecker::GetEnd(nodes.front(), &first_end) != lite::RET_OK) {
45     return false;
46   }
47   MS_CHECK_TRUE_RET(first_begin.size() == first_end.size(), false);
48   if (axis >= static_cast<int>(first_begin.size())) {
49     return false;
50   }
51   for (size_t i = 1; i < nodes.size(); ++i) {
52     std::vector<int> second_begin;
53     if (StridedSliceChecker::GetBegin(nodes[i], &second_begin) != lite::RET_OK) {
54       return false;
55     }
56     std::vector<int> second_end;
57     if (StridedSliceChecker::GetEnd(nodes[i], &second_end) != lite::RET_OK) {
58       return false;
59     }
60     MS_CHECK_TRUE_RET(second_begin.size() == second_end.size(), false);
61     if (second_begin.size() != first_begin.size()) {
62       return false;
63     }
64     for (int j = 0; j < static_cast<int>(first_begin.size()); ++j) {
65       if (j == axis) {
66         continue;
67       }
68       if (second_begin[j] != first_begin[j] || second_end[j] != first_end[j]) {
69         return false;
70       }
71     }
72     if (second_begin[axis] != first_end[axis]) {
73       return false;
74     }
75     first_begin = second_begin;
76     first_end = second_end;
77   }
78   return true;
79 }
80 }  // namespace
81 
Run(const FuncGraphPtr & func_graph)82 bool StridedSliceFusion::Run(const FuncGraphPtr &func_graph) {
83   MS_CHECK_TRUE_MSG(func_graph != nullptr, false, "FuncGraph is a nullptr.");
84   auto manager = func_graph->manager();
85   MS_CHECK_TRUE_MSG(manager != nullptr, false, "The manager of this graph is a nullptr.");
86   auto nodes_list = TopoSort(func_graph->get_return());
87   for (auto &node : nodes_list) {
88     MS_CHECK_TRUE_RET(node != nullptr, false);
89     if (!utils::isa<CNode>(node)) {
90       continue;
91     }
92     if (!CheckPrimitiveType(node, prim::kPrimConcat)) {
93       continue;
94     }
95     auto cnode = node->cast<CNodePtr>();
96     if (IsMarkedTrainOp(cnode)) {
97       continue;
98     }
99     auto prim = GetCNodePrimitive(cnode);
100     MS_CHECK_TRUE_MSG(prim != nullptr, false, "Concat's prim is a nullptr.");
101     axis_ = prim->GetAttr(ops::kAxis) == nullptr ? 0 : static_cast<int>(GetValue<int64_t>(prim->GetAttr(ops::kAxis)));
102     if (axis_ < 0) {
103       continue;
104     }
105     if (Process(func_graph, cnode) != lite::RET_OK) {
106       MS_LOG(ERROR) << "Do StridedSliceFusion failed.";
107       return false;
108     }
109   }
110   UpdateManager(func_graph);
111   return true;
112 }
113 
Process(const FuncGraphPtr & func_graph,const CNodePtr & cnode)114 int StridedSliceFusion::Process(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
115   MS_ASSERT(func_graph != nullptr && cnode != nullptr);
116   FindStridedSliceOp(func_graph, cnode);
117   if (!CheckCanFusion()) {
118     return lite::RET_OK;
119   }
120   auto manager = func_graph->manager();
121   MS_ASSERT(manager != nullptr);
122   for (const auto &nodes : strided_slice_ops_) {
123     auto first_node = nodes.front();
124     auto end_node = nodes.back();
125     manager->SetEdge(first_node, kInputIndexThree, end_node->input(kInputIndexThree));
126     for (size_t i = 1; i < nodes.size(); ++i) {
127       if (!manager->Replace(nodes[i], NewValueNode(std::make_shared<UMonad>()))) {
128         MS_LOG(ERROR) << "Manager Replace strided_slice op with Mond failed.";
129         return lite::RET_ERROR;
130       }
131     }
132     auto first_prim = GetCNodePrimitive(first_node);
133     MS_ASSERT(first_prim != nullptr);
134     auto end_prim = GetCNodePrimitive(end_node);
135     MS_ASSERT(end_prim != nullptr);
136     first_prim->set_attr(ops::kEndMask, end_prim->GetAttr(ops::kEndMask));
137   }
138   auto inputs = cnode->inputs();
139   std::vector<AnfNodePtr> new_inputs;
140   for (const auto &input : inputs) {
141     if (utils::isa<ValueNode>(input) && utils::isa<Monad>(input->cast<ValueNodePtr>()->value())) {
142       continue;
143     }
144     new_inputs.push_back(input);
145   }
146   cnode->set_inputs(new_inputs);
147   return lite::RET_OK;
148 }
149 
FindStridedSliceOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)150 void StridedSliceFusion::FindStridedSliceOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
151   strided_slice_ops_.clear();
152   AnfNodePtr input{nullptr};
153   size_t index = 0;
154   for (size_t i = 1; i < cnode->size(); ++i) {
155     if (!utils::isa<CNode>(cnode->input(i)) || !CheckPrimitiveType(cnode->input(i), prim::kPrimStridedSlice)) {
156       continue;
157     }
158     auto pre_cnode = cnode->input(i)->cast<CNodePtr>();
159     if (pre_cnode->size() != kInputSizeFive) {
160       strided_slice_ops_.clear();
161       return;
162     }
163     if (IsMultiOutputTensors(func_graph, pre_cnode)) {
164       continue;
165     }
166     auto input_cur = pre_cnode->input(1);
167     if (input_cur == nullptr) {
168       strided_slice_ops_.clear();
169       return;
170     }
171     if (input_cur == input && i - index == 1) {
172       strided_slice_ops_[strided_slice_ops_.size() - 1].push_back(pre_cnode);
173     } else {
174       strided_slice_ops_.push_back({pre_cnode});
175       input = input_cur;
176     }
177     index = i;
178   }
179 }
180 
CheckCanFusion()181 bool StridedSliceFusion::CheckCanFusion() {
182   std::vector<std::vector<CNodePtr>> strided_slice_ops = strided_slice_ops_;
183   strided_slice_ops_.clear();
184   for (auto &nodes : strided_slice_ops) {
185     if (nodes.size() <= 1) {
186       continue;
187     }
188     if (CheckContinuity(nodes, axis_)) {
189       strided_slice_ops_.push_back(nodes);
190     }
191   }
192   if (strided_slice_ops_.empty()) {
193     return false;
194   }
195   return true;
196 }
197 }  // namespace opt
198 }  // namespace mindspore
199