• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/irpass/arithmetic_simplify.h"
18 
19 namespace mindspore {
20 namespace opt {
21 namespace irpass {
operator ()(const OptimizerPtr &,const AnfNodePtr & node)22 AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
23   PatternNode x, y, z;
24   PConstant one_(node, false, 1);
25   PConstant one_scalar_(node, false, 1, true);
26   PConstant zero_(node, false, 0);
27   PConstant zero_scalar_(node, false, 0, true);
28   PConstant const_(node);
29   PConstant const_2(node);
30   PConstant any_const(node);
31 
32   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
33     MATCH_REPLACE(node, x + zero_, x);                                                           // Add by zero
34     MATCH_REPLACE(node, x + zero_scalar_, x);                                                    // Add by zero
35     MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x);          // Scalar Add by zero
36     MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node));  // Multiply by one
37     MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x);           // Scalar Mul by one
38 
39     // Scalar Mul by zero
40     MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, zero_scalar_, true), zero_scalar_.NewValue());
41   }
42   // Prim Eliminate (identity)
43   MATCH_REPLACE(node, PPrimitive(prim::kPrimIdentity, x), x);
44   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
45     return nullptr;
46   }
47 
48   // ConstantDuplicateMul
49   auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr {
50     auto new_mul_tensor = const_.MulByPatternConst(const_2, x.GetNode(node));
51     auto mul_node = node->cast<CNodePtr>()->inputs()[0];
52     if (new_mul_tensor == nullptr) {
53       auto ttmul = NewCNode({mul_node, const_.GetNode(node), const_2.GetNode(node)}, node->func_graph());
54       return NewCNode({mul_node, x.GetNode(node), ttmul}, node->func_graph());
55     }
56     auto new_cnode = NewCNode({mul_node, x.GetNode(node), new_mul_tensor}, node->func_graph());
57     new_cnode->set_abstract(node->abstract());
58     return new_cnode;
59   };
60   MATCH_REPLACE_LAMBDA(node, const_ * (const_2 * x), const_dup_lambda);
61 
62   if (node->func_graph() == nullptr) {
63     return nullptr;
64   }
65 
66   // OptUpdateZeroTensor: {kPrimMomentum, {kPrimZerosLike, x}, y, z, xs} -> {kPrimMakeTuple, z, y}
67   MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z).MinExtraNodes(0),
68                 PPrimitive(prim::kPrimMakeTuple, z, y));
69 
70   // PowerOneEliminate
71   MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimPow, x, one_scalar_), x,
72                    one_scalar_.CheckFunc(IsValueNode<Scalar>, node));
73 
74   return nullptr;
75 }
76 
// Tries the "multiply by zero" rewrite rules on `node`.
//
// Nothing is attempted in PyNative mode. Otherwise two patterns are tried in order: `x * 0`
// with a constant zero, and `x * ZerosLike(y)`. In both cases the replacement is
// zero_pat.WithShapeAs(node). The MATCH_REPLACE_IF macro is defined outside this file;
// presumably it returns the replacement from this function on a match (TODO confirm).
//
// Returns nullptr when the end of the function is reached without a rule having returned.
AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  const bool in_pynative_mode = MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
  if (in_pynative_mode) {
    return nullptr;
  }

  PatternNode x, y;
  PConstant zero_pat(node, false, 0);

  // Shared guard: the node captured as `x` must live in the same function graph as `node`.
  // It is only invoked from inside the macro conditions, i.e. after the pattern was tried.
  auto x_in_same_graph = [&x, &node]() { return x.GetNode(node)->func_graph() == node->func_graph(); };

  // Multiply by zero (constant zero operand)
  MATCH_REPLACE_IF(node, x * zero_pat, zero_pat.WithShapeAs(node),
                   !(zero_pat.CheckFunc(IsParam, node) || x.CheckFunc(IsLoad, node)) && x_in_same_graph());

  // Multiply by zero (ZerosLike operand)
  auto zeros_like_pat = PPrimitive(prim::kPrimZerosLike, y);
  MATCH_REPLACE_IF(node, x * zeros_like_pat, zero_pat.WithShapeAs(node),
                   !zeros_like_pat.CheckFunc(IsParam, node) && x_in_same_graph());

  return nullptr;
}
94 
95 // grad = AllReduce(grad) / worker_number
96 // grad = grad + weight * decy
97 // ->
98 // grad = grad + weight * decy
99 // grad = AllReduce(grad) / worker_number
100 // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
101 // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
operator ()(const OptimizerPtr &,const AnfNodePtr & node)102 AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
103   PatternNode x, y, z;
104   auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x);
105   auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true);
106   auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true);
107   auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat);
108   auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr {
109     auto fg = all_reduce_pat.GetFuncGraph();
110     auto z_ = z.GetNode(node);
111     auto x_ = x.GetNode(node);
112 
113     // If addn inputs cross the graph, make the inputs same as allreduce node.
114     if (z_->isa<CNode>() && fg != z_->func_graph()) {
115       auto cnode_z = z_->cast<CNodePtr>();
116       z_ = NewCNode(cnode_z->inputs(), fg);
117     }
118 
119     auto addn_cnode = addn_pat.GetOriginalNode()->cast<CNodePtr>();
120     auto addn_op_node = addn_cnode->input(0);
121     auto make_tuple_op_node = addn_cnode->input(1)->cast<CNodePtr>()->input(0);
122     auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast<CNodePtr>()->input(0);
123     mul_cnode_ = mul_pat.GetOriginalNode();
124     auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0);
125     auto addn_maketuple = admktup_pat.GetOriginalNode();
126 
127     ShapeVector x_shape, z_shape;
128     if (!x_->isa<ValueNode>()) {
129       if ((x_->abstract() == nullptr) || !x_->abstract()->isa<abstract::AbstractTensor>()) {
130         return nullptr;
131       }
132       auto x_abstract = x_->abstract()->cast<abstract::AbstractTensorPtr>();
133       x_shape = x_abstract->shape()->shape();
134     } else {
135       ValuePtr x_value = x_->cast<ValueNodePtr>()->value();
136       if (!x_value->isa<tensor::Tensor>()) {
137         return nullptr;
138       }
139       auto x_tensor = GetValueNode<tensor::TensorPtr>(x_->cast<ValueNodePtr>());
140       x_shape = x_tensor->shape();
141     }
142     if (!z_->isa<ValueNode>()) {
143       if ((z_->abstract() == nullptr) || !z_->abstract()->isa<abstract::AbstractTensor>()) {
144         return nullptr;
145       }
146       auto z_abstract = z_->abstract()->cast<abstract::AbstractTensorPtr>();
147       z_shape = z_abstract->shape()->shape();
148     } else {
149       ValuePtr z_value = z_->cast<ValueNodePtr>()->value();
150       if (!z_value->isa<tensor::Tensor>()) {
151         return nullptr;
152       }
153       auto z_tensor = GetValueNode<tensor::TensorPtr>(z_->cast<ValueNodePtr>());
154       z_shape = z_tensor->shape();
155     }
156 
157     if (x_shape != z_shape) {
158       // AddN requires x_ and z_ have the same shape.
159       // If broadcasting TensorAdd is supported then can use this
160       return nullptr;
161     }
162     AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
163     AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
164     AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg);
165     AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg);
166     ProcessDependEdge(fg, addn_maketuple, all_reduce);
167     return mul;
168   };
169   MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda);
170   return nullptr;
171 }
172 
ProcessDependEdge(const FuncGraphPtr & fg,const AnfNodePtr & addn_maketuple,const AnfNodePtr & new_node)173 void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple,
174                                               const AnfNodePtr &new_node) {
175   // If has dynamic loss scale.
176   MS_EXCEPTION_IF_NULL(fg);
177   auto manager = fg->manager();
178   MS_EXCEPTION_IF_NULL(manager);
179   auto &users_map = manager->node_users();
180   auto it = users_map.find(mul_cnode_);
181   if (it != users_map.end()) {
182     auto users = it->second;
183     for (auto &user_pair : users) {
184       auto node = user_pair.first;
185       if (node != addn_maketuple && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
186         manager->SetEdge(node, user_pair.second, new_node);
187       }
188     }
189   }
190 }
191 }  // namespace irpass
192 }  // namespace opt
193 }  // namespace mindspore
194