• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/irpass/less_batch_normalization.h"
18 
19 #include <set>
20 #include <unordered_map>
21 
22 namespace mindspore {
23 namespace opt {
24 namespace irpass {
25 namespace {
// Selects which entry of kRemoveIndex is used when a node is rebuilt as a MakeTuple:
// kOtherNode for Load/RefToEmbed nodes, kOptimizerNode for the remaining nodes of kNeedRemoveNodeSet.
26 enum class RemoveNodeType { kOtherNode = 0, kOptimizerNode };
// Func graph attribute that must be present for the pass to do anything.
27 const char kLessBatchNormalizationPassName[] = "less_bn";
// Input index followed while walking a branch; it is also the input that replaces a removed node
// and the input of a Load node that is collected as a removable parameter.
28 constexpr auto kValidResidualStructureIndex = 1;
// First BatchNorm input position that is scanned for Load nodes.
29 constexpr auto kBNParametersStartIndex = 2;
// Layout of every kStructureTuple below: {node count of the branch, primitive sequence, {first, last}}.
//  - The node counts are accumulated branch by branch to decide when matching switches to the next branch
//    and when the whole pattern is matched.
//  - The primitive sequence is matched cyclically (index modulo its length); a kPrimTupleGetItem slot
//    accepts any node.
//  - {first, last} is the inclusive range of the running matched-node counter in which BatchNorm,
//    TupleGetItem and FuncGraph-call nodes are recorded for removal.
30 // Pattern 1
31 // Add -> BatchNorm -> Conv2D -> Relu ... -> End
32 //     ↘  BatchNorm -> Conv2D -> -> -> -> ↗
33 constexpr auto kFirstBranchPattern1 = 12;
34 constexpr auto kSecondBranchPattern1 = 3;
35 constexpr auto kFirstBranchStartIndexPattern1 = 4;
36 constexpr auto kFirstBranchEndIndexPattern1 = 11;
37 constexpr auto kSecondBranchStartIndexPattern1 = kFirstBranchPattern1;
38 constexpr auto kSecondBranchEndIndexPattern1 = 2 + kFirstBranchPattern1;
39 const std::vector<kStructureTuple> ResidualStructureBasePattern{
40   {kFirstBranchPattern1,
41    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
42    {kFirstBranchStartIndexPattern1, kFirstBranchEndIndexPattern1}},
43   {kSecondBranchPattern1,
44    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
45    {kSecondBranchStartIndexPattern1, kSecondBranchEndIndexPattern1}}};
46 // Pattern 2
47 // Add -> BatchNorm -> Conv2D -> Relu ... -> End
48 //     ↘  -> ->     ...   ... ...    -> -> ↗
49 constexpr auto kFirstBranchPattern2 = 12;
50 constexpr auto kSecondBranchPattern2 = 1;
51 constexpr auto kFirstBranchStartIndexPattern2 = 4;
52 constexpr auto kFirstBranchEndIndexPattern2 = 11;
53 constexpr auto kSecondBranchStartIndexPattern2 = kFirstBranchPattern2;
// NOTE(review): this end index (2) is smaller than the start index (12), so the second-branch range is empty —
// presumably intentional because the shortcut branch holds no BatchNorm; confirm.
54 constexpr auto kSecondBranchEndIndexPattern2 = 1 + kSecondBranchPattern2;
55 const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
56   {kFirstBranchPattern2,
57    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
58    {kFirstBranchStartIndexPattern2, kFirstBranchEndIndexPattern2}},
59   {kSecondBranchPattern2, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern2, kSecondBranchEndIndexPattern2}}};
60 // Pattern 3
61 // Add -> BatchNorm -> Conv2D -> Relu ... BatchNorm -> Conv2D -> End
62 //     ↘  BatchNorm -> Conv2D -> ->   ...   ...   ...   -> -> ↗
63 constexpr auto kFirstBranchPattern3 = 11;
64 constexpr auto kSecondBranchPattern3 = 3;
65 constexpr auto kFirstBranchStartIndexPattern3 = 4;
66 constexpr auto kFirstBranchEndIndexPattern3 = 10;
67 constexpr auto kSecondBranchStartIndexPattern3 = kFirstBranchPattern3;
68 constexpr auto kSecondBranchEndIndexPattern3 = 2 + kFirstBranchPattern3;
69 const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
70   {kFirstBranchPattern3,
71    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
72     prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
73     prim::kPrimConv2D},
74    {kFirstBranchStartIndexPattern3, kFirstBranchEndIndexPattern3}},
75   {kSecondBranchPattern3,
76    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
77    {kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
// Patterns 4-6: two-branch structures. Each tuple is
// {node count of the branch, primitive sequence (matched cyclically), {first, last} removal index range},
// where the range is compared against the running matched-node counter.
78 // Pattern 4
// First branch: 8 nodes cycling TupleGetItem/BatchNorm/Conv2D/Relu; second branch: TupleGetItem/BatchNorm/Conv2D.
79 constexpr auto kFirstBranchPattern4 = 8;
80 constexpr auto kSecondBranchPattern4 = 3;
81 constexpr auto kFirstBranchStartIndexPattern4 = 4;
82 constexpr auto kFirstBranchEndIndexPattern4 = 6;
83 constexpr auto kSecondBranchStartIndexPattern4 = kFirstBranchPattern4;
84 constexpr auto kSecondBranchEndIndexPattern4 = 3 + kFirstBranchPattern4;
85 const std::vector<kStructureTuple> BasicStructBasePattern{
86   {kFirstBranchPattern4,
87    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
88    {kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
89   {kSecondBranchPattern4,
90    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
91    {kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
92 // Pattern 5
// First branch: 7 explicitly listed nodes; second branch: a single MaxPool node.
93 constexpr auto kFirstBranchPattern5 = 7;
94 constexpr auto kSecondBranchPattern5 = 1;
95 constexpr auto kFirstBranchStartIndexPattern5 = 4;
96 constexpr auto kFirstBranchEndIndexPattern5 = 6;
97 constexpr auto kSecondBranchStartIndexPattern5 = kFirstBranchPattern5;
98 constexpr auto kSecondBranchEndIndexPattern5 = 3 + kFirstBranchPattern5;
99 const std::vector<kStructureTuple> BasicStructFirstStepPattern{
100   {kFirstBranchPattern5,
101    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
102     prim::kPrimBatchNorm, prim::kPrimConv2D},
103    {kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
104   {kSecondBranchPattern5, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
105 // Pattern 6
// First branch: 8 nodes cycling TupleGetItem/BatchNorm/Conv2D/Relu; second branch: a single Relu node.
106 constexpr auto kFirstBranchPattern6 = 8;
107 constexpr auto kSecondBranchPattern6 = 1;
108 constexpr auto kFirstBranchStartIndexPattern6 = 4;
109 constexpr auto kFirstBranchEndIndexPattern6 = 6;
110 constexpr auto kSecondBranchStartIndexPattern6 = kFirstBranchPattern6;
111 constexpr auto kSecondBranchEndIndexPattern6 = 3 + kFirstBranchPattern6;
112 const std::vector<kStructureTuple> BasicStructShortCutPattern{
113   {kFirstBranchPattern6,
114    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
115    {kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
116   {kSecondBranchPattern6, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
// Patterns 7-11 (Relu6-based structures). Each tuple is
// {node count of the branch, primitive sequence (matched cyclically), {first, last} removal index range}.
// A range of {SIZE_MAX, SIZE_MAX} effectively disables removal while that branch is being matched, because the
// running matched-node counter is compared against it with >= first && <= last.
117 // Pattern 7
118 constexpr auto kFirstBranchPattern7 = 1;
119 constexpr auto kSecondBranchPattern7 = 13;
120 constexpr auto kFirstBranchStartIndexPattern7 = SIZE_MAX;
121 constexpr auto kFirstBranchEndIndexPattern7 = SIZE_MAX;
122 constexpr auto kSecondBranchStartIndexPattern7 = 7;
123 constexpr auto kSecondBranchEndIndexPattern7 = 10;
124 const std::vector<kStructureTuple> InvertedResidualShortCutPattern{
125   {kFirstBranchPattern7,
126    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
127    {kFirstBranchStartIndexPattern7, kFirstBranchEndIndexPattern7}},
128   {kSecondBranchPattern7,
129    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
130     prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
131     prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm},
132    {kSecondBranchStartIndexPattern7, kSecondBranchEndIndexPattern7}}};
133 // Pattern 8
// Single-branch pattern: all four matched nodes are inside the removal range.
134 constexpr auto kFirstBranchPattern8 = 4;
135 constexpr auto kFirstBranchStartIndexPattern8 = 0;
136 constexpr auto kFirstBranchEndIndexPattern8 = 3;
137 const std::vector<kStructureTuple> InvertedResidualPattern{
138   {kFirstBranchPattern8,
139    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimAdd},
140    {kFirstBranchStartIndexPattern8, kFirstBranchEndIndexPattern8}}};
141 // Pattern 9
142 constexpr auto kFirstBranchPattern9 = 1;
143 constexpr auto kSecondBranchPattern9 = 12;
144 constexpr auto kFirstBranchStartIndexPattern9 = SIZE_MAX;
145 constexpr auto kFirstBranchEndIndexPattern9 = SIZE_MAX;
146 constexpr auto kSecondBranchStartIndexPattern9 = 7;
147 constexpr auto kSecondBranchEndIndexPattern9 = 10;
148 const std::vector<kStructureTuple> InvertedResidualShortCutPattern2{
149   {kFirstBranchPattern9, {prim::kPrimAdd}, {kFirstBranchStartIndexPattern9, kFirstBranchEndIndexPattern9}},
150   {kSecondBranchPattern9,
151    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
152     prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
153     prim::kPrimConv2D, prim::kPrimAdd},
154    {kSecondBranchStartIndexPattern9, kSecondBranchEndIndexPattern9}}};
155 // Pattern 10
// Single-branch pattern: all five matched nodes are inside the removal range.
156 constexpr auto kFirstBranchPattern10 = 5;
157 constexpr auto kFirstBranchStartIndexPattern10 = 0;
158 constexpr auto kFirstBranchEndIndexPattern10 = 4;
159 const std::vector<kStructureTuple> InvertedResidualPattern2{
160   {kFirstBranchPattern10,
161    {prim::kPrimReduceMean, prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
162    {kFirstBranchStartIndexPattern10, kFirstBranchEndIndexPattern10}}};
163 // Pattern 11
// NOTE(review): the node count is 17 while the primitive sequence lists 18 entries, so the final Conv2D entry is
// never reached — confirm this is intended.
164 constexpr auto kFirstBranchPattern11 = 17;
165 constexpr auto kFirstBranchStartIndexPattern11 = 3;
166 constexpr auto kFirstBranchEndIndexPattern11 = 6;
167 const std::vector<kStructureTuple> InvertedResidualPattern3{
168   {kFirstBranchPattern11,
169    {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6, prim::kPrimTupleGetItem,
170     prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
171     prim::kPrimRelu6, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu6,
172     prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
173    {kFirstBranchStartIndexPattern11, kFirstBranchEndIndexPattern11}}};
// Patterns 12-16 (Concat/MaxPool-based structures). Each tuple is
// {node count of the branch, primitive sequence (matched cyclically), {first, last} removal index range}.
// The range is compared against the running matched-node counter, which keeps counting across branches;
// {SIZE_MAX, SIZE_MAX} effectively means "record nothing for removal in this branch".
174 // Pattern 12
175 constexpr auto kFirstBranchPattern12 = 1;
176 constexpr auto kSecondBranchPattern12 = 9;
177 constexpr auto kFirstBranchStartIndexPattern12 = SIZE_MAX;
178 constexpr auto kFirstBranchEndIndexPattern12 = SIZE_MAX;
179 constexpr auto kSecondBranchStartIndexPattern12 = kFirstBranchPattern12 + 5;
180 constexpr auto kSecondBranchEndIndexPattern12 = kFirstBranchPattern12 + 8;
181 const std::vector<kStructureTuple> DenseBlockShortCutPattern{
182   {kFirstBranchPattern12, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern12, kFirstBranchEndIndexPattern12}},
183   {kSecondBranchPattern12,
184    {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
185     prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
186    {kSecondBranchStartIndexPattern12, kSecondBranchEndIndexPattern12}}};
187 // Pattern 13
// Single-branch pattern: all five matched nodes are inside the removal range.
188 constexpr auto kFirstBranchPattern13 = 5;
189 constexpr auto kFirstBranchStartIndexPattern13 = 0;
190 constexpr auto kFirstBranchEndIndexPattern13 = 4;
191 const std::vector<kStructureTuple> DenseBlockPattern{
192   {kFirstBranchPattern13,
193    {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
194    {kFirstBranchStartIndexPattern13, kFirstBranchEndIndexPattern13}}};
195 // Pattern 14
196 constexpr auto kFirstBranchPattern14 = 9;
197 constexpr auto kSecondBranchPattern14 = 1;
198 constexpr auto kFirstBranchStartIndexPattern14 = 5;
199 constexpr auto kFirstBranchEndIndexPattern14 = 8;
200 constexpr auto kSecondBranchStartIndexPattern14 = SIZE_MAX;
201 constexpr auto kSecondBranchEndIndexPattern14 = SIZE_MAX;
202 const std::vector<kStructureTuple> DenseBlockShortCutPattern2{
203   {kFirstBranchPattern14,
204    {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
205     prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConcat},
206    {kFirstBranchStartIndexPattern14, kFirstBranchEndIndexPattern14}},
207   {kSecondBranchPattern14, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern14, kSecondBranchEndIndexPattern14}}};
208 // Pattern 15
209 constexpr auto kFirstBranchPattern15 = 9;
210 constexpr auto kSecondBranchPattern15 = 1;
211 constexpr auto kFirstBranchStartIndexPattern15 = 0;
212 constexpr auto kFirstBranchEndIndexPattern15 = 4;
213 constexpr auto kSecondBranchStartIndexPattern15 = SIZE_MAX;
214 constexpr auto kSecondBranchEndIndexPattern15 = SIZE_MAX;
215 const std::vector<kStructureTuple> DenseBlockPoolPattern{
216   {kFirstBranchPattern15,
217    {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
218     prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
219    {kFirstBranchStartIndexPattern15, kFirstBranchEndIndexPattern15}},
220   {kSecondBranchPattern15, {prim::kPrimConcat}, {kSecondBranchStartIndexPattern15, kSecondBranchEndIndexPattern15}}};
221 // Pattern 16
// Mirror of pattern 15 with the single Concat node as the first branch.
222 constexpr auto kFirstBranchPattern16 = 1;
223 constexpr auto kSecondBranchPattern16 = 9;
224 constexpr auto kFirstBranchStartIndexPattern16 = SIZE_MAX;
225 constexpr auto kFirstBranchEndIndexPattern16 = SIZE_MAX;
226 constexpr auto kSecondBranchStartIndexPattern16 = kFirstBranchPattern16;
227 constexpr auto kSecondBranchEndIndexPattern16 = kFirstBranchPattern16 + 4;
228 const std::vector<kStructureTuple> DenseBlockPoolPatter2{
229   {kFirstBranchPattern16, {prim::kPrimConcat}, {kFirstBranchStartIndexPattern16, kFirstBranchEndIndexPattern16}},
230   {kSecondBranchPattern16,
231    {prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D,
232     prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimMaxPool},
233    {kSecondBranchStartIndexPattern16, kSecondBranchEndIndexPattern16}}};
// All structure patterns the pass knows about. LessBatchNormalization::operator() tries them in this order and
// stops at the first one that matches, so the order of the entries is significant.
234 static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {ResidualStructureBasePattern,
235                                                                             ResidualStructureShortCutPattern,
236                                                                             ResidualStructureFirstStepPattern,
237                                                                             BasicStructBasePattern,
238                                                                             BasicStructFirstStepPattern,
239                                                                             BasicStructShortCutPattern,
240                                                                             InvertedResidualShortCutPattern,
241                                                                             InvertedResidualPattern,
242                                                                             InvertedResidualShortCutPattern2,
243                                                                             InvertedResidualPattern2,
244                                                                             InvertedResidualPattern3,
245                                                                             DenseBlockShortCutPattern,
246                                                                             DenseBlockPattern,
247                                                                             DenseBlockShortCutPattern2,
248                                                                             DenseBlockPoolPattern,
249                                                                             DenseBlockPoolPatter2};
// Primitives whose use of a parameter does not count as a "real" use: a user node built from one of these is
// replaced by a MakeTuple instead of keeping the parameter alive.
250 const std::set<PrimitivePtr> kNeedRemoveNodeSet{
251   prim::kPrimLoad,      prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
252   prim::kPrimApplyFtrl, prim::kPrimSGD,        prim::kPrimApplyRMSProp,  prim::kPrimAdam};
// Input positions that are copied into the replacement MakeTuple, keyed by node category.
// Despite the name, these are the inputs that are kept; every other input is dropped.
253 static std::unordered_map<RemoveNodeType, std::unordered_set<size_t>> kRemoveIndex{
254   {RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};
255 
NeedRemove(const ParameterPtr & a,const std::vector<AnfNodePtr> & parameter_list)256 bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
257   if (a == nullptr) {
258     return false;
259   }
260   return std::any_of(parameter_list.begin(), parameter_list.end(), [&a](const AnfNodePtr &b) {
261     return (b->isa<Parameter>() && a->name() == b->cast<ParameterPtr>()->name());
262   });
263 }
264 
IsNotRealUseNode(const AnfNodePtr & node)265 bool IsNotRealUseNode(const AnfNodePtr &node) {
266   for (const auto &prim : kNeedRemoveNodeSet) {
267     if (IsPrimitiveCNode(node, prim)) {
268       return true;
269     }
270   }
271   return false;
272 }
273 
// Builds a MakeTuple node that stands in for `cnode` and carries only a subset of its inputs.
//
// The subset is taken from kRemoveIndex: Load and RefToEmbed nodes use the kOtherNode index set, every other
// node uses the kOptimizerNode index set. Inputs whose position is in the set are kept (in their original
// order); all other inputs are dropped. The new node is created in the func graph of `cnode`.
//
// Raises (via MS_EXCEPTION_IF_NULL) when `cnode` or its func graph is null.
CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  const bool is_other_node =
    IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed);
  // Bind by reference and use at(): no per-call copy of the set, and a missing key fails loudly instead of
  // silently inserting an empty set into the global table.
  const auto &keep_index =
    kRemoveIndex.at(is_other_node ? RemoveNodeType::kOtherNode : RemoveNodeType::kOptimizerNode);

  // Put the MakeTuple primitive first so the kept inputs can simply be appended.
  std::vector<AnfNodePtr> args;
  (void)args.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  const auto &inputs = cnode->inputs();
  for (size_t index = 0; index < inputs.size(); ++index) {
    if (keep_index.find(index) != keep_index.end()) {
      (void)args.emplace_back(inputs[index]);
    }
  }

  const auto &fg = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  return fg->NewCNode(args);
}
294 
IsRealRemoveParameterNode(const FuncGraphManagerPtr & manager,const AnfNodePtr & parameter)295 bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &parameter) {
296   auto param_output = manager->node_users().find(parameter);
297   if (param_output == manager->node_users().end()) {
298     return true;
299   }
300 
301   bool need_remove = true;
302   auto output_info_list = param_output->second;
303   for (const auto &output_info : output_info_list) {
304     const auto &node = output_info.first;
305     if (IsNotRealUseNode(node)) {
306       const auto &cnode = node->cast<CNodePtr>();
307       const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
308       (void)manager->Replace(cnode, new_cnode);
309       continue;
310     }
311     need_remove = false;
312   }
313 
314   return need_remove;
315 }
316 
RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr & manager,const std::vector<AnfNodePtr> & remove_parameter_list)317 void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
318                                               const std::vector<AnfNodePtr> &remove_parameter_list) {
319   auto roots = manager->roots();
320   if (roots.size() != 1) {
321     MS_LOG(ERROR) << "The size of roots " << roots.size() << " is not valid.";
322     return;
323   }
324   auto root_graph = *(roots.begin());
325   MS_EXCEPTION_IF_NULL(root_graph);
326 
327   std::vector<AnfNodePtr> real_remove_parameter_list;
328   (void)std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(),
329                      std::back_inserter(real_remove_parameter_list),
330                      [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
331 
332   auto root_parameters = root_graph->parameters();
333   size_t origin_param_count = root_parameters.size();
334   (void)root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
335                                              [&real_remove_parameter_list](const AnfNodePtr &node) {
336                                                return NeedRemove(node->cast<ParameterPtr>(),
337                                                                  real_remove_parameter_list);
338                                              }),
339                               root_parameters.end());
340   size_t remove_param_count = origin_param_count - root_parameters.size();
341   size_t hyper_param_count = root_graph->hyper_param_count();
342   if (remove_param_count > hyper_param_count) {
343     MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
344     return;
345   }
346   hyper_param_count = hyper_param_count - remove_param_count;
347   root_graph->set_hyper_param_count(hyper_param_count);
348   manager->SetParameters(root_graph, root_parameters);
349 }
350 }  // namespace
351 
MatchStructureNode(const CNodePtr & cnode,const int32_t index,const kStructureTuple & patternTuple) const352 bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index,
353                                                 const kStructureTuple &patternTuple) const {
354   if (index < 0) {
355     return false;
356   }
357   const auto &use_pattern = std::get<1>(patternTuple);
358   int32_t use_index = index % static_cast<int32_t>(use_pattern.size());
359   if (!IsPrimitiveCNode(cnode, use_pattern[IntToSize(use_index)]) &&
360       use_pattern[IntToSize(use_index)] != prim::kPrimTupleGetItem) {
361     return false;
362   }
363   return true;
364 }
365 
MatchGraphStructure(const CNodePtr & cnode,const std::vector<kStructureTuple> & match_pattern)366 bool LessBatchNormalization::MatchGraphStructure(const CNodePtr &cnode,
367                                                  const std::vector<kStructureTuple> &match_pattern) {
368   if ((match_branch_ + 1 >= total_match_node_.size()) || (match_branch_ >= match_pattern.size())) {
369     return false;
370   }
371 
372   int32_t index = static_cast<int32_t>(match_node_) - static_cast<int32_t>(total_match_node_[match_branch_]);
373   const auto &pattern = match_pattern[match_branch_];
374   if (!MatchStructureNode(cnode, index, pattern)) {
375     return false;
376   }
377 
378   match_node_++;
379   if (match_node_ == total_match_node_.back()) {
380     is_match_ = true;
381     return false;
382   }
383   if (match_node_ == total_match_node_[match_branch_ + 1]) {
384     match_branch_++;
385     return false;
386   }
387   return true;
388 }
389 
IsRemoveNode(const CNodePtr & cnode,const std::vector<kStructureTuple> & match_pattern)390 void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vector<kStructureTuple> &match_pattern) {
391   if (!IsPrimitiveCNode(cnode, prim::kPrimBatchNorm) && !IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem) &&
392       !IsValueNode<FuncGraph>(cnode->input(0))) {
393     return;
394   }
395   if (match_pattern.empty()) {
396     return;
397   }
398   const auto &start_end_pair = std::get<2>(match_pattern.at(match_branch_));
399   if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) {
400     (void)remove_node_list_.insert(cnode);
401   }
402 }
403 
// Entry point of the pass: tries to match the sub-graph rooted at `node` against the known structure patterns
// and, on success, removes the BatchNorm-related nodes recorded during matching.
//
// The pass only runs when the func graph of `node` carries the "less_bn" attribute. Patterns from
// kNeedMatchPattern are tried in order; matching itself happens in Visit(), driven by AnfVisitor::Match.
// For every recorded node, the node is replaced by its input at kValidResidualStructureIndex; for recorded
// BatchNorm nodes, the parameters behind their Load inputs are additionally collected and handed to
// RemoveBatchNormalizetionNotUseParameters.
//
// Returns `node` when something was removed, nullptr otherwise (attribute missing, `node` is not a CNode with
// inputs, no pattern matched, or nothing was recorded for removal).
AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) {
  const auto &fg = node->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  if (!fg->has_attr(kLessBatchNormalizationPassName)) {
    return nullptr;
  }

  // None of this depends on the pattern being tried, so compute it once up front.
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr || cnode->inputs().empty()) {
    return nullptr;
  }
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  std::vector<PredicateFuncType> funcs(cnode->inputs().size() - 1, IsCNode);

  for (match_pattern_ = 0; match_pattern_ < kNeedMatchPattern.size(); ++match_pattern_) {
    Reset();
    // total_match_node_ holds the cumulative node count at the end of each branch (it starts with 0).
    size_t sum_match_node = 0;
    for (const auto &branch : kNeedMatchPattern.at(match_pattern_)) {
      sum_match_node += std::get<0>(branch);
      (void)total_match_node_.emplace_back(sum_match_node);
    }
    AnfVisitor::Match(prim, funcs)(node);
    if (is_match_) {
      break;
    }
  }

  if (!is_match_ || remove_node_list_.empty()) {
    return nullptr;
  }

  auto manager = optimizer->manager();
  MS_EXCEPTION_IF_NULL(manager);
  std::vector<AnfNodePtr> remove_parameter_list;
  for (auto &iter : remove_node_list_) {
    // Need to remove batchnorm's parameter input.
    // Collect the parameter behind each Load input of this BatchNorm only; re-processing the Load nodes of
    // previously visited BatchNorm nodes would add the same parameters over and over again.
    if (IsPrimitiveCNode(iter, prim::kPrimBatchNorm)) {
      const auto &bn_inputs = iter->inputs();
      for (size_t i = kBNParametersStartIndex; i < bn_inputs.size(); ++i) {
        const auto &bn_input = bn_inputs[i];
        if (IsPrimitiveCNode(bn_input, prim::kPrimLoad)) {
          (void)remove_parameter_list.emplace_back(bn_input->cast<CNodePtr>()->input(kValidResidualStructureIndex));
        }
      }
    }
    // Remove useless node.
    auto input_cnode = iter->input(kValidResidualStructureIndex);
    (void)manager->Replace(iter, input_cnode);
  }
  RemoveBatchNormalizetionNotUseParameters(manager, remove_parameter_list);

  return node;
}
458 
Visit(const CNodePtr & cnode)459 void LessBatchNormalization::Visit(const CNodePtr &cnode) {
460   if (cnode == nullptr) {
461     return;
462   }
463 
464   const auto &current_pattern = kNeedMatchPattern.at(match_pattern_);
465   IsRemoveNode(cnode, current_pattern);
466   if (!MatchGraphStructure(cnode, current_pattern)) {
467     return;
468   }
469 
470   auto search_input = cnode->input(kValidResidualStructureIndex);
471   if (search_input != nullptr && search_input->isa<CNode>()) {
472     this->Visit(search_input->cast<CNodePtr>());
473   }
474   return;
475 }
476 
Reset()477 void LessBatchNormalization::Reset() {
478   remove_node_list_.clear();
479   total_match_node_.clear();
480   (void)total_match_node_.emplace_back(0);
481   match_node_ = 0;
482   match_branch_ = 0;
483   is_match_ = false;
484 }
485 }  // namespace irpass
486 }  // namespace opt
487 }  // namespace mindspore
488