1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_ENGINE_OPTIMIZER_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_ENGINE_OPTIMIZER_H_ 19 20 #include "ir/func_graph.h" 21 #include "frontend/optimizer/anf_visitor.h" 22 #include "frontend/optimizer/optimizer.h" 23 #include "mindspore/core/symbolic_shape/symbol.h" 24 25 namespace mindspore { 26 namespace opt { 27 namespace irpass { 28 class SymbolEngineBuilder { 29 public: only_dynshape_graph_(only_dynshape_graph)30 explicit SymbolEngineBuilder(bool only_dynshape_graph = true) : only_dynshape_graph_(only_dynshape_graph) {} 31 ~SymbolEngineBuilder() = default; 32 bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &opt); 33 34 protected: 35 bool HasDynamicShapeNode(const OptimizerPtr &opt) const; 36 bool only_dynshape_graph_{true}; // If true, only build SymbolEngine when dynamic shape node exists. 37 }; 38 39 /** 40 * Eliminate the ShapeCalc-Reduce-Reshape pattern generated by BroadcastGradientArgs. 41 * 42 * %5 = Add(a, b) // when shape of "a" is equal to shape of "%5" 43 * ... 44 * %10 = ShapeCalc(a, b) // backward op of "%5-Add". 45 * %11 = TupleGetItem(%10, 0) 46 * %12 = ReduceSum(dout, %11) 47 * %13 = Shape(a) 48 * %14 = Reshape(%12, %13) 49 * %15 = op(%14) 50 * ---> 51 * %5 = Add(a, b) 52 * ... 53 * %10 = op(dout) 54 * 55 * There may be another `TupleGetItem(%10, 1)` branch. when both branches are eliminated together, the "ShapeCalc" 56 * is eliminated. 57 */ 58 class ElimShapeCalcOnBroadcastArgsGrad : public AnfVisitor { 59 public: 60 AnfNodePtr operator()(const OptimizerPtr &opt, const AnfNodePtr &node) override; 61 62 protected: 63 bool Check(const OptimizerPtr &opt, const AnfNodePtr &shape_calc, size_t input_index); 64 bool CheckSymbolEqual(const ListSymbolPtr &input_shape, const ListSymbolPtr &output_shape, size_t shift); 65 }; 66 67 // Some ops like ReduceSum or Reshape, if the input shape and output shape are the same (in symbolic shape), it means 68 // that this op is not effective in running, so we can eliminate it. 69 class ElimNotEffectiveNode : public AnfVisitor { 70 public: 71 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; 72 }; 73 74 // the symbolic value of "shape" is static or has only one "-1", replace the "shape" to a const tensor. 75 class OptReshape : public AnfVisitor { 76 public: 77 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; 78 }; 79 80 /** 81 * Fold the input cnode when the symbolic value is constant value. 82 * 83 * example: 84 * %0 = ShapeCalc(p, ()) // the ShapeCalc has two output 85 * %1 = TupleGetItem(%0, 1) // the symbolic value of item 1 is const. 86 * %2 = Tile(p, %1) 87 * --> 88 * %2 = Tile(p, const_value) 89 */ 90 class FoldConstSymbol : public AnfVisitor { 91 public: 92 AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; 93 }; 94 95 class ShapeOpCse { 96 public: 97 ShapeOpCse() = default; 98 ~ShapeOpCse() = default; 99 bool operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer); 100 }; 101 } // namespace irpass 102 } // namespace opt 103 } // namespace mindspore 104 #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SYMBOL_ENGINE_OPTIMIZER_H_ 105