1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/scale_activation_fusion.h"
19 #include <memory>
20 #include "mindspore/core/ops/lite_ops.h"
21 #include "ops/fusion/activation.h"
22 #include "ops/fusion/scale_fusion.h"
23 #include "ops/op_utils.h"
24 #include "tools/optimizer/common/gllo_utils.h"
25 #include "nnacl/op_base.h"
26
27 namespace mindspore::opt {
DefinePattern() const28 const BaseRef ScaleActivationFusion::DefinePattern() const {
29 auto is_scale = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimScaleFusion>);
30 MS_CHECK_TRUE_RET(is_scale != nullptr, {});
31 auto is_activation = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimActivation>);
32 MS_CHECK_TRUE_RET(is_activation != nullptr, {});
33 return VectorRef({is_activation, is_scale});
34 }
35
/**
 * Folds a RELU/RELU6 Activation node into the ScaleFusion node that feeds it, by writing
 * the activation type into the ScaleFusion primitive's kActivationType attribute.
 *
 * Fusion is skipped (nullptr is returned) when:
 *  - the matched node is not an Activation, or either node is marked as a training op;
 *  - the activation is neither RELU nor RELU6;
 *  - the scale node's output has more than one consumer (IsMultiOutputTensors);
 *  - the scale node already carries an activation other than NO_ACTIVATION/RELU/RELU6.
 *
 * @param func_graph graph being optimized; a null graph records RET_NULL_PTR.
 * @param node the Activation CNode matched by DefinePattern(); a null node records RET_NULL_PTR.
 * @return the ScaleFusion node that should replace @p node, or nullptr when no fusion is done.
 */
const AnfNodePtr ScaleActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                const EquivPtr &) const {
  if (func_graph == nullptr || node == nullptr) {
    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
    return nullptr;
  }
  auto act_node = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(act_node != nullptr, nullptr);
  if (!CheckPrimitiveType(act_node, prim::kPrimActivation) || IsMarkedTrainOp(act_node)) {
    return nullptr;
  }
  // Expect exactly: primitive value + one data input.
  MS_CHECK_TRUE_RET(act_node->size() == kInputSizeTwo, nullptr);
  auto act_prim = ops::GetOperator<mindspore::ops::Activation>(act_node->input(FIRST_INPUT));
  MS_CHECK_TRUE_RET(act_prim != nullptr, nullptr);
  auto act_prim_c = act_prim->GetPrim();
  MS_CHECK_TRUE_RET(act_prim_c != nullptr && act_prim_c->GetAttr(ops::kActivationType) != nullptr, nullptr);
  const ActivationType outer_act = act_prim->get_activation_type();
  // Only RELU and RELU6 are fused by this pass.
  if (outer_act != mindspore::RELU && outer_act != mindspore::RELU6) {
    return nullptr;
  }

  auto scale_node = act_node->input(SECOND_INPUT);
  MS_CHECK_TRUE_RET(scale_node != nullptr, nullptr);
  auto scale_cnode = scale_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(scale_cnode != nullptr, nullptr);
  // If other nodes also consume the scale output, changing its activation would alter their inputs.
  if (IsMarkedTrainOp(scale_cnode) || IsMultiOutputTensors(func_graph, scale_cnode)) {
    return nullptr;
  }
  auto scale_prim = ops::GetOperator<ops::ScaleFusion>(scale_cnode->input(FIRST_INPUT));
  // A real check (not MS_ASSERT, which is typically a no-op in release builds): scale_prim is
  // dereferenced immediately below, so a null must abort the fusion instead of crashing.
  MS_CHECK_TRUE_RET(scale_prim != nullptr, nullptr);
  auto scale_prim_c = scale_prim->GetPrim();
  MS_CHECK_TRUE_RET(scale_prim_c != nullptr, nullptr);
  ActivationType act_type = outer_act;
  if (scale_prim_c->GetAttr(ops::kActivationType) != nullptr && scale_prim->get_activation_type() != NO_ACTIVATION) {
    auto scale_act = scale_prim->get_activation_type();
    MS_CHECK_TRUE_RET(scale_act == RELU || scale_act == RELU6, nullptr);
    // Any composition of RELU and RELU6 that contains a RELU6 is equivalent to RELU6.
    act_type = scale_act == RELU6 ? RELU6 : act_type;
  }
  (void)scale_prim_c->AddAttr(ops::kActivationType, MakeValue<int64_t>(static_cast<int64_t>(act_type)));
  return scale_node;
}
76 } // namespace mindspore::opt
77