1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_ 18 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include <map> 24 #include "schema/inner/model_generated.h" 25 #include "backend/optimizer/common/optimizer.h" 26 #include "utils/utils.h" 27 #include "tools/optimizer/common/gllo_utils.h" 28 29 namespace mindspore { 30 namespace opt { 31 32 /// fuse layer_norm or instance_norm into one operator 33 class NormFusion : public PatternProcessPass { 34 public: 35 explicit NormFusion(const std::string &name = "NormFusion", bool multigraph = true) PatternProcessPass(name,multigraph)36 : PatternProcessPass(name, multigraph) { 37 InitShapeSizeInferFuncMap(); 38 } 39 40 ~NormFusion() override = default; 41 42 protected: 43 bool Init() const; 44 45 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 46 47 private: 48 void InitShapeSizeInferFuncMap(); 49 bool GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode, 50 const std::vector<int> &mean_axes, const std::vector<int> ¶ms_shape, 51 schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const; 52 bool CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon, 53 int *begin_norm_axis, int *begin_params_axis) const; 54 CNodePtr CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const schema::PrimitiveType type, 55 float epsilon, int begin_norm_axis, int begin_params_axis) const; 56 std::map<string, int> ShapeSizeInfer(const FuncGraphPtr &func_graph) const; 57 58 protected: 59 mutable VarPtr input_ = nullptr; 60 mutable VarPtr mean1_ = nullptr; 61 mutable VarPtr mean1_axes_ = nullptr; 62 mutable VarPtr mean2_ = nullptr; 63 mutable VarPtr mean2_axes_ = nullptr; 64 mutable VarPtr gamma_ = nullptr; 65 mutable VarPtr beta_ = nullptr; 66 mutable VarPtr epsilon_ = nullptr; 67 std::map<schema::PrimitiveType, std::function<int(std::vector<int>, const schema::PrimitiveT &)>> 68 shape_size_infer_registry_; 69 }; 70 71 /// fuse tf layer_norm or instance_norm into one operator 72 class TfNormFusion : public NormFusion { 73 public: 74 explicit TfNormFusion(const std::string &name = "TfNormFusion", bool multigraph = true) NormFusion(name,multigraph)75 : NormFusion(name, multigraph) {} 76 77 ~TfNormFusion() override = default; 78 79 private: 80 const BaseRef DefinePattern() const override; 81 }; 82 83 /// fuse onnx layer_norm into one operator 84 class OnnxLayerNormFusion : public NormFusion { 85 public: 86 explicit OnnxLayerNormFusion(const std::string &name = "OnnxLayerNormFusion", bool multigraph = true) NormFusion(name,multigraph)87 : NormFusion(name, multigraph) {} 88 89 ~OnnxLayerNormFusion() override = default; 90 91 private: 92 const BaseRef DefinePattern() const override; 93 }; 94 } // namespace opt 95 } // namespace mindspore 96 97 #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_ 98