/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "frontend/optimizer/irpass/symbol_engine_optimizer.h"

#include <algorithm>
#include <vector>
#include <memory>
#include <utility>
#include "ir/pattern_matcher.h"
#include "ir/functor.h"
#include "ops/array_ops.h"
#include "ops/math_ops.h"
#include "ops/op_def.h"
#include "include/common/utils/utils.h"
#include "mindspore/core/symbolic_shape/symbol.h"
#include "mindspore/core/symbolic_shape/utils.h"
#include "include/common/symbol_engine/symbol_engine_impl.h"

namespace mindspore {
namespace opt {
namespace irpass {
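// Returns the symbol engine attached to the func_graph that owns the node, or nullptr when
// no engine has been built.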
inline SymbolEnginePtr GetSymbolEngine(const AnfNodePtr &node) { return node->func_graph()->symbol_engine(); }

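// Builds a symbol engine for the whole func_graph. When only_dynshape_graph_ is set, graphs
// without any dynamic-shape node are skipped. A build failure is logged but not fatal.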
bool SymbolEngineBuilder::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &opt) {
  if (only_dynshape_graph_ && !HasDynamicShapeNode(opt)) {
    MS_LOG(INFO) << "There is no dynamic shape node, the SymbolEngineBuilder is disabled.";
    return false;
  }
  try {
    MS_LOG_TRY_CATCH_SCOPE;
    symshape::SymbolEngineImpl::Build(func_graph);
    MS_LOG(INFO) << "Build symbol engine successfully.";
  } catch (std::exception &e) {
    MS_LOG(WARNING) << "Build symbol engine failed. message: " << e.what();
  }
  return true;
}

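// Returns true if any CNode in the graph has a dynamic shape in its abstract.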
bool SymbolEngineBuilder::HasDynamicShapeNode(const OptimizerPtr &opt) const {
  auto mng = opt->manager();
  if (mng == nullptr) {
    return false;
  }
  auto &nodes = mng->all_nodes();
  for (auto &node : nodes) {
    if (!node->isa<CNode>()) {
      continue;
    }
    auto abs = node->abstract();
    if (abs != nullptr && abs->GetShape()->IsDynamic()) {
      return true;
    }
  }
  return false;
}

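// Matches "ReduceSum(dout, TupleGetItem(ShapeCalc, 0/1), keepdims, skipmode)" and replaces
// the ReduceSum with dout when Check proves the reduction does nothing.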
AnfNodePtr ElimShapeCalcOnBroadcastArgsGrad::operator()(const OptimizerPtr &opt, const AnfNodePtr &node) {
  if (GetSymbolEngine(node) == nullptr) {
    return nullptr;
  }
  PatternNode<AnfNodePtr> dout;
  PatternNode<AnfNodePtr> shape_calc;
  PatternNode<AnfNodePtr> shape;
  PatternNode<AnfNodePtr> keepdims;
  PatternNode<AnfNodePtr> skipmode;
  PConstant idx0(node, false, 0, true);
  PConstant idx1(node, false, 1, true);
  MATCH_REPLACE_IF(
    node,
    PPrimitive(prim::kPrimReduceSum, dout, PPrimitive(prim::kPrimTupleGetItem, shape_calc, idx0), keepdims, skipmode),
    dout, Check(opt, shape_calc.GetNode(node), kIndex1));
  MATCH_REPLACE_IF(
    node,
    PPrimitive(prim::kPrimReduceSum, dout, PPrimitive(prim::kPrimTupleGetItem, shape_calc, idx1), keepdims, skipmode),
    dout, Check(opt, shape_calc.GetNode(node), kIndex2));
  return nullptr;
}

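// Checks that the ShapeCalc node was generated by BroadcastGradientArgs in a bprop graph and
// that the symbolic shape of input `input_index` equals the forward node's output shape, in
// which case the gradient needs no reduction and the ReduceSum can be eliminated.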
bool ElimShapeCalcOnBroadcastArgsGrad::Check(const OptimizerPtr &opt, const AnfNodePtr &shape_calc,
                                             size_t input_index) {
  auto mng = opt->manager();
  MS_EXCEPTION_IF_NULL(mng);
  auto &users = mng->node_users();

  auto shapecalc_node = shape_calc->cast<CNodePtr>();
  constexpr size_t shapecalc_size = 3;
  if (shapecalc_node == nullptr || !IsPrimitiveCNode(shapecalc_node, prim::kPrimShapeCalc) ||
      shapecalc_node->size() != shapecalc_size) {
    return false;
  }
  auto input_node = shapecalc_node->input(input_index);
  auto shapecalc_functor = common::AnfAlgo::GetNodeAttr<ShapeCalcBaseFunctorPtr>(shapecalc_node, kAttrFunctor);
  MS_EXCEPTION_IF_NULL(shapecalc_functor);
  if (shapecalc_functor->name() != "ShapeCalc_BroadcastGradientArgs") {
    // Only the ShapeCalc generated by BroadcastGradientArgs is supported.
    return false;
  }
  auto fwd_unique_id = shapecalc_node->primal_attrs().find(kPrimalAttrForwardUniqueId);
  if (fwd_unique_id == shapecalc_node->primal_attrs().end()) {
    // Only nodes in a bprop graph are supported.
    return false;
  }
  // Find the forward node whose unique id matches the ShapeCalc's forward unique id.
  AnfNodePtr fwd_node = nullptr;
  for (auto &user : users[input_node]) {
    auto user_cnode = user.first->cast<CNodePtr>();
    if (user_cnode == nullptr) {
      continue;
    }
    if (auto uniq_id = user_cnode->primal_attrs().find(kPrimalAttrUniqueId);
        uniq_id != user_cnode->primal_attrs().end()) {
      if (*uniq_id->second == *fwd_unique_id->second) {
        fwd_node = user.first;
        break;
      }
    }
  }
  if (fwd_node == nullptr) {
    return false;
  }

  auto input_shape = input_node->abstract()->GetSymbolicShape();
  auto output_shape = fwd_node->abstract()->GetSymbolicShape();
  auto ret = CheckSymbolEqual(input_shape, output_shape, GetValue<size_t>(shapecalc_functor->ToValue()));
  if (ret) {
    MS_LOG(INFO) << "For " << shape_calc->DebugString() << " (" << shape_calc->fullname_with_scope()
                 << ") generated by BroadcastGradientArgs, the gradient for input " << input_index
                 << " is unnecessary and can be eliminated. grad symbol: " << input_shape->ToString()
                 << ". out symbol: " << output_shape->ToString();
  }
  return ret;
}

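// Compares the two symbolic shapes right-aligned, ignoring the trailing `shift` dims and any
// leading dims of the input that have no counterpart in the output.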
bool ElimShapeCalcOnBroadcastArgsGrad::CheckSymbolEqual(const ListSymbolPtr &input_shape,
                                                        const ListSymbolPtr &output_shape, size_t shift) {
  if (input_shape == nullptr || output_shape == nullptr) {
    return false;
  }
  if (input_shape->size() < output_shape->size()) {
    return false;
  }
  if (input_shape->is_dyn_len() || output_shape->is_dyn_len()) {
    return input_shape->EqualsTo(output_shape);
  }
  for (size_t i = input_shape->size(); i > shift; i--) {
    auto inp = input_shape->symbols()[input_shape->size() - i];
    if (i <= output_shape->size() && !inp->EqualsTo(output_shape->symbols()[output_shape->size() - i])) {
      return false;
    }
  }
  return true;
}

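// Replaces Reshape/ReduceSum/ReduceMax/ReduceMin with its first input when the symbolic
// input and output shapes are equal, since the op then has no effect.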
AnfNodePtr ElimNotEffectiveNode::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  if (GetSymbolEngine(node) == nullptr) {
    return nullptr;
  }
  static const PrimitiveSet supports_op = {prim::kPrimReshape, prim::kPrimReduceSum, prim::kPrimReduceMax,
                                           prim::kPrimReduceMin};
  if (!IsOneOfPrimitiveCNode(node, supports_op)) {
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input_node = cnode->input(1);
  auto input_shape = input_node->abstract()->GetSymbolicShape();
  auto output_shape = node->abstract()->GetSymbolicShape();
  if (input_shape != nullptr && input_shape->EqualsTo(output_shape)) {
    MS_LOG(INFO) << "For node " << node->DebugString() << " (" << node->fullname_with_scope()
                 << "), the input shape and output shape are the same, so the node can be eliminated.";
    return input_node;
  }
  return nullptr;
}

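// If the "shape" input of a Reshape is computed at runtime but its symbolic value is known
// with at most one dynamic dim, replaces it with a constant value node.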
AnfNodePtr OptReshape::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  if (GetSymbolEngine(node) == nullptr) {
    return nullptr;
  }
  PatternNode<AnfNodePtr> input;
  PatternNode<AnfNodePtr> shape;
  ShapeVector shape_vec;

  auto MakeReshape = [&shape_vec, &node]() -> AnfNodePtr {
    auto shape_val = MakeValue(shape_vec);
    auto shape = NewValueNode(shape_val);
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    MS_LOG(INFO) << "For node " << cnode->DebugString()
                 << ", the symbolic value of \"shape\" is static or has only one dynamic dim, "
                 << "replace the \"shape\" with a value node: " << shape_val->ToString();
    shape->set_abstract(shape_val->ToAbstract());
    auto reshape = NewCNode({cnode->input(0), cnode->input(1), shape}, node->func_graph());
    reshape->set_abstract(node->abstract());
    return reshape;
  };
  auto CheckShape = [&shape_vec](const AnfNodePtr &shape) {
    // The shape can be folded only when it is computed by a CNode and its symbolic value is
    // known, with at most one unknown (kShapeDimAny) dim.
    if (!shape->isa<CNode>()) {
      return false;
    }
    auto symshape = shape->abstract()->GetSymbolicValue();
    if (symshape == nullptr || !symshape->HasData()) {
      return false;
    }
    shape_vec = symshape::ToShape(symshape);
    return std::count(shape_vec.cbegin(), shape_vec.cend(), abstract::Shape::kShapeDimAny) <= 1;
  };
  MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimReshape, input, shape), MakeReshape,
                          CheckShape(shape.GetNode(node)));
  return nullptr;
}

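// For ops registered with an op_def, folds CNode inputs whose symbolic value is a constant
// int tuple into value nodes.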
AnfNodePtr FoldConstSymbol::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
  auto symbol_engine = GetSymbolEngine(node);
  if (symbol_engine == nullptr) {
    return nullptr;
  }
  auto op_def = mindspore::ops::GetOpDef(AnfUtils::GetCNodeName(node));
  if (op_def == nullptr) {
    return nullptr;
  }
  if (node->abstract() != nullptr && !symshape::QueryValue(node->abstract())->isa<ValueAny>()) {
    // The node's own output value is already constant, no need to fold its inputs.
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  AnfNodePtrList new_inputs;
  bool need_replace = false;
  for (size_t i = 1; i < cnode->size(); i++) {
    auto inp = cnode->input(i);
    if (!inp->isa<CNode>() || inp->abstract() == nullptr || i - 1 >= op_def->args_.size()) {
      continue;
    }
    auto v = symshape::QueryValue(inp->abstract());
    if (v->isa<ValueAny>()) {
      continue;
    }
    if (new_inputs.empty()) {
      new_inputs = cnode->inputs();
    }
    if (v->isa<ValueSequence>() && op_def->args_[i - 1].arg_dtype_ == ops::OP_DTYPE::DT_TUPLE_INT) {
      new_inputs[i] = NewValueNode(v);
      need_replace = true;
    } else {
      MS_LOG(INFO) << "For node " << node->DebugString() << ", the input[" << i
                   << "]'s value does not match the op_def type (" << op_def->args_[i - 1].arg_dtype_
                   << "). value: " << v->ToString();
      continue;
    }
    MS_LOG(INFO) << "For node " << node->DebugString() << ", the input[" << i
                 << "]'s symbolic value is constant, fold the input value: " << v->ToString();
    auto new_abs = v->ToAbstract();
    MS_EXCEPTION_IF_NULL(new_abs);
    new_abs->SetSymbolicValue(inp->abstract()->GetSymbolicValue());
    new_inputs[i]->set_abstract(new_abs);
  }
  if (!need_replace) {
    return nullptr;
  }
  auto new_node = NewCNode(new_inputs, node->func_graph());
  new_node->set_abstract(node->abstract());
  return new_node;
}

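// Common-subexpression elimination for Shape: when two Shape nodes in the same graph have
// equal symbolic values, the later one is replaced with the earlier one.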
bool ShapeOpCse::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) {
  if (func_graph->symbol_engine() == nullptr) {
    return false;
  }
  auto nodes = TopoSort(func_graph->get_return(), SuccDeeperSimple, AlwaysInclude);
  auto mng = optimizer->manager();
  MS_EXCEPTION_IF_NULL(mng);
  std::vector<std::pair<AnfNodePtr, SymbolPtr>> shape_values;
  bool changed = false;
  for (auto &node : nodes) {
    if (IsPrimitiveCNode(node, prim::kPrimShape)) {
      auto v = node->abstract()->GetSymbolicValue();
      if (v == nullptr) {
        continue;
      }
      bool matched = false;
      for (auto &prev : shape_values) {
        if (node->func_graph() == prev.first->func_graph() && v->EqualsTo(prev.second)) {
          MS_LOG(INFO) << "The symbolic value of " << node->DebugString() << " (" << node->fullname_with_scope()
                       << ") is the same as previous node " << prev.first->DebugString() << " ("
                       << prev.first->fullname_with_scope() << "), eliminating it. Value: " << v->ToString();
          mng->Replace(node, prev.first);
          changed = true;
          matched = true;
          break;
        }
      }
      if (!matched) {
        shape_values.emplace_back(node, v);
      }
    }
  }
  return changed;
}
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore