/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_

#include <securec.h>
#include <algorithm>
#include <memory>
#include <vector>
#include <string>

#include "frontend/optimizer/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/prim_eliminate.h"
#include "frontend/optimizer/optimizer.h"
#include "utils/comm_manager.h"
#include "frontend/parallel/context.h"
#include "pipeline/jit/parse/resolve.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace opt {
namespace irpass {
class SpecialOpEliminater : public OptimizerCaller {
 public:
  SpecialOpEliminater()
      : insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)),
        stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)),
        hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)),
        print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)),
        get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)),
        mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)),
        virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) {
    eliminaters_.emplace_back(insert_gradient_of_);
    eliminaters_.emplace_back(stop_gradient_);
    eliminaters_.emplace_back(hook_backward_);
    eliminaters_.emplace_back(print_shape_type_);
    eliminaters_.emplace_back(get_ref_value_);
    eliminaters_.emplace_back(mirror_);
    eliminaters_.emplace_back(virtual_div_);
  }
  ~SpecialOpEliminater() = default;

  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    AnfNodePtr new_node;
    for (auto &eliminater : eliminaters_) {
      new_node = (*eliminater)(optimizer, node);
      if (new_node != nullptr) {
        if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
          MS_LOG(WARNING)
            << "Hook operation does not work in graph mode or ms_function; it will be eliminated during compilation.";
        }
        return new_node;
      }
    }
    return nullptr;
  }

 private:
  OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
    virtual_div_;
  std::vector<OptimizerCallerPtr> eliminaters_{};
};
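
// A minimal sketch of what each PrimEliminater above does (illustrative IR, not extra
// functionality): the listed primitives are single-purpose markers that wrap one data input,
// so a fragment such as
//   %1 = StopGradient(%x)
//   %2 = Mul(%1, %y)
// is rewritten to
//   %2 = Mul(%x, %y)
// once autodiff and the parallel planner no longer need the marker.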

// {prim::kPrimVirtualDataset, Xs} -> {prim::kPrimMakeTuple, Xs}
class VirtualDatasetEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualDataset) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 1) {
      return nullptr;
    }

    std::vector<AnfNodePtr> args;
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
    (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));

    return node->func_graph()->NewCNode(args);
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimVirtualOutput, X} -> X
class VirtualOutputEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualOutput) || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode->inputs().size() <= 1) {
      return nullptr;
    }
    return cnode->input(1);
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimReceive, X} -> prim::kPrimReceive
class ReceiveEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimReceive) || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode->inputs().size() == 1) {
      return nullptr;
    }
    std::vector<AnfNodePtr> args = {cnode->input(0)};
    return node->func_graph()->NewCNode(args);
  }

  void Visit(const AnfNodePtr &) override {}
};
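
// Illustrative example for the three passes above (assumed IR shapes): VirtualDataset exists
// only so the parallel planner sees all dataset outputs as one unit, so
//   %1 = VirtualDataset(%img, %label)
// becomes
//   %1 = MakeTuple(%img, %label)
// while VirtualOutput forwards its single data input and Receive drops its placeholder inputs.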

class VirtualAssignAddEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAssignAdd) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

 private:
  AnfNodePtr x_{nullptr};
};

class VirtualAccuGradEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAccuGrad) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

 private:
  AnfNodePtr x_{nullptr};
};

// {prim::kPrimMirrorMicroStep, X, Z} -> X
class MirrorMicroStepEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimSameTypeShape, X, Y} -> X
class SameEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    x_ = nullptr;
    AnfVisitor::Match(prim::kPrimSameTypeShape, {IsNode, IsNode})(node);
    return x_;
  }

  void Visit(const AnfNodePtr &node) override {
    if (x_ == nullptr) {
      x_ = node;
    }
  }

 private:
  AnfNodePtr x_{nullptr};
};

// {prim::kPrimCheckBprop, X, Y} -> X
class CheckBpropEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    x_ = nullptr;
    AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node);
    return x_;
  }

  void Visit(const AnfNodePtr &node) override {
    if (x_ == nullptr) {
      x_ = node;
    }
  }

 private:
  AnfNodePtr x_{nullptr};
};

// {prim::kPrimMirrorMiniStep, X, Z} -> X
class MirrorMiniStepEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimVirtualAdd, X, Z} -> X
class VirtualAddEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAdd) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};
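
// Shared shape of the eliminators above (a reading aid, not extra behavior): each matched op
// is an identity on its first data input, kept in the IR only for parallel/autodiff
// bookkeeping, so every pass reduces
//   %1 = <MarkerOp>(%x, %z)
// to plain uses of %x.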

// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
class MiniStepAllGatherPass : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMiniStepAllGather) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }
    auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0));
    MS_EXCEPTION_IF_NULL(prim);
    auto attrs = prim->attrs();
    std::string group = attrs[parallel::GROUP]->ToString();
    auto fusion = attrs[parallel::FUSION];
    bool contain_recompute = prim->HasAttr(parallel::RECOMPUTE);
    bool recompute = contain_recompute && GetValue<bool>(attrs[parallel::RECOMPUTE]);
    parallel::Operator op = parallel::CreateAllGatherOp(group);
    std::vector<AnfNodePtr> node_input =
      parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE);
    auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
    prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(prim);
    attrs = prim->attrs();
    attrs[parallel::FUSION] = fusion;
    if (contain_recompute) {
      attrs[parallel::RECOMPUTE] = MakeValue(recompute);
    }
    prim->SetAttrs(attrs);
    auto func_graph = inputs[1]->func_graph();
    CNodePtr new_node = func_graph->NewCNode(node_input);
    return new_node;
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimMicroStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
class MicroStepAllGatherPass : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMicroStepAllGather) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }
    auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0));
    MS_EXCEPTION_IF_NULL(prim);
    auto attrs = prim->attrs();
    std::string group = attrs[parallel::GROUP]->ToString();
    auto fusion = attrs[parallel::FUSION];
    bool contain_recompute = prim->HasAttr(parallel::RECOMPUTE);
    bool recompute = contain_recompute && GetValue<bool>(attrs[parallel::RECOMPUTE]);
    parallel::Operator op = parallel::CreateAllGatherOp(group);
    std::vector<AnfNodePtr> node_input =
      parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE);
    auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
    prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(prim);
    attrs = prim->attrs();
    attrs[parallel::FUSION] = fusion;
    if (contain_recompute) {
      attrs[parallel::RECOMPUTE] = MakeValue(recompute);
    }
    prim->SetAttrs(attrs);
    auto func_graph = inputs[1]->func_graph();
    CNodePtr new_node = func_graph->NewCNode(node_input);
    return new_node;
  }

  void Visit(const AnfNodePtr &) override {}
};
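
// Sketch of the rewrite performed by the two AllGather passes above (illustrative operand
// names; the attribute keys are the ones read in the code):
//   %1 = MiniStepAllGather(%param, %accu) {group: "g1", fusion: 1, recompute: false}
// becomes
//   %1 = AllGather(%param) {group: "g1", fusion: 1, recompute: false}
// i.e. only the first input survives, and the GROUP, FUSION and RECOMPUTE attributes are
// copied onto the newly created collective.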

// Reset the defer_inline flag.
class ResetDeferInline : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (IsValueNode<FuncGraph>(node)) {
      auto fg = GetValueNode<FuncGraphPtr>(node);
      fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
    }
    return nullptr;
  }
};

// {PrimZerosLike, Y} ->
// {PrimFill, {PrimDType, Y}, {PrimShape, Y}, 0}
class ZeroLikeFillZero : public AnfVisitor {
 public:
  ZeroLikeFillZero()
      : PrimFill_(prim::GetPythonOps("fill", "mindspore.ops.functional")->cast<PrimitivePtr>()),
        PrimShape_(prim::GetPythonOps("shape", "mindspore.ops.functional")->cast<PrimitivePtr>()),
        PrimDType_(prim::GetPythonOps("dtype", "mindspore.ops.functional")->cast<PrimitivePtr>()) {}
  ~ZeroLikeFillZero() override = default;

  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    y_ = nullptr;
    AnfVisitor::Match(prim::kPrimZerosLike, {IsNode})(node);
    if (y_ == nullptr || node->func_graph() == nullptr) {
      return nullptr;
    }
    if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) {
      auto fg = node->func_graph();
      auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
      auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
      return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(static_cast<int64_t>(0)))});
    }

    abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>();

    TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
    std::vector<int64_t> tensor_shape = tensor_abstract->shape()->shape();

    // If the shape is unknown, don't optimize this operator away.
    for (const int64_t &dimension : tensor_shape) {
      if (dimension < 0) {
        return node;
      }
    }

    tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
    size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
    char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
    (void)memset_s(data, mem_size, 0, mem_size);

    auto new_value_node = NewValueNode(new_tensor_ptr);
    new_value_node->set_abstract(new_tensor_ptr->ToAbstract());

    return new_value_node;
  }

  void Visit(const AnfNodePtr &node) override { y_ = node; }

 private:
  AnfNodePtr y_{nullptr};
  PrimitivePtr PrimFill_, PrimShape_, PrimDType_;
};

// {prim::kPrimDepend, X, ValueCond} -> X
class DependValueElim : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    PatternNode<AnfNodePtr> x, cond;
    MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node)));
    return nullptr;
  }
};
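
// Worked example for ZeroLikeFillZero (illustrative dtype/shape): when the input abstract is a
// statically shaped float32 (2, 3) tensor, ZerosLike(%y) folds into a constant zero tensor
// value node with no runtime op at all; when dtype/shape are not statically known it becomes
//   Fill(DType(%y), Shape(%y), 0)
// and inputs with a dynamic (negative) dimension are left untouched.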

// {{prim::GetAttr, {prim::Resolve, SymbolStr, C}, zeros_like}, Xy} -> Tensor(0, shape(Xy))
// {{prim::Resolve, SymbolStr, zeros_like}, Xy} -> Tensor(0, shape(Xy))
// {{prim::Resolve, CommonOPS, getitem}, (tensor0, tensor1, ...), 0} -> tensor0
class PynativeEliminater : public OptimizerCaller {
  bool CheckNameSpaceVNode(const AnfNodePtr &node, const std::string &str_value) {
    ValueNodePtr value_node = node->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      return false;
    }
    auto name_space = GetValueNode<parse::NameSpacePtr>(value_node);
    if (name_space == nullptr) {
      return false;
    }
    return name_space->module() == str_value;
  }

  bool CheckSymbolVNode(const AnfNodePtr &node, const std::string &str_value) {
    ValueNodePtr value_node = node->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      return false;
    }
    auto symbol = GetValueNode<parse::SymbolPtr>(value_node);
    if (symbol == nullptr) {
      return false;
    }
    return symbol->symbol() == str_value;
  }

  bool CheckStrVNode(const AnfNodePtr &node, const std::string &str_value) {
    ValueNodePtr value_node = node->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      return false;
    }
    auto str = GetValueNode<StringImmPtr>(value_node);
    if (str == nullptr) {
      return false;
    }
    return str->value() == str_value;
  }

  ValuePtr FillGetItem(const ValuePtr &value, const ValuePtr &idx) {
    MS_LOG(DEBUG) << "Start FillGetItem " << value->ToString() << " " << idx->ToString();
    if (!idx->isa<Int64Imm>()) {
      MS_LOG(EXCEPTION) << "Getitem idx must be an int64 scalar: " << idx->ToString();
    }

    if (!value->isa<ValueTuple>()) {
      MS_LOG(EXCEPTION) << "Getitem value must be a tuple: " << value->ToString();
    }

    auto value_tuple = value->cast<ValueTuplePtr>();
    int64_t idx_t = idx->cast<Int64ImmPtr>()->value();
    MS_LOG(DEBUG) << "Fill getitem " << idx_t << " " << (*value_tuple)[static_cast<size_t>(idx_t)]->ToString();
    return (*value_tuple)[static_cast<size_t>(idx_t)];
  }

  ValuePtr FillZero(const ValuePtr &value) {
    MS_LOG(DEBUG) << "Start FillZero";
    ValuePtr out = nullptr;
    if (value->isa<Int64Imm>()) {
      return MakeValue(value->cast<Int64ImmPtr>()->value());
    }

    if (value->isa<tensor::Tensor>()) {
      MS_LOG(DEBUG) << "Start FillZero Tensor";
      auto tensor = value->cast<tensor::TensorPtr>();
      tensor::TensorPtr out_t = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
      char *data = reinterpret_cast<char *>(out_t->data_c());
      std::fill(data, data + out_t->data().nbytes(), 0);
      out = out_t;
    }

    std::vector<ValuePtr> value_list;
    if (value->isa<ValueTuple>()) {
      MS_LOG(DEBUG) << "Start FillZero Tuple " << value->ToString();
      auto value_tuple = value->cast<ValueTuplePtr>();
      for (size_t i = 0; i < value_tuple->size(); i++) {
        value_list.push_back(FillZero((*value_tuple)[i]));
      }
      out = std::make_shared<ValueTuple>(value_list);
    }
    if (out == nullptr) {
      MS_LOG(EXCEPTION) << "FillZero failed: " << value->ToString();
    }
    MS_LOG(DEBUG) << "Result: " << out->ToString();
    return out;
  }

 private:
  AnfNodePtr OperatorHandle1(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
    auto rep = (arg).GetNode(node);
    if (rep != nullptr) {
      if (rep->isa<ValueNode>()) {
        auto value_node = rep->cast<ValueNodePtr>();
        auto new_value_node = NewValueNode(FillZero(value_node->value()));
        new_value_node->set_has_new_value(value_node->has_new_value());
        MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4);
        return new_value_node;
      }
    }
    return nullptr;
  }

  AnfNodePtr OperatorHandle2(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
    auto rep = (arg).GetNode(node);
    if (rep != nullptr) {
      if (rep->isa<ValueNode>() && !HasAbstractMonad(rep)) {
        auto value_node = rep->cast<ValueNodePtr>();
        auto new_value_node = NewValueNode(FillZero(value_node->value()));
        new_value_node->set_has_new_value(value_node->has_new_value());
        MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4);
        return new_value_node;
      }
    }
    return nullptr;
  }

  void OperatorHandle3(const std::vector<PatternNode<AnfNodePtr>> &args, const AnfNodePtr &node) {
    for (size_t i = 0; i < 2; i++) {
      auto rep = (args[i]).GetNode(node);
      if (rep != nullptr && rep->isa<ValueNode>()) {
        auto value_node = rep->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(value_node);
        auto &value = value_node->value();
        MS_EXCEPTION_IF_NULL(value);
        // When the use count of the value node equals one, it is only used in the binop_grad_common function.
        if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) {
          auto tensor = value->cast<tensor::TensorPtr>();
          MS_EXCEPTION_IF_NULL(tensor);
          auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
          value_node->set_value(new_tensor);
        }
      }
    }
  }

  AnfNodePtr OperatorHandle4(const PatternNode<AnfNodePtr> &arg, const PatternNode<AnfNodePtr> &arg1,
                             const AnfNodePtr &node) {
    auto rep = (arg).GetNode(node);
    if (rep != nullptr) {
      if (rep->isa<ValueNode>()) {
        MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4);
        ValueNodePtr new_node = nullptr;
        auto value_node = rep->cast<ValueNodePtr>();
        auto rep1 = (arg1).GetNode(node);
        if (rep1 != nullptr) {
          if (rep1->isa<ValueNode>()) {
            auto idx = rep1->cast<ValueNodePtr>();
            if (!value_node->value()->isa<ValueTuple>()) {
              return nullptr;
            }
            new_node = NewValueNode(FillGetItem(value_node->value(), idx->value()));
            new_node->set_has_new_value(value_node->has_new_value());
          }
        }
        // If the index pattern did not capture a value node, nothing was built; bail out
        // instead of dereferencing a null pointer in the log below.
        if (new_node == nullptr) {
          return nullptr;
        }
        MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4);
        return new_node;
      }
    }
    return nullptr;
  }

 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
    PatternNode<AnfNodePtr> symbol_str_vnode;
    PatternNode<AnfNodePtr> c_vnode;
    PatternNode<AnfNodePtr> zeros_like_vnode;
    PatternNode<AnfNodePtr> arg;
    auto resolve = PPrimitive(prim::kPrimResolve, symbol_str_vnode, c_vnode);
    auto getattr = PPrimitive(prim::kPrimGetAttr, resolve, zeros_like_vnode);
    auto pattern = PCNode(getattr, arg);
    // {{prim::GetAttr, {prim::Resolve, SymbolStr, C}, zeros_like}, Xy} -> Tensor(0, shape(Xy))
    if ((pattern).TryCapture(node) &&
        (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
         CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
      auto new_value_node = OperatorHandle1(arg, node);
      if (new_value_node != nullptr) {
        return new_value_node;
      }
    }
    MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4);
    // {{prim::Resolve, SymbolStr, zeros_like}, Xy} -> Tensor(0, shape(Xy))
    auto resolve1 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, zeros_like_vnode);
    auto pattern1 = PCNode(resolve1, arg);

    if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
                                        CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
      auto new_value_node = OperatorHandle2(arg, node);
      if (new_value_node != nullptr) {
        return new_value_node;
      }
    }
    // {{prim::Resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout}
    PatternNode<AnfNodePtr> binop_grad_common;
    PatternNode<AnfNodePtr> getitem_vnode;
    std::vector<PatternNode<AnfNodePtr>> args(4);
    auto resolve_binop = PPrimitive(prim::kPrimResolve, symbol_str_vnode, binop_grad_common);
    auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]);
    if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
                                             CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) {
      OperatorHandle3(args, node);
      return nullptr;
    }
    // resolve(CommonOPS, getitem)((tensors), 3)
    PatternNode<AnfNodePtr> arg1;
    auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode);
    auto pattern2 = PCNode(resolve2, arg, arg1);
    if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") &&
                                        CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) {
      auto new_value_node = OperatorHandle4(arg, arg1, node);
      if (new_value_node != nullptr) {
        return new_value_node;
      }
    }

    MS_LOG(DEBUG) << "End Replace " << node->DebugString(4);
    return nullptr;
  }
};
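
// Worked example for PynativeEliminater (illustrative, PyNative-parsed IR): a call captured as
//   {{GetAttr, {Resolve, SymbolStr, C}, "zeros_like"}, %const}
// folds at compile time into a value node holding a zero tensor of %const's dtype and shape,
// and {{Resolve, CommonOPS, getitem}, (t0, t1), 0} folds to t0 when both operands are
// constant value nodes.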

class AllReduceConstElim : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    PatternNode<AnfNodePtr> x;
    auto pattern = PPrimitive(prim::kPrimAllReduce, x);
    // AllReduce can be folded only if it takes a constant value as input and the values across
    // devices are all the same (this is ensured by the parallel mode).
    if (pattern.TryCapture(node) && IsVNode(x.GetNode(node)) &&
        (pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) ||
         pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) {
      auto cur_func_graph = pattern.GetFuncGraph();
      // If the reduce operation is sum, multiply the constant by the number of devices;
      // otherwise just return the constant.
      auto prim_cnode = pattern.GetOriginalNode();
      MS_EXCEPTION_IF_NULL(prim_cnode);
      auto primitive = GetCNodePrimitive(prim_cnode);
      MS_EXCEPTION_IF_NULL(primitive);
      auto reduce_op = primitive->GetAttr("op");
      auto group = primitive->GetAttr("group")->ToString();
      // For a sum operation, multiply the constant tensor by the number of devices.
      if (reduce_op->ToString() == "sum") {
        uint32_t num_of_devices;
        // Get the number of devices.
        if (!CommManager::GetInstance().GetRankSize(group, &num_of_devices)) {
          MS_LOG(EXCEPTION) << "Failed to get num of devices for group [" << group << "]";
        }
        // Multiply the constant by the number of devices, then return.
        std::vector<AnfNodePtr> mul_inputs;
        auto constant_node = x.GetNode(node);
        MS_EXCEPTION_IF_NULL(constant_node);
        auto constant_value_node = constant_node->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(constant_value_node);
        if (!constant_value_node->value()->isa<tensor::Tensor>()) {
          MS_LOG(EXCEPTION) << "Expect the constant input for AllReduce to be a Tensor. Got "
                            << constant_value_node->value()->ToString();
        }
        auto constant_tensor = constant_value_node->value()->cast<tensor::TensorPtr>();
        auto tensor_dtype = constant_tensor->Dtype();
        auto num_of_device_node =
          NewValueNode(std::make_shared<tensor::Tensor>(static_cast<int64_t>(num_of_devices), tensor_dtype));
        // Build the multiply node.
        auto mul_prim = prim::GetPythonOps("tensor_mul", "mindspore.ops.functional");
        MS_EXCEPTION_IF_NULL(mul_prim);
        mul_inputs.push_back(NewValueNode(mul_prim));
        mul_inputs.push_back(constant_node);
        mul_inputs.push_back(num_of_device_node);
        return cur_func_graph->NewCNode(mul_inputs);
      } else {
        return x.GetNode(node);
      }
    }
    return nullptr;
  }
};
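
// Worked example (illustrative numbers): under semi/auto parallel every one of the N devices
// holds the same constant, so AllReduce(sum) over a constant c on 4 devices is simply c * 4,
// while for the other reduce ops (e.g. max or min over identical values) the result is the
// constant itself, which is exactly the second branch handled above.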

// This pattern is introduced by Depend(CollectCNodeWithIsolateNodes) in program_specialize.cc.
// {{prim::kPrimDepend, X, Y}, Xs} -> {prim::kPrimDepend, {X, Xs}, Y}
class FloatDependGCall : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    // IsCNodeDup has already checked that inputs.size() >= 1, so there is no check here.
    if (IsPrimitiveCNode(inputs[0], prim::kPrimDepend)) {
      auto &depend_inputs = inputs[0]->cast<CNodePtr>()->inputs();
      if (depend_inputs.size() != 3) {
        return nullptr;
      }
      // Put {X, Xs} into new_inputs: the called graph X plus the original call arguments.
      std::vector<AnfNodePtr> new_inputs({depend_inputs[1]});
      new_inputs.insert(new_inputs.end(), inputs.cbegin() + 1, inputs.cend());
      TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
      ScopePtr scope = node->scope();
      ScopeGuard scope_guard(scope);
      auto new_call_node = node->func_graph()->NewCNode(new_inputs);
      auto new_node = node->func_graph()->NewCNode({depend_inputs[0], new_call_node, depend_inputs[2]});
      return new_node;
    }
    return nullptr;
  }
};

}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_