• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
18 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
19 
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/pattern_process_pass_extends.h"
#include "include/common/utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"
28 
29 namespace mindspore {
30 namespace opt {
31 
32 /// fuse layer_norm or instance_norm into one operator
33 class NormFusion : public LitePatternProcessPass {
34  public:
35   explicit NormFusion(const std::string &name = "NormFusion", bool multigraph = true)
LitePatternProcessPass(name,multigraph)36       : LitePatternProcessPass(name, multigraph) {
37     InitShapeSizeInferFuncMap();
38   }
39 
40   ~NormFusion() override = default;
41 
42   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
43 
44  protected:
45   bool Init() const;
46 
47  private:
48   void InitShapeSizeInferFuncMap();
49   bool GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode,
50                           const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
51                           schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const;
52   bool CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon,
53                     int *begin_norm_axis, int *begin_params_axis) const;
54   CNodePtr CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const schema::PrimitiveType type,
55                           float epsilon, int begin_norm_axis, int begin_params_axis) const;
56   CNodePtr CreateActivationNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
57   std::map<string, int> ShapeSizeInfer(const FuncGraphPtr &func_graph) const;
58 
59  protected:
60   mutable VarPtr input_ = nullptr;
61   mutable VarPtr mean1_ = nullptr;
62   mutable VarPtr mean1_axes_ = nullptr;
63   mutable VarPtr mean2_ = nullptr;
64   mutable VarPtr mean2_axes_ = nullptr;
65   mutable VarPtr gamma_ = nullptr;
66   mutable VarPtr beta_ = nullptr;
67   mutable VarPtr epsilon_ = nullptr;
68   mutable ActivationType add_act_type_{NO_ACTIVATION};
69   std::map<schema::PrimitiveType, std::function<int(std::vector<int>, const schema::PrimitiveT &)>>
70     shape_size_infer_registry_;
71 };
72 
73 /// fuse tf layer_norm or instance_norm into one operator
74 class TfNormFusion : public NormFusion {
75  public:
76   explicit TfNormFusion(const std::string &name = "TfNormFusion", bool multigraph = true)
NormFusion(name,multigraph)77       : NormFusion(name, multigraph) {}
78 
79   ~TfNormFusion() override = default;
80 
81  private:
82   const BaseRef DefinePattern() const override;
83 };
84 
85 /// fuse onnx layer_norm into one operator
86 class OnnxLayerNormFusion : public NormFusion {
87  public:
88   explicit OnnxLayerNormFusion(const std::string &name = "OnnxLayerNormFusion", bool multigraph = true)
NormFusion(name,multigraph)89       : NormFusion(name, multigraph) {}
90 
91   ~OnnxLayerNormFusion() override = default;
92 
93  private:
94   const BaseRef DefinePattern() const override;
95 };
96 
97 /// fuse onnx layer_norm into one operator with little variance
98 class OnnxLayerNormFusion2 : public NormFusion {
99  public:
100   explicit OnnxLayerNormFusion2(const std::string &name = "OnnxLayerNormFusion2", bool multigraph = true)
NormFusion(name,multigraph)101       : NormFusion(name, multigraph) {}
102 
103   ~OnnxLayerNormFusion2() override = default;
104 
105  private:
106   const BaseRef DefinePattern() const override;
107 };
108 }  // namespace opt
109 }  // namespace mindspore
110 
111 #endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
112