/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_ENGINE_OPTIMIZER_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_ENGINE_OPTIMIZER_H_

#include "ir/func_graph.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/optimizer.h"
#include "mindspore/core/symbolic_shape/symbol.h"

namespace mindspore {
namespace opt {
namespace irpass {
class SymbolEngineBuilder {
 public:
  explicit SymbolEngineBuilder(bool only_dynshape_graph = true) : only_dynshape_graph_(only_dynshape_graph) {}
  ~SymbolEngineBuilder() = default;
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &opt);

 protected:
  bool HasDynamicShapeNode(const OptimizerPtr &opt) const;
  bool only_dynshape_graph_{true};  // If true, only build the SymbolEngine when a dynamic shape node exists.
};
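
// Illustrative usage as a graph-level pass (a sketch using only the interface declared above;
// the real registration into the optimizer pass list happens elsewhere):
//   SymbolEngineBuilder build_symbol_engine(/* only_dynshape_graph = */ true);
//   bool changed = build_symbol_engine(func_graph, opt);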

/**
 * Eliminate the ShapeCalc-Reduce-Reshape pattern generated by BroadcastGradientArgs.
 *
 * %5 = Add(a, b)  // when the shape of "a" is equal to the shape of "%5"
 * ...
 * %10 = ShapeCalc(a, b)   // backward op of "%5-Add".
 * %11 = TupleGetItem(%10, 0)
 * %12 = ReduceSum(dout, %11)
 * %13 = Shape(a)
 * %14 = Reshape(%12, %13)
 * %15 = op(%14)
 * --->
 * %5 = Add(a, b)
 * ...
 * %10 = op(dout)
 *
 * There may be another `TupleGetItem(%10, 1)` branch; the "ShapeCalc" itself is eliminated only when
 * both branches are eliminated.
 */
class ElimShapeCalcOnBroadcastArgsGrad : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &opt, const AnfNodePtr &node) override;

 protected:
  bool Check(const OptimizerPtr &opt, const AnfNodePtr &shape_calc, size_t input_index);
  bool CheckSymbolEqual(const ListSymbolPtr &input_shape, const ListSymbolPtr &output_shape, size_t shift);
};

// For some ops, such as ReduceSum or Reshape, if the input shape and the output shape are the same
// (in symbolic shape), the op has no effect at runtime, so it can be eliminated.
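// For example (an illustrative sketch, assuming the symbol engine proves the two shapes equal):
//   %1 = Reshape(x, s)   // the symbolic shape of %1 equals the symbolic shape of "x"
//   %2 = op(%1)
//   --->
//   %2 = op(x)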
class ElimNotEffectiveNode : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};

// If the symbolic value of the "shape" input is static, or has only one unknown dimension (expressible
// as "-1"), replace the "shape" input with a const tensor.
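// For example (an illustrative sketch; "(s0, 4, 16)" stands for a symbolic value with a single unknown dim):
//   %2 = Reshape(x, %1)   // the symbolic value of %1 is (s0, 4, 16)
//   --->
//   %2 = Reshape(x, const([-1, 4, 16]))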
class OptReshape : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};

/**
 * Fold the input cnode when its symbolic value is a constant.
 *
 * Example:
 * %0 = ShapeCalc(p, ()) // the ShapeCalc has two outputs
 * %1 = TupleGetItem(%0, 1) // the symbolic value of item 1 is constant.
 * %2 = Tile(p, %1)
 * --->
 * %2 = Tile(p, const_value)
 */
class FoldConstSymbol : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
};

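// Common subexpression elimination (CSE) for shape operators. (This description is inferred from the class
// name and interface; the precise matching rule lives in the implementation.)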
class ShapeOpCse {
 public:
  ShapeOpCse() = default;
  ~ShapeOpCse() = default;
  bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer);
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_ENGINE_OPTIMIZER_H_