1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/irpass/less_batch_normalization.h"
18
19 #include <set>
20 #include "mindspore/core/ops/sequence_ops.h"
21 #include "mindspore/core/ops/conv_pool_ops.h"
22 #include "mindspore/core/ops/nn_optimizer_ops.h"
23 #include "mindspore/core/ops/nn_ops.h"
24 #include "mindspore/core/ops/math_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "utils/hash_map.h"
28
29 namespace mindspore {
30 namespace opt {
31 namespace irpass {
32 namespace {
// Selects which entry of kRemoveIndex ConvertRemoveNodeToVirtualNode uses: kOtherNode for Load/RefToEmbed,
// kOptimizerNode for every other node in kNeedRemoveNodeSet.
enum class RemoveNodeType { kOtherNode = 0, kOptimizerNode };
// Pass name. It is also the "boost_type" attribute value that enables the pass on a BatchNorm node, and the name of
// the boolean graph attribute that caches whether a graph contains such a BatchNorm.
const char kLessBatchNormalizationPassName[] = "less_bn";
const auto kBoostType = "boost_type";
// Input index followed when walking a structure (input 1 of each CNode). It is also the input that replaces a
// removed node, and the parameter input read from a Load node.
constexpr auto kValidResidualStructureIndex = 1;
// First BatchNorm input index scanned for Load nodes whose parameters may be removed.
constexpr auto kBNParametersStartIndex = 2;
// Each pattern below is a list of branches. Each branch is a kStructureTuple holding:
//   <0> the number of nodes matched on that branch (accumulated into the prefix sums of total_match_node_);
//   <1> the primitive sequence checked position by position (a kPrimTupleGetItem entry accepts any node);
//   <2> the inclusive [start, end] range of match positions. Match positions are counted across all branches.
//       BatchNorm / TupleGetItem / FuncGraph-call nodes seen at these positions are scheduled for removal.
//       {SIZE_MAX, SIZE_MAX} removes nothing.
// Pattern 1
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// BatchNorm -> Conv2D -> -> -> ->
constexpr auto kFirstBranchPattern1 = 12;
constexpr auto kSecondBranchPattern1 = 3;
constexpr auto kFirstBranchStartIndexPattern1 = 4;
constexpr auto kFirstBranchEndIndexPattern1 = 11;
constexpr auto kSecondBranchStartIndexPattern1 = kFirstBranchPattern1;
constexpr auto kSecondBranchEndIndexPattern1 = 2 + kFirstBranchPattern1;
const std::vector<kStructureTuple> ResidualStructureBasePattern{
  {kFirstBranchPattern1,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU},
   {kFirstBranchStartIndexPattern1, kFirstBranchEndIndexPattern1}},
  {kSecondBranchPattern1,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kSecondBranchStartIndexPattern1, kSecondBranchEndIndexPattern1}}};
// Pattern 2
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// -> -> ... ... ... -> ->
constexpr auto kFirstBranchPattern2 = 12;
constexpr auto kSecondBranchPattern2 = 1;
constexpr auto kFirstBranchStartIndexPattern2 = 4;
constexpr auto kFirstBranchEndIndexPattern2 = 11;
constexpr auto kSecondBranchStartIndexPattern2 = kFirstBranchPattern2;
// NOTE(review): this evaluates to 2, which is below the start index (12), so the range is empty.
// The second branch is a single ReLU, which IsRemoveNode never removes, so the result is the same either way.
// Other patterns use `N + kFirstBranchPatternX` here; confirm the intent.
constexpr auto kSecondBranchEndIndexPattern2 = 1 + kSecondBranchPattern2;
const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
  {kFirstBranchPattern2,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU},
   {kFirstBranchStartIndexPattern2, kFirstBranchEndIndexPattern2}},
  {kSecondBranchPattern2, {prim::kPrimReLU}, {kSecondBranchStartIndexPattern2, kSecondBranchEndIndexPattern2}}};
// Pattern 3
// Add -> BatchNorm -> Conv2D -> Relu ... BatchNorm -> Conv2D -> End
// BatchNorm -> Conv2D -> -> ... ... ... -> ->
constexpr auto kFirstBranchPattern3 = 11;
constexpr auto kSecondBranchPattern3 = 3;
constexpr auto kFirstBranchStartIndexPattern3 = 4;
constexpr auto kFirstBranchEndIndexPattern3 = 10;
constexpr auto kSecondBranchStartIndexPattern3 = kFirstBranchPattern3;
constexpr auto kSecondBranchEndIndexPattern3 = 2 + kFirstBranchPattern3;
const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
  {kFirstBranchPattern3,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem,
    prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
    prim::kPrimConv2D},
   {kFirstBranchStartIndexPattern3, kFirstBranchEndIndexPattern3}},
  {kSecondBranchPattern3,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
// Pattern 4
constexpr auto kFirstBranchPattern4 = 8;
constexpr auto kSecondBranchPattern4 = 3;
constexpr auto kFirstBranchStartIndexPattern4 = 4;
constexpr auto kFirstBranchEndIndexPattern4 = 6;
constexpr auto kSecondBranchStartIndexPattern4 = kFirstBranchPattern4;
constexpr auto kSecondBranchEndIndexPattern4 = 3 + kFirstBranchPattern4;
const std::vector<kStructureTuple> BasicStructBasePattern{
  {kFirstBranchPattern4,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU},
   {kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
  {kSecondBranchPattern4,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
// Pattern 5
constexpr auto kFirstBranchPattern5 = 7;
constexpr auto kSecondBranchPattern5 = 1;
constexpr auto kFirstBranchStartIndexPattern5 = 4;
constexpr auto kFirstBranchEndIndexPattern5 = 6;
constexpr auto kSecondBranchStartIndexPattern5 = kFirstBranchPattern5;
constexpr auto kSecondBranchEndIndexPattern5 = 3 + kFirstBranchPattern5;
const std::vector<kStructureTuple> BasicStructFirstStepPattern{
  {kFirstBranchPattern5,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem,
    prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
  {kSecondBranchPattern5, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
// Pattern 6
constexpr auto kFirstBranchPattern6 = 8;
constexpr auto kSecondBranchPattern6 = 1;
constexpr auto kFirstBranchStartIndexPattern6 = 4;
constexpr auto kFirstBranchEndIndexPattern6 = 6;
constexpr auto kSecondBranchStartIndexPattern6 = kFirstBranchPattern6;
constexpr auto kSecondBranchEndIndexPattern6 = 3 + kFirstBranchPattern6;
const std::vector<kStructureTuple> BasicStructShortCutPattern{
  {kFirstBranchPattern6,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU},
   {kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
  {kSecondBranchPattern6, {prim::kPrimReLU}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
// Pattern 7
constexpr auto kFirstBranchPattern7 = 1;
constexpr auto kSecondBranchPattern7 = 13;
constexpr auto kFirstBranchStartIndexPattern7 = SIZE_MAX;
constexpr auto kFirstBranchEndIndexPattern7 = SIZE_MAX;
constexpr auto kSecondBranchStartIndexPattern7 = 7;
constexpr auto kSecondBranchEndIndexPattern7 = 10;
const std::vector<kStructureTuple> InvertedResidualShortCutPattern{
  {kFirstBranchPattern7,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
   {kFirstBranchStartIndexPattern7, kFirstBranchEndIndexPattern7}},
  {kSecondBranchPattern7,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem,
    prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
    prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
   {kSecondBranchStartIndexPattern7, kSecondBranchEndIndexPattern7}}};
// Pattern 8
constexpr auto kFirstBranchPattern8 = 4;
constexpr auto kFirstBranchStartIndexPattern8 = 0;
constexpr auto kFirstBranchEndIndexPattern8 = 3;
const std::vector<kStructureTuple> InvertedResidualPattern{
  {kFirstBranchPattern8,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimAdd},
   {kFirstBranchStartIndexPattern8, kFirstBranchEndIndexPattern8}}};
// Pattern 9
constexpr auto kFirstBranchPattern9 = 1;
constexpr auto kSecondBranchPattern9 = 12;
constexpr auto kFirstBranchStartIndexPattern9 = SIZE_MAX;
constexpr auto kFirstBranchEndIndexPattern9 = SIZE_MAX;
constexpr auto kSecondBranchStartIndexPattern9 = 7;
constexpr auto kSecondBranchEndIndexPattern9 = 10;
const std::vector<kStructureTuple> InvertedResidualShortCutPattern2{
  {kFirstBranchPattern9, {prim::kPrimAdd}, {kFirstBranchStartIndexPattern9, kFirstBranchEndIndexPattern9}},
  {kSecondBranchPattern9,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem,
    prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
    prim::kPrimConv2D, prim::kPrimAdd},
   {kSecondBranchStartIndexPattern9, kSecondBranchEndIndexPattern9}}};
// Pattern 10
constexpr auto kFirstBranchPattern10 = 5;
constexpr auto kFirstBranchStartIndexPattern10 = 0;
constexpr auto kFirstBranchEndIndexPattern10 = 4;
const std::vector<kStructureTuple> InvertedResidualPattern2{
  {kFirstBranchPattern10,
   {prim::kPrimReduceMean, prim::kPrimReLU6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kFirstBranchStartIndexPattern10, kFirstBranchEndIndexPattern10}}};
// Pattern 11
constexpr auto kFirstBranchPattern11 = 17;
constexpr auto kFirstBranchStartIndexPattern11 = 3;
constexpr auto kFirstBranchEndIndexPattern11 = 6;
const std::vector<kStructureTuple> InvertedResidualPattern3{
  {kFirstBranchPattern11,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem,
    prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
    prim::kPrimReLU6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6,
    prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kFirstBranchStartIndexPattern11, kFirstBranchEndIndexPattern11}}};
// Pattern 12
constexpr auto kFirstBranchPattern12 = 1;
constexpr auto kSecondBranchPattern12 = 9;
constexpr auto kFirstBranchStartIndexPattern12 = SIZE_MAX;
constexpr auto kFirstBranchEndIndexPattern12 = SIZE_MAX;
constexpr auto kSecondBranchStartIndexPattern12 = kFirstBranchPattern12 + 5;
constexpr auto kSecondBranchEndIndexPattern12 = kFirstBranchPattern12 + 8;
const std::vector<kStructureTuple> DenseBlockShortCutPattern{
  {kFirstBranchPattern12, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern12, kFirstBranchEndIndexPattern12}},
  {kSecondBranchPattern12,
   {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
    prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
   {kSecondBranchStartIndexPattern12, kSecondBranchEndIndexPattern12}}};
// Pattern 13
constexpr auto kFirstBranchPattern13 = 5;
constexpr auto kFirstBranchStartIndexPattern13 = 0;
constexpr auto kFirstBranchEndIndexPattern13 = 4;
const std::vector<kStructureTuple> DenseBlockPattern{
  {kFirstBranchPattern13,
   {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
   {kFirstBranchStartIndexPattern13, kFirstBranchEndIndexPattern13}}};
// Pattern 14
constexpr auto kFirstBranchPattern14 = 9;
constexpr auto kSecondBranchPattern14 = 1;
constexpr auto kFirstBranchStartIndexPattern14 = 5;
constexpr auto kFirstBranchEndIndexPattern14 = 8;
constexpr auto kSecondBranchStartIndexPattern14 = SIZE_MAX;
constexpr auto kSecondBranchEndIndexPattern14 = SIZE_MAX;
const std::vector<kStructureTuple> DenseBlockShortCutPattern2{
  {kFirstBranchPattern14,
   {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
    prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
   {kFirstBranchStartIndexPattern14, kFirstBranchEndIndexPattern14}},
  {kSecondBranchPattern14, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern14, kSecondBranchEndIndexPattern14}}};
// Pattern 15
constexpr auto kFirstBranchPattern15 = 9;
constexpr auto kSecondBranchPattern15 = 1;
constexpr auto kFirstBranchStartIndexPattern15 = 0;
constexpr auto kFirstBranchEndIndexPattern15 = 4;
constexpr auto kSecondBranchStartIndexPattern15 = SIZE_MAX;
constexpr auto kSecondBranchEndIndexPattern15 = SIZE_MAX;
const std::vector<kStructureTuple> DenseBlockPoolPattern{
  {kFirstBranchPattern15,
   {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
    prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
   {kFirstBranchStartIndexPattern15, kFirstBranchEndIndexPattern15}},
  {kSecondBranchPattern15, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern15, kSecondBranchEndIndexPattern15}}};
// Pattern 16
constexpr auto kFirstBranchPattern16 = 1;
constexpr auto kSecondBranchPattern16 = 9;
constexpr auto kFirstBranchStartIndexPattern16 = SIZE_MAX;
constexpr auto kFirstBranchEndIndexPattern16 = SIZE_MAX;
constexpr auto kSecondBranchStartIndexPattern16 = kFirstBranchPattern16;
constexpr auto kSecondBranchEndIndexPattern16 = kFirstBranchPattern16 + 4;
const std::vector<kStructureTuple> DenseBlockPoolPatter2{
  {kFirstBranchPattern16, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern16, kFirstBranchEndIndexPattern16}},
  {kSecondBranchPattern16,
   {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
    prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
   {kSecondBranchStartIndexPattern16, kSecondBranchEndIndexPattern16}}};
// Patterns tried by LessBatchNormalization::operator(), in this order. The first pattern that fully matches wins.
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {ResidualStructureBasePattern,
                                                                            ResidualStructureShortCutPattern,
                                                                            ResidualStructureFirstStepPattern,
                                                                            BasicStructBasePattern,
                                                                            BasicStructFirstStepPattern,
                                                                            BasicStructShortCutPattern,
                                                                            InvertedResidualShortCutPattern,
                                                                            InvertedResidualPattern,
                                                                            InvertedResidualShortCutPattern2,
                                                                            InvertedResidualPattern2,
                                                                            InvertedResidualPattern3,
                                                                            DenseBlockShortCutPattern,
                                                                            DenseBlockPattern,
                                                                            DenseBlockShortCutPattern2,
                                                                            DenseBlockPoolPattern,
                                                                            DenseBlockPoolPatter2};
// Users of a parameter that are not counted as a real use. When a BatchNorm parameter is dropped, users of these
// kinds are rewritten into MakeTuple stand-ins (see IsRealRemoveParameterNode).
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
  prim::kPrimLoad,      prim::kPrimRefToEmbed, prim::kPrimApplyMomentum,  prim::kPrimMomentum,
  prim::kPrimApplyFtrl, prim::kPrimSGD,        prim::kPrimApplyRMSProp, prim::kPrimAdam};
// Despite the name, these are the input positions that ConvertRemoveNodeToVirtualNode KEEPS when it builds the
// MakeTuple stand-in. All other inputs, including the primitive at index 0, are dropped.
static mindspore::HashMap<RemoveNodeType, mindspore::HashSet<size_t>> kRemoveIndex{
  {RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};
263
NeedRemove(const ParameterPtr & a,const std::vector<AnfNodePtr> & parameter_list)264 bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> ¶meter_list) {
265 if (a == nullptr) {
266 return false;
267 }
268 return std::any_of(parameter_list.begin(), parameter_list.end(), [&a](const AnfNodePtr &b) {
269 return (b->isa<Parameter>() && a->name() == b->cast<ParameterPtr>()->name());
270 });
271 }
272
IsNotRealUseNode(const AnfNodePtr & node)273 bool IsNotRealUseNode(const AnfNodePtr &node) {
274 for (const auto &prim : kNeedRemoveNodeSet) {
275 if (IsPrimitiveCNode(node, prim)) {
276 return true;
277 }
278 }
279 return false;
280 }
281
// Builds, in `cnode`'s func graph, a MakeTuple CNode that stands in for a Load/RefToEmbed/optimizer node.
// Only the inputs at the positions listed in kRemoveIndex are forwarded, in their original order:
// {2} for Load/RefToEmbed and {3, 5, 6} for the optimizer primitives. Every other input, including the
// primitive at index 0, is dropped, so the stand-in no longer references it.
// Throws (MS_EXCEPTION_IF_NULL) if `cnode` or its func graph is null.
CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  const bool is_other_node =
    IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed);
  // Bind by reference instead of copying the hash set on every call. `at` never inserts into the map.
  const auto &keep_index =
    kRemoveIndex.at(is_other_node ? RemoveNodeType::kOtherNode : RemoveNodeType::kOptimizerNode);

  const auto &inputs = cnode->inputs();
  std::vector<AnfNodePtr> args;
  args.reserve(keep_index.size() + 1);
  // MakeTuple goes first, so nothing has to be shifted by a later front insertion.
  (void)args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (keep_index.find(i) != keep_index.end()) {
      (void)args.emplace_back(inputs[i]);
    }
  }

  const auto &fg = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  return fg->NewCNode(args);
}
302
IsRealRemoveParameterNode(const FuncGraphManagerPtr & manager,const AnfNodePtr & parameter)303 bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr ¶meter) {
304 auto param_output = manager->node_users().find(parameter);
305 if (param_output == manager->node_users().end()) {
306 return true;
307 }
308
309 bool need_remove = true;
310 auto output_info_list = param_output->second;
311 for (const auto &output_info : output_info_list) {
312 const auto &node = output_info.first;
313 if (IsNotRealUseNode(node)) {
314 const auto &cnode = node->cast<CNodePtr>();
315 const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
316 (void)manager->Replace(cnode, new_cnode);
317 continue;
318 }
319 need_remove = false;
320 }
321
322 return need_remove;
323 }
324
RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & remove_parameter_list)325 void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
326 const std::vector<AnfNodePtr> &remove_parameter_list) {
327 auto roots = manager->roots();
328 if (roots.size() != 1) {
329 MS_LOG(ERROR) << "The size of roots " << roots.size() << " is not valid.";
330 return;
331 }
332 auto root_graph = *(roots.begin());
333 MS_EXCEPTION_IF_NULL(root_graph);
334
335 std::vector<AnfNodePtr> real_remove_parameter_list;
336 (void)std::copy_if(remove_parameter_list.cbegin(), remove_parameter_list.cend(),
337 std::back_inserter(real_remove_parameter_list),
338 [&manager](const AnfNodePtr ¶m) { return IsRealRemoveParameterNode(manager, param); });
339
340 auto root_parameters = root_graph->parameters();
341 size_t origin_param_count = root_parameters.size();
342 (void)root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
343 [&real_remove_parameter_list](const AnfNodePtr &node) {
344 return NeedRemove(node->cast<ParameterPtr>(),
345 real_remove_parameter_list);
346 }),
347 root_parameters.cend());
348 size_t remove_param_count = origin_param_count - root_parameters.size();
349 size_t fv_param_count = root_graph->fv_param_count();
350 if (remove_param_count > fv_param_count) {
351 MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
352 return;
353 }
354 fv_param_count = fv_param_count - remove_param_count;
355 root_graph->set_fv_param_count(fv_param_count);
356 manager->SetParameters(root_graph, root_parameters);
357 }
358
CheckNodeLessBNAttr(const FuncGraphPtr & func_graph)359 bool CheckNodeLessBNAttr(const FuncGraphPtr &func_graph) {
360 MS_EXCEPTION_IF_NULL(func_graph);
361 auto nodes = TopoSort(func_graph->return_node());
362 auto result = std::find_if(nodes.begin(), nodes.end(), [](const AnfNodePtr &node) {
363 auto cnode = node->cast<CNodePtr>();
364 return cnode != nullptr && GetCNodeFuncName(cnode) == kBatchNormOpName &&
365 common::AnfAlgo::HasNodeAttr(kBoostType, cnode) &&
366 common::AnfAlgo::GetNodeAttr<std::string>(cnode, kBoostType) == kLessBatchNormalizationPassName;
367 });
368 if (result != nodes.end()) {
369 func_graph->set_attr(kLessBatchNormalizationPassName, MakeValue(true));
370 return true;
371 } else {
372 func_graph->set_attr(kLessBatchNormalizationPassName, MakeValue(false));
373 return false;
374 }
375 }
376 } // namespace
377
MatchStructureNode(const CNodePtr & cnode,const int32_t index,const kStructureTuple & patternTuple) const378 bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index,
379 const kStructureTuple &patternTuple) const {
380 if (index < 0) {
381 return false;
382 }
383 const auto &use_pattern = std::get<1>(patternTuple);
384 int32_t use_index = index % static_cast<int32_t>(use_pattern.size());
385 if (!IsPrimitiveCNode(cnode, use_pattern[IntToSize(use_index)]) &&
386 use_pattern[IntToSize(use_index)] != prim::kPrimTupleGetItem) {
387 return false;
388 }
389 return true;
390 }
391
MatchGraphStructure(const CNodePtr & cnode,const std::vector<kStructureTuple> & match_pattern)392 bool LessBatchNormalization::MatchGraphStructure(const CNodePtr &cnode,
393 const std::vector<kStructureTuple> &match_pattern) {
394 if ((match_branch_ + 1 >= total_match_node_.size()) || (match_branch_ >= match_pattern.size())) {
395 return false;
396 }
397
398 int32_t index = static_cast<int32_t>(match_node_) - static_cast<int32_t>(total_match_node_[match_branch_]);
399 const auto &pattern = match_pattern[match_branch_];
400 if (!MatchStructureNode(cnode, index, pattern)) {
401 return false;
402 }
403
404 match_node_++;
405 if (match_node_ == total_match_node_.back()) {
406 is_match_ = true;
407 return false;
408 }
409 if (match_node_ == total_match_node_[match_branch_ + 1]) {
410 match_branch_++;
411 return false;
412 }
413 return true;
414 }
415
IsRemoveNode(const CNodePtr & cnode,const std::vector<kStructureTuple> & match_pattern)416 void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern) {
417 if (!IsPrimitiveCNode(cnode, prim::kPrimBatchNorm) && !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) &&
418 !IsValueNode<FuncGraph>(cnode->input(0))) {
419 return;
420 }
421 if (match_pattern.empty()) {
422 return;
423 }
424 const auto &start_end_pair = std::get<2>(match_pattern.at(match_branch_));
425 if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) {
426 (void)remove_node_list_.insert(cnode);
427 }
428 }
429
// Entry point of the less_bn pass for `node`.
// Runs only when the owning graph is tagged as less_bn. The tag is the cached "less_bn" graph attribute, computed
// by CheckNodeLessBNAttr on first use.
// Each entry of kNeedMatchPattern is tried in order. On the first full match:
//   - every node collected in remove_node_list_ is bypassed, i.e. replaced by its input 1;
//   - the parameters that fed the removed BatchNorms through Load inputs are passed to
//     RemoveBatchNormalizetionNotUseParameters.
// Returns `node` when the graph was rewritten, nullptr otherwise.
AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  const auto &fg = node->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  const auto flag = fg->get_attr(kLessBatchNormalizationPassName);
  const bool enabled = (flag != nullptr) ? GetValue<bool>(flag) : CheckNodeLessBNAttr(fg);
  if (!enabled) {
    return nullptr;
  }

  const auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr || cnode->inputs().empty()) {
    return nullptr;
  }
  // The root primitive and the per-input predicates do not depend on the pattern, so build them once.
  const auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  const std::vector<PredicateFuncType> funcs(cnode->size() - 1, IsCNode);

  for (match_pattern_ = 0; match_pattern_ < kNeedMatchPattern.size(); ++match_pattern_) {
    Reset();
    // total_match_node_ becomes {0, len(branch0), len(branch0) + len(branch1), ...}.
    size_t sum_match_node = 0;
    for (const auto &branch : kNeedMatchPattern.at(match_pattern_)) {
      sum_match_node += std::get<0>(branch);
      (void)total_match_node_.emplace_back(sum_match_node);
    }
    AnfVisitor::Match(prim, funcs)(node);
    if (is_match_) {
      break;
    }
  }

  if (!is_match_ || remove_node_list_.empty()) {
    return nullptr;
  }

  MS_EXCEPTION_IF_NULL(optimizer);
  auto manager = optimizer->manager();
  MS_EXCEPTION_IF_NULL(manager);
  std::vector<AnfNodePtr> remove_parameter_list;
  for (const auto &remove_node : remove_node_list_) {
    if (IsPrimitiveCNode(remove_node, prim::kPrimBatchNorm)) {
      // Collect the parameter behind each Load input of this BatchNorm, from input kBNParametersStartIndex on.
      // The Loads are gathered per BatchNorm. A list shared across iterations would re-append the parameters of
      // every earlier BatchNorm again for each later one.
      const auto &bn_inputs = remove_node->inputs();
      if (bn_inputs.size() > static_cast<size_t>(kBNParametersStartIndex)) {
        for (auto it = bn_inputs.begin() + kBNParametersStartIndex; it != bn_inputs.end(); ++it) {
          if (IsPrimitiveCNode(*it, prim::kPrimLoad)) {
            (void)remove_parameter_list.emplace_back((*it)->cast<CNodePtr>()->input(kValidResidualStructureIndex));
          }
        }
      }
    }
    // Bypass the node: its users now consume its input 1 directly.
    (void)manager->Replace(remove_node, remove_node->input(kValidResidualStructureIndex));
  }
  RemoveBatchNormalizetionNotUseParameters(manager, remove_parameter_list);

  return node;
}
492
Visit(const CNodePtr & cnode)493 void LessBatchNormalization::Visit(const CNodePtr &cnode) {
494 if (cnode == nullptr) {
495 return;
496 }
497
498 const auto ¤t_pattern = kNeedMatchPattern.at(match_pattern_);
499 IsRemoveNode(cnode, current_pattern);
500 if (!MatchGraphStructure(cnode, current_pattern)) {
501 return;
502 }
503
504 auto search_input = cnode->input(kValidResidualStructureIndex);
505 if (search_input != nullptr && search_input->isa<CNode>()) {
506 this->Visit(search_input->cast<CNodePtr>());
507 }
508 return;
509 }
510
Reset()511 void LessBatchNormalization::Reset() {
512 remove_node_list_.clear();
513 total_match_node_.clear();
514 (void)total_match_node_.emplace_back(0);
515 match_node_ = 0;
516 match_branch_ = 0;
517 is_match_ = false;
518 }
519 } // namespace irpass
520 } // namespace opt
521 } // namespace mindspore
522