1 /** 2 * Copyright 2020-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_ 18 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include <map> 24 #include "schema/inner/model_generated.h" 25 #include "tools/optimizer/common/pattern_process_pass_extends.h" 26 #include "include/common/utils/utils.h" 27 #include "tools/optimizer/common/gllo_utils.h" 28 29 namespace mindspore { 30 namespace opt { 31 32 /// fuse layer_norm or instance_norm into one operator 33 class NormFusion : public LitePatternProcessPass { 34 public: 35 explicit NormFusion(const std::string &name = "NormFusion", bool multigraph = true) LitePatternProcessPass(name,multigraph)36 : LitePatternProcessPass(name, multigraph) { 37 InitShapeSizeInferFuncMap(); 38 } 39 40 ~NormFusion() override = default; 41 42 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 43 44 protected: 45 bool Init() const; 46 47 private: 48 void InitShapeSizeInferFuncMap(); 49 bool GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode, 50 const std::vector<int> &mean_axes, const std::vector<int> ¶ms_shape, 51 schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const; 52 bool CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, 53 int *begin_norm_axis, int *begin_params_axis) const; 54 CNodePtr CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const schema::PrimitiveType type, 55 float epsilon, int begin_norm_axis, int begin_params_axis) const; 56 CNodePtr CreateActivationNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; 57 std::map<string, int> ShapeSizeInfer(const FuncGraphPtr &func_graph) const; 58 59 protected: 60 mutable VarPtr input_ = nullptr; 61 mutable VarPtr mean1_ = nullptr; 62 mutable VarPtr mean1_axes_ = nullptr; 63 mutable VarPtr mean2_ = nullptr; 64 mutable VarPtr mean2_axes_ = nullptr; 65 mutable VarPtr gamma_ = nullptr; 66 mutable VarPtr beta_ = nullptr; 67 mutable VarPtr epsilon_ = nullptr; 68 mutable ActivationType add_act_type_{NO_ACTIVATION}; 69 std::map<schema::PrimitiveType, std::function<int(std::vector<int>, const schema::PrimitiveT &)>> 70 shape_size_infer_registry_; 71 }; 72 73 /// fuse tf layer_norm or instance_norm into one operator 74 class TfNormFusion : public NormFusion { 75 public: 76 explicit TfNormFusion(const std::string &name = "TfNormFusion", bool multigraph = true) NormFusion(name,multigraph)77 : NormFusion(name, multigraph) {} 78 79 ~TfNormFusion() override = default; 80 81 private: 82 const BaseRef DefinePattern() const override; 83 }; 84 85 /// fuse onnx layer_norm into one operator 86 class OnnxLayerNormFusion : public NormFusion { 87 public: 88 explicit OnnxLayerNormFusion(const std::string &name = "OnnxLayerNormFusion", bool multigraph = true) NormFusion(name,multigraph)89 : NormFusion(name, multigraph) {} 90 91 ~OnnxLayerNormFusion() override = default; 92 93 private: 94 const BaseRef DefinePattern() const override; 95 }; 96 97 /// fuse onnx layer_norm into one operator with little variance 98 class OnnxLayerNormFusion2 : public NormFusion { 99 public: 100 explicit OnnxLayerNormFusion2(const std::string &name = "OnnxLayerNormFusion2", bool multigraph = true) NormFusion(name,multigraph)101 : NormFusion(name, multigraph) {} 102 103 ~OnnxLayerNormFusion2() override = default; 104 105 private: 106 const BaseRef DefinePattern() const override; 107 }; 108 } // namespace opt 109 } // namespace mindspore 110 111 #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_ 112