1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONST_OUTPUT_ELIMINATE_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONST_OUTPUT_ELIMINATE_H_ 19 20 #include <memory> 21 #include <vector> 22 #include "ir/anf.h" 23 #include "frontend/optimizer/optimizer.h" 24 #include "frontend/optimizer/anf_visitor.h" 25 #include "frontend/optimizer/irpass.h" 26 #include "mindspore/core/ops/array_ops.h" 27 #include "include/common/utils/anfalgo.h" 28 29 namespace mindspore::opt::irpass { 30 // {a=makeTule(0, 0, 0);return a;} --> {a=makeTuple(0,0,0); b=depend(0, a); return b;} 31 // {a=makeTule(0, 0, 0, grad);return a;} --> {a=makeTuple(0,0,0);b=depend(0, a); c=makeTuple(b, grad); return c;} 32 class ConstOutputEliminater : public AnfVisitor { 33 public: operator()34 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { 35 Reset(); 36 auto flag = IsEliminate(node); 37 if (!flag) { 38 return nullptr; 39 } 40 41 MS_LOG(INFO) << "const output eliminater process"; 42 43 auto fg = GetValueNode<FuncGraphPtr>(node); 44 auto output = fg->output(); 45 const size_t min_input_size = 3; 46 const auto &inputs = output->cast<CNodePtr>()->inputs(); 47 if (inputs.size() < min_input_size) { 48 MS_LOG(INFO) << "maketuple input size small, size=" << inputs.size(); 49 return nullptr; 50 } 51 52 if (!grad_mode_) { 53 const auto const_data = Tensor0Builder(); 54 new_out_abstract_ = const_data->ToAbstract(); 55 auto new_value_node = NewValueNode(const_data); 56 new_value_node->set_abstract(new_out_abstract_); 57 58 auto depend = fg->NewCNode({NewValueNode(prim::kPrimDepend), new_value_node, output}); 59 MS_EXCEPTION_IF_NULL(depend); 60 depend->set_abstract(new_out_abstract_); 61 fg->set_output(depend); 62 } else { 63 // Zeros + grad 64 std::vector<AnfNodePtr> zero_inputs(inputs.begin() + 1, inputs.end() - 1); 65 auto grad_input = inputs.back(); 66 67 std::vector<AnfNodePtr> make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple)}; 68 make_tuple_inputs.insert(make_tuple_inputs.end(), zero_inputs.begin(), zero_inputs.end()); 69 auto tuple_zero_node_abstract = GetTupleAbstract(zero_inputs); 70 auto tuple_zero_node = fg->NewCNode(make_tuple_inputs); 71 tuple_zero_node->set_abstract(tuple_zero_node_abstract); 72 73 const auto const_data = Tensor0Builder(); 74 auto abstract_tensor = const_data->ToAbstract(); 75 auto new_value_node = NewValueNode(const_data); 76 new_value_node->set_abstract(abstract_tensor); 77 auto depend = fg->NewCNode({NewValueNode(prim::kPrimDepend), new_value_node, tuple_zero_node}); 78 depend->set_abstract(abstract_tensor); 79 80 new_out_abstract_ = GetTupleAbstract({new_value_node, grad_input}); 81 auto new_out = fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), depend, grad_input}); 82 new_out->set_abstract(new_out_abstract_); 83 fg->manager()->Replace(output, new_out); 84 } 85 fg->return_node()->set_abstract(new_out_abstract_); 86 87 (void)DoProcess(fg, true); 88 89 return nullptr; 90 } 91 92 private: 93 bool grad_mode_ = false; 94 size_t grad_index_ = 0; 95 AbstractBasePtr new_out_abstract_ = nullptr; 96 Reset()97 void Reset() { 98 grad_mode_ = false; 99 grad_index_ = 0; 100 new_out_abstract_ = nullptr; 101 } 102 GetTupleAbstract(const std::vector<AnfNodePtr> & inputs)103 AbstractBasePtr GetTupleAbstract(const std::vector<AnfNodePtr> &inputs) const { 104 AbstractBasePtrList new_sep_abstracts; 105 for (const auto &input : inputs) { 106 new_sep_abstracts.push_back(input->abstract()); 107 } 108 109 return std::make_shared<abstract::AbstractTuple>(new_sep_abstracts); 110 } 111 IsTupleAllZero(const AnfNodePtr & node)112 bool IsTupleAllZero(const AnfNodePtr &node) { 113 auto tuple = node->abstract()->cast<abstract::AbstractTuplePtr>(); 114 if (tuple == nullptr) { 115 return false; 116 } 117 size_t element_cnt = 0; 118 for (const auto &element : tuple->elements()) { 119 element_cnt++; 120 if (element->isa<abstract::AbstractTensor>()) { 121 const auto &tensor_abstract = element->cast<abstract::AbstractTensorPtr>(); 122 MS_EXCEPTION_IF_NULL(tensor_abstract); 123 auto dim_zero = tensor_abstract->BuildShape()->IsDimZero(); 124 auto value_any = tensor_abstract->BuildValue()->isa<ValueAny>(); 125 if (!value_any) { 126 return false; 127 } 128 129 if (element_cnt == tuple->elements().size()) { 130 grad_mode_ = dim_zero ? false : true; 131 grad_index_ = tuple->elements().size() - 1; 132 continue; 133 } 134 135 if (!dim_zero) { 136 return false; 137 } 138 139 continue; 140 } 141 142 if (!element->isa<abstract::AbstractScalar>()) { 143 return false; 144 } 145 const auto &scalar_abstract = element->cast<abstract::AbstractScalarPtr>(); 146 MS_EXCEPTION_IF_NULL(scalar_abstract); 147 auto abs_value = scalar_abstract->BuildValue(); 148 MS_EXCEPTION_IF_NULL(abs_value); 149 auto abs_int32 = dyn_cast<Int32Imm>(abs_value); 150 if (abs_int32 != nullptr) { 151 if (abs_int32->value() != 0) { 152 return false; 153 } 154 continue; 155 } 156 157 auto abs_int64 = dyn_cast<Int64Imm>(abs_value); 158 if (abs_int64 == nullptr) { 159 return false; 160 } 161 162 if (abs_int64->value() != 0) { 163 return false; 164 } 165 } 166 167 return true; 168 } 169 IsEliminate(const AnfNodePtr & node)170 bool IsEliminate(const AnfNodePtr &node) { 171 auto fg = GetValueNode<FuncGraphPtr>(node); 172 if (fg == nullptr) { 173 return false; 174 } 175 auto output = fg->output(); 176 if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) { 177 return false; 178 } 179 180 // Check whether the output is 0 181 if (!IsTupleAllZero(output)) { 182 return false; 183 } 184 185 // Check output users 186 return DoProcess(fg); 187 } 188 189 bool DoProcess(const FuncGraphPtr &func, bool is_replace = false) const { 190 MS_EXCEPTION_IF_NULL(func); 191 auto &fg_use_map = func->func_graph_cnodes_index(); 192 if (fg_use_map.empty()) { 193 return false; 194 } 195 196 for (auto &fg_use : fg_use_map) { 197 auto use_node = fg_use.first->first->cast<CNodePtr>(); 198 if (!IsPrimitiveCNode(use_node, prim::kPrimMakeTuple)) { 199 return false; 200 } 201 auto use_node_graph = use_node->func_graph(); 202 auto &fg_use_map_sub = use_node_graph->func_graph_cnodes_index(); 203 auto mng_sub = use_node_graph->manager(); 204 for (auto &fg_use_sub : fg_use_map_sub) { 205 auto fg_use_node = fg_use_sub.first->first->cast<CNodePtr>(); 206 if (fg_use_node == nullptr) { 207 return false; 208 } 209 auto users_sub = mng_sub->node_users()[fg_use_node]; 210 211 auto ret = SubUsersProcess(users_sub, is_replace); 212 if (!ret) { 213 return false; 214 } 215 } 216 } 217 218 return true; 219 } 220 SubUsersProcess(const AnfNodeIndexSet & users,bool is_replace)221 bool SubUsersProcess(const AnfNodeIndexSet &users, bool is_replace) const { 222 for (auto &user : users) { 223 if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kDependAttachNodeIndex) { 224 continue; 225 } 226 227 if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { 228 return false; 229 } 230 231 auto index = common::AnfAlgo::GetTupleGetItemOutIndex(user.first->cast<CNodePtr>()); 232 if (index != kIndex1) { 233 continue; 234 } 235 236 auto mng_sub = user.first->func_graph()->manager(); 237 auto users_sub = mng_sub->node_users()[user.first]; 238 for (auto &user_sub : users_sub) { 239 if (is_replace) { 240 user_sub.first->set_abstract(new_out_abstract_); 241 } 242 243 auto ret = ConstNodeRealUserProcess(user_sub.first, user_sub.first->func_graph(), is_replace); 244 if (!ret) { 245 return false; 246 } 247 } 248 } 249 250 return true; 251 } 252 ConstNodeRealUserProcess(const AnfNodePtr & node,const FuncGraphPtr & func,bool is_replace)253 bool ConstNodeRealUserProcess(const AnfNodePtr &node, const FuncGraphPtr &func, bool is_replace) const { 254 MS_EXCEPTION_IF_NULL(node); 255 MS_EXCEPTION_IF_NULL(func); 256 257 auto mng = func->manager(); 258 auto users = mng->node_users()[node]; 259 if (users.empty()) { 260 return false; 261 } 262 263 for (auto &user : users) { 264 if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kDependAttachNodeIndex) { 265 continue; 266 } 267 268 if (!IsPrimitiveCNode(user.first, prim::kPrimTupleGetItem)) { 269 return false; 270 } 271 272 if (!is_replace) { 273 // Check 274 auto ret = RealUserCallerCheck(user.first, user.first->func_graph()); 275 if (!ret) { 276 return false; 277 } 278 } 279 280 if (is_replace) { 281 // Real caller 282 if (!grad_mode_) { 283 mng->Replace(user.first, node); 284 } else { 285 auto index = common::AnfAlgo::GetTupleGetItemOutIndex(user.first->cast<CNodePtr>()); 286 auto real_input = common::AnfAlgo::GetTupleGetItemRealInput(user.first->cast<CNodePtr>()); 287 size_t new_index = index == grad_index_ ? 1 : 0; 288 auto new_index_value = NewValueNode(MakeValue(SizeToLong(new_index))); 289 auto new_node = func->NewCNode({NewValueNode(prim::kPrimTupleGetItem), real_input, new_index_value}); 290 new_node->set_abstract(user.first->abstract()); 291 mng->Replace(user.first, new_node); 292 } 293 } 294 } 295 296 return true; 297 } 298 Tensor0Builder()299 tensor::TensorPtr Tensor0Builder() const { return std::make_shared<tensor::Tensor>(0.0); } 300 RealUserCallerCheck(const AnfNodePtr & node,const FuncGraphPtr & func)301 bool RealUserCallerCheck(const AnfNodePtr &node, const FuncGraphPtr &func) const { 302 MS_EXCEPTION_IF_NULL(node); 303 MS_EXCEPTION_IF_NULL(func); 304 305 auto mng = func->manager(); 306 auto &users = mng->node_users()[node]; 307 308 if (users.empty()) { 309 return false; 310 } 311 312 for (auto &user : users) { 313 if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kDependAttachNodeIndex) { 314 continue; 315 } 316 317 if (IsPrimitiveCNode(user.first, prim::kPrimDepend) && user.second == kRealInputIndexInDepend && grad_mode_) { 318 continue; 319 } 320 321 if (IsPrimitiveCNode(user.first, prim::kPrimSend) && grad_mode_) { 322 continue; 323 } 324 325 if (!IsPrimitiveCNode(user.first, prim::kPrimMakeTuple)) { 326 return false; 327 } 328 329 auto tuple = user.first->abstract()->cast<abstract::AbstractTuplePtr>(); 330 if (!tuple) { 331 return false; 332 } 333 334 // Check whether the element of tuple is empty tensor 335 for (const auto &element : tuple->elements()) { 336 if (!element->isa<abstract::AbstractTensor>()) { 337 return false; 338 } 339 340 const auto &tensor_abstract = element->cast<abstract::AbstractTensorPtr>(); 341 MS_EXCEPTION_IF_NULL(tensor_abstract); 342 if (!(tensor_abstract->BuildShape()->IsDimZero() && tensor_abstract->BuildValue()->isa<ValueAny>())) { 343 return false; 344 } 345 } 346 } 347 348 return true; 349 } 350 }; 351 } // namespace mindspore::opt::irpass 352 353 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CONST_OUTPUT_ELIMINATE_H_ 354