/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "frontend/optimizer/irpass.h"
18 #include "frontend/optimizer/irpass/arithmetic_simplify.h"
19 #include "frontend/optimizer/irpass/branch_culling.h"
20 #include "frontend/optimizer/irpass/cast_eliminate.h"
21 #include "frontend/optimizer/irpass/convert.h"
22 #include "frontend/optimizer/irpass/env_item_eliminate.h"
23 #include "frontend/optimizer/irpass/grad_var_prepare.h"
24 #include "frontend/optimizer/irpass/gradient_eliminate.h"
25 #include "frontend/optimizer/irpass/inline.h"
26 #include "frontend/optimizer/irpass/updatestate_eliminate.h"
27 #include "frontend/optimizer/irpass/load_eliminate.h"
28 #include "frontend/optimizer/irpass/stopgrad_eliminate.h"
29 #include "frontend/optimizer/irpass/incorporate_call.h"
30 #include "frontend/optimizer/irpass/incorporate_getitem.h"
31 #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
32 #include "frontend/optimizer/irpass/merge_addn.h"
33 #include "frontend/optimizer/irpass/accumulaten_eliminate.h"
34 #include "frontend/optimizer/irpass/less_batch_normalization.h"
35 #include "frontend/optimizer/irpass/minmax_grad.h"
36 #include "frontend/optimizer/irpass/param_replace.h"
37 #include "frontend/optimizer/irpass/partial_eliminate.h"
38 #include "frontend/optimizer/irpass/reduce_eliminate.h"
39 #include "frontend/optimizer/irpass/ref_eliminate.h"
40 #include "frontend/optimizer/irpass/reshape_eliminate.h"
41 #include "frontend/optimizer/irpass/special_op_eliminate.h"
42 #include "frontend/optimizer/irpass/specialize_transform.h"
43 #include "frontend/optimizer/irpass/symbol_resolver.h"
44 #include "frontend/optimizer/irpass/tile_eliminate.h"
45 #include "frontend/optimizer/irpass/transpose_eliminate.h"
46 #include "frontend/optimizer/irpass/value_based_eliminate.h"
47 #include "frontend/optimizer/opt.h"
48 #include "frontend/optimizer/irpass/row_tensor_eliminate.h"
49 #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
50 #include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
51 #include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
52 #include "frontend/optimizer/irpass/recompute_prepare.h"
53 
54 namespace mindspore {
55 namespace opt {
56 namespace irpass {
OptimizeIRPassLib()57 OptimizeIRPassLib::OptimizeIRPassLib() {
58   arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
59                                           {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimAdd,
60                                            prim::kPrimIdentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
61   arithmetic_simplify2_ =
62     MakeSubstitution(std::make_shared<ArithmeticSimplify2>(), "arithmetic_simplify2", {prim::kPrimMul});
63   special_op_eliminate_ =
64     MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
65                      {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
66                       prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv});
67   pynative_eliminate_ = MakeSubstitution(std::make_shared<PynativeEliminater>(), "pynative_eliminate", IsCNodeDup);
68   zero_like_fill_zero_ =
69     MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
70   adjust_all_reduce_mul_add_ =
71     MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
72   float_depend_g_call_ = MakeSubstitution(std::make_shared<FloatDependGCall>(), "float_depend_g_call", IsCNodeDup);
73 
74   // ops eliminate
75   tuple_list_get_item_eliminator_ =
76     MakeSubstitution(std::make_shared<TupleListGetitemEliminator>(), "tuple_list_get_item_eliminator",
77                      {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
78   tuple_list_get_item_const_eliminator_ =
79     MakeSubstitution(std::make_shared<TupleListGetitemConstEliminator>(), "tuple_list_get_item_const_eliminator",
80                      {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
81   tuple_list_set_item_eliminator_ =
82     MakeSubstitution(std::make_shared<TupleListSetitemEliminator>(), "tuple_list_set_item_eliminator",
83                      {prim::kPrimTupleSetItem, prim::kPrimListSetItem});
84   tuple_list_get_set_item_eliminator_ =
85     MakeSubstitution(std::make_shared<TupleListGetSetitemEliminator>(), "tuple_list_get_set_item_eliminator",
86                      {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
87   tuple_list_get_item_depend_reorder_ =
88     MakeSubstitution(std::make_shared<TupleListGetitemDependReorder>(), "tuple_list_get_item_depend_reorder",
89                      {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
90   tuple_list_convert_item_index_to_positive_ = MakeSubstitution(
91     std::make_shared<TupleListConvertItemIndexToPositive>(), "tuple_list_convert_item_index_to_positive",
92     {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
93 
94   tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
95   cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
96   reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
97   transpose_eliminate_ =
98     MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
99   reduce_eliminate_ = MakeSubstitution(
100     std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
101     {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
102   partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
103   same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
104   mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
105                                             prim::kPrimMirrorMiniStep);
106   mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
107                                                   "mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
108   micro_step_allgather_replace_ = MakeSubstitution(std::make_shared<MicroStepAllGatherPass>(),
109                                                    "micro_step_allgather_replace", prim::kPrimMicroStepAllGather);
110   virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual add", prim::kPrimVirtualAdd);
111   check_bprop_eliminate_ =
112     MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
113   reset_defer_inline_ =
114     MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
115   depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
116   all_reduce_const_elim_ =
117     MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);
118 
119   // Env Item Eliminate
120   env_get_item_eliminate_ =
121     MakeSubstitution(std::make_shared<EnvGetItemEliminater>(), "env_get_item_eliminate", prim::kPrimEnvGetItem);
122   env_get_item_add_eliminate_ =
123     MakeSubstitution(std::make_shared<EnvGetItemAddEliminater>(), "env_get_item_add_eliminate_", prim::kPrimEnvGetItem);
124   env_get_set_item_eliminate_ =
125     MakeSubstitution(std::make_shared<EnvGetSetItemEliminater>(), "env_get_set_item_eliminate", prim::kPrimEnvGetItem);
126   env_get_item_depend_swap_ =
127     MakeSubstitution(std::make_shared<EnvGetItemDependSwap>(), "env_get_item_depend_swap", prim::kPrimEnvGetItem);
128 
129   incorporate_env_getitem_bypass_recursive_ =
130     MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(true), "incorporate_env_get_item", prim::kPrimEnvGetItem);
131   incorporate_env_getitem_switch_ = MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitch>(),
132                                                      "incorporate_env_getitem_switch", prim::kPrimEnvGetItem);
133   incorporate_env_getitem_ =
134     MakeSubstitution(std::make_shared<IncorporateEnvGetitem>(), "incorporate_env_get_item", prim::kPrimEnvGetItem);
135 
136   incorporate_env_getitem_switch_layer_ =
137     MakeSubstitution(std::make_shared<IncorporateEnvGetitemSwitchLayer>(), "incorporate_env_getitem_switch_layer",
138                      prim::kPrimEnvGetItem);
139 
140   // Ref eliminate
141   make_ref_eliminate_ =
142     MakeSubstitution(std::make_shared<MakeRefEliminater>(), "make_ref_eliminate", prim::kPrimMakeRef);
143   get_ref_param_eliminate_ =
144     MakeSubstitution(std::make_shared<GetRefParamEliminater>(), "get_ref_param_eliminate", {prim::kPrimGetRefValue});
145   get_make_ref_eliminate_ = MakeSubstitution(std::make_shared<GetMakeRefEliminater>(), "get_make_ref_eliminate",
146                                              {prim::kPrimGetRefKey, prim::kPrimGetRefValue});
147 
148   replace_refkey_by_param_ = MakeSubstitution(std::make_shared<ReplaceRefkeyByParam>(), "replace_refkey_by_param",
149                                               IsValueNode<RefKey>, opt::FORCE_RENORM);
150   replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
151   minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
152 
153   // branch culling
154   switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
155   float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
156                                                  "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
157   float_env_getitem_switch_ =
158     MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
159   exchange_switch_depend_value_ =
160     MakeSubstitution(std::make_shared<ExchangeSwitchDependValue>(), "exchange_switch_depend_value", prim::kPrimSwitch);
161 
162   switch_partial_eliminater_ =
163     MakeSubstitution(std::make_shared<SwitchPartialEliminater>(), "eliminate_switch_partial_", IsCNodeDup);
164   switch_layer_partial_eliminater_ =
165     MakeSubstitution(std::make_shared<SwitchLayerPartialEliminater>(), "eliminate_switch_layer_partial_", IsCNodeDup);
166 
167   // Addn
168   merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
169   addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
170 
171   // AccumulateNV2
172   accumulaten_eliminater_ =
173     MakeSubstitution(std::make_shared<AccumulateNV2Eliminater>(), "accumulaten_eliminater", prim::kPrimAccumulateNV2);
174 
175   // Accelerated Algorithm
176   less_batch_normalization_ =
177     MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization",
178                      {prim::kPrimAdd, prim::kPrimRelu6, prim::kPrimMatMul, prim::kPrimMakeTuple, prim::kPrimMaxPool});
179 
180   // inline
181   inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
182   inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
183   replace_applicator_ =
184     MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
185   specialize_transform_ =
186     MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
187 
188   // UpdateState eliminate
189   updatestate_useless_node_eliminater_ =
190     MakeSubstitution(std::make_shared<UpdatestateUselessNodeEliminater>(), "updatestate_useless_node_eliminater",
191                      prim::kPrimUpdateState);
192   updatestate_pure_node_eliminater_ = MakeSubstitution(std::make_shared<UpdatestatePureNodeEliminater>(),
193                                                        "updatestate_pure_node_eliminater", prim::kPrimUpdateState);
194   switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(),
195                                                    "switch_call_monad_eliminater", IsCNodeDup);
196 
197   // Load eliminate
198   load_eliminater_ = MakeSubstitution(std::make_shared<LoadEliminater>(), "load_eliminater", prim::kPrimLoad);
199 
200   // StopGradient eliminate
201   stopgrad_eliminater_ =
202     MakeSubstitution(std::make_shared<StopGradientEliminater>(), "stopgrad_eliminater", prim::kPrimStopGradient);
203 
204   // Incorporation
205   incorporate_getitem_set_ =
206     MakeSubstitution(std::make_shared<IncorporateGetitemSet>(), "incorporate_getitem_set", prim::kPrimTupleGetItem);
207   incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
208   incorporate_call_switch_ =
209     MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
210 
211   // Virtual Dataset
212   virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
213                                                 "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
214   // Virtual Dataset
215   virtual_output_eliminate_ =
216     MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);
217 
218   // PipelineSplit
219   receive_eliminate_ = MakeSubstitution(std::make_shared<ReceiveEliminater>(), "receive_eliminate", prim::kPrimReceive);
220   virtual_accu_grad_ =
221     MakeSubstitution(std::make_shared<VirtualAccuGradEliminater>(), "virtual_accu_grad", prim::kPrimVirtualAccuGrad);
222   virtual_assign_add_ =
223     MakeSubstitution(std::make_shared<VirtualAssignAddEliminater>(), "virtual_assign_add", prim::kPrimVirtualAssignAdd);
224   mirror_micro_step_ =
225     MakeSubstitution(std::make_shared<MirrorMicroStepEliminater>(), "mirror_micro_step", prim::kPrimMirrorMicroStep);
226 
227   // Convert
228   print_tuple_wrapper_ =
229     MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
230 
231   // tuple parameter graph transform
232   call_graph_tuple_transform_ =
233     MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
234 
235   // RowTensor Eliminate
236   row_tensor_eliminate_ = MakeSubstitution(
237     std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
238     {prim::kPrimRowTensorGetIndices, prim::kPrimRowTensorGetValues, prim::kPrimRowTensorGetDenseShape});
239 
240   // RowTensorAddZerosLike Eliminate
241   row_tensor_add_zeros_like_ =
242     MakeSubstitution(std::make_shared<RowTensorAddZerosLike>(), "row_tensor_add_zeros_like", prim::kPrimRowTensorAdd);
243 
244   // SparseTensor Eliminate
245   sparse_tensor_eliminate_ = MakeSubstitution(
246     std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
247     {prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
248 
249   // Value_Based Eliminate
250   value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
251                                             {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
252 
253   // switch defer inline
254   switch_defer_inline_ =
255     MakeSubstitution(std::make_shared<SwitchDeferInline>(), "switch_defer_inline", prim::kPrimSwitch);
256 
257   // switch_layer defer inline
258   switch_layer_defer_inline_ =
259     MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
260 
261   // recompute
262   set_cell_output_no_recompute_ = MakeSubstitution(std::make_shared<SetCellOutputNoRecompute>(),
263                                                    "set_cell_output_no_recompute", IsValueNode<FuncGraph>);
264 }
265 
ResolveIRPassLib()266 ResolveIRPassLib::ResolveIRPassLib() {
267   // In resolver_getattr_resolve_, some patterns have priority over others.
268   resolver_getattr_resolve_ = MakeSubstitution(std::make_shared<ResolverGetAttrResolve>(), "getattr_resolve",
269                                                {prim::kPrimGetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
270 }
271 
InferenceOptPrepareLib()272 InferenceOptPrepareLib::InferenceOptPrepareLib() {
273   grad_var_prepare_ = MakeSubstitution(std::make_shared<GradVarPrepare>(), "grad_var_prepare", IsCNode);
274 }
275 }  // namespace irpass
276 }  // namespace opt
277 }  // namespace mindspore
278