1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/optimizer/irpass.h"
18 #include "mindspore/core/ops/structure_ops.h"
19 #include "mindspore/core/ops/sparse_tensor_ops.h"
20 #include "mindspore/core/ops/sequence_ops.h"
21 #include "mindspore/core/ops/conv_pool_ops.h"
22 #include "mindspore/core/ops/other_ops.h"
23 #include "mindspore/core/ops/nn_optimizer_ops.h"
24 #include "mindspore/core/ops/math_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "mindspore/core/ops/arithmetic_ops.h"
27 #include "mindspore/core/ops/framework_ops.h"
28 #include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"
29 #include "frontend/optimizer/irpass/arithmetic_simplify.h"
30 #include "frontend/optimizer/irpass/branch_culling.h"
31 #include "frontend/optimizer/irpass/cast_eliminate.h"
32 #include "frontend/optimizer/irpass/get_grad_eliminate.h"
33 #include "frontend/optimizer/irpass/print_converter.h"
34 #include "frontend/optimizer/irpass/environ_eliminate.h"
35 #include "frontend/optimizer/irpass/inline.h"
36 #include "frontend/optimizer/irpass/updatestate_eliminate.h"
37 #include "frontend/optimizer/irpass/load_eliminate.h"
38 #include "frontend/optimizer/irpass/stopgrad_eliminate.h"
39 #include "frontend/optimizer/irpass/incorporate_call.h"
40 #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
41 #include "frontend/optimizer/irpass/seqence_to_sequence_op_eliminate.h"
42 #include "frontend/optimizer/irpass/item_dict_eliminate.h"
43 #include "frontend/optimizer/irpass/merge_addn.h"
44 #include "frontend/optimizer/irpass/accumulaten_eliminate.h"
45 #include "frontend/optimizer/irpass/less_batch_normalization.h"
46 #include "frontend/optimizer/irpass/minmax_grad.h"
47 #include "frontend/optimizer/irpass/param_replace.h"
48 #include "frontend/optimizer/irpass/partial_eliminate.h"
49 #include "frontend/optimizer/irpass/reduce_eliminate.h"
50 #include "frontend/optimizer/irpass/reshape_eliminate.h"
51 #include "frontend/optimizer/irpass/special_op_eliminate.h"
52 #include "frontend/optimizer/irpass/specialize_transform.h"
53 #include "frontend/optimizer/irpass/symbol_resolver.h"
54 #include "frontend/optimizer/irpass/tile_eliminate.h"
55 #include "frontend/optimizer/irpass/transpose_eliminate.h"
56 #include "frontend/optimizer/irpass/value_based_eliminate.h"
57 #include "frontend/optimizer/irpass/pynative_no_grad_eliminate.h"
58 #include "frontend/optimizer/opt.h"
59 #include "frontend/optimizer/irpass/row_tensor_eliminate.h"
60 #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
61 #include "frontend/optimizer/irpass/stack_unstack_eliminate.h"
62 #include "frontend/optimizer/irpass/mutable_eliminate.h"
63 #include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
64 #include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
65 #include "frontend/optimizer/irpass/recompute_prepare.h"
66 #include "frontend/optimizer/irpass/real_op_eliminate.h"
67 #include "frontend/optimizer/irpass/convert_tensor_eliminate.h"
68 #include "frontend/optimizer/irpass/recompute.h"
69 #include "frontend/optimizer/irpass/grad_partial_transform.h"
70 #include "frontend/optimizer/irpass/symbol_engine_optimizer.h"
71 #include "frontend/optimizer/irpass/const_output_eliminate.h"
72 #include "frontend/optimizer/irpass/slice_to_tuple.h"
73 #include "frontend/optimizer/irpass/j_node_and_user_rematch.h"
74
75 namespace mindspore {
76 namespace opt {
77 namespace irpass {
OptimizeIRPassLib()78 OptimizeIRPassLib::OptimizeIRPassLib() {
79 arithmetic_simplify_ = MakeSubstitution(std::make_shared<ArithmeticSimplify>(), "arithmetic_simplify",
80 {prim::kPrimScalarAdd, prim::kPrimScalarMul, prim::kPrimAdd,
81 prim::kPrimidentity, prim::kPrimMomentum, prim::kPrimMul, prim::kPrimPow});
82 special_op_eliminate_ = MakeSubstitution(
83 std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
84 {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimCellBackwardHook, prim::kPrimPrintShapeType});
85 mutable_op_eliminate_ =
86 MakeSubstitution(std::make_shared<MutableEliminater>(), "mutable_eliminate", prim::kPrimMutable);
87 ad_related_special_op_eliminate_ =
88 MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "ad_related_special_op_eliminate",
89 {prim::kPrimMirror, prim::kPrimVirtualDiv, prim::kPrimStopGradient});
90 pynative_eliminate_ = MakeSubstitution(std::make_shared<PynativeEliminater>(), "pynative_eliminate", IsCNodeDup);
91 pynative_no_grad_eliminate_ =
92 MakeSubstitution(std::make_shared<PynativeNoGradEliminater>(), "pynative_no_grad_eliminate", prim::kPrimMakeTuple);
93 zero_like_fill_zero_ =
94 MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
95 adjust_all_reduce_mul_add_ =
96 MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
97 float_depend_g_call_ = MakeSubstitution(std::make_shared<FloatDependGCall>(), "float_depend_g_call", IsCNodeDup);
98
99 // ops eliminate
100 tuple_list_get_item_eliminator_ =
101 MakeSubstitution(std::make_shared<TupleListGetitemEliminator>(), "tuple_list_get_item_eliminator",
102 {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
103 tuple_list_get_item_const_eliminator_ =
104 MakeSubstitution(std::make_shared<TupleListGetitemConstEliminator>(), "tuple_list_get_item_const_eliminator",
105 {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
106 tuple_list_set_item_eliminator_ =
107 MakeSubstitution(std::make_shared<TupleListSetitemEliminator>(), "tuple_list_set_item_eliminator",
108 {prim::kPrimTupleSetItem, prim::kPrimListSetItem});
109 tuple_list_get_set_item_eliminator_ =
110 MakeSubstitution(std::make_shared<TupleListGetSetitemEliminator>(), "tuple_list_get_set_item_eliminator",
111 {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
112 tuple_list_get_item_depend_reorder_ =
113 MakeSubstitution(std::make_shared<TupleListGetitemDependReorder>(), "tuple_list_get_item_depend_reorder",
114 {prim::kPrimTupleGetItem, prim::kPrimListGetItem});
115 list_to_tuple_eliminator_ =
116 MakeSubstitution(std::make_shared<ListToTupleEliminator>(), "list_to_tuple_eliminator_", {prim::kPrimListToTuple});
117 tuple_to_list_eliminator_ =
118 MakeSubstitution(std::make_shared<TupleToListEliminator>(), "tuple_to_list_eliminator_", {prim::kPrimTupleToList});
119 tuple_list_convert_item_index_to_positive_ = MakeSubstitution(
120 std::make_shared<TupleListConvertItemIndexToPositive>(), "tuple_list_convert_item_index_to_positive",
121 {prim::kPrimTupleGetItem, prim::kPrimTupleSetItem, prim::kPrimListGetItem, prim::kPrimListSetItem});
122 make_slice_get_slice_eliminator_ = MakeSubstitution(std::make_shared<MakeSliceSliceGetItemEliminator>(),
123 "make_slice_get_slice_eliminator", {prim::kPrimSliceGetItem});
124 slice_to_tuple_ = MakeSubstitution(std::make_shared<SliceToTuple>(), "make_slice_get_slice_eliminator",
125 {prim::kPrimSliceGetItem, prim::kPrimMakeSlice});
126 dict_get_item_eliminator_ =
127 MakeSubstitution(std::make_shared<DictGetitemEliminator>(), "dict_get_item_eliminator", prim::kPrimDictGetItem);
128 dict_get_item_const_eliminator_ = MakeSubstitution(std::make_shared<DictGetitemConstEliminator>(),
129 "dict_get_item_const_eliminator", prim::kPrimDictGetItem);
130 dict_set_item_eliminator_ =
131 MakeSubstitution(std::make_shared<DictSetitemEliminator>(), "dict_set_item_eliminator", prim::kPrimDictSetItem);
132 stack_unstack_eliminate_ =
133 MakeSubstitution(std::make_shared<StackUnstackEliminator>(), "stack_unstack_eliminate", prim::kPrimUnstack);
134 tile_eliminate_ = MakeSubstitution(std::make_shared<TileEliminater>(), "tile_eliminate", prim::kPrimTile);
135 cast_eliminate_ = MakeSubstitution(std::make_shared<CastEliminater>(), "cast_eliminate", prim::kPrimCast);
136 get_grad_eliminate_ =
137 MakeSubstitution(std::make_shared<GetGradEliminater>(), "get_grad_eliminate", prim::kPrimGetGrad);
138 reshape_eliminate_ = MakeSubstitution(std::make_shared<ReshapeEliminater>(), "reshape_eliminate", prim::kPrimReshape);
139 transpose_eliminate_ =
140 MakeSubstitution(std::make_shared<TransposeSameIOEliminater>(), "transpose_eliminate", prim::kPrimTranspose);
141 reduce_eliminate_ = MakeSubstitution(
142 std::make_shared<ReduceOneEliminater>(), "reduce_eliminate",
143 {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
144 partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
145 same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
146 mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
147 "mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
148 micro_step_allgather_replace_ = MakeSubstitution(std::make_shared<MicroStepAllGatherPass>(),
149 "micro_step_allgather_replace", prim::kPrimMicroStepAllGather);
150 check_bprop_eliminate_ =
151 MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
152 reset_defer_inline_ =
153 MakeSubstitution(std::make_shared<ResetDeferInline>(), "reset_defer_inline", IsValueNode<FuncGraph>);
154 const_output_eliminate_ =
155 MakeSubstitution(std::make_shared<ConstOutputEliminater>(), "const_output_eliminate", IsValueNode<FuncGraph>);
156 depend_value_elim_ = MakeSubstitution(std::make_shared<DependValueElim>(), "depend_value_elim", prim::kPrimDepend);
157 all_reduce_const_elim_ =
158 MakeSubstitution(std::make_shared<AllReduceConstElim>(), "reduce_all_const_elim", prim::kPrimAllReduce);
159 real_op_eliminate_ = MakeSubstitution(std::make_shared<RealOpEliminate>(), "real_op_eliminate", prim::kPrimRealInner);
160 convert_tensor_eliminate_ = MakeSubstitution(std::make_shared<ConvertTensorEliminate>(), "convert_tensor_eliminate",
161 {prim::kPrimConvertToAdapterTensor, prim::kPrimConvertToMsTensor});
162 convert_tensor_all_eliminate_ =
163 MakeSubstitution(std::make_shared<ConvertTensorAllEliminate>(), "convert_tensor_all_eliminate",
164 {prim::kPrimConvertToAdapterTensor, prim::kPrimConvertToMsTensor});
165
166 // Environ Item Eliminate
167 environ_get_eliminate_ =
168 MakeSubstitution(std::make_shared<EnvironGetEliminater>(), "environ_get_eliminate", prim::kPrimEnvironGet);
169 environ_get_add_eliminate_ =
170 MakeSubstitution(std::make_shared<EnvironGetAddEliminater>(), "environ_get_add_eliminate", prim::kPrimEnvironGet);
171 environ_get_set_eliminate_ =
172 MakeSubstitution(std::make_shared<EnvironGetSetEliminater>(), "environ_get_set_eliminate", prim::kPrimEnvironGet);
173 environ_get_depend_swap_ =
174 MakeSubstitution(std::make_shared<EnvironGetDependSwap>(), "environ_get_depend_swap", prim::kPrimEnvironGet);
175 environ_add_const_eliminate_ = MakeSubstitution(std::make_shared<EnvironAddConstEliminater>(),
176 "environ_add_const_eliminate_", prim::kPrimEnvironAdd);
177 split_environ_get_set_with_tuple_value_ =
178 MakeSubstitution(std::make_shared<SplitEnvironGetSetWithTupleValue>(), "split_environ_get_set_with_tuple_value",
179 {prim::kPrimEnvironGet, prim::kPrimEnvironSet});
180
181 // Ref eliminate
182 replace_old_param_ = MakeSubstitution(std::make_shared<ReplaceOldParam>(), "replace_old_param", IsParam);
183
184 // Gradient
185 minmaximum_grad_ = MakeSubstitution(std::make_shared<MinMaximumGrad>(), "minmaximum_grad", prim::kPrimTupleGetItem);
186 j_node_and_user_rematch_ =
187 MakeSubstitution(std::make_shared<JNodeAndUserRematch>(), "j_node_and_user_rematch", IsCNode);
188
189 // branch culling
190 switch_simplify_ = MakeSubstitution(std::make_shared<SwitchSimplify>(), "switch_simplify", prim::kPrimSwitch);
191 compare_switch_simplify_ =
192 MakeSubstitution(std::make_shared<CompareSwitchSimplify>(), "compare_switch_simplify", prim::kPrimSwitch);
193 float_tuple_getitem_switch_ = MakeSubstitution(std::make_shared<FloatTupleGetItemSwitch>(),
194 "float_tuple_getitem_switch", prim::kPrimTupleGetItem);
195 float_environ_get_switch_ =
196 MakeSubstitution(std::make_shared<FloatEnvironGetSwitch>(), "float_environ_get_switch", prim::kPrimEnvironGet);
197 exchange_switch_depend_value_ =
198 MakeSubstitution(std::make_shared<ExchangeSwitchDependValue>(), "exchange_switch_depend_value", prim::kPrimSwitch);
199
200 switch_partial_eliminater_ =
201 MakeSubstitution(std::make_shared<SwitchPartialEliminater>(), "eliminate_switch_partial_", IsCNodeDup);
202 switch_layer_partial_eliminater_ =
203 MakeSubstitution(std::make_shared<SwitchLayerPartialEliminater>(), "eliminate_switch_layer_partial_", IsCNodeDup);
204
205 // Addn
206 merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
207 addn_zero_filter_ = MakeSubstitution(std::make_shared<AddNZeroFilter>(), "addn_zero_filter", prim::kPrimAddN);
208 addn_check_dump_ = MakeSubstitution(std::make_shared<AddNCheckDump>(), "addn_check_dump", prim::kPrimAddN);
209
210 // AccumulateNV2
211 accumulaten_eliminater_ =
212 MakeSubstitution(std::make_shared<AccumulateNV2Eliminater>(), "accumulaten_eliminater", prim::kPrimAccumulateNV2);
213
214 // Accelerated Algorithm
215 less_batch_normalization_ =
216 MakeSubstitution(std::make_shared<LessBatchNormalization>(), "less_batch_normalization",
217 {prim::kPrimAdd, prim::kPrimReLU6, prim::kPrimMatMul, prim::kPrimMakeTuple, prim::kPrimMaxPool});
218
219 // inline
220 inline_ = MakeSubstitution(std::make_shared<Inliner>(), "inline", IsCNodeGraph);
221 inline_without_move_ = MakeSubstitution(std::make_shared<DirectInliner>(false), "inline", IsCNodeGraph);
222 replace_applicator_ =
223 MakeSubstitution(std::make_shared<ReplaceApplicator>(), "replace_applicator", IsValueNode<FuncGraph>);
224 specialize_transform_ =
225 MakeSubstitution(std::make_shared<SpecializeOnGraphArguments>(), "specialize_transform", IsCNodeGraph);
226
227 // UpdateState eliminate
228 updatestate_useless_node_eliminater_ =
229 MakeSubstitution(std::make_shared<UpdatestateUselessNodeEliminater>(), "updatestate_useless_node_eliminater",
230 prim::kPrimUpdateState);
231 updatestate_pure_node_eliminater_ = MakeSubstitution(std::make_shared<UpdatestatePureNodeEliminater>(),
232 "updatestate_pure_node_eliminater", prim::kPrimUpdateState);
233 switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared<SwitchCallMonadParameterEliminater>(),
234 "switch_call_monad_eliminater", IsCNodeDup);
235
236 // Load eliminate
237 load_eliminater_ = MakeSubstitution(std::make_shared<LoadEliminater>(), "load_eliminater", prim::kPrimLoad);
238
239 // StopGradient eliminate
240 stopgrad_eliminater_ =
241 MakeSubstitution(std::make_shared<StopGradientEliminater>(), "stopgrad_eliminater", prim::kPrimStopGradient);
242
243 // Incorporation
244 incorporate_call_ = MakeSubstitution(std::make_shared<IncorporateCall>(), "incorporate_call", IsCNodeDup);
245 incorporate_call_switch_ =
246 MakeSubstitution(std::make_shared<IncorporateCallSwitch>(), "incorporate_call_switch", IsCNodeDup);
247
248 // Virtual Dataset
249 virtual_dataset_eliminate_ = MakeSubstitution(std::make_shared<VirtualDatasetEliminater>(),
250 "virtual_dataset_eliminate", prim::kPrimVirtualDataset);
251
252 // Virtual Output
253 virtual_output_eliminate_ =
254 MakeSubstitution(std::make_shared<VirtualOutputEliminater>(), "virtual_output_eliminate", prim::kPrimVirtualOutput);
255
256 // Virtual Shard Identity
257 virtual_shard_identity_ = MakeSubstitution(std::make_shared<AShardIdentityEliminater>(), "shard_identity_eliminate",
258 prim::kPrimAShardIdentity);
259
260 // PipelineSplit
261 parallel_virtual_node_ = MakeSubstitution(
262 std::make_shared<ParallelVirtualNodeEliminater>(), "parallel_virtual_node",
263 {prim::kPrimVirtualAssignAdd, prim::kPrimVirtualPipelineEnd, prim::kPrimVirtualAccuGrad, prim::kPrimMirrorMicroStep,
264 prim::kPrimVirtualAdd, prim::kPrimMirrorMiniStep, prim::kPrimMirrorSilentCheck});
265
266 // Convert
267 print_tuple_wrapper_ =
268 MakeSubstitution(std::make_shared<PrintTupleWrapper>(), "print_tuple_wrapper", prim::kPrimPrint);
269
270 print_const_string_wrapper_ =
271 MakeSubstitution(std::make_shared<PrintConstStringWrapper>(), "print_const_string_wrapper", prim::kPrimPrint);
272
273 // tuple parameter graph transform
274 call_graph_tuple_transform_ =
275 MakeSubstitution(std::make_shared<CallGraphSequenceTransform>(), "graph_param_transform", IsNode);
276
277 // RowTensor Eliminate
278 row_tensor_eliminate_ = MakeSubstitution(
279 std::make_shared<RowTensorEliminater>(), "row_tensor_eliminate",
280 {prim::kPrimRowTensorGetIndices, prim::kPrimRowTensorGetValues, prim::kPrimRowTensorGetDenseShape});
281
282 // RowTensorAddZerosLike Eliminate
283 row_tensor_add_zeros_like_ =
284 MakeSubstitution(std::make_shared<RowTensorAddZerosLike>(), "row_tensor_add_zeros_like", prim::kPrimRowTensorAdd);
285
286 // SparseTensor Eliminate
287 sparse_tensor_eliminate_ = MakeSubstitution(
288 std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
289 {prim::kPrimCOOTensorGetIndices, prim::kPrimCOOTensorGetValues, prim::kPrimCOOTensorGetDenseShape});
290
291 // Value_Based Eliminate
292 value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
293 {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
294 // Partial func graph input defer inline
295 partial_defer_inline_ =
296 MakeSubstitution(std::make_shared<PartialDeferInline>(), "partial_defer_inline", prim::kPrimPartial);
297
298 // Switch func graph input defer inline
299 switch_defer_inline_ =
300 MakeSubstitution(std::make_shared<SwitchDeferInline>(), "switch_defer_inline", prim::kPrimSwitch);
301
302 // SwitchLayer func graph input defer inline
303 switch_layer_defer_inline_ =
304 MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
305
306 // Recompute
307 set_cell_output_no_recompute_ = MakeSubstitution(std::make_shared<SetCellOutputNoRecompute>(),
308 "set_cell_output_no_recompute", IsValueNode<FuncGraph>);
309 remove_not_recompute_node_ =
310 MakeSubstitution(std::make_shared<RemoveNotRecomputeNode>(), "remove_not_recompute_node", IsCNode);
311
312 // Optimize with SymbolEngine
313 elim_shapecalc_of_broadcastargs_ = MakeSubstitution(std::make_shared<ElimShapeCalcOnBroadcastArgsGrad>(),
314 "elim_shapecalc_of_broadcastargs", prim::kPrimReduceSum);
315 elim_not_effective_node_ = MakeSubstitution(std::make_shared<ElimNotEffectiveNode>(), "elim_not_effective", IsCNode);
316 opt_reshape_ = MakeSubstitution(std::make_shared<OptReshape>(), "opt_reshape", prim::kPrimReshape);
317 fold_const_symbol_ = MakeSubstitution(std::make_shared<FoldConstSymbol>(), "fold_const_symbol", IsCNode);
318 }
319
ResolveIRPassLib()320 ResolveIRPassLib::ResolveIRPassLib() {
321 // In resolver_, some patterns have priority over others.
322 resolver_ = MakeSubstitution(std::make_shared<Resolver>(), "getattr_setattr_resolve",
323 {prim::kPrimGetAttr, prim::kPrimSetAttr, prim::kPrimResolve}, opt::CHECK_RENORM, true);
324 }
325
GradPartialPassLib()326 GradPartialPassLib::GradPartialPassLib() {
327 grad_partial_transform_ =
328 MakeSubstitution(std::make_shared<GradPartialTransform>(), "grad_partial_transform", IsCNode);
329 }
330 } // namespace irpass
331 } // namespace opt
332 } // namespace mindspore
333