• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/optimizer/irpass/symbol_engine_optimizer.h"
18 
19 #include <vector>
20 #include <memory>
21 #include <utility>
22 #include "ir/pattern_matcher.h"
23 #include "ir/functor.h"
24 #include "ops/array_ops.h"
25 #include "ops/math_ops.h"
26 #include "ops/op_def.h"
27 #include "include/common/utils/utils.h"
28 #include "mindspore/core/symbolic_shape/symbol.h"
29 #include "mindspore/core/symbolic_shape/utils.h"
30 #include "include/common/symbol_engine/symbol_engine_impl.h"
31 
32 namespace mindspore {
33 namespace opt {
34 namespace irpass {
GetSymbolEngine(const AnfNodePtr & node)35 inline SymbolEnginePtr GetSymbolEngine(const AnfNodePtr &node) { return node->func_graph()->symbol_engine(); }
36 
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & opt)37 bool SymbolEngineBuilder::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &opt) {
38   if (only_dynshape_graph_ && !HasDynamicShapeNode(opt)) {
39     MS_LOG(INFO) << "There is no dynamic shape node, the SymbolEngineBuilder is disabled.";
40     return false;
41   }
42   try {
43     MS_LOG_TRY_CATCH_SCOPE;
44     symshape::SymbolEngineImpl::Build(func_graph);
45     MS_LOG(INFO) << "Build symbol engine successfully.";
46   } catch (std::exception &e) {
47     MS_LOG(WARNING) << "Build symbol engine failed. message: " << e.what();
48   }
49   return true;
50 }
51 
HasDynamicShapeNode(const OptimizerPtr & opt) const52 bool SymbolEngineBuilder::HasDynamicShapeNode(const OptimizerPtr &opt) const {
53   auto mng = opt->manager();
54   if (mng == nullptr) {
55     return false;
56   }
57   auto &nodes = mng->all_nodes();
58   for (auto &node : nodes) {
59     if (!node->isa<CNode>()) {
60       continue;
61     }
62     auto abs = node->abstract();
63     if (abs != nullptr && abs->GetShape()->IsDynamic()) {
64       return true;
65     }
66   }
67   return false;
68 }
69 
// Eliminate a ReduceSum whose axis comes from a broadcast-gradient ShapeCalc when the reduction is
// not needed.
//
// Two patterns are tried in order:
//   ReduceSum(dout, TupleGetItem(shape_calc, 0), keepdims, skipmode)  -> checked against ShapeCalc input 1
//   ReduceSum(dout, TupleGetItem(shape_calc, 1), keepdims, skipmode)  -> checked against ShapeCalc input 2
// When a pattern matches and Check() approves, the node is replaced by `dout`.
// Returns nullptr (no replacement) when the graph has no symbol engine or nothing matches.
AnfNodePtr ElimShapeCalcOnBroadcastArgsGrad::operator()(const OptimizerPtr &opt, const AnfNodePtr &node) {
  if (GetSymbolEngine(node) == nullptr) {
    return nullptr;
  }
  PatternNode<AnfNodePtr> dout;
  PatternNode<AnfNodePtr> shape_calc;
  PatternNode<AnfNodePtr> keepdims;
  PatternNode<AnfNodePtr> skipmode;
  PConstant idx0(node, false, 0, true);
  PConstant idx1(node, false, 1, true);
  // Output 0 of the ShapeCalc corresponds to its first data input (kIndex1).
  MATCH_REPLACE_IF(
    node,
    PPrimitive(prim::kPrimReduceSum, dout, PPrimitive(prim::kPrimTupleGetItem, shape_calc, idx0), keepdims, skipmode),
    dout, Check(opt, shape_calc.GetNode(node), kIndex1));
  // Output 1 of the ShapeCalc corresponds to its second data input (kIndex2).
  MATCH_REPLACE_IF(
    node,
    PPrimitive(prim::kPrimReduceSum, dout, PPrimitive(prim::kPrimTupleGetItem, shape_calc, idx1), keepdims, skipmode),
    dout, Check(opt, shape_calc.GetNode(node), kIndex2));
  return nullptr;
}
91 
Check(const OptimizerPtr & opt,const AnfNodePtr & shape_calc,size_t input_index)92 bool ElimShapeCalcOnBroadcastArgsGrad::Check(const OptimizerPtr &opt, const AnfNodePtr &shape_calc,
93                                              size_t input_index) {
94   auto mng = opt->manager();
95   MS_EXCEPTION_IF_NULL(mng);
96   auto &users = mng->node_users();
97 
98   auto shapecalc_node = shape_calc->cast<CNodePtr>();
99   constexpr const size_t shapecalc_size = 3;
100   if (shapecalc_node == nullptr || !IsPrimitiveCNode(shapecalc_node, prim::kPrimShapeCalc) ||
101       shapecalc_node->size() != shapecalc_size) {
102     return false;
103   }
104   auto input_node = shapecalc_node->input(input_index);
105   auto shapecalc_functor = common::AnfAlgo::GetNodeAttr<ShapeCalcBaseFunctorPtr>(shapecalc_node, kAttrFunctor);
106   MS_EXCEPTION_IF_NULL(shapecalc_functor);
107   if (shapecalc_functor->name() != "ShapeCalc_BroadcastGradientArgs") {
108     // only support the broadcast gradient condition
109     return false;
110   }
111   auto fwd_unique_id = shapecalc_node->primal_attrs().find(kPrimalAttrForwardUniqueId);
112   if (fwd_unique_id == shapecalc_node->primal_attrs().end()) {
113     // only support bprop node
114     return false;
115   }
116   AnfNodePtr fwd_node = nullptr;
117   for (auto &user : users[input_node]) {
118     auto user_cnode = user.first->cast<CNodePtr>();
119     if (user_cnode == nullptr) {
120       continue;
121     }
122     if (auto uniq_id = user_cnode->primal_attrs().find(kPrimalAttrUniqueId);
123         uniq_id != user_cnode->primal_attrs().end()) {
124       if (*uniq_id->second == *fwd_unique_id->second) {
125         fwd_node = user.first;
126         break;
127       }
128     }
129   }
130   if (fwd_node == nullptr) {
131     return false;
132   }
133 
134   auto input_shape = input_node->abstract()->GetSymbolicShape();
135   auto output_shape = fwd_node->abstract()->GetSymbolicShape();
136   auto ret = CheckSymbolEqual(input_shape, output_shape, GetValue<size_t>(shapecalc_functor->ToValue()));
137   if (ret) {
138     MS_LOG(INFO) << "For " << shape_calc->DebugString() << " (" << shape_calc->fullname_with_scope() << ")"
139                  << " generated by BroadcastGradientArgs. The gradient for input " << input_index
140                  << " is unnecessary, which can be eliminated. grad symbol: " << input_shape->ToString()
141                  << ". out symbol: " << output_shape->ToString();
142   }
143   return ret;
144 }
145 
CheckSymbolEqual(const ListSymbolPtr & input_shape,const ListSymbolPtr & output_shape,size_t shift)146 bool ElimShapeCalcOnBroadcastArgsGrad::CheckSymbolEqual(const ListSymbolPtr &input_shape,
147                                                         const ListSymbolPtr &output_shape, size_t shift) {
148   if (input_shape == nullptr || output_shape == nullptr) {
149     return false;
150   }
151   if (input_shape->size() < output_shape->size()) {
152     return false;
153   }
154   if (input_shape->is_dyn_len() || output_shape->is_dyn_len()) {
155     return input_shape->EqualsTo(output_shape);
156   }
157   for (size_t i = input_shape->size(); i > shift; i--) {
158     auto inp = input_shape->symbols()[input_shape->size() - i];
159     if (i <= output_shape->size() && !inp->EqualsTo(output_shape->symbols()[output_shape->size() - i])) {
160       return false;
161     }
162   }
163   return true;
164 }
165 
// Replace a Reshape/ReduceSum/ReduceMax/ReduceMin node by its first input when the symbolic shapes
// of that input and of the node's output are equal.
//
// Returns the input node as the replacement, or nullptr (no replacement) when the graph has no
// symbol engine, the node is not one of the supported primitives, an abstract or symbolic shape
// is missing, or the shapes differ.
AnfNodePtr ElimNotEffectiveNode::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  if (GetSymbolEngine(node) == nullptr) {
    return nullptr;
  }
  static const PrimitiveSet supports_op = {prim::kPrimReshape, prim::kPrimReduceSum, prim::kPrimReduceMax,
                                           prim::kPrimReduceMin};
  if (!IsOneOfPrimitiveCNode(node, supports_op)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input_node = cnode->input(1);
  // Both abstracts are needed to read symbolic shapes; without them nothing can be proven equal.
  auto input_abs = input_node->abstract();
  auto output_abs = node->abstract();
  if (input_abs == nullptr || output_abs == nullptr) {
    return nullptr;
  }
  auto input_shape = input_abs->GetSymbolicShape();
  auto output_shape = output_abs->GetSymbolicShape();
  if (input_shape != nullptr && output_shape != nullptr && input_shape->EqualsTo(output_shape)) {
    MS_LOG(INFO) << "For node " << node->DebugString() << " (" << node->fullname_with_scope()
                 << "), the input shape and output shape is same, which can be eliminated.";
    return input_node;
  }
  return nullptr;
}
187 
// Turn the "shape" input of a Reshape into a constant when its symbolic value allows it.
//
// The pattern Reshape(input, shape) is rewritten when `shape` is a CNode whose symbolic value has
// data and converts to a shape vector with at most one `abstract::Shape::kShapeDimAny` entry. The
// replacement is a new CNode with the same first two inputs, a value node holding that shape
// vector, and the original node's abstract. Returns nullptr when no rewrite applies.
AnfNodePtr OptReshape::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  if (GetSymbolEngine(node) == nullptr) {
    return nullptr;
  }
  PatternNode<AnfNodePtr> input;
  PatternNode<AnfNodePtr> shape;
  // Filled by CheckShape and consumed by MakeReshape.
  ShapeVector shape_vec;

  auto MakeReshape = [&shape_vec, &node]() -> AnfNodePtr {
    auto shape_val = MakeValue(shape_vec);
    auto shape_node = NewValueNode(shape_val);
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(INFO) << "For node " << cnode->DebugString()
                 << ", the symbolic value of \"shape\" is static or has only one dynamic dim, "
                 << "replace the \"shape\" to a value node: " << shape_val->ToString();
    shape_node->set_abstract(shape_val->ToAbstract());
    auto reshape = NewCNode({cnode->input(0), cnode->input(1), shape_node}, node->func_graph());
    reshape->set_abstract(node->abstract());
    return reshape;
  };
  auto CheckShape = [&shape_vec](const AnfNodePtr &shape_input) {
    // Only a computed shape (CNode) is worth folding.
    if (shape_input == nullptr || !shape_input->isa<CNode>()) {
      return false;
    }
    auto shape_abs = shape_input->abstract();
    if (shape_abs == nullptr) {
      return false;
    }
    auto sym_value = shape_abs->GetSymbolicValue();
    if (sym_value == nullptr || !sym_value->HasData()) {
      return false;
    }
    shape_vec = symshape::ToShape(sym_value);
    return std::count(shape_vec.cbegin(), shape_vec.cend(), abstract::Shape::kShapeDimAny) <= 1;
  };
  MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimReshape, input, shape), MakeReshape,
                          CheckShape(shape.GetNode(node)));
  return nullptr;
}
224 
// Replace CNode inputs whose symbolic value is constant by value nodes.
//
// The node is skipped when the graph has no symbol engine, when no op definition is found for the
// node's name, or when the node's own abstract already yields a value that is not ValueAny.
// For each input that is a CNode with an abstract and has a matching op-def argument, a queried
// value that is not ValueAny is folded if it is a ValueSequence and the argument's dtype is
// DT_TUPLE_INT; otherwise the mismatch is only logged. Returns a new CNode carrying the original
// abstract when at least one input was folded, nullptr otherwise.
AnfNodePtr FoldConstSymbol::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  if (GetSymbolEngine(node) == nullptr) {
    return nullptr;
  }
  auto op_def = mindspore::ops::GetOpDef(AnfUtils::GetCNodeName(node));
  if (op_def == nullptr) {
    return nullptr;
  }
  const bool has_abstract = node->abstract() != nullptr;
  if (has_abstract && !symshape::QueryValue(node->abstract())->isa<ValueAny>()) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);

  AnfNodePtrList folded_inputs;
  bool folded_any = false;
  for (size_t idx = 1; idx < cnode->size(); ++idx) {
    // Input `idx` of the CNode corresponds to op-def argument `idx - 1`.
    const size_t arg_idx = idx - 1;
    auto arg = cnode->input(idx);
    const bool is_candidate = arg->isa<CNode>() && arg->abstract() != nullptr && arg_idx < op_def->args_.size();
    if (!is_candidate) {
      continue;
    }
    auto value = symshape::QueryValue(arg->abstract());
    if (value->isa<ValueAny>()) {
      continue;
    }
    // Copy the input list lazily, on the first input that has a known value.
    if (folded_inputs.empty()) {
      folded_inputs = cnode->inputs();
    }
    const bool type_matched =
      value->isa<ValueSequence>() && op_def->args_[arg_idx].arg_dtype_ == ops::OP_DTYPE::DT_TUPLE_INT;
    if (!type_matched) {
      MS_LOG(INFO) << "For node " << node->DebugString() << ", the input[" << idx
                   << "]'s value does not match the op_def type(" << op_def->args_[arg_idx].arg_dtype_
                   << "). value = :" << value->ToString();
      continue;
    }
    folded_inputs[idx] = NewValueNode(value);
    folded_any = true;
    MS_LOG(INFO) << "For node " << node->DebugString() << ", the input[" << idx
                 << "]'s symbolic value is constant, fold the input value: " << value->ToString();
    // Keep the symbolic value of the replaced input on the new value node's abstract.
    auto folded_abs = value->ToAbstract();
    MS_EXCEPTION_IF_NULL(folded_abs);
    folded_abs->SetSymbolicValue(arg->abstract()->GetSymbolicValue());
    folded_inputs[idx]->set_abstract(folded_abs);
  }
  if (!folded_any) {
    return nullptr;
  }
  auto folded_node = NewCNode(folded_inputs, node->func_graph());
  folded_node->set_abstract(node->abstract());
  return folded_node;
}
276 
operator ()(const FuncGraphPtr & func_graph,const OptimizerPtr & optimizer)277 bool ShapeOpCse::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
278   if (func_graph->symbol_engine() == nullptr) {
279     return false;
280   }
281   auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple, AlwaysInclude);
282   auto mng = optimizer->manager();
283   MS_EXCEPTION_IF_NULL(mng);
284   std::vector<std::pair<AnfNodePtr, SymbolPtr>> shape_values;
285   bool changed = false;
286   for (auto &node : nodes) {
287     if (IsPrimitiveCNode(node, prim::kPrimShape)) {
288       auto v = node->abstract()->GetSymbolicValue();
289       if (v == nullptr) {
290         continue;
291       }
292       bool matched = false;
293       for (auto &prev : shape_values) {
294         if (node->func_graph() == prev.first->func_graph() && v->EqualsTo(prev.second)) {
295           MS_LOG(INFO) << "The symbolic value of " << node->DebugString() << " (" << node->fullname_with_scope()
296                        << ") is same as previous node " << prev.first->DebugString() << " ("
297                        << prev.first->fullname_with_scope() << "), eliminated it. Value:" << v->ToString();
298           mng->Replace(node, prev.first);
299           changed = true;
300           matched = true;
301           break;
302         }
303       }
304       if (!matched) {
305         shape_values.emplace_back(std::make_pair(node, v));
306       }
307     }
308   }
309   return changed;
310 }
311 }  // namespace irpass
312 }  // namespace opt
313 }  // namespace mindspore
314