/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_

#include <securec.h>
#include <algorithm>
#include <memory>
#include <vector>
#include <string>

#include "frontend/optimizer/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/prim_eliminate.h"
#include "frontend/optimizer/optimizer.h"
#include "utils/comm_manager.h"
#include "frontend/parallel/context.h"
#include "pipeline/jit/parse/resolve.h"
#include "frontend/parallel/step_parallel.h"

namespace mindspore {
namespace opt {
namespace irpass {
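// Tries each registered PrimEliminater in turn and returns the first
// non-null replacement; eliminates the special pass-through primitives
// listed in the constructor.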
class SpecialOpEliminater : public OptimizerCaller {
 public:
  SpecialOpEliminater()
      : insert_gradient_of_(std::make_shared<PrimEliminater>(prim::kPrimInsertGradientOf)),
        stop_gradient_(std::make_shared<PrimEliminater>(prim::kPrimStopGradient)),
        hook_backward_(std::make_shared<PrimEliminater>(prim::kPrimHookBackward)),
        print_shape_type_(std::make_shared<PrimEliminater>(prim::kPrimPrintShapeType)),
        get_ref_value_(std::make_shared<PrimEliminater>(prim::kPrimGetRefValue)),
        mirror_(std::make_shared<PrimEliminater>(prim::kPrimMirror)),
        virtual_div_(std::make_shared<PrimEliminater>(prim::kPrimVirtualDiv)) {
    eliminaters_.emplace_back(insert_gradient_of_);
    eliminaters_.emplace_back(stop_gradient_);
    eliminaters_.emplace_back(hook_backward_);
    eliminaters_.emplace_back(print_shape_type_);
    eliminaters_.emplace_back(get_ref_value_);
    eliminaters_.emplace_back(mirror_);
    eliminaters_.emplace_back(virtual_div_);
  }
  ~SpecialOpEliminater() = default;

  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    AnfNodePtr new_node;
    for (auto &eliminater : eliminaters_) {
      new_node = (*eliminater)(optimizer, node);
      if (new_node != nullptr) {
        if (IsPrimitiveCNode(node, prim::kPrimHookBackward)) {
          MS_LOG(WARNING)
            << "Hook operation does not work in graph mode or ms_function, it will be eliminated during compilation.";
        }
        return new_node;
      }
    }
    return nullptr;
  }

 private:
  OptimizerCallerPtr insert_gradient_of_, stop_gradient_, hook_backward_, print_shape_type_, get_ref_value_, mirror_,
    virtual_div_;
  std::vector<OptimizerCallerPtr> eliminaters_{};
};

// {prim::kPrimVirtualDataset, X} -> X
// {prim::kPrimVirtualDataset, Xs} -> {prim::kPrimMakeTuple, Xs}
class VirtualDatasetEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualDataset) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.empty()) {
      return nullptr;
    }

    std::vector<AnfNodePtr> args;
    (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
    (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));

    return node->func_graph()->NewCNode(args);
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimVirtualOutput, X} -> X
class VirtualOutputEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualOutput) || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode->inputs().size() <= 1) {
      return nullptr;
    }
    return cnode->input(1);
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimReceive, X} -> prim::kPrimReceive
class ReceiveEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimReceive) || node->func_graph() == nullptr) {
      return nullptr;
    }
    auto cnode = node->cast<CNodePtr>();
    if (cnode->inputs().size() == 1) {
      return nullptr;
    }
    std::vector<AnfNodePtr> args = {cnode->input(0)};
    return node->func_graph()->NewCNode(args);
  }

  void Visit(const AnfNodePtr &) override {}
};

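// {prim::kPrimVirtualAssignAdd, X, Z} -> X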
class VirtualAssignAddEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAssignAdd) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

 private:
  AnfNodePtr x_{nullptr};
};

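// {prim::kPrimVirtualAccuGrad, X, Z} -> X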
class VirtualAccuGradEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAccuGrad) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

 private:
  AnfNodePtr x_{nullptr};
};

// {prim::kPrimMirrorMicroStep, X, Z} -> X
class MirrorMicroStepEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMicroStep) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimSameTypeShape, X, Y} -> X
class SameEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    x_ = nullptr;
    AnfVisitor::Match(prim::kPrimSameTypeShape, {IsNode, IsNode})(node);
    return x_;
  }

  void Visit(const AnfNodePtr &node) override {
    if (x_ == nullptr) {
      x_ = node;
    }
  }

 private:
  AnfNodePtr x_{nullptr};
};

// {prim::kPrimCheckBprop, X, Y} -> X
class CheckBpropEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    x_ = nullptr;
    AnfVisitor::Match(prim::kPrimCheckBprop, {IsNode, IsNode})(node);
    return x_;
  }

  void Visit(const AnfNodePtr &node) override {
    if (x_ == nullptr) {
      x_ = node;
    }
  }

 private:
  AnfNodePtr x_{nullptr};
};

// {prim::kPrimMirrorMiniStep, X, Z} -> X
class MirrorMiniStepEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimVirtualAdd, X, Z} -> X
class VirtualAddEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAdd) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }

    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
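// The rebuilt AllGather reuses the original group and carries over the
// fusion and, when present, recompute attributes.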
class MiniStepAllGatherPass : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMiniStepAllGather) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }
    auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0));
    MS_EXCEPTION_IF_NULL(prim);
    auto attrs = prim->attrs();
    std::string group = attrs[parallel::GROUP]->ToString();
    auto fusion = attrs[parallel::FUSION];
    bool contain_recompute = prim->HasAttr(parallel::RECOMPUTE);
    bool recompute = contain_recompute && GetValue<bool>(attrs[parallel::RECOMPUTE]);
    parallel::Operator op = parallel::CreateAllGatherOp(group);
    std::vector<AnfNodePtr> node_input =
      parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE);
    auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
    prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(prim);
    attrs = prim->attrs();
    attrs[parallel::FUSION] = fusion;
    if (contain_recompute) {
      attrs[parallel::RECOMPUTE] = MakeValue(recompute);
    }
    prim->SetAttrs(attrs);
    auto func_graph = inputs[1]->func_graph();
    CNodePtr new_node = func_graph->NewCNode(node_input);
    return new_node;
  }

  void Visit(const AnfNodePtr &) override {}
};

// {prim::kPrimMicroStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
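// Same rewrite as MiniStepAllGatherPass: the group, fusion and recompute
// attributes are preserved on the new AllGather.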
class MicroStepAllGatherPass : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMicroStepAllGather) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }
    auto prim = GetValueNode<PrimitivePtr>(node->cast<CNodePtr>()->input(0));
    MS_EXCEPTION_IF_NULL(prim);
    auto attrs = prim->attrs();
    std::string group = attrs[parallel::GROUP]->ToString();
    auto fusion = attrs[parallel::FUSION];
    bool contain_recompute = prim->HasAttr(parallel::RECOMPUTE);
    bool recompute = contain_recompute && GetValue<bool>(attrs[parallel::RECOMPUTE]);
    parallel::Operator op = parallel::CreateAllGatherOp(group);
    std::vector<AnfNodePtr> node_input =
      parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER_NOT_COMPUTE);
    auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
    prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(prim);
    attrs = prim->attrs();
    attrs[parallel::FUSION] = fusion;
    if (contain_recompute) {
      attrs[parallel::RECOMPUTE] = MakeValue(recompute);
    }
    prim->SetAttrs(attrs);
    auto func_graph = inputs[1]->func_graph();
    CNodePtr new_node = func_graph->NewCNode(node_input);
    return new_node;
  }

  void Visit(const AnfNodePtr &) override {}
};

// Reset the defer_inline flag
class ResetDeferInline : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (IsValueNode<FuncGraph>(node)) {
      auto fg = GetValueNode<FuncGraphPtr>(node);
      fg->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, false);
    }
    return nullptr;
  }
};

// {prim::kPrimZerosLike, Y} ->
// {prim::kPrimFill, {prim::kPrimDType, Y}, {prim::kPrimShape, Y}, 0}
class ZeroLikeFillZero : public AnfVisitor {
 public:
  ZeroLikeFillZero()
      : PrimFill_(prim::GetPythonOps("fill", "mindspore.ops.functional")->cast<PrimitivePtr>()),
        PrimShape_(prim::GetPythonOps("shape", "mindspore.ops.functional")->cast<PrimitivePtr>()),
        PrimDType_(prim::GetPythonOps("dtype", "mindspore.ops.functional")->cast<PrimitivePtr>()) {}
  ~ZeroLikeFillZero() override = default;

  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    y_ = nullptr;
    AnfVisitor::Match(prim::kPrimZerosLike, {IsNode})(node);
    if (y_ == nullptr || node->func_graph() == nullptr) {
      return nullptr;
    }
    if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) {
      auto fg = node->func_graph();
      auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
      auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
      return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(static_cast<int64_t>(0)))});
    }

    abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>();

    TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
    std::vector<int64_t> tensor_shape = tensor_abstract->shape()->shape();

    // If the shape is unknown, don't optimize this operator away.
    for (const int64_t &dimension : tensor_shape) {
      if (dimension < 0) {
        return node;
      }
    }

    tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
    size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
    char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c());
    (void)memset_s(data, mem_size, 0, mem_size);

    auto new_cnode = NewValueNode(new_tensor_ptr);
    new_cnode->set_abstract(new_tensor_ptr->ToAbstract());

    return new_cnode;
  }

  void Visit(const AnfNodePtr &node) override { y_ = node; }

 private:
  AnfNodePtr y_{nullptr};
  PrimitivePtr PrimFill_, PrimShape_, PrimDType_;
};

// {prim::kPrimDepend, X, ValueCond} -> X
class DependValueElim : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    PatternNode<AnfNodePtr> x, cond;
    MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimDepend, x, cond), x, IsVNode(cond.GetNode(node)));
    return nullptr;
  }
};

// {{prim::getattr, {prim::resolve, SymbolStr, C}, zeros_like}, Xy} -> Tensor(0, shape(Xy))
// {prim::getattr, {prim::resolve, SymbolStr, zeros_like}, Xy} -> Tensor(0, shape(Xy))
// {{prim::resolve, CommonOPS, getitem}, (tensor0, tensor1, ...), 0} -> tensor0
class PynativeEliminater : public OptimizerCaller {
  bool CheckNameSpaceVNode(const AnfNodePtr &node, const std::string &str_value) {
    ValueNodePtr value_node = node->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      return false;
    }
    // Guard against value nodes that do not hold a NameSpace.
    auto name_space = GetValueNode<parse::NameSpacePtr>(value_node);
    if (name_space == nullptr) {
      return false;
    }
    return name_space->module() == str_value;
  }

  bool CheckSymbolVNode(const AnfNodePtr &node, const std::string &str_value) {
    ValueNodePtr value_node = node->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      return false;
    }
    // Guard against value nodes that do not hold a Symbol.
    auto symbol = GetValueNode<parse::SymbolPtr>(value_node);
    if (symbol == nullptr) {
      return false;
    }
    return symbol->symbol() == str_value;
  }

  bool CheckStrVNode(const AnfNodePtr &node, const std::string &str_value) {
    ValueNodePtr value_node = node->cast<ValueNodePtr>();
    if (value_node == nullptr) {
      return false;
    }
    // Guard against value nodes that do not hold a StringImm.
    auto string_imm = GetValueNode<StringImmPtr>(value_node);
    if (string_imm == nullptr) {
      return false;
    }
    return string_imm->value() == str_value;
  }

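  // Returns element `idx` of a ValueTuple; raises an exception when idx is
  // not an Int64 scalar or value is not a tuple.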
  ValuePtr FillGetItem(const ValuePtr &value, const ValuePtr &idx) {
    MS_LOG(DEBUG) << "Start FillGetItem" << value->ToString() << idx->ToString();
    if (!idx->isa<Int64Imm>()) {
      MS_LOG(EXCEPTION) << "Getitem idx must be an int:" << idx->ToString();
    }

    if (!value->isa<ValueTuple>()) {
      MS_LOG(EXCEPTION) << "Getitem value must be a tuple:" << value->ToString();
    }

    auto value_tuple = value->cast<ValueTuplePtr>();
    int64_t idx_t = idx->cast<Int64ImmPtr>()->value();
    MS_LOG(DEBUG) << "Fill getitem" << idx_t << (*value_tuple)[static_cast<size_t>(idx_t)]->ToString();
    return (*value_tuple)[static_cast<size_t>(idx_t)];
  }

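  // Builds a value of the same structure filled with zeros: tensors get a
  // zeroed buffer, tuples are handled recursively. Note that scalar Int64
  // values are returned unchanged.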
  ValuePtr FillZero(const ValuePtr &value) {
    MS_LOG(DEBUG) << "Start FillZero";
    ValuePtr out = nullptr;
    if (value->isa<Int64Imm>()) {
      return MakeValue(value->cast<Int64ImmPtr>()->value());
    }

    if (value->isa<tensor::Tensor>()) {
      MS_LOG(DEBUG) << "Start FillZero Tensor";
      auto tensor = value->cast<tensor::TensorPtr>();
      tensor::TensorPtr out_t = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
      char *data = reinterpret_cast<char *>(out_t->data_c());
      std::fill(data, data + out_t->data().nbytes(), 0);
      out = out_t;
    }

    std::vector<ValuePtr> value_list;
    if (value->isa<ValueTuple>()) {
      MS_LOG(DEBUG) << "Start FillZero Tuple" << value->ToString();
      auto value_tuple = value->cast<ValueTuplePtr>();
      for (size_t i = 0; i < value_tuple->size(); i++) {
        value_list.push_back(FillZero((*value_tuple)[i]));
      }
      out = std::make_shared<ValueTuple>(value_list);
    }
    if (out == nullptr) {
      MS_LOG(EXCEPTION) << "FillZero failed:" << value->ToString();
    }
    MS_LOG(DEBUG) << "Result: " << out->ToString();
    return out;
  }

 private:
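  // Replaces a captured constant argument of zeros_like with a zero-filled
  // value node of the same type and shape.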
  AnfNodePtr OperatorHandle1(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
    auto rep = (arg).GetNode(node);
    if (rep != nullptr) {
      if (rep->isa<ValueNode>()) {
        auto value_node = rep->cast<ValueNodePtr>();
        auto new_value_node = NewValueNode(FillZero(value_node->value()));
        new_value_node->set_has_new_value(value_node->has_new_value());
        MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4);
        return new_value_node;
      }
    }
    return nullptr;
  }

  // Same as OperatorHandle1, but skips nodes that carry a monad abstract.
  AnfNodePtr OperatorHandle2(const PatternNode<AnfNodePtr> &arg, const AnfNodePtr &node) {
    auto rep = (arg).GetNode(node);
    if (rep != nullptr) {
      if (rep->isa<ValueNode>() && !HasAbstractMonad(rep)) {
        auto value_node = rep->cast<ValueNodePtr>();
        auto new_value_node = NewValueNode(FillZero(value_node->value()));
        new_value_node->set_has_new_value(value_node->has_new_value());
        MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4);
        return new_value_node;
      }
    }
    return nullptr;
  }

  void OperatorHandle3(const std::vector<PatternNode<AnfNodePtr>> &args, const AnfNodePtr &node) {
    for (size_t i = 0; i < 2; i++) {
      auto rep = (args[i]).GetNode(node);
      if (rep != nullptr && rep->isa<ValueNode>()) {
        auto value_node = rep->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(value_node);
        auto &value = value_node->value();
        MS_EXCEPTION_IF_NULL(value);
        // When the use count of the value node equals one, it is only used in the binop_grad_common function.
        if (value->isa<tensor::Tensor>() && value_node->used_graph_count() == 1) {
          auto tensor = value->cast<tensor::TensorPtr>();
          MS_EXCEPTION_IF_NULL(tensor);
          auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
          value_node->set_value(new_tensor);
        }
      }
    }
  }

  AnfNodePtr OperatorHandle4(const PatternNode<AnfNodePtr> &arg, const PatternNode<AnfNodePtr> &arg1,
                             const AnfNodePtr &node) {
    auto rep = (arg).GetNode(node);
    if (rep != nullptr) {
      if (rep->isa<ValueNode>()) {
        MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4);
        ValueNodePtr new_node;
        auto value_node = rep->cast<ValueNodePtr>();
        auto rep1 = (arg1).GetNode(node);
        if (rep1 != nullptr) {
          if (rep1->isa<ValueNode>()) {
            auto idx = rep1->cast<ValueNodePtr>();
            if (!value_node->value()->isa<ValueTuple>()) {
              return nullptr;
            }
            new_node = NewValueNode(FillGetItem(value_node->value(), idx->value()));
            new_node->set_has_new_value(value_node->has_new_value());
          }
        }
        // Bail out if the index argument was not captured as a value node;
        // otherwise the log below would dereference a null pointer.
        if (new_node == nullptr) {
          return nullptr;
        }
        MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4);
        return new_node;
      }
    }
    return nullptr;
  }

 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
    PatternNode<AnfNodePtr> symbol_str_vnode;
    PatternNode<AnfNodePtr> c_vnode;
    PatternNode<AnfNodePtr> zeros_like_vnode;
    PatternNode<AnfNodePtr> arg;
    auto resolve = PPrimitive(prim::kPrimResolve, symbol_str_vnode, c_vnode);
    auto getattr = PPrimitive(prim::kPrimGetAttr, resolve, zeros_like_vnode);
    auto pattern = PCNode(getattr, arg);
    // {{prim::getattr, {prim::resolve, SymbolStr, C}, zeros_like}, Xy} -> Tensor(0, shape(Xy))
    if ((pattern).TryCapture(node) &&
        (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
         CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
      auto new_value_node = OperatorHandle1(arg, node);
      if (new_value_node != nullptr) {
        return new_value_node;
      }
    }
    MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4);
    // {prim::getattr, {prim::resolve, SymbolStr, zeros_like}, Xy} -> Tensor(0, shape(Xy))
    auto resolve1 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, zeros_like_vnode);
    auto pattern1 = PCNode(resolve1, arg);

    if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
                                        CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
      auto new_value_node = OperatorHandle2(arg, node);
      if (new_value_node != nullptr) {
        return new_value_node;
      }
    }
    // {{prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout}
    PatternNode<AnfNodePtr> binop_grad_common;
    PatternNode<AnfNodePtr> getitem_vnode;
    std::vector<PatternNode<AnfNodePtr>> args(4);
    auto resolve_binop = PPrimitive(prim::kPrimResolve, symbol_str_vnode, binop_grad_common);
    auto pattern_binop = PCNode(resolve_binop, args[0], args[1], args[2], args[3]);
    if ((pattern_binop).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
                                             CheckSymbolVNode(binop_grad_common.GetNode(node), "binop_grad_common"))) {
      OperatorHandle3(args, node);
      return nullptr;
    }
    // resolve(CommonOPS, getitem)((tensors), 3)
    PatternNode<AnfNodePtr> arg1;
    auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode);
    auto pattern2 = PCNode(resolve2, arg, arg1);
    if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") &&
                                        CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) {
      auto new_value_node = OperatorHandle4(arg, arg1, node);
      if (new_value_node != nullptr) {
        return new_value_node;
      }
    }

    MS_LOG(DEBUG) << "End Replace " << node->DebugString(4);
    return nullptr;
  }
};

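// {prim::kPrimAllReduce, Const} -> Const, or Const * group_size when the
// reduce op is "sum"; only applied under (semi) auto parallel.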
class AllReduceConstElim : public OptimizerCaller {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    PatternNode<AnfNodePtr> x;
    auto pattern = PPrimitive(prim::kPrimAllReduce, x);
    // If AllReduce takes a constant value as input and the values across devices are all the same
    // (ensured by the parallel mode):
    if (pattern.TryCapture(node) && IsVNode(x.GetNode(node)) &&
        (pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) ||
         pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) {
      auto cur_func_graph = pattern.GetFuncGraph();
      // If the reduce operation is sum, multiply the constant by the number of devices;
      // otherwise just return the constant.
      auto prim_cnode = pattern.GetOriginalNode();
      MS_EXCEPTION_IF_NULL(prim_cnode);
      auto primitive = GetCNodePrimitive(prim_cnode);
      auto reduce_op = primitive->GetAttr("op");
      auto group = primitive->GetAttr("group")->ToString();
      // For a sum operation, multiply the constant tensor by the number of devices.
      if (reduce_op->ToString() == "sum") {
        uint32_t num_of_devices;
        // Get the number of devices.
        if (!CommManager::GetInstance().GetRankSize(group, &num_of_devices)) {
          MS_LOG(EXCEPTION) << "Failed to get num of devices for group [" + group + "]";
        }
        // Multiply the constant by the number of devices, then return.
        std::vector<AnfNodePtr> mul_inputs;
        auto constant_node = x.GetNode(node);
        MS_EXCEPTION_IF_NULL(constant_node);
        auto constant_value_node = constant_node->cast<ValueNodePtr>();
        MS_EXCEPTION_IF_NULL(constant_value_node);
        if (!constant_value_node->value()->isa<tensor::Tensor>()) {
          MS_LOG(EXCEPTION) << "Expect the constant input for AllReduce to be a Tensor. Got " +
                                 constant_value_node->value()->ToString();
        }
        auto constant_tensor = constant_value_node->value()->cast<tensor::TensorPtr>();
        auto tensor_dtype = constant_tensor->Dtype();
        auto num_of_device_node =
          NewValueNode(std::make_shared<tensor::Tensor>(static_cast<int64_t>(num_of_devices), tensor_dtype));
        // Build the multiply node.
        auto mul_prim = prim::GetPythonOps("tensor_mul", "mindspore.ops.functional");
        MS_EXCEPTION_IF_NULL(mul_prim);
        mul_inputs.push_back(NewValueNode(mul_prim));
        mul_inputs.push_back(constant_node);
        mul_inputs.push_back(num_of_device_node);
        return cur_func_graph->NewCNode(mul_inputs);
      } else {
        return x.GetNode(node);
      }
    }
    return nullptr;
  }
};

// This pattern is introduced by Depend(CollectCNodeWithIsolateNodes) in program_specialize.cc.
// {{prim::kPrimDepend, X, Y}, Xs} -> {prim::kPrimDepend, {X, Xs}, Y}
class FloatDependGCall : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!node->isa<CNode>() || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto &inputs = node->cast<CNodePtr>()->inputs();
    // IsCNodeDup has already ensured that inputs holds at least one element, so no size check is needed here.
    if (IsPrimitiveCNode(inputs[0], prim::kPrimDepend)) {
      auto &depend_inputs = inputs[0]->cast<CNodePtr>()->inputs();
      if (depend_inputs.size() != 3) {
        return nullptr;
      }
      // Build {X, Xs} as the inputs of the new call node.
      std::vector<AnfNodePtr> new_inputs({depend_inputs[1]});
      new_inputs.insert(new_inputs.end(), inputs.cbegin() + 1, inputs.cend());
      TraceGuard guard(std::make_shared<TraceCopy>(node->debug_info()));
      ScopePtr scope = node->scope();
      ScopeGuard scope_guard(scope);
      auto new_call_node = node->func_graph()->NewCNode(new_inputs);
      auto new_node = node->func_graph()->NewCNode({depend_inputs[0], new_call_node, depend_inputs[2]});
      return new_node;
    }
    return nullptr;
  }
};

}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SPECIAL_OP_ELIMINATE_H_