1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
19 #include <memory>
20 #include "mindspore/core/ops/lite_ops.h"
21 #include "ops/fusion/activation.h"
22 #include "ops/op_utils.h"
23 #include "include/common/utils/utils.h"
24 #include "tools/optimizer/common/gllo_utils.h"
25 #include "nnacl/op_base.h"
26
27 namespace mindspore::opt {
DefineSigmoidMulFirstPattern() const28 VectorRef SigmoidMulFusion::DefineSigmoidMulFirstPattern() const {
29 auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
30 MS_CHECK_TRUE_RET(is_activation != nullptr, {});
31 auto is_var = std::make_shared<Var>();
32 MS_CHECK_TRUE_RET(is_var != nullptr, {});
33 auto activation_input = VectorRef({is_activation, is_var});
34 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
35 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
36 auto is_const = std::make_shared<CondVar>(IsParamOrValueNodeWithData);
37 MS_CHECK_TRUE_RET(is_const != nullptr, {});
38 return VectorRef({is_mul, activation_input, is_const});
39 }
40
DefineSigmoidMulSecondPattern() const41 VectorRef SigmoidMulFusion::DefineSigmoidMulSecondPattern() const {
42 auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
43 MS_CHECK_TRUE_RET(is_activation != nullptr, {});
44 auto is_var = std::make_shared<Var>();
45 MS_CHECK_TRUE_RET(is_var != nullptr, {});
46 auto activation_input = VectorRef({is_activation, is_var});
47 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
48 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
49 return VectorRef({is_mul, is_var, activation_input});
50 }
51
DefinePatterns() const52 std::unordered_map<std::string, VectorRef> SigmoidMulFusion::DefinePatterns() const {
53 std::unordered_map<std::string, VectorRef> patterns;
54 patterns["SigmoidMulFirstPatternName"] = DefineSigmoidMulFirstPattern();
55 patterns["SigmoidMulSecondPatternName"] = DefineSigmoidMulSecondPattern();
56 return patterns;
57 }
58
59 // x * sigmoid(x) ->swish(x)
Process(const std::string & pattern_name,const mindspore::FuncGraphPtr & func_graph,const mindspore::AnfNodePtr & node,const mindspore::EquivPtr &) const60 AnfNodePtr SigmoidMulFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
61 const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &) const {
62 if (func_graph == nullptr || node == nullptr) {
63 return nullptr;
64 }
65 auto mul_cnode = node->cast<CNodePtr>();
66 MS_CHECK_TRUE_RET(mul_cnode != nullptr, nullptr);
67 if (IsMarkedTrainOp(mul_cnode)) {
68 return nullptr;
69 }
70 auto activation_cnode = mul_cnode->input(kInputIndexTwo)->cast<CNodePtr>();
71 MS_CHECK_TRUE_RET(activation_cnode != nullptr, nullptr);
72 if (IsMarkedTrainOp(activation_cnode)) {
73 return nullptr;
74 }
75
76 if (!CheckPattern(pattern_name, func_graph, activation_cnode, mul_cnode)) {
77 return nullptr;
78 }
79 auto activation_prim = ops::GetOperator<mindspore::ops::Activation>(activation_cnode->input(0));
80 MS_CHECK_TRUE_RET(activation_prim != nullptr, nullptr);
81 activation_prim->set_activation_type(mindspore::SWISH);
82 return activation_cnode;
83 }
84
CheckPattern(const std::string & pattern_name,const FuncGraphPtr & func_graph,const CNodePtr & act_cnode,const CNodePtr & mul_cnode) const85 bool SigmoidMulFusion::CheckPattern(const std::string &pattern_name, const FuncGraphPtr &func_graph,
86 const CNodePtr &act_cnode, const CNodePtr &mul_cnode) const {
87 // activation must sigmoid
88 auto activation_prim = ops::GetOperator<mindspore::ops::Activation>(act_cnode->input(0));
89 MS_CHECK_TRUE_RET(activation_prim != nullptr, false);
90 if (activation_prim == nullptr || (activation_prim->GetAttr(ops::kActivationType) != nullptr &&
91 activation_prim->get_activation_type() != mindspore::SIGMOID)) {
92 MS_LOG(ERROR) << "activation type is not sigmoid.";
93 return false;
94 }
95 MS_CHECK_TRUE_RET(mul_cnode->input(kInputIndexOne) != nullptr, false);
96 if (pattern_name == "SigmoidMulFirstPatternName") {
97 return true;
98 } else {
99 MS_CHECK_TRUE_RET(act_cnode->input(kInputIndexOne) != nullptr, false);
100 if (act_cnode->input(kInputIndexOne) != mul_cnode->input(kInputIndexOne)) {
101 return false;
102 }
103 }
104 return true;
105 }
106 } // namespace mindspore::opt
107