• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/irpass/arithmetic_simplify.h"
18 
19 #include "include/common/utils/parallel_context.h"
20 #include "mindspore/core/ops/sequence_ops.h"
21 #include "mindspore/core/ops/other_ops.h"
22 #include "mindspore/core/ops/nn_optimizer_ops.h"
23 #include "mindspore/core/ops/math_ops.h"
24 #include "mindspore/core/ops/array_ops.h"
25 #include "mindspore/core/ops/arithmetic_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 namespace mindspore {
28 namespace opt {
29 namespace irpass {
operator ()(const OptimizerPtr &,const AnfNodePtr & node)30 AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
31   PatternNode x;
32   PatternNode y;
33   PatternNode z;
34   PConstant one_(node, false, 1);
35   PConstant one_scalar_(node, false, 1, true);
36   PConstant zero_(node, false, 0);
37   PConstant zero_scalar_(node, false, 0, true);
38   PConstant const_(node);
39   PConstant const_2(node);
40   PConstant any_const(node);
41   // if node has keep_alive attr, it would not be eliminated.
42   if (node->isa<CNode>()) {
43     auto cnode = node->cast<CNodePtr>();
44     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
45     if (prim->HasAttr("keep_alive") && GetValue<bool>(prim->GetAttr("keep_alive"))) {
46       MS_LOG(INFO) << "keep node " << node->fullname_with_scope() << " alive";
47       return nullptr;
48     }
49   }
50   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
51     auto IsAddByZeroSimplifiable = [node](const AnfNodePtr &real_x) {
52       // If real_x is Load CNode, We should not simplify it as Load is a no-op at backend, after simplification, the
53       // result of the Load may be incorrect.
54       if (IsPrimitiveCNode(real_x, prim::kPrimLoad)) {
55         MS_LOG(DEBUG) << "Cannot simplify as real_x is CNode Load: " << real_x->ToString();
56         return false;
57       }
58 
59       if (real_x->abstract() != nullptr && real_x->abstract()->GetShapeTrack() != nullptr &&
60           node->abstract() != nullptr && node->abstract()->GetShapeTrack() != nullptr &&
61           *real_x->abstract()->GetShapeTrack() == *node->abstract()->GetShapeTrack()) {
62         MS_LOG(DEBUG) << "Can simplify when their shapes are same: real_x shape:"
63                       << real_x->abstract()->GetShapeTrack()->ToString()
64                       << ", node shape: " << node->abstract()->GetShapeTrack()->ToString();
65         return true;
66       }
67       MS_LOG(DEBUG) << "Cannot simplify when their shapes are not same: real_x shape:"
68                     << real_x->abstract()->GetShapeTrack()->ToString()
69                     << ", node shape: " << node->abstract()->GetShapeTrack()->ToString();
70       return false;
71     };
72     MATCH_REPLACE_IF(node, x + zero_, x, x.CheckFunc(IsAddByZeroSimplifiable, node));  // Add by zero
73 
74     MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarAdd, x, zero_scalar_, true), x);          // Scalar Add by zero
75     MATCH_REPLACE_IF(node, x * one_, any_const.WithValueOf(x), !one_.CheckFunc(IsParam, node));  // Multiply by one
76     MATCH_REPLACE(node, PBinOperation(prim::kPrimScalarMul, x, one_scalar_, true), x);           // Scalar Mul by one
77   }
78   // Prim Eliminate (identity)
79   MATCH_REPLACE(node, PPrimitive(prim::kPrimidentity, x), x);
80   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
81     return nullptr;
82   }
83 
84   // ConstantDuplicateMul
85   auto const_dup_lambda = [&node, &x, &const_, &const_2]() -> AnfNodePtr {
86     auto new_mul_tensor = const_.MulByPatternConst(const_2, x.GetNode(node));
87     auto mul_node = node->cast<CNodePtr>()->inputs()[0];
88     if (new_mul_tensor == nullptr) {
89       auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
90       if (parallel_mode == parallel::kAutoParallel || parallel_mode == parallel::kSemiAutoParallel) {
91         return nullptr;
92       }
93       auto ttmul = NewCNode({mul_node, const_.GetNode(node), const_2.GetNode(node)}, node->func_graph());
94       return NewCNode({mul_node, x.GetNode(node), ttmul}, node->func_graph());
95     }
96     auto new_cnode = NewCNode({mul_node, x.GetNode(node), new_mul_tensor}, node->func_graph());
97     new_cnode->set_abstract(node->abstract());
98     return new_cnode;
99   };
100   MATCH_REPLACE_LAMBDA(node, const_ * (const_2 * x), const_dup_lambda);
101 
102   if (node->func_graph() == nullptr) {
103     return nullptr;
104   }
105 
106   // OptUpdateZeroTensor: {kPrimMomentum, {kPrimZerosLike, x}, y, z, xs} -> {kPrimMakeTuple, z, y}
107   MATCH_REPLACE(node, PPrimitive(prim::kPrimMomentum, PPrimitive(prim::kPrimZerosLike, x), y, z).MinExtraNodes(0),
108                 PPrimitive(prim::kPrimMakeTuple, z, y));
109 
110   // PowerOneEliminate
111   MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimPow, x, one_scalar_), x,
112                    one_scalar_.CheckFunc(IsValueNode<Scalar>, node));
113 
114   return nullptr;
115 }
116 
117 // grad = AllReduce(grad) / worker_number
// grad = grad + weight * decay
// ->
// grad = grad + weight * decay
121 // grad = AllReduce(grad) / worker_number
122 // {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
123 // {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN,{prim::kPrimMakeTuple, Z, X}}}, Y}
operator ()(const OptimizerPtr &,const AnfNodePtr & node)124 AnfNodePtr AdjustAllReduceMulAdd::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
125   PatternNode x;
126   PatternNode y;
127   PatternNode z;
128   auto all_reduce_pat = PPrimitive(prim::kPrimAllReduce, x);
129   auto mul_pat = PBinOperation(prim::kPrimMul, all_reduce_pat, y, true);
130   auto admktup_pat = PBinOperation(prim::kPrimMakeTuple, mul_pat, z, true);
131   auto addn_pat = PPrimitive(prim::kPrimAddN, admktup_pat);
132   auto adjust_lambda = [&node, &x, &y, &z, &addn_pat, &all_reduce_pat, &admktup_pat, &mul_pat, this]() -> AnfNodePtr {
133     auto fg = all_reduce_pat.GetFuncGraph();
134     auto z_ = z.GetNode(node);
135     auto x_ = x.GetNode(node);
136 
137     // If addn inputs cross the graph, make the inputs same as allreduce node.
138     if (z_->isa<CNode>() && fg != z_->func_graph()) {
139       auto cnode_z = z_->cast<CNodePtr>();
140       z_ = NewCNode(cnode_z->inputs(), fg);
141     }
142 
143     auto addn_cnode = addn_pat.GetOriginalNode()->cast<CNodePtr>();
144     auto addn_op_node = addn_cnode->input(0);
145     auto make_tuple_op_node = addn_cnode->input(1)->cast<CNodePtr>()->input(0);
146     auto all_reduce_prim = all_reduce_pat.GetOriginalNode()->cast<CNodePtr>()->input(0);
147     mul_cnode_ = mul_pat.GetOriginalNode();
148     auto mul_prim = mul_cnode_->cast<CNodePtr>()->input(0);
149     auto addn_maketuple = admktup_pat.GetOriginalNode();
150 
151     ShapeVector x_shape;
152     ShapeVector z_shape;
153     if (!x_->isa<ValueNode>()) {
154       if ((x_->abstract() == nullptr) || !x_->abstract()->isa<abstract::AbstractTensor>()) {
155         return nullptr;
156       }
157       auto x_abstract = x_->abstract()->cast<abstract::AbstractTensorPtr>();
158       x_shape = x_abstract->shape()->shape();
159     } else {
160       ValuePtr x_value = x_->cast<ValueNodePtr>()->value();
161       if (!x_value->isa<tensor::Tensor>()) {
162         return nullptr;
163       }
164       auto x_tensor = GetValueNode<tensor::TensorPtr>(x_->cast<ValueNodePtr>());
165       x_shape = x_tensor->shape();
166     }
167     if (!z_->isa<ValueNode>()) {
168       if ((z_->abstract() == nullptr) || !z_->abstract()->isa<abstract::AbstractTensor>()) {
169         return nullptr;
170       }
171       auto z_abstract = z_->abstract()->cast<abstract::AbstractTensorPtr>();
172       z_shape = z_abstract->shape()->shape();
173     } else {
174       ValuePtr z_value = z_->cast<ValueNodePtr>()->value();
175       if (!z_value->isa<tensor::Tensor>()) {
176         return nullptr;
177       }
178       auto z_tensor = GetValueNode<tensor::TensorPtr>(z_->cast<ValueNodePtr>());
179       z_shape = z_tensor->shape();
180     }
181 
182     if (x_shape != z_shape) {
183       // AddN requires x_ and z_ have the same shape.
184       // If broadcasting TensorAdd is supported then can use this
185       return nullptr;
186     }
187     AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
188     AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
189     AnfNodePtr all_reduce = NewCNode({all_reduce_prim, add}, fg);
190     AnfNodePtr mul = NewCNode({mul_prim, all_reduce, y.GetNode(node)}, fg);
191     ProcessDependEdge(fg, addn_maketuple, all_reduce);
192     return mul;
193   };
194   MATCH_REPLACE_LAMBDA(node, addn_pat, adjust_lambda);
195   return nullptr;
196 }
197 
ProcessDependEdge(const FuncGraphPtr & fg,const AnfNodePtr & addn_maketuple,const AnfNodePtr & new_node)198 void AdjustAllReduceMulAdd::ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple,
199                                               const AnfNodePtr &new_node) {
200   // If has dynamic loss scale.
201   MS_EXCEPTION_IF_NULL(fg);
202   auto manager = fg->manager();
203   MS_EXCEPTION_IF_NULL(manager);
204   auto &users_map = manager->node_users();
205   auto it = users_map.find(mul_cnode_);
206   if (it != users_map.end()) {
207     auto users = it->second;
208     for (auto &user_pair : users) {
209       auto node = user_pair.first;
210       if (node != addn_maketuple && IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
211         manager->SetEdge(node, user_pair.second, new_node);
212       }
213     }
214   }
215 }
216 }  // namespace irpass
217 }  // namespace opt
218 }  // namespace mindspore
219