/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_

#include <vector>
#include <memory>
#include <string>
#include <map>
#include <functional>
#include "schema/inner/model_generated.h"
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace opt {

/// Fuse a decomposed layer_norm or instance_norm subgraph into a single fused operator.
class NormFusion : public PatternProcessPass {
 public:
  explicit NormFusion(const std::string &name = "NormFusion", bool multigraph = true)
      : PatternProcessPass(name, multigraph) {
    InitShapeSizeInferFuncMap();
  }

  ~NormFusion() override = default;

 protected:
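  // Creates the pattern variables (input_, mean1_, gamma_, ...) used by DefinePattern();
  // presumably returns false if any variable fails to allocate.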
  bool Init() const;

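  // Called on each pattern match: validates the subgraph and, if it qualifies,
  // returns the fused norm node to substitute (nullptr otherwise).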
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 private:
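  // Populates shape_size_infer_registry_ with per-op shape-size inference callbacks.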
  void InitShapeSizeInferFuncMap();
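  // Decides whether the match is a LayerNorm or an InstanceNorm from the reduced
  // axes and the parameter (gamma/beta) shape, and derives begin_norm_axis / begin_params_axis.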
  bool GetNormTypeAndAxis(const FuncGraphPtr &func_graph, const CNodePtr &input_cnode,
                          const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
                          schema::PrimitiveType *type, int *begin_norm_axis, int *begin_params_axis) const;
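  // Verifies the captured subgraph and extracts the fusion attributes
  // (norm type, epsilon, normalization axes).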
  bool CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, schema::PrimitiveType *type, float *epsilon,
                    int *begin_norm_axis, int *begin_params_axis) const;
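  // Builds the fused LayerNorm/InstanceNorm CNode from the captured inputs.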
  CNodePtr CreateNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const schema::PrimitiveType type,
                          float epsilon, int begin_norm_axis, int begin_params_axis) const;
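  // Walks the graph and infers a shape size per node (keyed by node name),
  // presumably via the callbacks registered in shape_size_infer_registry_.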
  std::map<std::string, int> ShapeSizeInfer(const FuncGraphPtr &func_graph) const;

 protected:
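  // Pattern placeholders, bound to concrete nodes when the pattern matches.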
  mutable VarPtr input_ = nullptr;
  mutable VarPtr mean1_ = nullptr;
  mutable VarPtr mean1_axes_ = nullptr;
  mutable VarPtr mean2_ = nullptr;
  mutable VarPtr mean2_axes_ = nullptr;
  mutable VarPtr gamma_ = nullptr;
  mutable VarPtr beta_ = nullptr;
  mutable VarPtr epsilon_ = nullptr;
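  // Per-op callback: given the input shape sizes and the primitive, returns the
  // node's output shape size.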
  std::map<schema::PrimitiveType, std::function<int(std::vector<int>, const schema::PrimitiveT &)>>
    shape_size_infer_registry_;
};

/// Fuse the TensorFlow-style layer_norm or instance_norm pattern into a single fused operator.
class TfNormFusion : public NormFusion {
 public:
  explicit TfNormFusion(const std::string &name = "TfNormFusion", bool multigraph = true)
      : NormFusion(name, multigraph) {}

  ~TfNormFusion() override = default;

 private:
  const BaseRef DefinePattern() const override;
};

/// Fuse the ONNX-style layer_norm pattern into a single fused operator.
class OnnxLayerNormFusion : public NormFusion {
 public:
  explicit OnnxLayerNormFusion(const std::string &name = "OnnxLayerNormFusion", bool multigraph = true)
      : NormFusion(name, multigraph) {}

  ~OnnxLayerNormFusion() override = default;

 private:
  const BaseRef DefinePattern() const override;
};
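
/// Usage sketch (an assumption based on the usual PatternProcessPass workflow;
/// the exact registration point in the converter may differ):
///   auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
///   fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>());
///   fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>());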
}  // namespace opt
}  // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_NORM_FUSION_H_