/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_

#include <vector>
#include <memory>
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {
27 class FusedBatchNormFusion : public PatternProcessPass {
28  public:
29   explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true)
PatternProcessPass(name,multigraph)30       : PatternProcessPass(name, multigraph) {
31     data_input0_var_ = std::make_shared<Var>();
32     data_input1_var_ = std::make_shared<Var>();
33     data_input2_var_ = std::make_shared<Var>();
34     variable_input0_var_ = std::make_shared<Var>();
35     variable_input1_var_ = std::make_shared<Var>();
36     constant_input0_var_ = std::make_shared<Var>();
37     constant_input1_var_ = std::make_shared<Var>();
38     batch_norm_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimBatchNorm->name()));
39     assign_sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAssignSub->name()));
40     assign_sub1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAssignSub->name()));
41     monad0_var_ = std::make_shared<Var>();
42     monad1_var_ = std::make_shared<Var>();
43   }
44   ~FusedBatchNormFusion() override = default;
45   const BaseRef DefinePattern() const override;
46   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
47 
48  protected:
49   AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
50                                     const EquivPtr &equiv) const;
51   void GetBNTrainingUpdateInputs(const EquivPtr &equiv, const std::vector<AnfNodePtr> &bn_training_reduce_outputs,
52                                  std::vector<AnfNodePtr> *bn_training_update_inputs) const;
53   void GetBNTrainingUpdateAbstractList(const EquivPtr &equiv, const AnfNodePtr &bn,
54                                        std::vector<AbstractBasePtr> *abstract_list) const;
55   AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
56                                     const std::vector<AnfNodePtr> &bn_training_reduce_outputs) const;
57   ValuePtr GetFactor(const EquivPtr &equiv) const;
58   void EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
59 
60   VarPtr data_input0_var_;
61   VarPtr data_input1_var_;
62   VarPtr data_input2_var_;
63   VarPtr variable_input0_var_;
64   VarPtr variable_input1_var_;
65   VarPtr constant_input0_var_;
66   VarPtr constant_input1_var_;
67   VarPtr batch_norm_var_;
68   VarPtr assign_sub0_var_;
69   VarPtr assign_sub1_var_;
70   VarPtr monad0_var_;
71   VarPtr monad1_var_;
72 };
74 class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion {
75  public:
76   explicit FusedBatchNormMixPrecisionFusion0(bool multigraph = true)
77       : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
78 
79   ~FusedBatchNormMixPrecisionFusion0() override = default;
80   const BaseRef DefinePattern() const override;
81 };
83 class FusedBatchNormMixPrecisionFusion1 : public FusedBatchNormFusion {
84  public:
85   explicit FusedBatchNormMixPrecisionFusion1(bool multigraph = true)
86       : FusedBatchNormFusion("fused_batch_norm_mix_precision_fusion", multigraph) {}
87 
88   ~FusedBatchNormMixPrecisionFusion1() override = default;
89   const BaseRef DefinePattern() const override;
90 };
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_FUSED_BATCH_NORM_FUSION_H_