1 /** 2 * Copyright 2020-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_H_ 19 20 #include <memory> 21 22 #include "frontend/optimizer/optimizer.h" 23 #include "mindspore/core/ops/structure_ops.h" 24 #include "mindspore/core/ops/framework_ops.h" 25 #include "frontend/optimizer/opt.h" 26 #include "frontend/optimizer/anf_visitor.h" 27 28 namespace mindspore { 29 namespace opt { 30 namespace irpass { 31 // the collection of irpass for optimie action 32 class OptimizeIRPassLib { 33 public: 34 OptimizeIRPassLib(); 35 ~OptimizeIRPassLib() = default; 36 37 SubstitutionPtr arithmetic_simplify_; 38 SubstitutionPtr special_op_eliminate_; 39 SubstitutionPtr ad_related_special_op_eliminate_; 40 SubstitutionPtr zero_like_fill_zero_; 41 SubstitutionPtr mutable_op_eliminate_; 42 SubstitutionPtr adjust_all_reduce_mul_add_; 43 SubstitutionPtr float_depend_g_call_; 44 // ops eliminate 45 SubstitutionPtr tuple_list_get_item_eliminator_; 46 SubstitutionPtr tuple_list_get_item_const_eliminator_; 47 SubstitutionPtr tuple_list_set_item_eliminator_; 48 SubstitutionPtr tuple_list_get_set_item_eliminator_; 49 SubstitutionPtr tuple_list_get_item_depend_reorder_; 50 SubstitutionPtr list_to_tuple_eliminator_; 51 SubstitutionPtr tuple_to_list_eliminator_; 52 SubstitutionPtr tuple_list_convert_item_index_to_positive_; 53 SubstitutionPtr make_slice_get_slice_eliminator_; 54 SubstitutionPtr slice_to_tuple_; 55 SubstitutionPtr dict_get_item_eliminator_; 56 SubstitutionPtr dict_get_item_const_eliminator_; 57 SubstitutionPtr dict_set_item_eliminator_; 58 59 SubstitutionPtr stack_unstack_eliminate_; 60 SubstitutionPtr tile_eliminate_; 61 SubstitutionPtr cast_eliminate_; 62 SubstitutionPtr reshape_eliminate_; 63 SubstitutionPtr transpose_eliminate_; 64 SubstitutionPtr reduce_eliminate_; 65 SubstitutionPtr partial_eliminate_; 66 SubstitutionPtr same_eliminate_; 67 SubstitutionPtr check_bprop_eliminate_; 68 SubstitutionPtr reset_defer_inline_; 69 SubstitutionPtr const_output_eliminate_; 70 SubstitutionPtr depend_value_elim_; 71 SubstitutionPtr all_reduce_const_elim_; 72 SubstitutionPtr mini_step_allgather_replace_; 73 SubstitutionPtr micro_step_allgather_replace_; 74 SubstitutionPtr real_op_eliminate_; 75 SubstitutionPtr convert_tensor_eliminate_; 76 SubstitutionPtr convert_tensor_all_eliminate_; 77 SubstitutionPtr get_grad_eliminate_; 78 79 // Env Item Eliminate 80 SubstitutionPtr environ_get_eliminate_; 81 SubstitutionPtr environ_get_add_eliminate_; 82 SubstitutionPtr environ_get_set_eliminate_; 83 SubstitutionPtr environ_get_depend_swap_; 84 SubstitutionPtr environ_add_const_eliminate_; 85 SubstitutionPtr split_environ_get_set_with_tuple_value_; 86 87 // Ref eliminate 88 SubstitutionPtr replace_old_param_; 89 90 // Branch culling 91 SubstitutionPtr switch_simplify_; 92 SubstitutionPtr compare_switch_simplify_; 93 SubstitutionPtr float_tuple_getitem_switch_; 94 SubstitutionPtr float_environ_get_switch_; 95 SubstitutionPtr exchange_switch_depend_value_; 96 97 SubstitutionPtr switch_partial_eliminater_; 98 SubstitutionPtr switch_layer_partial_eliminater_; 99 100 // AddN 101 SubstitutionPtr merge_addn_; 102 SubstitutionPtr addn_zero_filter_; 103 SubstitutionPtr addn_check_dump_; 104 105 // AccumulateNV2 106 SubstitutionPtr accumulaten_eliminater_; 107 108 // Accelerated Algorithm 109 SubstitutionPtr less_batch_normalization_; 110 111 // Gradient irpasses 112 SubstitutionPtr minmaximum_grad_; 113 SubstitutionPtr j_node_and_user_rematch_; 114 115 // inline 116 SubstitutionPtr inline_; 117 SubstitutionPtr halfway_inline_; 118 SubstitutionPtr inline_without_move_; 119 SubstitutionPtr replace_applicator_; 120 SubstitutionPtr specialize_transform_; 121 122 // Auto-monad related eliminaters. 123 SubstitutionPtr updatestate_useless_node_eliminater_; 124 SubstitutionPtr updatestate_pure_node_eliminater_; 125 SubstitutionPtr switch_call_monad_eliminater_; 126 SubstitutionPtr stopgrad_eliminater_; 127 SubstitutionPtr load_eliminater_; 128 129 // Incorporation 130 SubstitutionPtr incorporate_call_; 131 SubstitutionPtr incorporate_call_switch_; 132 133 // virtual dataset 134 SubstitutionPtr virtual_dataset_eliminate_; 135 136 // virtual output 137 SubstitutionPtr virtual_output_eliminate_; 138 139 // virtual shard identity 140 SubstitutionPtr virtual_shard_identity_; 141 142 // PipelineSplit 143 SubstitutionPtr parallel_virtual_node_; 144 145 // Convert 146 SubstitutionPtr print_tuple_wrapper_; 147 148 // Print const Convert string 149 SubstitutionPtr print_const_string_wrapper_; 150 151 // tuple parameter graph transform 152 SubstitutionPtr call_graph_tuple_transform_; 153 154 // RowTensor Eliminate 155 SubstitutionPtr row_tensor_eliminate_; 156 157 // RowTensorAddZerosLike Eliminate 158 SubstitutionPtr row_tensor_add_zeros_like_; 159 160 // SparseTensor Eliminate 161 SubstitutionPtr sparse_tensor_eliminate_; 162 163 // Value_Based Eliminate 164 SubstitutionPtr value_based_eliminate_; 165 166 // Partial defer inline 167 SubstitutionPtr partial_defer_inline_; 168 169 // Switch defer inline 170 SubstitutionPtr switch_defer_inline_; 171 172 // SwitchLayer defer inline 173 SubstitutionPtr switch_layer_defer_inline_; 174 175 // Pynative Eliminate 176 SubstitutionPtr pynative_eliminate_; 177 178 // Pynative no need grad eliminate 179 SubstitutionPtr pynative_no_grad_eliminate_; 180 181 // Recompute 182 SubstitutionPtr set_cell_output_no_recompute_; 183 SubstitutionPtr remove_not_recompute_node_; 184 185 // Optimize with SymbolEngine 186 SubstitutionPtr elim_not_effective_node_; 187 SubstitutionPtr elim_shapecalc_of_broadcastargs_; 188 SubstitutionPtr opt_reshape_; 189 SubstitutionPtr fold_const_symbol_; 190 }; 191 192 // the collection of irpass for resolve action 193 class ResolveIRPassLib { 194 public: 195 ResolveIRPassLib(); 196 ~ResolveIRPassLib() = default; 197 SubstitutionPtr resolver_; 198 }; 199 200 class GradPartialPassLib { 201 public: 202 GradPartialPassLib(); 203 ~GradPartialPassLib() = default; 204 SubstitutionPtr grad_partial_transform_; 205 }; 206 207 // Predicate functions IsNode(const AnfNodePtr &)208inline bool IsNode(const AnfNodePtr &) { return true; } 209 IsCNode(const AnfNodePtr & node)210inline bool IsCNode(const AnfNodePtr &node) { 211 if (node != nullptr) { 212 return node->isa<CNode>(); 213 } 214 return false; 215 } 216 IsVNode(const AnfNodePtr & node)217inline bool IsVNode(const AnfNodePtr &node) { 218 if (node != nullptr) { 219 return node->isa<ValueNode>(); 220 } 221 return false; 222 } 223 IsParam(const AnfNodePtr & node)224inline bool IsParam(const AnfNodePtr &node) { 225 if (node != nullptr) { 226 return node->isa<Parameter>(); 227 } 228 return false; 229 } 230 IsLoad(const AnfNodePtr & node)231inline bool IsLoad(const AnfNodePtr &node) { 232 if (node == nullptr || !node->isa<CNode>()) { 233 return false; 234 } 235 return IsPrimitiveCNode(node, prim::kPrimLoad); 236 } 237 238 // Check if CNode Input 0 is Func Graph IsCNodeGraph(const AnfNodePtr & node)239inline bool IsCNodeGraph(const AnfNodePtr &node) { 240 if (node == nullptr || !node->isa<CNode>()) { 241 return false; 242 } 243 244 auto inp0 = node->cast<CNodePtr>()->input(0); 245 return IsValueNode<FuncGraph>(inp0); 246 } 247 248 // Check if CNode Input 0 is CNode IsCNodeDup(const AnfNodePtr & node)249inline bool IsCNodeDup(const AnfNodePtr &node) { 250 if (node == nullptr || !node->isa<CNode>()) { 251 return false; 252 } 253 254 auto inp0 = node->cast<CNodePtr>()->input(0); 255 return (inp0 != nullptr) && inp0->isa<CNode>(); 256 } 257 258 // check if the cnode is a switch cnode IsCNodeSwitch(const AnfNodePtr & node)259inline bool IsCNodeSwitch(const AnfNodePtr &node) { 260 if (node != nullptr) { 261 if (node->isa<CNode>()) { 262 return IsPrimitiveCNode(node, prim::kPrimSwitch); 263 } 264 } 265 return false; 266 } 267 268 // check if the cnode is a do_signature cnode IsCNodeDoSignature(const AnfNodePtr & node)269inline bool IsCNodeDoSignature(const AnfNodePtr &node) { 270 auto cnode = dyn_cast_ptr<CNode>(node); 271 if (cnode == nullptr) { 272 return false; 273 } 274 return IsValueNode<prim::DoSignaturePrimitive>(cnode->input(0)); 275 } 276 } // namespace irpass 277 } // namespace opt 278 } // namespace mindspore 279 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_H_ 280