• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
19 #include <memory>
20 #include "mindspore/core/ops/lite_ops.h"
21 #include "ops/fusion/activation.h"
22 #include "ops/op_utils.h"
23 #include "include/common/utils/utils.h"
24 #include "tools/optimizer/common/gllo_utils.h"
25 #include "nnacl/op_base.h"
26 
27 namespace mindspore::opt {
DefineSigmoidMulFirstPattern() const28 VectorRef SigmoidMulFusion::DefineSigmoidMulFirstPattern() const {
29   auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
30   MS_CHECK_TRUE_RET(is_activation != nullptr, {});
31   auto is_var = std::make_shared<Var>();
32   MS_CHECK_TRUE_RET(is_var != nullptr, {});
33   auto activation_input = VectorRef({is_activation, is_var});
34   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
35   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
36   auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
37   MS_CHECK_TRUE_RET(is_const != nullptr, {});
38   return VectorRef({is_mul, activation_input, is_const});
39 }
40 
DefineSigmoidMulSecondPattern() const41 VectorRef SigmoidMulFusion::DefineSigmoidMulSecondPattern() const {
42   auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
43   MS_CHECK_TRUE_RET(is_activation != nullptr, {});
44   auto is_var = std::make_shared<Var>();
45   MS_CHECK_TRUE_RET(is_var != nullptr, {});
46   auto activation_input = VectorRef({is_activation, is_var});
47   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
48   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
49   return VectorRef({is_mul, is_var, activation_input});
50 }
51 
DefinePatterns() const52 std::unordered_map<std::string, VectorRef> SigmoidMulFusion::DefinePatterns() const {
53   std::unordered_map<std::string, VectorRef> patterns;
54   patterns["SigmoidMulFirstPatternName"] = DefineSigmoidMulFirstPattern();
55   patterns["SigmoidMulSecondPatternName"] = DefineSigmoidMulSecondPattern();
56   return patterns;
57 }
58 
59 // x * sigmoid(x) ->swish(x)
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr &) const60 AnfNodePtr SigmoidMulFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
61                                      const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &) const {
62   if (func_graph == nullptr || node == nullptr) {
63     return nullptr;
64   }
65   auto mul_cnode = node->cast<CNodePtr>();
66   MS_CHECK_TRUE_RET(mul_cnode != nullptr, nullptr);
67   if (IsMarkedTrainOp(mul_cnode)) {
68     return nullptr;
69   }
70   auto activation_cnode = mul_cnode->input(kInputIndexTwo)->cast<CNodePtr>();
71   MS_CHECK_TRUE_RET(activation_cnode != nullptr, nullptr);
72   if (IsMarkedTrainOp(activation_cnode)) {
73     return nullptr;
74   }
75 
76   if (!CheckPattern(pattern_name, func_graph, activation_cnode, mul_cnode)) {
77     return nullptr;
78   }
79   auto activation_prim = ops::GetOperator<mindspore::ops::Activation>(activation_cnode->input(0));
80   MS_CHECK_TRUE_RET(activation_prim != nullptr, nullptr);
81   activation_prim->set_activation_type(mindspore::SWISH);
82   return activation_cnode;
83 }
84 
CheckPattern(const std::string & pattern_name,const FuncGraphPtr & func_graph,const CNodePtr & act_cnode,const CNodePtr & mul_cnode) const85 bool SigmoidMulFusion::CheckPattern(const std::string &pattern_name, const FuncGraphPtr &func_graph,
86                                     const CNodePtr &act_cnode, const CNodePtr &mul_cnode) const {
87   // activation must sigmoid
88   auto activation_prim = ops::GetOperator<mindspore::ops::Activation>(act_cnode->input(0));
89   MS_CHECK_TRUE_RET(activation_prim != nullptr, false);
90   if (activation_prim == nullptr || (activation_prim->GetAttr(ops::kActivationType) != nullptr &&
91                                      activation_prim->get_activation_type() != mindspore::SIGMOID)) {
92     MS_LOG(ERROR) << "activation type is not sigmoid.";
93     return false;
94   }
95   MS_CHECK_TRUE_RET(mul_cnode->input(kInputIndexOne) != nullptr, false);
96   if (pattern_name == "SigmoidMulFirstPatternName") {
97     return true;
98   } else {
99     MS_CHECK_TRUE_RET(act_cnode->input(kInputIndexOne) != nullptr, false);
100     if (act_cnode->input(kInputIndexOne) != mul_cnode->input(kInputIndexOne)) {
101       return false;
102     }
103   }
104   return true;
105 }
106 }  // namespace mindspore::opt
107