• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_H_
19 
20 #include <memory>
21 
22 #include "frontend/optimizer/optimizer.h"
23 #include "mindspore/core/ops/structure_ops.h"
24 #include "mindspore/core/ops/framework_ops.h"
25 #include "frontend/optimizer/opt.h"
26 #include "frontend/optimizer/anf_visitor.h"
27 
28 namespace mindspore {
29 namespace opt {
30 namespace irpass {
31 // the collection of irpass for optimie action
32 class OptimizeIRPassLib {
33  public:
34   OptimizeIRPassLib();
35   ~OptimizeIRPassLib() = default;
36 
37   SubstitutionPtr arithmetic_simplify_;
38   SubstitutionPtr special_op_eliminate_;
39   SubstitutionPtr ad_related_special_op_eliminate_;
40   SubstitutionPtr zero_like_fill_zero_;
41   SubstitutionPtr mutable_op_eliminate_;
42   SubstitutionPtr adjust_all_reduce_mul_add_;
43   SubstitutionPtr float_depend_g_call_;
44   //  ops eliminate
45   SubstitutionPtr tuple_list_get_item_eliminator_;
46   SubstitutionPtr tuple_list_get_item_const_eliminator_;
47   SubstitutionPtr tuple_list_set_item_eliminator_;
48   SubstitutionPtr tuple_list_get_set_item_eliminator_;
49   SubstitutionPtr tuple_list_get_item_depend_reorder_;
50   SubstitutionPtr list_to_tuple_eliminator_;
51   SubstitutionPtr tuple_to_list_eliminator_;
52   SubstitutionPtr tuple_list_convert_item_index_to_positive_;
53   SubstitutionPtr make_slice_get_slice_eliminator_;
54   SubstitutionPtr slice_to_tuple_;
55   SubstitutionPtr dict_get_item_eliminator_;
56   SubstitutionPtr dict_get_item_const_eliminator_;
57   SubstitutionPtr dict_set_item_eliminator_;
58 
59   SubstitutionPtr stack_unstack_eliminate_;
60   SubstitutionPtr tile_eliminate_;
61   SubstitutionPtr cast_eliminate_;
62   SubstitutionPtr reshape_eliminate_;
63   SubstitutionPtr transpose_eliminate_;
64   SubstitutionPtr reduce_eliminate_;
65   SubstitutionPtr partial_eliminate_;
66   SubstitutionPtr same_eliminate_;
67   SubstitutionPtr check_bprop_eliminate_;
68   SubstitutionPtr reset_defer_inline_;
69   SubstitutionPtr const_output_eliminate_;
70   SubstitutionPtr depend_value_elim_;
71   SubstitutionPtr all_reduce_const_elim_;
72   SubstitutionPtr mini_step_allgather_replace_;
73   SubstitutionPtr micro_step_allgather_replace_;
74   SubstitutionPtr real_op_eliminate_;
75   SubstitutionPtr convert_tensor_eliminate_;
76   SubstitutionPtr convert_tensor_all_eliminate_;
77   SubstitutionPtr get_grad_eliminate_;
78 
79   // Env Item Eliminate
80   SubstitutionPtr environ_get_eliminate_;
81   SubstitutionPtr environ_get_add_eliminate_;
82   SubstitutionPtr environ_get_set_eliminate_;
83   SubstitutionPtr environ_get_depend_swap_;
84   SubstitutionPtr environ_add_const_eliminate_;
85   SubstitutionPtr split_environ_get_set_with_tuple_value_;
86 
87   // Ref eliminate
88   SubstitutionPtr replace_old_param_;
89 
90   // Branch culling
91   SubstitutionPtr switch_simplify_;
92   SubstitutionPtr compare_switch_simplify_;
93   SubstitutionPtr float_tuple_getitem_switch_;
94   SubstitutionPtr float_environ_get_switch_;
95   SubstitutionPtr exchange_switch_depend_value_;
96 
97   SubstitutionPtr switch_partial_eliminater_;
98   SubstitutionPtr switch_layer_partial_eliminater_;
99 
100   // AddN
101   SubstitutionPtr merge_addn_;
102   SubstitutionPtr addn_zero_filter_;
103   SubstitutionPtr addn_check_dump_;
104 
105   // AccumulateNV2
106   SubstitutionPtr accumulaten_eliminater_;
107 
108   // Accelerated Algorithm
109   SubstitutionPtr less_batch_normalization_;
110 
111   // Gradient irpasses
112   SubstitutionPtr minmaximum_grad_;
113   SubstitutionPtr j_node_and_user_rematch_;
114 
115   // inline
116   SubstitutionPtr inline_;
117   SubstitutionPtr halfway_inline_;
118   SubstitutionPtr inline_without_move_;
119   SubstitutionPtr replace_applicator_;
120   SubstitutionPtr specialize_transform_;
121 
122   // Auto-monad related eliminaters.
123   SubstitutionPtr updatestate_useless_node_eliminater_;
124   SubstitutionPtr updatestate_pure_node_eliminater_;
125   SubstitutionPtr switch_call_monad_eliminater_;
126   SubstitutionPtr stopgrad_eliminater_;
127   SubstitutionPtr load_eliminater_;
128 
129   // Incorporation
130   SubstitutionPtr incorporate_call_;
131   SubstitutionPtr incorporate_call_switch_;
132 
133   // virtual dataset
134   SubstitutionPtr virtual_dataset_eliminate_;
135 
136   // virtual output
137   SubstitutionPtr virtual_output_eliminate_;
138 
139   // virtual shard identity
140   SubstitutionPtr virtual_shard_identity_;
141 
142   // PipelineSplit
143   SubstitutionPtr parallel_virtual_node_;
144 
145   // Convert
146   SubstitutionPtr print_tuple_wrapper_;
147 
148   // Print const Convert string
149   SubstitutionPtr print_const_string_wrapper_;
150 
151   // tuple parameter graph transform
152   SubstitutionPtr call_graph_tuple_transform_;
153 
154   // RowTensor Eliminate
155   SubstitutionPtr row_tensor_eliminate_;
156 
157   // RowTensorAddZerosLike Eliminate
158   SubstitutionPtr row_tensor_add_zeros_like_;
159 
160   // SparseTensor Eliminate
161   SubstitutionPtr sparse_tensor_eliminate_;
162 
163   // Value_Based Eliminate
164   SubstitutionPtr value_based_eliminate_;
165 
166   // Partial defer inline
167   SubstitutionPtr partial_defer_inline_;
168 
169   // Switch defer inline
170   SubstitutionPtr switch_defer_inline_;
171 
172   // SwitchLayer defer inline
173   SubstitutionPtr switch_layer_defer_inline_;
174 
175   // Pynative Eliminate
176   SubstitutionPtr pynative_eliminate_;
177 
178   // Pynative no need grad eliminate
179   SubstitutionPtr pynative_no_grad_eliminate_;
180 
181   // Recompute
182   SubstitutionPtr set_cell_output_no_recompute_;
183   SubstitutionPtr remove_not_recompute_node_;
184 
185   // Optimize with SymbolEngine
186   SubstitutionPtr elim_not_effective_node_;
187   SubstitutionPtr elim_shapecalc_of_broadcastargs_;
188   SubstitutionPtr opt_reshape_;
189   SubstitutionPtr fold_const_symbol_;
190 };
191 
192 // the collection of irpass for resolve action
193 class ResolveIRPassLib {
194  public:
195   ResolveIRPassLib();
196   ~ResolveIRPassLib() = default;
197   SubstitutionPtr resolver_;
198 };
199 
200 class GradPartialPassLib {
201  public:
202   GradPartialPassLib();
203   ~GradPartialPassLib() = default;
204   SubstitutionPtr grad_partial_transform_;
205 };
206 
207 // Predicate functions
IsNode(const AnfNodePtr &)208 inline bool IsNode(const AnfNodePtr &) { return true; }
209 
IsCNode(const AnfNodePtr & node)210 inline bool IsCNode(const AnfNodePtr &node) {
211   if (node != nullptr) {
212     return node->isa<CNode>();
213   }
214   return false;
215 }
216 
IsVNode(const AnfNodePtr & node)217 inline bool IsVNode(const AnfNodePtr &node) {
218   if (node != nullptr) {
219     return node->isa<ValueNode>();
220   }
221   return false;
222 }
223 
IsParam(const AnfNodePtr & node)224 inline bool IsParam(const AnfNodePtr &node) {
225   if (node != nullptr) {
226     return node->isa<Parameter>();
227   }
228   return false;
229 }
230 
IsLoad(const AnfNodePtr & node)231 inline bool IsLoad(const AnfNodePtr &node) {
232   if (node == nullptr || !node->isa<CNode>()) {
233     return false;
234   }
235   return IsPrimitiveCNode(node, prim::kPrimLoad);
236 }
237 
238 // Check if CNode Input 0 is Func Graph
IsCNodeGraph(const AnfNodePtr & node)239 inline bool IsCNodeGraph(const AnfNodePtr &node) {
240   if (node == nullptr || !node->isa<CNode>()) {
241     return false;
242   }
243 
244   auto inp0 = node->cast<CNodePtr>()->input(0);
245   return IsValueNode<FuncGraph>(inp0);
246 }
247 
248 // Check if CNode Input 0 is CNode
IsCNodeDup(const AnfNodePtr & node)249 inline bool IsCNodeDup(const AnfNodePtr &node) {
250   if (node == nullptr || !node->isa<CNode>()) {
251     return false;
252   }
253 
254   auto inp0 = node->cast<CNodePtr>()->input(0);
255   return (inp0 != nullptr) && inp0->isa<CNode>();
256 }
257 
258 // check if the cnode is a switch cnode
IsCNodeSwitch(const AnfNodePtr & node)259 inline bool IsCNodeSwitch(const AnfNodePtr &node) {
260   if (node != nullptr) {
261     if (node->isa<CNode>()) {
262       return IsPrimitiveCNode(node, prim::kPrimSwitch);
263     }
264   }
265   return false;
266 }
267 
268 // check if the cnode is a do_signature cnode
IsCNodeDoSignature(const AnfNodePtr & node)269 inline bool IsCNodeDoSignature(const AnfNodePtr &node) {
270   auto cnode = dyn_cast_ptr<CNode>(node);
271   if (cnode == nullptr) {
272     return false;
273   }
274   return IsValueNode<prim::DoSignaturePrimitive>(cnode->input(0));
275 }
276 }  // namespace irpass
277 }  // namespace opt
278 }  // namespace mindspore
279 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_H_
280