• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/irpass/less_batch_normalization.h"
18 
19 #include <set>
20 #include "mindspore/core/ops/sequence_ops.h"
21 #include "mindspore/core/ops/conv_pool_ops.h"
22 #include "mindspore/core/ops/nn_optimizer_ops.h"
23 #include "mindspore/core/ops/nn_ops.h"
24 #include "mindspore/core/ops/math_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "utils/hash_map.h"
28 
29 namespace mindspore {
30 namespace opt {
31 namespace irpass {
32 namespace {
// Selects which kRemoveIndex entry ConvertRemoveNodeToVirtualNode uses for a pseudo-user of a parameter.
enum class RemoveNodeType { kOtherNode = 0, kOptimizerNode = 1 };
// Name of this pass: it is the graph attribute that caches whether the graph contains a tagged BatchNorm, and the
// value a BatchNorm's kBoostType attribute must carry to be handled by the pass.
const char kLessBatchNormalizationPassName[] = "less_bn";
// BatchNorm attribute holding the boost tag compared against kLessBatchNormalizationPassName.
const auto kBoostType = "boost_type";
// Input followed when walking a structure; it is also the input a removed node is replaced with, and the
// parameter input of a Load node.
constexpr auto kValidResidualStructureIndex = 1;
// Index of the first BatchNorm input that is scanned for Load nodes (parameters to possibly drop).
constexpr auto kBNParametersStartIndex = 2;
38 // Pattern 1
39 // Add -> BatchNorm -> Conv2D -> Relu ... -> End
40 //     BatchNorm -> Conv2D -> -> -> ->
41 constexpr auto kFirstBranchPattern1 = 12;
42 constexpr auto kSecondBranchPattern1 = 3;
43 constexpr auto kFirstBranchStartIndexPattern1 = 4;
44 constexpr auto kFirstBranchEndIndexPattern1 = 11;
45 constexpr auto kSecondBranchStartIndexPattern1 = kFirstBranchPattern1;
46 constexpr auto kSecondBranchEndIndexPattern1 = 2 + kFirstBranchPattern1;
47 const std::vector<kStructureTuple> ResidualStructureBasePattern{
48   {kFirstBranchPattern1,
49    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU},
50    {kFirstBranchStartIndexPattern1, kFirstBranchEndIndexPattern1}},
51   {kSecondBranchPattern1,
52    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
53    {kSecondBranchStartIndexPattern1, kSecondBranchEndIndexPattern1}}};
54 // Pattern 2
55 // Add -> BatchNorm -> Conv2D -> Relu ... -> End
56 //       -> ->     ...   ... ...    -> ->
57 constexpr auto kFirstBranchPattern2 = 12;
58 constexpr auto kSecondBranchPattern2 = 1;
59 constexpr auto kFirstBranchStartIndexPattern2 = 4;
60 constexpr auto kFirstBranchEndIndexPattern2 = 11;
61 constexpr auto kSecondBranchStartIndexPattern2 = kFirstBranchPattern2;
62 constexpr auto kSecondBranchEndIndexPattern2 = 1 + kSecondBranchPattern2;
63 const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
64   {kFirstBranchPattern2,
65    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU},
66    {kFirstBranchStartIndexPattern2, kFirstBranchEndIndexPattern2}},
67   {kSecondBranchPattern2, {prim::kPrimReLU}, {kSecondBranchStartIndexPattern2, kSecondBranchEndIndexPattern2}}};
68 // Pattern 3
69 // Add -> BatchNorm -> Conv2D -> Relu ... BatchNorm -> Conv2D -> End
70 //       BatchNorm -> Conv2D -> ->   ...   ...   ...   -> ->
71 constexpr auto kFirstBranchPattern3 = 11;
72 constexpr auto kSecondBranchPattern3 = 3;
73 constexpr auto kFirstBranchStartIndexPattern3 = 4;
74 constexpr auto kFirstBranchEndIndexPattern3 = 10;
75 constexpr auto kSecondBranchStartIndexPattern3 = kFirstBranchPattern3;
76 constexpr auto kSecondBranchEndIndexPattern3 = 2 + kFirstBranchPattern3;
77 const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
78   {kFirstBranchPattern3,
79    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem,
80     prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
81     prim::kPrimConv2D},
82    {kFirstBranchStartIndexPattern3, kFirstBranchEndIndexPattern3}},
83   {kSecondBranchPattern3,
84    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
85    {kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
86 // Pattern 4
87 constexpr auto kFirstBranchPattern4 = 8;
88 constexpr auto kSecondBranchPattern4 = 3;
89 constexpr auto kFirstBranchStartIndexPattern4 = 4;
90 constexpr auto kFirstBranchEndIndexPattern4 = 6;
91 constexpr auto kSecondBranchStartIndexPattern4 = kFirstBranchPattern4;
92 constexpr auto kSecondBranchEndIndexPattern4 = 3 + kFirstBranchPattern4;
93 const std::vector<kStructureTuple> BasicStructBasePattern{
94   {kFirstBranchPattern4,
95    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU},
96    {kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
97   {kSecondBranchPattern4,
98    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
99    {kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
100 // Pattern 5
101 constexpr auto kFirstBranchPattern5 = 7;
102 constexpr auto kSecondBranchPattern5 = 1;
103 constexpr auto kFirstBranchStartIndexPattern5 = 4;
104 constexpr auto kFirstBranchEndIndexPattern5 = 6;
105 constexpr auto kSecondBranchStartIndexPattern5 = kFirstBranchPattern5;
106 constexpr auto kSecondBranchEndIndexPattern5 = 3 + kFirstBranchPattern5;
107 const std::vector<kStructureTuple> BasicStructFirstStepPattern{
108   {kFirstBranchPattern5,
109    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem,
110     prim::kPrimBatchNorm, prim::kPrimConv2D},
111    {kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
112   {kSecondBranchPattern5, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
113 // Pattern 6
114 constexpr auto kFirstBranchPattern6 = 8;
115 constexpr auto kSecondBranchPattern6 = 1;
116 constexpr auto kFirstBranchStartIndexPattern6 = 4;
117 constexpr auto kFirstBranchEndIndexPattern6 = 6;
118 constexpr auto kSecondBranchStartIndexPattern6 = kFirstBranchPattern6;
119 constexpr auto kSecondBranchEndIndexPattern6 = 3 + kFirstBranchPattern6;
120 const std::vector<kStructureTuple> BasicStructShortCutPattern{
121   {kFirstBranchPattern6,
122    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU},
123    {kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
124   {kSecondBranchPattern6, {prim::kPrimReLU}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
125 // Pattern 7
126 constexpr auto kFirstBranchPattern7 = 1;
127 constexpr auto kSecondBranchPattern7 = 13;
128 constexpr auto kFirstBranchStartIndexPattern7 = SIZE_MAX;
129 constexpr auto kFirstBranchEndIndexPattern7 = SIZE_MAX;
130 constexpr auto kSecondBranchStartIndexPattern7 = 7;
131 constexpr auto kSecondBranchEndIndexPattern7 = 10;
132 const std::vector<kStructureTuple> InvertedResidualShortCutPattern{
133   {kFirstBranchPattern7,
134    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
135    {kFirstBranchStartIndexPattern7, kFirstBranchEndIndexPattern7}},
136   {kSecondBranchPattern7,
137    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem,
138     prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
139     prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
140    {kSecondBranchStartIndexPattern7, kSecondBranchEndIndexPattern7}}};
141 // Pattern 8
142 constexpr auto kFirstBranchPattern8 = 4;
143 constexpr auto kFirstBranchStartIndexPattern8 = 0;
144 constexpr auto kFirstBranchEndIndexPattern8 = 3;
145 const std::vector<kStructureTuple> InvertedResidualPattern{
146   {kFirstBranchPattern8,
147    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimAdd},
148    {kFirstBranchStartIndexPattern8, kFirstBranchEndIndexPattern8}}};
149 // Pattern 9
150 constexpr auto kFirstBranchPattern9 = 1;
151 constexpr auto kSecondBranchPattern9 = 12;
152 constexpr auto kFirstBranchStartIndexPattern9 = SIZE_MAX;
153 constexpr auto kFirstBranchEndIndexPattern9 = SIZE_MAX;
154 constexpr auto kSecondBranchStartIndexPattern9 = 7;
155 constexpr auto kSecondBranchEndIndexPattern9 = 10;
156 const std::vector<kStructureTuple> InvertedResidualShortCutPattern2{
157   {kFirstBranchPattern9, {prim::kPrimAdd}, {kFirstBranchStartIndexPattern9, kFirstBranchEndIndexPattern9}},
158   {kSecondBranchPattern9,
159    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem,
160     prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
161     prim::kPrimConv2D, prim::kPrimAdd},
162    {kSecondBranchStartIndexPattern9, kSecondBranchEndIndexPattern9}}};
163 // Pattern 10
164 constexpr auto kFirstBranchPattern10 = 5;
165 constexpr auto kFirstBranchStartIndexPattern10 = 0;
166 constexpr auto kFirstBranchEndIndexPattern10 = 4;
167 const std::vector<kStructureTuple> InvertedResidualPattern2{
168   {kFirstBranchPattern10,
169    {prim::kPrimReduceMean, prim::kPrimReLU6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
170    {kFirstBranchStartIndexPattern10, kFirstBranchEndIndexPattern10}}};
171 // Pattern 11
172 constexpr auto kFirstBranchPattern11 = 17;
173 constexpr auto kFirstBranchStartIndexPattern11 = 3;
174 constexpr auto kFirstBranchEndIndexPattern11 = 6;
175 const std::vector<kStructureTuple> InvertedResidualPattern3{
176   {kFirstBranchPattern11,
177    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6, prim::kPrimTupleGetItem,
178     prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
179     prim::kPrimReLU6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimReLU6,
180     prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
181    {kFirstBranchStartIndexPattern11, kFirstBranchEndIndexPattern11}}};
182 // Pattern 12
183 constexpr auto kFirstBranchPattern12 = 1;
184 constexpr auto kSecondBranchPattern12 = 9;
185 constexpr auto kFirstBranchStartIndexPattern12 = SIZE_MAX;
186 constexpr auto kFirstBranchEndIndexPattern12 = SIZE_MAX;
187 constexpr auto kSecondBranchStartIndexPattern12 = kFirstBranchPattern12 + 5;
188 constexpr auto kSecondBranchEndIndexPattern12 = kFirstBranchPattern12 + 8;
189 const std::vector<kStructureTuple> DenseBlockShortCutPattern{
190   {kFirstBranchPattern12, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern12, kFirstBranchEndIndexPattern12}},
191   {kSecondBranchPattern12,
192    {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
193     prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
194    {kSecondBranchStartIndexPattern12, kSecondBranchEndIndexPattern12}}};
195 // Pattern 13
196 constexpr auto kFirstBranchPattern13 = 5;
197 constexpr auto kFirstBranchStartIndexPattern13 = 0;
198 constexpr auto kFirstBranchEndIndexPattern13 = 4;
199 const std::vector<kStructureTuple> DenseBlockPattern{
200   {kFirstBranchPattern13,
201    {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
202    {kFirstBranchStartIndexPattern13, kFirstBranchEndIndexPattern13}}};
203 // Pattern 14
204 constexpr auto kFirstBranchPattern14 = 9;
205 constexpr auto kSecondBranchPattern14 = 1;
206 constexpr auto kFirstBranchStartIndexPattern14 = 5;
207 constexpr auto kFirstBranchEndIndexPattern14 = 8;
208 constexpr auto kSecondBranchStartIndexPattern14 = SIZE_MAX;
209 constexpr auto kSecondBranchEndIndexPattern14 = SIZE_MAX;
210 const std::vector<kStructureTuple> DenseBlockShortCutPattern2{
211   {kFirstBranchPattern14,
212    {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
213     prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
214    {kFirstBranchStartIndexPattern14, kFirstBranchEndIndexPattern14}},
215   {kSecondBranchPattern14, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern14, kSecondBranchEndIndexPattern14}}};
216 // Pattern 15
217 constexpr auto kFirstBranchPattern15 = 9;
218 constexpr auto kSecondBranchPattern15 = 1;
219 constexpr auto kFirstBranchStartIndexPattern15 = 0;
220 constexpr auto kFirstBranchEndIndexPattern15 = 4;
221 constexpr auto kSecondBranchStartIndexPattern15 = SIZE_MAX;
222 constexpr auto kSecondBranchEndIndexPattern15 = SIZE_MAX;
223 const std::vector<kStructureTuple> DenseBlockPoolPattern{
224   {kFirstBranchPattern15,
225    {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
226     prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
227    {kFirstBranchStartIndexPattern15, kFirstBranchEndIndexPattern15}},
228   {kSecondBranchPattern15, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern15, kSecondBranchEndIndexPattern15}}};
229 // Pattern 16
230 constexpr auto kFirstBranchPattern16 = 1;
231 constexpr auto kSecondBranchPattern16 = 9;
232 constexpr auto kFirstBranchStartIndexPattern16 = SIZE_MAX;
233 constexpr auto kFirstBranchEndIndexPattern16 = SIZE_MAX;
234 constexpr auto kSecondBranchStartIndexPattern16 = kFirstBranchPattern16;
235 constexpr auto kSecondBranchEndIndexPattern16 = kFirstBranchPattern16 + 4;
236 const std::vector<kStructureTuple> DenseBlockPoolPatter2{
237   {kFirstBranchPattern16, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern16, kFirstBranchEndIndexPattern16}},
238   {kSecondBranchPattern16,
239    {prim::kPrimConv2D, prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
240     prim::kPrimReLU, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
241    {kSecondBranchStartIndexPattern16, kSecondBranchEndIndexPattern16}}};
242 static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {ResidualStructureBasePattern,
243                                                                             ResidualStructureShortCutPattern,
244                                                                             ResidualStructureFirstStepPattern,
245                                                                             BasicStructBasePattern,
246                                                                             BasicStructFirstStepPattern,
247                                                                             BasicStructShortCutPattern,
248                                                                             InvertedResidualShortCutPattern,
249                                                                             InvertedResidualPattern,
250                                                                             InvertedResidualShortCutPattern2,
251                                                                             InvertedResidualPattern2,
252                                                                             InvertedResidualPattern3,
253                                                                             DenseBlockShortCutPattern,
254                                                                             DenseBlockPattern,
255                                                                             DenseBlockShortCutPattern2,
256                                                                             DenseBlockPoolPattern,
257                                                                             DenseBlockPoolPatter2};
258 const std::set<PrimitivePtr> kNeedRemoveNodeSet{
259   prim::kPrimLoad,      prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
260   prim::kPrimApplyFtrl, prim::kPrimSGD,        prim::kPrimApplyRMSProp,  prim::kPrimAdam};
261 static mindspore::HashMap<RemoveNodeType, mindspore::HashSet<size_t>> kRemoveIndex{
262   {RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};
263 
NeedRemove(const ParameterPtr & a,const std::vector<AnfNodePtr> & parameter_list)264 bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
265   if (a == nullptr) {
266     return false;
267   }
268   return std::any_of(parameter_list.begin(), parameter_list.end(), [&a](const AnfNodePtr &b) {
269     return (b->isa<Parameter>() && a->name() == b->cast<ParameterPtr>()->name());
270   });
271 }
272 
IsNotRealUseNode(const AnfNodePtr & node)273 bool IsNotRealUseNode(const AnfNodePtr &node) {
274   for (const auto &prim : kNeedRemoveNodeSet) {
275     if (IsPrimitiveCNode(node, prim)) {
276       return true;
277     }
278   }
279   return false;
280 }
281 
// Builds a placeholder MakeTuple that stands in for a pseudo-user `cnode` of a parameter (Load / RefToEmbed /
// optimizer node). The tuple keeps only the inputs whose positions are listed in kRemoveIndex for that node kind
// (input 2 for Load / RefToEmbed, inputs 3, 5, 6 otherwise), in their original order.
// The new node is created in `cnode`'s func graph; the caller is responsible for replacing `cnode` with it.
CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  const bool is_load_like = IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed);
  const auto node_type = is_load_like ? RemoveNodeType::kOtherNode : RemoveNodeType::kOptimizerNode;
  // Read-only lookup by reference: operator[] would copy the set out of the global table and could insert into it.
  const auto &keep_index = kRemoveIndex.at(node_type);

  const auto &inputs = cnode->inputs();
  std::vector<AnfNodePtr> args;
  args.reserve(keep_index.size() + 1);
  args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  // Explicit index loop instead of a copy_if predicate that mutated a captured counter.
  for (size_t index = 0; index < inputs.size(); ++index) {
    if (keep_index.find(index) != keep_index.end()) {
      args.emplace_back(inputs[index]);
    }
  }

  const auto &fg = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto new_make_tuple = fg->NewCNode(args);
  return new_make_tuple;
}
302 
IsRealRemoveParameterNode(const FuncGraphManagerPtr & manager,const AnfNodePtr & parameter)303 bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &parameter) {
304   auto param_output = manager->node_users().find(parameter);
305   if (param_output == manager->node_users().end()) {
306     return true;
307   }
308 
309   bool need_remove = true;
310   auto output_info_list = param_output->second;
311   for (const auto &output_info : output_info_list) {
312     const auto &node = output_info.first;
313     if (IsNotRealUseNode(node)) {
314       const auto &cnode = node->cast<CNodePtr>();
315       const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
316       (void)manager->Replace(cnode, new_cnode);
317       continue;
318     }
319     need_remove = false;
320   }
321 
322   return need_remove;
323 }
324 
RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & remove_parameter_list)325 void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
326                                               const std::vector<AnfNodePtr> &remove_parameter_list) {
327   auto roots = manager->roots();
328   if (roots.size() != 1) {
329     MS_LOG(ERROR) << "The size of roots " << roots.size() << " is not valid.";
330     return;
331   }
332   auto root_graph = *(roots.begin());
333   MS_EXCEPTION_IF_NULL(root_graph);
334 
335   std::vector<AnfNodePtr> real_remove_parameter_list;
336   (void)std::copy_if(remove_parameter_list.cbegin(), remove_parameter_list.cend(),
337                      std::back_inserter(real_remove_parameter_list),
338                      [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
339 
340   auto root_parameters = root_graph->parameters();
341   size_t origin_param_count = root_parameters.size();
342   (void)root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
343                                              [&real_remove_parameter_list](const AnfNodePtr &node) {
344                                                return NeedRemove(node->cast<ParameterPtr>(),
345                                                                  real_remove_parameter_list);
346                                              }),
347                               root_parameters.cend());
348   size_t remove_param_count = origin_param_count - root_parameters.size();
349   size_t fv_param_count = root_graph->fv_param_count();
350   if (remove_param_count > fv_param_count) {
351     MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
352     return;
353   }
354   fv_param_count = fv_param_count - remove_param_count;
355   root_graph->set_fv_param_count(fv_param_count);
356   manager->SetParameters(root_graph, root_parameters);
357 }
358 
CheckNodeLessBNAttr(const FuncGraphPtr & func_graph)359 bool CheckNodeLessBNAttr(const FuncGraphPtr &func_graph) {
360   MS_EXCEPTION_IF_NULL(func_graph);
361   auto nodes = TopoSort(func_graph->return_node());
362   auto result = std::find_if(nodes.begin(), nodes.end(), [](const AnfNodePtr &node) {
363     auto cnode = node->cast<CNodePtr>();
364     return cnode != nullptr && GetCNodeFuncName(cnode) == kBatchNormOpName &&
365            common::AnfAlgo::HasNodeAttr(kBoostType, cnode) &&
366            common::AnfAlgo::GetNodeAttr<std::string>(cnode, kBoostType) == kLessBatchNormalizationPassName;
367   });
368   if (result != nodes.end()) {
369     func_graph->set_attr(kLessBatchNormalizationPassName, MakeValue(true));
370     return true;
371   } else {
372     func_graph->set_attr(kLessBatchNormalizationPassName, MakeValue(false));
373     return false;
374   }
375 }
376 }  // namespace
377 
MatchStructureNode(const CNodePtr & cnode,const int32_t index,const kStructureTuple & patternTuple) const378 bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index,
379                                                 const kStructureTuple &patternTuple) const {
380   if (index < 0) {
381     return false;
382   }
383   const auto &use_pattern = std::get<1>(patternTuple);
384   int32_t use_index = index % static_cast<int32_t>(use_pattern.size());
385   if (!IsPrimitiveCNode(cnode, use_pattern[IntToSize(use_index)]) &&
386       use_pattern[IntToSize(use_index)] != prim::kPrimTupleGetItem) {
387     return false;
388   }
389   return true;
390 }
391 
MatchGraphStructure(const CNodePtr & cnode,const std::vector<kStructureTuple> & match_pattern)392 bool LessBatchNormalization::MatchGraphStructure(const CNodePtr &cnode,
393                                                  const std::vector<kStructureTuple> &match_pattern) {
394   if ((match_branch_ + 1 >= total_match_node_.size()) || (match_branch_ >= match_pattern.size())) {
395     return false;
396   }
397 
398   int32_t index = static_cast<int32_t>(match_node_) - static_cast<int32_t>(total_match_node_[match_branch_]);
399   const auto &pattern = match_pattern[match_branch_];
400   if (!MatchStructureNode(cnode, index, pattern)) {
401     return false;
402   }
403 
404   match_node_++;
405   if (match_node_ == total_match_node_.back()) {
406     is_match_ = true;
407     return false;
408   }
409   if (match_node_ == total_match_node_[match_branch_ + 1]) {
410     match_branch_++;
411     return false;
412   }
413   return true;
414 }
415 
IsRemoveNode(const CNodePtr & cnode,const std::vector<kStructureTuple> & match_pattern)416 void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern) {
417   if (!IsPrimitiveCNode(cnode, prim::kPrimBatchNorm) && !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) &&
418       !IsValueNode<FuncGraph>(cnode->input(0))) {
419     return;
420   }
421   if (match_pattern.empty()) {
422     return;
423   }
424   const auto &start_end_pair = std::get<2>(match_pattern.at(match_branch_));
425   if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) {
426     (void)remove_node_list_.insert(cnode);
427   }
428 }
429 
operator ()(const OptimizerPtr & optimizer,const AnfNodePtr & node)430 AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
431   const auto &fg = node->func_graph();
432   MS_EXCEPTION_IF_NULL(fg);
433   auto flag = fg->get_attr(kLessBatchNormalizationPassName);
434   if (flag != nullptr) {
435     if (!GetValue<bool>(flag)) {
436       return nullptr;
437     }
438   } else {
439     if (!CheckNodeLessBNAttr(fg)) {
440       return nullptr;
441     }
442   }
443 
444   match_pattern_ = 0;
445   while (match_pattern_ < kNeedMatchPattern.size()) {
446     Reset();
447     const auto &current_pattern = kNeedMatchPattern.at(match_pattern_);
448     size_t sum_match_node = 0;
449     (void)std::for_each(current_pattern.begin(), current_pattern.end(), [&, this](const kStructureTuple &t) {
450       sum_match_node += std::get<0>(t);
451       (void)this->total_match_node_.emplace_back(sum_match_node);
452     });
453     auto cnode = node->cast<CNodePtr>();
454     if (cnode == nullptr || cnode->inputs().empty()) {
455       return nullptr;
456     }
457     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
458     std::vector<PredicateFuncType> funcs(cnode->size() - 1, IsCNode);
459     AnfVisitor::Match(prim, funcs)(node);
460     if (is_match_) {
461       break;
462     }
463     match_pattern_++;
464   }
465 
466   if (!is_match_ || remove_node_list_.empty()) {
467     return nullptr;
468   }
469 
470   auto manager = optimizer->manager();
471   MS_EXCEPTION_IF_NULL(manager);
472   std::vector<AnfNodePtr> remove_load_list;
473   std::vector<AnfNodePtr> remove_parameter_list;
474   for (auto &iter : remove_node_list_) {
475     // Need to remove batchnorm's parameter input.
476     if (IsPrimitiveCNode(iter, prim::kPrimBatchNorm)) {
477       (void)std::copy_if(iter->inputs().begin() + kBNParametersStartIndex, iter->inputs().end(),
478                          std::back_inserter(remove_load_list),
479                          [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); });
480       (void)std::transform(
481         remove_load_list.begin(), remove_load_list.end(), std::back_inserter(remove_parameter_list),
482         [](const AnfNodePtr &node) { return node->cast<CNodePtr>()->input(kValidResidualStructureIndex); });
483     }
484     // Remove useless node.
485     auto input_cnode = iter->input(kValidResidualStructureIndex);
486     (void)manager->Replace(iter, input_cnode);
487   }
488   RemoveBatchNormalizetionNotUseParameters(manager, remove_parameter_list);
489 
490   return node;
491 }
492 
Visit(const CNodePtr & cnode)493 void LessBatchNormalization::Visit(const CNodePtr &cnode) {
494   if (cnode == nullptr) {
495     return;
496   }
497 
498   const auto &current_pattern = kNeedMatchPattern.at(match_pattern_);
499   IsRemoveNode(cnode, current_pattern);
500   if (!MatchGraphStructure(cnode, current_pattern)) {
501     return;
502   }
503 
504   auto search_input = cnode->input(kValidResidualStructureIndex);
505   if (search_input != nullptr && search_input->isa<CNode>()) {
506     this->Visit(search_input->cast<CNodePtr>());
507   }
508   return;
509 }
510 
Reset()511 void LessBatchNormalization::Reset() {
512   remove_node_list_.clear();
513   total_match_node_.clear();
514   (void)total_match_node_.emplace_back(0);
515   match_node_ = 0;
516   match_branch_ = 0;
517   is_match_ = false;
518 }
519 }  // namespace irpass
520 }  // namespace opt
521 }  // namespace mindspore
522