• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_
18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_
19 
20 #include <vector>
21 #include <string>
22 #include <utility>
23 #include <unordered_map>
24 #include <memory>
25 #include "ir/anf.h"
26 #include "backend/optimizer/common/pattern_engine.h"
27 #include "backend/optimizer/common/helper.h"
28 #include "backend/optimizer/common/optimizer.h"
29 
30 namespace mindspore {
31 namespace opt {
32 class LambNextMVRule : public MultipleOutputPatternProcessPass {
33  public:
34   explicit LambNextMVRule(const std::string &name = "", bool multigraph = true)
MultipleOutputPatternProcessPass(name,multigraph)35       : MultipleOutputPatternProcessPass(name, multigraph) {
36     input0_ = std::make_shared<Var>();
37     input1_ = std::make_shared<Var>();
38     input2_ = std::make_shared<Var>();
39     input3_ = std::make_shared<Var>();
40     input4_ = std::make_shared<Var>();
41     input5_ = std::make_shared<Var>();
42     input6_ = std::make_shared<Var>();
43     mul0_x_ = std::make_shared<Var>();
44     mul1_sub_ = std::make_shared<Var>();
45     mul2_x_ = std::make_shared<Var>();
46     mul3_sub1_ = std::make_shared<Var>();
47     mul4_x_ = std::make_shared<Var>();
48     add2_y_ = std::make_shared<Var>();
49     real_div0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
50     real_div1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(kRealDivOpName));
51     real_div2_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMul->name()));
52     add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
53     add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name()));
54   }
55   ~LambNextMVRule() override = default;
56   const BaseRef DefinePattern() const override = 0;
57   BaseRef DefineAnotherPattern() const override = 0;
58   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
59   bool IsShareNodes(const EquivPtr &equiv1, const EquivPtr &equiv2) const override;
60 
61  protected:
62   bool IsRuleMatched(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
63                      std::vector<AnfNodePtr> *old_pattern_outputs) const;
64   AnfNodePtr CreateLambNextMVNode(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &old_pattern_outputs,
65                                   const EquivPtr &equiv) const;
66 
67   VarPtr input0_;
68   VarPtr input1_;
69   VarPtr input2_;
70   VarPtr input3_;
71   VarPtr input4_;
72   VarPtr input5_;
73   VarPtr input6_;
74   VarPtr mul0_x_;
75   VarPtr mul1_sub_;
76   VarPtr mul2_x_;
77   VarPtr mul3_sub1_;
78   VarPtr mul4_x_;
79   VarPtr add2_y_;
80   // nodes which two patterns share, and add2_y_ also.
81   VarPtr real_div0_var_;
82   VarPtr real_div1_var_;
83   // part of output nodes
84   VarPtr add0_var_;
85   VarPtr add1_var_;
86   // other node
87   VarPtr real_div2_var_;
88 };
89 
90 class LambNextMVRuleCond1 : public LambNextMVRule {
91  public:
92   explicit LambNextMVRuleCond1(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond1", multigraph) {}
93 
94   ~LambNextMVRuleCond1() override = default;
95   const BaseRef DefinePattern() const override;
96   BaseRef DefineAnotherPattern() const override;
97 };
98 
99 class LambNextMVRuleCond2 : public LambNextMVRule {
100  public:
101   explicit LambNextMVRuleCond2(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond2", multigraph) {}
102 
103   ~LambNextMVRuleCond2() override = default;
104   const BaseRef DefinePattern() const override;
105   BaseRef DefineAnotherPattern() const override;
106 };
107 
108 class LambNextMVRuleCond3 : public LambNextMVRule {
109  public:
110   explicit LambNextMVRuleCond3(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond3", multigraph) {}
111 
112   ~LambNextMVRuleCond3() override = default;
113   const BaseRef DefinePattern() const override;
114   BaseRef DefineAnotherPattern() const override;
115 };
116 
117 class LambNextMVRuleCond4 : public LambNextMVRule {
118  public:
119   explicit LambNextMVRuleCond4(bool multigraph = true) : LambNextMVRule("lamb_next_mv_rule_cond4", multigraph) {}
120 
121   ~LambNextMVRuleCond4() override = default;
122   const BaseRef DefinePattern() const override;
123   BaseRef DefineAnotherPattern() const override;
124 };
125 }  // namespace opt
126 }  // namespace mindspore
127 
128 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_LAMB_NEXT_MV_RULE_H_
129