1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/irpass/less_batch_normalization.h"
18
19 #include <set>
20 #include <unordered_map>
21
22 namespace mindspore {
23 namespace opt {
24 namespace irpass {
25 namespace {
// Kind of "not a real use" consumer handled by ConvertRemoveNodeToVirtualNode; each kind keeps a different
// subset of its inputs (looked up in kRemoveIndex).
enum class RemoveNodeType : int { kOtherNode = 0, kOptimizerNode = 1 };
// Func graph attribute that opts a graph into this pass (checked at the top of operator()).
constexpr char kLessBatchNormalizationPassName[] = "less_bn";
// Input position 1: the operand followed when walking a branch upwards, and the parameter operand of a Load.
constexpr int kValidResidualStructureIndex = 1;
// First BatchNorm input position scanned for Load-wrapped parameters.
constexpr int kBNParametersStartIndex = 2;
30 // Pattern 1
31 // Add -> BatchNorm -> Conv2D -> Relu ... -> End
32 // ↘ BatchNorm -> Conv2D -> -> -> -> ↗
33 constexpr auto kFirstBranchPattern1 = 12;
34 constexpr auto kSecondBranchPattern1 = 3;
35 constexpr auto kFirstBranchStartIndexPattern1 = 4;
36 constexpr auto kFirstBranchEndIndexPattern1 = 11;
37 constexpr auto kSecondBranchStartIndexPattern1 = kFirstBranchPattern1;
38 constexpr auto kSecondBranchEndIndexPattern1 = 2 + kFirstBranchPattern1;
39 const std::vector<kStructureTuple> ResidualStructureBasePattern{
40 {kFirstBranchPattern1,
41 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
42 {kFirstBranchStartIndexPattern1, kFirstBranchEndIndexPattern1}},
43 {kSecondBranchPattern1,
44 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
45 {kSecondBranchStartIndexPattern1, kSecondBranchEndIndexPattern1}}};
46 // Pattern 2
47 // Add -> BatchNorm -> Conv2D -> Relu ... -> End
48 // ↘ -> -> ... ... ... -> -> ↗
49 constexpr auto kFirstBranchPattern2 = 12;
50 constexpr auto kSecondBranchPattern2 = 1;
51 constexpr auto kFirstBranchStartIndexPattern2 = 4;
52 constexpr auto kFirstBranchEndIndexPattern2 = 11;
53 constexpr auto kSecondBranchStartIndexPattern2 = kFirstBranchPattern2;
54 constexpr auto kSecondBranchEndIndexPattern2 = 1 + kSecondBranchPattern2;
55 const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
56 {kFirstBranchPattern2,
57 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
58 {kFirstBranchStartIndexPattern2, kFirstBranchEndIndexPattern2}},
59 {kSecondBranchPattern2, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern2, kSecondBranchEndIndexPattern2}}};
60 // Pattern 3
61 // Add -> BatchNorm -> Conv2D -> Relu ... BatchNorm -> Conv2D -> End
62 // ↘ BatchNorm -> Conv2D -> -> ... ... ... -> -> ↗
63 constexpr auto kFirstBranchPattern3 = 11;
64 constexpr auto kSecondBranchPattern3 = 3;
65 constexpr auto kFirstBranchStartIndexPattern3 = 4;
66 constexpr auto kFirstBranchEndIndexPattern3 = 10;
67 constexpr auto kSecondBranchStartIndexPattern3 = kFirstBranchPattern3;
68 constexpr auto kSecondBranchEndIndexPattern3 = 2 + kFirstBranchPattern3;
69 const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
70 {kFirstBranchPattern3,
71 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
72 prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
73 prim::kPrimConv2D},
74 {kFirstBranchStartIndexPattern3, kFirstBranchEndIndexPattern3}},
75 {kSecondBranchPattern3,
76 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
77 {kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
78 // Pattern 4
79 constexpr auto kFirstBranchPattern4 = 8;
80 constexpr auto kSecondBranchPattern4 = 3;
81 constexpr auto kFirstBranchStartIndexPattern4 = 4;
82 constexpr auto kFirstBranchEndIndexPattern4 = 6;
83 constexpr auto kSecondBranchStartIndexPattern4 = kFirstBranchPattern4;
84 constexpr auto kSecondBranchEndIndexPattern4 = 3 + kFirstBranchPattern4;
85 const std::vector<kStructureTuple> BasicStructBasePattern{
86 {kFirstBranchPattern4,
87 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
88 {kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
89 {kSecondBranchPattern4,
90 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
91 {kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
92 // Pattern 5
93 constexpr auto kFirstBranchPattern5 = 7;
94 constexpr auto kSecondBranchPattern5 = 1;
95 constexpr auto kFirstBranchStartIndexPattern5 = 4;
96 constexpr auto kFirstBranchEndIndexPattern5 = 6;
97 constexpr auto kSecondBranchStartIndexPattern5 = kFirstBranchPattern5;
98 constexpr auto kSecondBranchEndIndexPattern5 = 3 + kFirstBranchPattern5;
99 const std::vector<kStructureTuple> BasicStructFirstStepPattern{
100 {kFirstBranchPattern5,
101 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
102 prim::kPrimBatchNorm, prim::kPrimConv2D},
103 {kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
104 {kSecondBranchPattern5, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
105 // Pattern 6
106 constexpr auto kFirstBranchPattern6 = 8;
107 constexpr auto kSecondBranchPattern6 = 1;
108 constexpr auto kFirstBranchStartIndexPattern6 = 4;
109 constexpr auto kFirstBranchEndIndexPattern6 = 6;
110 constexpr auto kSecondBranchStartIndexPattern6 = kFirstBranchPattern6;
111 constexpr auto kSecondBranchEndIndexPattern6 = 3 + kFirstBranchPattern6;
112 const std::vector<kStructureTuple> BasicStructShortCutPattern{
113 {kFirstBranchPattern6,
114 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
115 {kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
116 {kSecondBranchPattern6, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
117 // Pattern 7
118 constexpr auto kFirstBranchPattern7 = 1;
119 constexpr auto kSecondBranchPattern7 = 13;
120 constexpr auto kFirstBranchStartIndexPattern7 = SIZE_MAX;
121 constexpr auto kFirstBranchEndIndexPattern7 = SIZE_MAX;
122 constexpr auto kSecondBranchStartIndexPattern7 = 7;
123 constexpr auto kSecondBranchEndIndexPattern7 = 10;
124 const std::vector<kStructureTuple> InvertedResidualShortCutPattern{
125 {kFirstBranchPattern7,
126 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
127 {kFirstBranchStartIndexPattern7, kFirstBranchEndIndexPattern7}},
128 {kSecondBranchPattern7,
129 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
130 prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
131 prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
132 {kSecondBranchStartIndexPattern7, kSecondBranchEndIndexPattern7}}};
133 // Pattern 8
134 constexpr auto kFirstBranchPattern8 = 4;
135 constexpr auto kFirstBranchStartIndexPattern8 = 0;
136 constexpr auto kFirstBranchEndIndexPattern8 = 3;
137 const std::vector<kStructureTuple> InvertedResidualPattern{
138 {kFirstBranchPattern8,
139 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimAdd},
140 {kFirstBranchStartIndexPattern8, kFirstBranchEndIndexPattern8}}};
141 // Pattern 9
142 constexpr auto kFirstBranchPattern9 = 1;
143 constexpr auto kSecondBranchPattern9 = 12;
144 constexpr auto kFirstBranchStartIndexPattern9 = SIZE_MAX;
145 constexpr auto kFirstBranchEndIndexPattern9 = SIZE_MAX;
146 constexpr auto kSecondBranchStartIndexPattern9 = 7;
147 constexpr auto kSecondBranchEndIndexPattern9 = 10;
148 const std::vector<kStructureTuple> InvertedResidualShortCutPattern2{
149 {kFirstBranchPattern9, {prim::kPrimAdd}, {kFirstBranchStartIndexPattern9, kFirstBranchEndIndexPattern9}},
150 {kSecondBranchPattern9,
151 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
152 prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
153 prim::kPrimConv2D, prim::kPrimAdd},
154 {kSecondBranchStartIndexPattern9, kSecondBranchEndIndexPattern9}}};
155 // Pattern 10
156 constexpr auto kFirstBranchPattern10 = 5;
157 constexpr auto kFirstBranchStartIndexPattern10 = 0;
158 constexpr auto kFirstBranchEndIndexPattern10 = 4;
159 const std::vector<kStructureTuple> InvertedResidualPattern2{
160 {kFirstBranchPattern10,
161 {prim::kPrimReduceMean, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
162 {kFirstBranchStartIndexPattern10, kFirstBranchEndIndexPattern10}}};
163 // Pattern 11
164 constexpr auto kFirstBranchPattern11 = 17;
165 constexpr auto kFirstBranchStartIndexPattern11 = 3;
166 constexpr auto kFirstBranchEndIndexPattern11 = 6;
167 const std::vector<kStructureTuple> InvertedResidualPattern3{
168 {kFirstBranchPattern11,
169 {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
170 prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
171 prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6,
172 prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
173 {kFirstBranchStartIndexPattern11, kFirstBranchEndIndexPattern11}}};
174 // Pattern 12
175 constexpr auto kFirstBranchPattern12 = 1;
176 constexpr auto kSecondBranchPattern12 = 9;
177 constexpr auto kFirstBranchStartIndexPattern12 = SIZE_MAX;
178 constexpr auto kFirstBranchEndIndexPattern12 = SIZE_MAX;
179 constexpr auto kSecondBranchStartIndexPattern12 = kFirstBranchPattern12 + 5;
180 constexpr auto kSecondBranchEndIndexPattern12 = kFirstBranchPattern12 + 8;
181 const std::vector<kStructureTuple> DenseBlockShortCutPattern{
182 {kFirstBranchPattern12, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern12, kFirstBranchEndIndexPattern12}},
183 {kSecondBranchPattern12,
184 {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
185 prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
186 {kSecondBranchStartIndexPattern12, kSecondBranchEndIndexPattern12}}};
187 // Pattern 13
188 constexpr auto kFirstBranchPattern13 = 5;
189 constexpr auto kFirstBranchStartIndexPattern13 = 0;
190 constexpr auto kFirstBranchEndIndexPattern13 = 4;
191 const std::vector<kStructureTuple> DenseBlockPattern{
192 {kFirstBranchPattern13,
193 {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
194 {kFirstBranchStartIndexPattern13, kFirstBranchEndIndexPattern13}}};
195 // Pattern 14
196 constexpr auto kFirstBranchPattern14 = 9;
197 constexpr auto kSecondBranchPattern14 = 1;
198 constexpr auto kFirstBranchStartIndexPattern14 = 5;
199 constexpr auto kFirstBranchEndIndexPattern14 = 8;
200 constexpr auto kSecondBranchStartIndexPattern14 = SIZE_MAX;
201 constexpr auto kSecondBranchEndIndexPattern14 = SIZE_MAX;
202 const std::vector<kStructureTuple> DenseBlockShortCutPattern2{
203 {kFirstBranchPattern14,
204 {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
205 prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
206 {kFirstBranchStartIndexPattern14, kFirstBranchEndIndexPattern14}},
207 {kSecondBranchPattern14, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern14, kSecondBranchEndIndexPattern14}}};
208 // Pattern 15
209 constexpr auto kFirstBranchPattern15 = 9;
210 constexpr auto kSecondBranchPattern15 = 1;
211 constexpr auto kFirstBranchStartIndexPattern15 = 0;
212 constexpr auto kFirstBranchEndIndexPattern15 = 4;
213 constexpr auto kSecondBranchStartIndexPattern15 = SIZE_MAX;
214 constexpr auto kSecondBranchEndIndexPattern15 = SIZE_MAX;
215 const std::vector<kStructureTuple> DenseBlockPoolPattern{
216 {kFirstBranchPattern15,
217 {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
218 prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
219 {kFirstBranchStartIndexPattern15, kFirstBranchEndIndexPattern15}},
220 {kSecondBranchPattern15, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern15, kSecondBranchEndIndexPattern15}}};
221 // Pattern 16
222 constexpr auto kFirstBranchPattern16 = 1;
223 constexpr auto kSecondBranchPattern16 = 9;
224 constexpr auto kFirstBranchStartIndexPattern16 = SIZE_MAX;
225 constexpr auto kFirstBranchEndIndexPattern16 = SIZE_MAX;
226 constexpr auto kSecondBranchStartIndexPattern16 = kFirstBranchPattern16;
227 constexpr auto kSecondBranchEndIndexPattern16 = kFirstBranchPattern16 + 4;
228 const std::vector<kStructureTuple> DenseBlockPoolPatter2{
229 {kFirstBranchPattern16, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern16, kFirstBranchEndIndexPattern16}},
230 {kSecondBranchPattern16,
231 {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
232 prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
233 {kSecondBranchStartIndexPattern16, kSecondBranchEndIndexPattern16}}};
234 static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {ResidualStructureBasePattern,
235 ResidualStructureShortCutPattern,
236 ResidualStructureFirstStepPattern,
237 BasicStructBasePattern,
238 BasicStructFirstStepPattern,
239 BasicStructShortCutPattern,
240 InvertedResidualShortCutPattern,
241 InvertedResidualPattern,
242 InvertedResidualShortCutPattern2,
243 InvertedResidualPattern2,
244 InvertedResidualPattern3,
245 DenseBlockShortCutPattern,
246 DenseBlockPattern,
247 DenseBlockShortCutPattern2,
248 DenseBlockPoolPattern,
249 DenseBlockPoolPatter2};
250 const std::set<PrimitivePtr> kNeedRemoveNodeSet{
251 prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
252 prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
253 static std::unordered_map<RemoveNodeType, std::unordered_set<size_t>> kRemoveIndex{
254 {RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};
255
NeedRemove(const ParameterPtr & a,const std::vector<AnfNodePtr> & parameter_list)256 bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> ¶meter_list) {
257 if (a == nullptr) {
258 return false;
259 }
260 return std::any_of(parameter_list.begin(), parameter_list.end(), [&a](const AnfNodePtr &b) {
261 return (b->isa<Parameter>() && a->name() == b->cast<ParameterPtr>()->name());
262 });
263 }
264
IsNotRealUseNode(const AnfNodePtr & node)265 bool IsNotRealUseNode(const AnfNodePtr &node) {
266 for (const auto &prim : kNeedRemoveNodeSet) {
267 if (IsPrimitiveCNode(node, prim)) {
268 return true;
269 }
270 }
271 return false;
272 }
273
// Builds, in `cnode`'s func graph, a MakeTuple node whose operands are the inputs of `cnode` at the positions
// listed in kRemoveIndex, in their original order. Load/RefToEmbed use the kOtherNode entry and every other
// primitive uses the kOptimizerNode entry. All remaining inputs of `cnode` are dropped.
CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  const bool is_load_like =
    IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed);
  // Positions listed here are the ones that survive into the MakeTuple.
  const auto &kept_positions =
    is_load_like ? kRemoveIndex[RemoveNodeType::kOtherNode] : kRemoveIndex[RemoveNodeType::kOptimizerNode];

  std::vector<AnfNodePtr> tuple_inputs{NewValueNode(prim::kPrimMakeTuple)};
  const auto &inputs = cnode->inputs();
  for (size_t pos = 0; pos < inputs.size(); ++pos) {
    if (kept_positions.count(pos) != 0) {
      tuple_inputs.push_back(inputs[pos]);
    }
  }

  const auto &fg = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  return fg->NewCNode(tuple_inputs);
}
294
IsRealRemoveParameterNode(const FuncGraphManagerPtr & manager,const AnfNodePtr & parameter)295 bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr ¶meter) {
296 auto param_output = manager->node_users().find(parameter);
297 if (param_output == manager->node_users().end()) {
298 return true;
299 }
300
301 bool need_remove = true;
302 auto output_info_list = param_output->second;
303 for (const auto &output_info : output_info_list) {
304 const auto &node = output_info.first;
305 if (IsNotRealUseNode(node)) {
306 const auto &cnode = node->cast<CNodePtr>();
307 const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
308 (void)manager->Replace(cnode, new_cnode);
309 continue;
310 }
311 need_remove = false;
312 }
313
314 return need_remove;
315 }
316
RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & remove_parameter_list)317 void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
318 const std::vector<AnfNodePtr> &remove_parameter_list) {
319 auto roots = manager->roots();
320 if (roots.size() != 1) {
321 MS_LOG(ERROR) << "The size of roots " << roots.size() << " is not valid.";
322 return;
323 }
324 auto root_graph = *(roots.begin());
325 MS_EXCEPTION_IF_NULL(root_graph);
326
327 std::vector<AnfNodePtr> real_remove_parameter_list;
328 (void)std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(),
329 std::back_inserter(real_remove_parameter_list),
330 [&manager](const AnfNodePtr ¶m) { return IsRealRemoveParameterNode(manager, param); });
331
332 auto root_parameters = root_graph->parameters();
333 size_t origin_param_count = root_parameters.size();
334 (void)root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
335 [&real_remove_parameter_list](const AnfNodePtr &node) {
336 return NeedRemove(node->cast<ParameterPtr>(),
337 real_remove_parameter_list);
338 }),
339 root_parameters.end());
340 size_t remove_param_count = origin_param_count - root_parameters.size();
341 size_t hyper_param_count = root_graph->hyper_param_count();
342 if (remove_param_count > hyper_param_count) {
343 MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
344 return;
345 }
346 hyper_param_count = hyper_param_count - remove_param_count;
347 root_graph->set_hyper_param_count(hyper_param_count);
348 manager->SetParameters(root_graph, root_parameters);
349 }
350 } // namespace
351
MatchStructureNode(const CNodePtr & cnode,const int32_t index,const kStructureTuple & patternTuple) const352 bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index,
353 const kStructureTuple &patternTuple) const {
354 if (index < 0) {
355 return false;
356 }
357 const auto &use_pattern = std::get<1>(patternTuple);
358 int32_t use_index = index % static_cast<int32_t>(use_pattern.size());
359 if (!IsPrimitiveCNode(cnode, use_pattern[IntToSize(use_index)]) &&
360 use_pattern[IntToSize(use_index)] != prim::kPrimTupleGetItem) {
361 return false;
362 }
363 return true;
364 }
365
MatchGraphStructure(const CNodePtr & cnode,const std::vector<kStructureTuple> & match_pattern)366 bool LessBatchNormalization::MatchGraphStructure(const CNodePtr &cnode,
367 const std::vector<kStructureTuple> &match_pattern) {
368 if ((match_branch_ + 1 >= total_match_node_.size()) || (match_branch_ >= match_pattern.size())) {
369 return false;
370 }
371
372 int32_t index = static_cast<int32_t>(match_node_) - static_cast<int32_t>(total_match_node_[match_branch_]);
373 const auto &pattern = match_pattern[match_branch_];
374 if (!MatchStructureNode(cnode, index, pattern)) {
375 return false;
376 }
377
378 match_node_++;
379 if (match_node_ == total_match_node_.back()) {
380 is_match_ = true;
381 return false;
382 }
383 if (match_node_ == total_match_node_[match_branch_ + 1]) {
384 match_branch_++;
385 return false;
386 }
387 return true;
388 }
389
IsRemoveNode(const CNodePtr & cnode,const std::vector<kStructureTuple> & match_pattern)390 void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern) {
391 if (!IsPrimitiveCNode(cnode, prim::kPrimBatchNorm) && !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) &&
392 !IsValueNode<FuncGraph>(cnode->input(0))) {
393 return;
394 }
395 if (match_pattern.empty()) {
396 return;
397 }
398 const auto &start_end_pair = std::get<2>(match_pattern.at(match_branch_));
399 if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) {
400 (void)remove_node_list_.insert(cnode);
401 }
402 }
403
// Entry point of the pass for `node`. Only graphs carrying the kLessBatchNormalizationPassName attribute are
// processed. Each structure in kNeedMatchPattern is tried in turn, anchored at `node`, until one fully matches.
// On a match:
//   - every node recorded in remove_node_list_ is replaced by its input 1;
//   - the parameters reaching removed BatchNorm nodes through Load are dropped from the root graph when nothing
//     else really uses them.
// Returns `node` when the graph was rewritten, nullptr otherwise.
AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto &fg = node->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  if (!fg->has_attr(kLessBatchNormalizationPassName)) {
    return nullptr;
  }

  // The anchor node and the matcher inputs built from it do not depend on the pattern being tried.
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr || cnode->inputs().empty()) {
    return nullptr;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  std::vector<PredicateFuncType> funcs(cnode->inputs().size() - 1, IsCNode);

  for (match_pattern_ = 0; match_pattern_ < kNeedMatchPattern.size(); ++match_pattern_) {
    Reset();
    const auto &current_pattern = kNeedMatchPattern.at(match_pattern_);
    // total_match_node_ becomes the running sum of branch lengths: {0, len0, len0 + len1, ...}.
    size_t sum_match_node = 0;
    for (const auto &branch : current_pattern) {
      sum_match_node += std::get<0>(branch);
      (void)total_match_node_.emplace_back(sum_match_node);
    }
    AnfVisitor::Match(prim, funcs)(node);
    if (is_match_) {
      break;
    }
  }

  if (!is_match_ || remove_node_list_.empty()) {
    return nullptr;
  }

  MS_EXCEPTION_IF_NULL(optimizer);
  auto manager = optimizer->manager();
  MS_EXCEPTION_IF_NULL(manager);
  std::vector<AnfNodePtr> remove_parameter_list;
  for (const auto &remove_node : remove_node_list_) {
    // Collect the parameters feeding this BatchNorm through Load (inputs from kBNParametersStartIndex on).
    // Only this node's own inputs are scanned, so earlier BatchNorms' Loads are not collected again.
    if (IsPrimitiveCNode(remove_node, prim::kPrimBatchNorm)) {
      const auto &bn_inputs = remove_node->inputs();
      for (size_t i = kBNParametersStartIndex; i < bn_inputs.size(); ++i) {
        if (IsPrimitiveCNode(bn_inputs[i], prim::kPrimLoad)) {
          remove_parameter_list.push_back(bn_inputs[i]->cast<CNodePtr>()->input(kValidResidualStructureIndex));
        }
      }
    }
    // Bypass the node by replacing it with its input 1.
    auto input_cnode = remove_node->input(kValidResidualStructureIndex);
    (void)manager->Replace(remove_node, input_cnode);
  }
  RemoveBatchNormalizetionNotUseParameters(manager, remove_parameter_list);

  return node;
}
458
Visit(const CNodePtr & cnode)459 void LessBatchNormalization::Visit(const CNodePtr &cnode) {
460 if (cnode == nullptr) {
461 return;
462 }
463
464 const auto ¤t_pattern = kNeedMatchPattern.at(match_pattern_);
465 IsRemoveNode(cnode, current_pattern);
466 if (!MatchGraphStructure(cnode, current_pattern)) {
467 return;
468 }
469
470 auto search_input = cnode->input(kValidResidualStructureIndex);
471 if (search_input != nullptr && search_input->isa<CNode>()) {
472 this->Visit(search_input->cast<CNodePtr>());
473 }
474 return;
475 }
476
Reset()477 void LessBatchNormalization::Reset() {
478 remove_node_list_.clear();
479 total_match_node_.clear();
480 (void)total_match_node_.emplace_back(0);
481 match_node_ = 0;
482 match_branch_ = 0;
483 is_match_ = false;
484 }
485 } // namespace irpass
486 } // namespace opt
487 } // namespace mindspore
488